Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,16 @@ enum ProblemType {
MAX_ITER_REACHED
};

int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, uint64_t maxIter,
int resume_mode=0, int return_checkpoint=0,
double* flow_state=nullptr, double* pi_state=nullptr,
signed char* state_state=nullptr, int* parent_state=nullptr,
int64_t* pred_state=nullptr, int* thread_state=nullptr,
int* rev_thread_state=nullptr, int* succ_num_state=nullptr,
int* last_succ_state=nullptr, signed char* forward_state=nullptr,
int64_t* search_arc_num_out=nullptr, int64_t* all_arc_num_out=nullptr);

int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);


Expand Down
46 changes: 40 additions & 6 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,14 @@


int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, uint64_t maxIter) {
double* alpha, double* beta, double *cost, uint64_t maxIter,
int resume_mode, int return_checkpoint,
double* flow_state, double* pi_state, signed char* state_state,
int* parent_state, int64_t* pred_state,
int* thread_state, int* rev_thread_state,
int* succ_num_state, int* last_succ_state,
signed char* forward_state,
int64_t* search_arc_num_out, int64_t* all_arc_num_out) {
// beware M and C are stored in row major C style!!!

using namespace lemon;
Expand Down Expand Up @@ -93,8 +100,29 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,


// Solve the problem with the network simplex algorithm

int ret=net.run();
// If resume_mode=1 and checkpoint data provided, resume from checkpoint
// Otherwise do normal run

int64_t search_arc_num_in = 0, all_arc_num_in = 0;
if (resume_mode == 1 && search_arc_num_out != nullptr && all_arc_num_out != nullptr) {
search_arc_num_in = *search_arc_num_out;
all_arc_num_in = *all_arc_num_out;
}

int ret;
if (resume_mode == 1 && flow_state != nullptr) {
// Resume from checkpoint
ret = net.runFromCheckpoint(
flow_state, pi_state, state_state,
parent_state, pred_state,
thread_state, rev_thread_state,
succ_num_state, last_succ_state, forward_state,
search_arc_num_in, all_arc_num_in);
} else {
// Normal run
ret = net.run();
}

uint64_t i, j;
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
Expand All @@ -111,16 +139,22 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,

}

// Save checkpoint if requested and arrays provided
if (return_checkpoint == 1 && flow_state != nullptr) {
net.saveCheckpoint(
flow_state, pi_state, state_state,
parent_state, pred_state,
thread_state, rev_thread_state,
succ_num_state, last_succ_state, forward_state,
search_arc_num_out, all_arc_num_out);
}

return ret;
}







