Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions ignite/base/mixins.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from collections import OrderedDict
from collections.abc import Mapping
from typing import Tuple
from typing import List, Tuple


class Serializable:
_state_dict_all_req_keys: Tuple = ()
_state_dict_one_of_opt_keys: Tuple = ()
_state_dict_all_req_keys: Tuple[str, ...] = ()
_state_dict_one_of_opt_keys: Tuple[Tuple[str, ...], ...] = ((),)

    def __init__(self) -> None:
        """Initialize the serializable object with an empty list of user state keys."""
        # Keys registered by the user; ``load_state_dict`` requires each of
        # them to be present in the provided state dict.
        self._state_dict_user_keys: List[str] = []

    @property
    def state_dict_user_keys(self) -> List[str]:
        """List of user-registered keys that must appear in any loaded state dict."""
        return self._state_dict_user_keys

    def state_dict(self) -> OrderedDict:
        """Return the object's state as an :class:`~collections.OrderedDict`.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
Expand All @@ -19,6 +26,21 @@ def load_state_dict(self, state_dict: Mapping) -> None:
raise ValueError(
f"Required state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
opts = [k in state_dict for k in self._state_dict_one_of_opt_keys]
if len(opts) > 0 and ((not any(opts)) or (all(opts))):
raise ValueError(f"state_dict should contain only one of '{self._state_dict_one_of_opt_keys}' keys")

# Handle groups of one-of optional keys
for one_of_opt_keys in self._state_dict_one_of_opt_keys:
if len(one_of_opt_keys) > 0:
opts = [k in state_dict for k in one_of_opt_keys]
num_present = sum(opts)
if num_present == 0:
raise ValueError(f"state_dict should contain at least one of '{one_of_opt_keys}' keys")
if num_present > 1:
raise ValueError(f"state_dict should contain only one of '{one_of_opt_keys}' keys")

# Check user keys
if hasattr(self, "_state_dict_user_keys") and isinstance(self._state_dict_user_keys, list):
for k in self._state_dict_user_keys:
if k not in state_dict:
raise ValueError(
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
226 changes: 177 additions & 49 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,14 @@ def compute_mean_std(engine, batch):

"""

_state_dict_all_req_keys = ("epoch_length", "max_epochs")
_state_dict_one_of_opt_keys = ("iteration", "epoch")
_state_dict_all_req_keys = ("epoch_length",)
_state_dict_one_of_opt_keys = (("iteration", "epoch"), ("max_epochs", "max_iters"))

# Flag to disable engine._internal_run as generator feature for BC
interrupt_resume_enabled = True

def __init__(self, process_function: Callable[["Engine", Any], Any]):
super(Engine, self).__init__()
self._event_handlers: Dict[Any, List] = defaultdict(list)
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._process_function = process_function
Expand All @@ -147,7 +148,6 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self.should_terminate_single_epoch: Union[bool, str] = False
self.should_interrupt = False
self.state = State()
self._state_dict_user_keys: List[str] = []
self._allowed_events: List[EventEnum] = []

self._dataloader_iter: Optional[Iterator[Any]] = None
Expand Down Expand Up @@ -691,14 +691,20 @@ def save_engine(_):
a dictionary containing engine's state

"""
keys: Tuple[str, ...] = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
keys: Tuple[str, ...] = self._state_dict_all_req_keys
keys += ("iteration",)
# Include either max_epochs or max_iters based on which was originally set
if self.state.max_iters is not None:
keys += ("max_iters",)
else:
keys += ("max_epochs",)
keys += tuple(self._state_dict_user_keys)
return OrderedDict([(k, getattr(self.state, k)) for k in keys])

def load_state_dict(self, state_dict: Mapping) -> None:
"""Setups engine from `state_dict`.

State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`.
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`.
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
Iteration and epoch values are 0-based: the first iteration or epoch is zero.

Expand All @@ -709,10 +715,12 @@ def load_state_dict(self, state_dict: Mapping) -> None:

.. code-block:: python

# Restore from the 4rd epoch
# Restore from the 4th epoch
state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
# or 500th iteration
# state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}
# or with max_iters
# state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)}

