diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h index b56f0601b..0bd655236 100644 --- a/ot/lp/EMD.h +++ b/ot/lp/EMD.h @@ -29,7 +29,16 @@ enum ProblemType { MAX_ITER_REACHED }; -int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter); +int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, uint64_t maxIter, + int resume_mode=0, int return_checkpoint=0, + double* flow_state=nullptr, double* pi_state=nullptr, + signed char* state_state=nullptr, int* parent_state=nullptr, + int64_t* pred_state=nullptr, int* thread_state=nullptr, + int* rev_thread_state=nullptr, int* succ_num_state=nullptr, + int* last_succ_state=nullptr, signed char* forward_state=nullptr, + int64_t* search_arc_num_out=nullptr, int64_t* all_arc_num_out=nullptr); + int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads); diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp index 4aa5a6e72..e54dd0776 100644 --- a/ot/lp/EMD_wrapper.cpp +++ b/ot/lp/EMD_wrapper.cpp @@ -20,7 +20,14 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, - double* alpha, double* beta, double *cost, uint64_t maxIter) { + double* alpha, double* beta, double *cost, uint64_t maxIter, + int resume_mode, int return_checkpoint, + double* flow_state, double* pi_state, signed char* state_state, + int* parent_state, int64_t* pred_state, + int* thread_state, int* rev_thread_state, + int* succ_num_state, int* last_succ_state, + signed char* forward_state, + int64_t* search_arc_num_out, int64_t* all_arc_num_out) { // beware M and C are stored in row major C style!!! using namespace lemon; @@ -93,8 +100,29 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, // Solve the problem with the network simplex algorithm - - int ret=net.run(); + // If resume_mode=1 and checkpoint data provided, resume from checkpoint + // Otherwise do normal run + + int64_t search_arc_num_in = 0, all_arc_num_in = 0; + if (resume_mode == 1 && search_arc_num_out != nullptr && all_arc_num_out != nullptr) { + search_arc_num_in = *search_arc_num_out; + all_arc_num_in = *all_arc_num_out; + } + + int ret; + if (resume_mode == 1 && flow_state != nullptr) { + // Resume from checkpoint + ret = net.runFromCheckpoint( + flow_state, pi_state, state_state, + parent_state, pred_state, + thread_state, rev_thread_state, + succ_num_state, last_succ_state, forward_state, + search_arc_num_in, all_arc_num_in); + } else { + // Normal run + ret = net.run(); + } + uint64_t i, j; if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) { *cost = 0; @@ -111,6 +139,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, } + // Save checkpoint if requested and arrays provided + if (return_checkpoint == 1 && flow_state != nullptr) { + net.saveCheckpoint( + flow_state, pi_state, state_state, + parent_state, pred_state, + thread_state, rev_thread_state, + succ_num_state, last_succ_state, forward_state, + search_arc_num_out, all_arc_num_out); + } return ret; } @@ -118,9 +155,6 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, - - - int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) { // beware M and C are stored in row major C style!!! diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py index 492e4c7ac..9ee21bbcc 100644 --- a/ot/lp/_network_simplex.py +++ b/ot/lp/_network_simplex.py @@ -172,6 +172,7 @@ def emd( center_dual=True, numThreads=1, check_marginals=True, + warm_start=False, ): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -232,6 +233,10 @@ def emd( check_marginals: bool, optional (default=True) If True, checks that the marginals mass are equal. If False, skips the check. + warm_start: bool or dict, optional (default=False) + If True, returns warm start data in the log for resuming computation. + If dict (from previous call with warm_start=True), resumes optimization + from the provided state. Requires log=True when saving state. Returns @@ -241,7 +246,9 @@ def emd( parameters log: dict, optional If input log is true, a dictionary containing the - cost and dual variables and exit status + cost and dual variables and exit status. If warm_start=True, + also contains a "checkpoint" key with the internal solver state + for resuming computation. Examples @@ -258,6 +265,14 @@ def emd( array([[0.5, 0. ], [0. , 0.5]]) + Warm start example for resuming optimization: + + >>> # First call - save warm start data + >>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True) + >>> # log["checkpoint"] contains the solver state + >>> # Resume from warm start + >>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log) + .. _references-emd: References @@ -321,7 +336,67 @@ def emd( numThreads = check_number_threads(numThreads) - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # Handle warm_start parameter + checkpoint_data = None + return_checkpoint = False + + if isinstance(warm_start, dict): + # Resume from previous warm_start dict + # Check if checkpoint is nested under "checkpoint" key or at top level + if "checkpoint" in warm_start: + chkpt = warm_start["checkpoint"] + else: + chkpt = warm_start + + checkpoint_data = { + "flow": nx.to_numpy(chkpt.get("flow", chkpt.get("_flow"))) + if ("flow" in chkpt or "_flow" in chkpt) + else None, + "pi": nx.to_numpy(chkpt.get("pi", chkpt.get("_pi"))) + if ("pi" in chkpt or "_pi" in chkpt) + else None, + "state": nx.to_numpy(chkpt.get("state", chkpt.get("_state"))) + if ("state" in chkpt or "_state" in chkpt) + else None, + "parent": nx.to_numpy(chkpt.get("parent", chkpt.get("_parent"))) + if ("parent" in chkpt or "_parent" in chkpt) + else None, + "pred": nx.to_numpy(chkpt.get("pred", chkpt.get("_pred"))) + if ("pred" in chkpt or "_pred" in chkpt) + else None, + "thread": nx.to_numpy(chkpt.get("thread", chkpt.get("_thread"))) + if ("thread" in chkpt or "_thread" in chkpt) + else None, + "rev_thread": nx.to_numpy(chkpt.get("rev_thread", chkpt.get("_rev_thread"))) + if ("rev_thread" in chkpt or "_rev_thread" in chkpt) + else None, + "succ_num": nx.to_numpy(chkpt.get("succ_num", chkpt.get("_succ_num"))) + if ("succ_num" in chkpt or "_succ_num" in chkpt) + else None, + "last_succ": nx.to_numpy(chkpt.get("last_succ", chkpt.get("_last_succ"))) + if ("last_succ" in chkpt or "_last_succ" in chkpt) + else None, + "forward": nx.to_numpy(chkpt.get("forward", chkpt.get("_forward"))) + if ("forward" in chkpt or "_forward" in chkpt) + else None, + "search_arc_num": int( + chkpt.get("search_arc_num", chkpt.get("_search_arc_num", 0)) + ), + "all_arc_num": int(chkpt.get("all_arc_num", chkpt.get("_all_arc_num", 0))), + } + # Filter out None values + checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None} + elif warm_start is True: + # Save warm_start data - requires log=True + if not log: + raise ValueError( + "warm_start=True requires log=True to return the warm start data" + ) + return_checkpoint = True + + G, cost, u, v, result_code, checkpoint_out = emd_c( + a, b, M, numItermax, numThreads, checkpoint_data, int(return_checkpoint) + ) if center_dual: u, v = center_ot_dual(u, v, a, b) @@ -345,6 +420,24 @@ def emd( log["v"] = nx.from_numpy(v, type_as=type_as) log["warning"] = result_code_string log["result_code"] = result_code + + # Add checkpoint data if requested (preserve original dtypes, don't cast) + if return_checkpoint and checkpoint_out is not None: + log["checkpoint"] = { + "flow": checkpoint_out["flow"], + "pi": checkpoint_out["pi"], + "state": checkpoint_out["state"], + "parent": checkpoint_out["parent"], + "pred": checkpoint_out["pred"], + "thread": checkpoint_out["thread"], + "rev_thread": checkpoint_out["rev_thread"], + "succ_num": checkpoint_out["succ_num"], + "last_succ": checkpoint_out["last_succ"], + "forward": checkpoint_out["forward"], + "search_arc_num": int(checkpoint_out["search_arc_num"]), + "all_arc_num": int(checkpoint_out["all_arc_num"]), + } + return nx.from_numpy(G, type_as=type_as), log return nx.from_numpy(G, type_as=type_as) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 53df54fc3..c99cdc011 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -14,13 +14,21 @@ from ..utils import dist cimport cython cimport libc.math as math -from libc.stdint cimport uint64_t +from libc.stdint cimport uint64_t, int64_t import warnings cdef extern from "EMD.h": - int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil + int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G, + double* alpha, double* beta, double *cost, uint64_t maxIter, + int resume_mode, int return_checkpoint, + double* flow_state, double* pi_state, signed char* state_state, + int* parent_state, int64_t* pred_state, + int* thread_state, int* rev_thread_state, + int* succ_num_state, int* last_succ_state, + signed char* forward_state, + int64_t* search_arc_num_out, int64_t* all_arc_num_out) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -40,9 +48,16 @@ def check_result(result_code): @cython.boundscheck(False) @cython.wraparound(False) -def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads): +def emd_c(np.ndarray[double, ndim=1, mode="c"] a, + np.ndarray[double, ndim=1, mode="c"] b, + np.ndarray[double, ndim=2, mode="c"] M, + uint64_t max_iter, + int numThreads, + checkpoint_in=None, + int return_checkpoint=0): """ Solves the Earth Movers distance problem and returns the optimal transport matrix + with optional checkpoint support for pause/resume. gamm=emd(a,b,M) @@ -79,43 +94,147 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod max_iter : uint64_t The maximum number of iterations before stopping the optimization algorithm if it has not converged. + numThreads : int + Number of threads for parallel computation (1 = no OpenMP) + checkpoint_in : dict or None + Checkpoint data to resume from. Should contain flow, pi, state, parent, + pred, thread, rev_thread, succ_num, last_succ, forward arrays. + return_checkpoint : int + If 1, returns checkpoint data; if 0, returns None for checkpoint. Returns ------- gamma: (ns x nt) numpy.ndarray Optimal transportation matrix for the given parameters + cost : float + Optimal transport cost + alpha : (ns,) numpy.ndarray + Source dual potentials + beta : (nt,) numpy.ndarray + Target dual potentials + result_code : int + Result code (OPTIMAL, INFEASIBLE, UNBOUNDED, MAX_ITER_REACHED) + checkpoint_out : dict or None + Checkpoint data if return_checkpoint=1, None otherwise """ - cdef int n1= M.shape[0] - cdef int n2= M.shape[1] - cdef int nmax=n1+n2-1 + cdef int n1 = M.shape[0] + cdef int n2 = M.shape[1] + cdef int all_nodes = n1 + n2 + 1 + cdef int64_t max_arcs = n1 * n2 + 2 * (n1 + n2) cdef int result_code = 0 - cdef int nG=0 - - cdef double cost=0 - cdef np.ndarray[double, ndim=1, mode="c"] alpha=np.zeros(n1) - cdef np.ndarray[double, ndim=1, mode="c"] beta=np.zeros(n2) - - cdef np.ndarray[double, ndim=2, mode="c"] G=np.zeros([0, 0]) - - cdef np.ndarray[double, ndim=1, mode="c"] Gv=np.zeros(0) + cdef double cost = 0 + cdef int64_t search_arc_num = 0 + cdef int64_t all_arc_num = 0 + + cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1) + cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2) + cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros([n1, n2]) + + # Checkpoint arrays (for both input and output) + cdef np.ndarray[double, ndim=1, mode="c"] flow_state + cdef np.ndarray[double, ndim=1, mode="c"] pi_state + cdef np.ndarray[signed char, ndim=1, mode="c"] state_state + cdef np.ndarray[int, ndim=1, mode="c"] parent_state + cdef np.ndarray[int64_t, ndim=1, mode="c"] pred_state + cdef np.ndarray[int, ndim=1, mode="c"] thread_state + cdef np.ndarray[int, ndim=1, mode="c"] rev_thread_state + cdef np.ndarray[int, ndim=1, mode="c"] succ_num_state + cdef np.ndarray[int, ndim=1, mode="c"] last_succ_state + cdef np.ndarray[signed char, ndim=1, mode="c"] forward_state + + cdef int resume_mode = 0 if not len(a): - a=np.ones((n1,))/n1 + a = np.ones((n1,)) / n1 if not len(b): - b=np.ones((n2,))/n2 - - # init OT matrix - G=np.zeros([n1, n2]) - - # calling the function + b = np.ones((n2,)) / n2 + + # Prepare checkpoint arrays + if checkpoint_in is not None: + resume_mode = 1 + flow_state = np.asarray(checkpoint_in['flow'], dtype=np.float64, order='C') + pi_state = np.asarray(checkpoint_in['pi'], dtype=np.float64, order='C') + state_state = np.asarray(checkpoint_in['state'], dtype=np.int8, order='C') + parent_state = np.asarray(checkpoint_in['parent'], dtype=np.int32, order='C') + pred_state = np.asarray(checkpoint_in['pred'], dtype=np.int64, order='C') + thread_state = np.asarray(checkpoint_in['thread'], dtype=np.int32, order='C') + rev_thread_state = np.asarray(checkpoint_in['rev_thread'], dtype=np.int32, order='C') + + # Sanity check: array sizes must match expected sizes + if flow_state.shape[0] != max_arcs or pi_state.shape[0] != all_nodes: + raise ValueError( + f"Checkpoint size mismatch: expected flow={max_arcs}, pi={all_nodes}, " + f"got flow={flow_state.shape[0]}, pi={pi_state.shape[0]}" + ) + succ_num_state = np.asarray(checkpoint_in['succ_num'], dtype=np.int32, order='C') + last_succ_state = np.asarray(checkpoint_in['last_succ'], dtype=np.int32, order='C') + forward_state = np.asarray(checkpoint_in['forward'], dtype=np.int8, order='C') + + # Extract the arc counts + search_arc_num = checkpoint_in['search_arc_num'] + all_arc_num = checkpoint_in['all_arc_num'] + else: + # Allocate empty arrays (will be filled if return_checkpoint=1) + flow_state = np.zeros(max_arcs, dtype=np.float64) + pi_state = np.zeros(all_nodes, dtype=np.float64) + state_state = np.zeros(max_arcs, dtype=np.int8) + parent_state = np.zeros(all_nodes, dtype=np.int32) + pred_state = np.zeros(all_nodes, dtype=np.int64) + thread_state = np.zeros(all_nodes, dtype=np.int32) + rev_thread_state = np.zeros(all_nodes, dtype=np.int32) + succ_num_state = np.zeros(all_nodes, dtype=np.int32) + last_succ_state = np.zeros(all_nodes, dtype=np.int32) + forward_state = np.zeros(all_nodes, dtype=np.int8) + + # Call C++ function with checkpoint support with nogil: if numThreads == 1: - result_code = EMD_wrap(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter) + result_code = EMD_wrap( + n1, n2, + a.data, b.data, M.data, + G.data, alpha.data, beta.data, + &cost, max_iter, + resume_mode, return_checkpoint, + flow_state.data, + pi_state.data, + state_state.data, + parent_state.data, + pred_state.data, + thread_state.data, + rev_thread_state.data, + succ_num_state.data, + last_succ_state.data, + forward_state.data, + &search_arc_num, + &all_arc_num + ) else: - result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, G.data, alpha.data, beta.data, &cost, max_iter, numThreads) - return G, cost, alpha, beta, result_code + # For now, OpenMP version falls back to regular (not implemented yet) + result_code = EMD_wrap_omp(n1, n2, a.data, b.data, M.data, + G.data, alpha.data, beta.data, + &cost, max_iter, numThreads) + + # Build checkpoint output dict if requested + checkpoint_out = None + if return_checkpoint: + checkpoint_out = { + 'flow': flow_state, + 'pi': pi_state, + 'state': state_state, + 'parent': parent_state, + 'pred': pred_state, + 'thread': thread_state, + 'rev_thread': rev_thread_state, + 'succ_num': succ_num_state, + 'last_succ': last_succ_state, + 'forward': forward_state, + 'search_arc_num': search_arc_num, + 'all_arc_num': all_arc_num, + } + + return G, cost, alpha, beta, result_code, checkpoint_out @cython.boundscheck(False) diff --git a/ot/lp/network_simplex_simple.h b/ot/lp/network_simplex_simple.h index 9612a8a24..388851f88 100644 --- a/ot/lp/network_simplex_simple.h +++ b/ot/lp/network_simplex_simple.h @@ -941,6 +941,116 @@ namespace lemon { } } + + /// This function saves the complete internal state of the solver, + /// including flow values, dual potentials, arc states, and the + /// spanning tree structure. This allows pausing and resuming + /// the optimization later. + + void saveCheckpoint( + double* flow_out, + double* pi_out, + signed char* state_out, + int* parent_out, + ArcsType* pred_out, + int* thread_out, + int* rev_thread_out, + int* succ_num_out, + int* last_succ_out, + signed char* forward_out, + ArcsType* search_arc_num_out, + ArcsType* all_arc_num_out) + { + // Copy internal state to output arrays + std::copy(_flow.begin(), _flow.end(), flow_out); + std::copy(_pi.begin(), _pi.end(), pi_out); + std::copy(_state.begin(), _state.end(), state_out); + std::copy(_parent.begin(), _parent.end(), parent_out); + std::copy(_pred.begin(), _pred.end(), pred_out); + std::copy(_thread.begin(), _thread.end(), thread_out); + std::copy(_rev_thread.begin(), _rev_thread.end(), rev_thread_out); + std::copy(_succ_num.begin(), _succ_num.end(), succ_num_out); + std::copy(_last_succ.begin(), _last_succ.end(), last_succ_out); + + // Convert bool vector to signed char + for (size_t i = 0; i < _forward.size(); i++) { + forward_out[i] = _forward[i] ? 1 : 0; + } + + // Save arc counts needed for start() + *search_arc_num_out = _search_arc_num; + *all_arc_num_out = _all_arc_num; + } + + + /// This function restores the complete internal state of the solver + /// from a previously saved checkpoint. + + void restoreCheckpoint( + double* flow_in, + double* pi_in, + signed char* state_in, + int* parent_in, + ArcsType* pred_in, + int* thread_in, + int* rev_thread_in, + int* succ_num_in, + int* last_succ_in, + signed char* forward_in, + ArcsType search_arc_num_in, + ArcsType all_arc_num_in) + { + // Copy from input arrays to internal state + std::copy(flow_in, flow_in + _flow.size(), _flow.begin()); + std::copy(pi_in, pi_in + _pi.size(), _pi.begin()); + std::copy(state_in, state_in + _state.size(), _state.begin()); + std::copy(parent_in, parent_in + _parent.size(), _parent.begin()); + std::copy(pred_in, pred_in + _pred.size(), _pred.begin()); + std::copy(thread_in, thread_in + _thread.size(), _thread.begin()); + std::copy(rev_thread_in, rev_thread_in + _rev_thread.size(), _rev_thread.begin()); + std::copy(succ_num_in, succ_num_in + _succ_num.size(), _succ_num.begin()); + std::copy(last_succ_in, last_succ_in + _last_succ.size(), _last_succ.begin()); + + // Convert signed char to bool vector + for (size_t i = 0; i < _forward.size(); i++) { + _forward[i] = (forward_in[i] != 0); + } + + // Restore root (it's always _node_num) + _root = _node_num; + + // Restore arc counts needed by start() + _search_arc_num = search_arc_num_in; + _all_arc_num = all_arc_num_in; + } + + + /// This function restores the solver state from a checkpoint and + /// continues the optimization from that point. It skips the normal + /// initialization phase and goes directly to the simplex iterations. + + ProblemType runFromCheckpoint( + double* flow_in, + double* pi_in, + signed char* state_in, + int* parent_in, + ArcsType* pred_in, + int* thread_in, + int* rev_thread_in, + int* succ_num_in, + int* last_succ_in, + signed char* forward_in, + ArcsType search_arc_num_in, + ArcsType all_arc_num_in) + { + // Restore state from checkpoint + restoreCheckpoint(flow_in, pi_in, state_in, parent_in, pred_in, + thread_in, rev_thread_in, succ_num_in, last_succ_in, forward_in, + search_arc_num_in, all_arc_num_in); + + return start(); + } + /// @} private: diff --git a/test/test_ot.py b/test/test_ot.py index e8217d54d..f5c7976d1 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -914,6 +914,107 @@ def test_dual_variables(): assert constraint_violation.max() < 1e-8 +def test_emd_checkpoint(): + # test checkpoint save and resume + n = 50 + rng = np.random.RandomState(42) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = rng.rand(n, n) + + G_ref, log_ref = ot.emd(a, b, M, numItermax=10000, log=True) + + G1, log1 = ot.emd(a, b, M, numItermax=500, log=True, warm_start=True) + + if log1["result_code"] == 3: # MAX_ITER_REACHED ? + G2, log2 = ot.emd(a, b, M, numItermax=10000, log=True, warm_start=log1) + + np.testing.assert_allclose(log2["cost"], log_ref["cost"], rtol=1e-6) + np.testing.assert_allclose(G2, G_ref, rtol=1e-6) + + +def test_emd_checkpoint_multiple(): + # test multiple checkpoint cycles + n = 100 + rng = np.random.RandomState(123) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = rng.rand(n, n) + + G_ref, log_ref = ot.emd(a, b, M, numItermax=50000, log=True) + + # multiple checkpoint phases with increasing iteration budgets + max_iters = [100, 300, 600, 1000] + warm_start_data = None + costs = [] + + for max_iter in max_iters: + if warm_start_data is None: + G, log = ot.emd( + a, + b, + M, + numItermax=max_iter, + log=True, + warm_start=True, + ) + else: + G, log = ot.emd( + a, + b, + M, + numItermax=max_iter, + log=True, + warm_start=warm_start_data, + ) + costs.append(log["cost"]) + + if log["result_code"] != 3: # converged + break + # Only use warm_start if checkpoint is present + warm_start_data = log if "checkpoint" in log else None + + # check cost decreases monotonically + for i in range(len(costs) - 1): + assert costs[i + 1] <= costs[i] + + # check final result matches reference + np.testing.assert_allclose(log["cost"], log_ref["cost"], rtol=1e-5) + + +def test_emd_checkpoint_structure(): + # test that checkpoint contains all required fields + n = 10 + a = ot.utils.unif(n) + b = ot.utils.unif(n) + M = np.random.rand(n, n) + + G, log = ot.emd(a, b, M, numItermax=10, log=True, warm_start=True) + + # Check that checkpoint key exists + assert "checkpoint" in log, "Missing checkpoint key in log" + + checkpoint = log["checkpoint"] + + required_fields = [ + "flow", + "pi", + "state", + "parent", + "pred", + "thread", + "rev_thread", + "succ_num", + "last_succ", + "forward", + "search_arc_num", + "all_arc_num", + ] + + for field in required_fields: + assert field in checkpoint, f"Missing checkpoint field: {field}" + + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) # Check that dual and primal cost are equal