int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!
Expand Down
97 changes: 95 additions & 2 deletions ot/lp/_network_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def emd(
center_dual=True,
numThreads=1,
check_marginals=True,
warm_start=False,
):
r"""Solves the Earth Movers distance problem and returns the OT matrix

Expand Down Expand Up @@ -232,6 +233,10 @@ def emd(
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.
warm_start: bool or dict, optional (default=False)
If True, returns warm start data in the log for resuming computation.
If dict (from previous call with warm_start=True), resumes optimization
from the provided state. Requires log=True when saving state.


Returns
Expand All @@ -241,7 +246,9 @@ def emd(
parameters
log: dict, optional
If input log is true, a dictionary containing the
cost and dual variables and exit status
cost and dual variables and exit status. If warm_start=True,
also contains a "checkpoint" key with the internal solver state
for resuming computation.


Examples
Expand All @@ -258,6 +265,14 @@ def emd(
array([[0.5, 0. ],
[0. , 0.5]])

Warm start example for resuming optimization:

>>> # First call - save warm start data
>>> G, log = ot.emd(a, b, M, numItermax=100, log=True, warm_start=True)
>>> # log["checkpoint"] contains the solver state
>>> # Resume from warm start
>>> G, log = ot.emd(a, b, M, numItermax=1000, log=True, warm_start=log)


.. _references-emd:
References
Expand Down Expand Up @@ -321,7 +336,67 @@ def emd(

numThreads = check_number_threads(numThreads)

G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
# Handle warm_start parameter
checkpoint_data = None
return_checkpoint = False

if isinstance(warm_start, dict):
# Resume from previous warm_start dict
# Check if checkpoint is nested under "checkpoint" key or at top level
if "checkpoint" in warm_start:
chkpt = warm_start["checkpoint"]
else:
chkpt = warm_start

checkpoint_data = {
"flow": nx.to_numpy(chkpt.get("flow", chkpt.get("_flow")))
if ("flow" in chkpt or "_flow" in chkpt)
else None,
"pi": nx.to_numpy(chkpt.get("pi", chkpt.get("_pi")))
if ("pi" in chkpt or "_pi" in chkpt)
else None,
"state": nx.to_numpy(chkpt.get("state", chkpt.get("_state")))
if ("state" in chkpt or "_state" in chkpt)
else None,
"parent": nx.to_numpy(chkpt.get("parent", chkpt.get("_parent")))
if ("parent" in chkpt or "_parent" in chkpt)
else None,
"pred": nx.to_numpy(chkpt.get("pred", chkpt.get("_pred")))
if ("pred" in chkpt or "_pred" in chkpt)
else None,
"thread": nx.to_numpy(chkpt.get("thread", chkpt.get("_thread")))
if ("thread" in chkpt or "_thread" in chkpt)
else None,
"rev_thread": nx.to_numpy(chkpt.get("rev_thread", chkpt.get("_rev_thread")))
if ("rev_thread" in chkpt or "_rev_thread" in chkpt)
else None,
"succ_num": nx.to_numpy(chkpt.get("succ_num", chkpt.get("_succ_num")))
if ("succ_num" in chkpt or "_succ_num" in chkpt)
else None,
"last_succ": nx.to_numpy(chkpt.get("last_succ", chkpt.get("_last_succ")))
if ("last_succ" in chkpt or "_last_succ" in chkpt)
else None,
"forward": nx.to_numpy(chkpt.get("forward", chkpt.get("_forward")))
if ("forward" in chkpt or "_forward" in chkpt)
else None,
"search_arc_num": int(
chkpt.get("search_arc_num", chkpt.get("_search_arc_num", 0))
),
"all_arc_num": int(chkpt.get("all_arc_num", chkpt.get("_all_arc_num", 0))),
}
# Filter out None values
checkpoint_data = {k: v for k, v in checkpoint_data.items() if v is not None}
elif warm_start is True:
# Save warm_start data - requires log=True
if not log:
raise ValueError(
"warm_start=True requires log=True to return the warm start data"
)
return_checkpoint = True

G, cost, u, v, result_code, checkpoint_out = emd_c(
a, b, M, numItermax, numThreads, checkpoint_data, int(return_checkpoint)
)

if center_dual:
u, v = center_ot_dual(u, v, a, b)
Expand All @@ -345,6 +420,24 @@ def emd(
log["v"] = nx.from_numpy(v, type_as=type_as)
log["warning"] = result_code_string
log["result_code"] = result_code

# Add checkpoint data if requested (preserve original dtypes, don't cast)
if return_checkpoint and checkpoint_out is not None:
log["checkpoint"] = {
"flow": checkpoint_out["flow"],
"pi": checkpoint_out["pi"],
"state": checkpoint_out["state"],
"parent": checkpoint_out["parent"],
"pred": checkpoint_out["pred"],
"thread": checkpoint_out["thread"],
"rev_thread": checkpoint_out["rev_thread"],
"succ_num": checkpoint_out["succ_num"],
"last_succ": checkpoint_out["last_succ"],
"forward": checkpoint_out["forward"],
"search_arc_num": int(checkpoint_out["search_arc_num"]),
"all_arc_num": int(checkpoint_out["all_arc_num"]),
}

return nx.from_numpy(G, type_as=type_as), log
return nx.from_numpy(G, type_as=type_as)

Expand Down
Loading