trainer = Engine(...)
trainer.load_state_dict(state_dict)
Expand All @@ -721,22 +729,20 @@ def load_state_dict(self, state_dict: Mapping) -> None:
"""
super(Engine, self).load_state_dict(state_dict)

for k in self._state_dict_user_keys:
if k not in state_dict:
raise ValueError(
f"Required user state attribute '{k}' is absent in provided state_dict '{state_dict.keys()}'"
)
self.state.max_epochs = state_dict["max_epochs"]
# Set epoch_length
self.state.epoch_length = state_dict["epoch_length"]

# Set user keys
for k in self._state_dict_user_keys:
setattr(self.state, k, state_dict[k])

# Set iteration or epoch
if "iteration" in state_dict:
self.state.iteration = state_dict["iteration"]
self.state.epoch = 0
if self.state.epoch_length is not None:
if self.state.epoch_length is not None and self.state.epoch_length > 0:
self.state.epoch = self.state.iteration // self.state.epoch_length
elif "epoch" in state_dict:
else: # epoch is in state_dict
self.state.epoch = state_dict["epoch"]
if self.state.epoch_length is None:
raise ValueError(
Expand All @@ -745,6 +751,36 @@ def load_state_dict(self, state_dict: Mapping) -> None:
)
self.state.iteration = self.state.epoch_length * self.state.epoch

# Set max_epochs or max_iters with validation
max_epochs_value = state_dict.get("max_epochs", None)
max_iters_value = state_dict.get("max_iters", None)

# Validate max_epochs if present
if max_epochs_value is not None:
if max_epochs_value < 1:
raise ValueError("max_epochs in state_dict is invalid. Please, set a correct max_epochs positive value")
if max_epochs_value < self.state.epoch:
raise ValueError(
"max_epochs in state_dict should be larger than or equal to the current epoch "
f"defined in the state: {max_epochs_value} vs {self.state.epoch}. "
)
self.state.max_epochs = max_epochs_value
else:
self.state.max_epochs = None

# Validate max_iters if present
if max_iters_value is not None:
if max_iters_value < 1:
raise ValueError("max_iters in state_dict is invalid. Please, set a correct max_iters positive value")
if max_iters_value < self.state.iteration:
raise ValueError(
"max_iters in state_dict should be larger than or equal to the current iteration "
f"defined in the state: {max_iters_value} vs {self.state.iteration}. "
)
self.state.max_iters = max_iters_value
else:
self.state.max_iters = None

@staticmethod
def _is_done(state: State) -> bool:
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
Expand All @@ -756,6 +792,59 @@ def _is_done(state: State) -> bool:
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
return is_done_iters or is_done_count or is_done_epochs

def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
"""Validate and set max_epochs with proper checks."""
if max_epochs is not None:
if max_epochs < 1:
raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")
# Only validate if training is actually done - allow resuming interrupted training
if self.state.max_epochs is not None and max_epochs < self.state.epoch:
raise ValueError(
"Argument max_epochs should be greater than or equal to the start "
f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. "
"Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_epochs = max_epochs

def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
"""Validate and set max_iters with proper checks."""
if max_iters is not None:
if max_iters < 1:
raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")
# Only validate if training is actually done - allow resuming interrupted training
if (self.state.max_iters is not None) and max_iters < self.state.iteration:
raise ValueError(
"Argument max_iters should be greater than or equal to the start "
f"iteration defined in the state: {max_iters} vs {self.state.iteration}. "
"Please, set engine.state.max_iters = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_iters = max_iters

def _check_and_set_epoch_length(self, data: Optional[Iterable], epoch_length: Optional[int] = None) -> None:
"""Validate and set epoch_length."""
# Check if we can redefine epoch_length
if self.state.epoch_length is not None:
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, "
f"but given {epoch_length} vs {self.state.epoch_length}"
)
else:
if epoch_length is None:
if data is not None:
epoch_length = self._get_data_length(data)

if epoch_length is not None:
if epoch_length < 1:
raise ValueError(
"Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
"check if input data has non-zero size."
)
self.state.epoch_length = epoch_length

def set_data(self, data: Union[Iterable, DataLoader]) -> None:
"""Method to set data. After calling the method the next batch passed to `processing_function` is
from newly provided data. Please, note that epoch length is not modified.
Expand Down Expand Up @@ -854,59 +943,98 @@ def switch_batch(engine):
if data is not None and not isinstance(data, Iterable):
raise TypeError("Argument data should be iterable")

if self.state.max_epochs is not None:
# Check and apply overridden parameters
if max_epochs is not None:
if max_epochs < self.state.epoch:
raise ValueError(
"Argument max_epochs should be greater than or equal to the start "
f"epoch defined in the state: {max_epochs} vs {self.state.epoch}. "
"Please, set engine.state.max_epochs = None "
"before calling engine.run() in order to restart the training from the beginning."
)
self.state.max_epochs = max_epochs
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, "
f"but given {epoch_length} vs {self.state.epoch_length}"
)
if max_epochs is not None and max_iters is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive."
"Please provide only max_epochs or max_iters."
)

if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
# Create new state
if epoch_length is None:
if data is None:
raise ValueError("epoch_length should be provided if data is None")
# Check if we need to create new state or resume
# Create new state if:
# 1. No termination params set (first run), OR
# 2. Training is done AND generator is None AND no new params provided
# 3. Training is done AND same termination params provided (restart case)
should_create_new_state = (
(self.state.max_epochs is None and self.state.max_iters is None)
or (
self._is_done(self.state)
and self._internal_run_generator is None
and max_epochs is None
and max_iters is None
)
or (
self._is_done(self.state)
and self._internal_run_generator is None
and (
(max_epochs is not None and max_epochs == self.state.max_epochs)
or (max_iters is not None and max_iters == self.state.max_iters)
)
)
)

epoch_length = self._get_data_length(data)
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")
if should_create_new_state:
# Create new state
if data is None and epoch_length is None and self.state.epoch_length is None:
raise ValueError("epoch_length should be provided if data is None")

# Set epoch_length for new state
if epoch_length is None:
# Try to get from data first, then fall back to existing state
if data is not None:
epoch_length = self._get_data_length(data)
if epoch_length is None and self.state.epoch_length is not None:
epoch_length = self.state.epoch_length
if epoch_length is not None and epoch_length < 1:
raise ValueError("Input data has zero size. Please provide non-empty data")

# Determine max_epochs/max_iters
if max_iters is None:
if max_epochs is None:
max_epochs = 1
else:
if max_epochs is not None:
raise ValueError(
"Arguments max_iters and max_epochs are mutually exclusive."
"Please provide only max_epochs or max_iters."
)
if epoch_length is not None:
max_epochs = math.ceil(max_iters / epoch_length)

# Initialize new state
self.state.iteration = 0
self.state.epoch = 0
self.state.max_epochs = max_epochs
self.state.max_iters = max_iters
self.state.epoch_length = epoch_length
# Reset generator if previously used
self._internal_run_generator = None
self.logger.info(f"Engine run starting with max_epochs={max_epochs}.")

# Log start message
if self.state.max_epochs is not None:
self.logger.info(f"Engine run starting with max_epochs={self.state.max_epochs}.")
else:
self.logger.info(f"Engine run starting with max_iters={self.state.max_iters}.")
else:
self.logger.info(
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)
# Resume from existing state
# Apply overridden parameters using helper methods
self._check_and_set_max_epochs(max_epochs)
self._check_and_set_max_iters(max_iters)

# Handle epoch_length validation (simplified from original)
if epoch_length is not None:
if epoch_length != self.state.epoch_length:
raise ValueError(
"Argument epoch_length should be same as in the state, "
f"but given {epoch_length} vs {self.state.epoch_length}"
)

# Log resuming message
if self.state.max_epochs is not None:
self.logger.info(
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_epochs} epochs"
)
else:
self.logger.info(
f"Engine run resuming from iteration {self.state.iteration}, "
f"epoch {self.state.epoch} until {self.state.max_iters} iterations"
)

if self.state.epoch_length is None and data is None:
raise ValueError("epoch_length should be provided if data is None")

Expand Down
Loading
Loading