@@ -128,13 +128,14 @@ def compute_mean_std(engine, batch):
128128
129129 """
130130
131- _state_dict_all_req_keys = ("epoch_length" , "max_epochs" )
132- _state_dict_one_of_opt_keys = ("iteration" , "epoch" )
131+ _state_dict_all_req_keys = ("epoch_length" ,)
132+ _state_dict_one_of_opt_keys = (( "iteration" , "epoch" ), ( "max_epochs" , "max_iters" ) )
133133
134134 # Flag to disable engine._internal_run as generator feature for BC
135135 interrupt_resume_enabled = True
136136
137137 def __init__ (self , process_function : Callable [["Engine" , Any ], Any ]):
138+ super (Engine , self ).__init__ ()
138139 self ._event_handlers : Dict [Any , List ] = defaultdict (list )
139140 self .logger = logging .getLogger (__name__ + "." + self .__class__ .__name__ )
140141 self ._process_function = process_function
@@ -147,7 +148,6 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
147148 self .should_terminate_single_epoch : Union [bool , str ] = False
148149 self .should_interrupt = False
149150 self .state = State ()
150- self ._state_dict_user_keys : List [str ] = []
151151 self ._allowed_events : List [EventEnum ] = []
152152
153153 self ._dataloader_iter : Optional [Iterator [Any ]] = None
@@ -691,14 +691,20 @@ def save_engine(_):
691691 a dictionary containing engine's state
692692
693693 """
694- keys : Tuple [str , ...] = self ._state_dict_all_req_keys + (self ._state_dict_one_of_opt_keys [0 ],)
694+ keys : Tuple [str , ...] = self ._state_dict_all_req_keys
695+ keys += ("iteration" ,)
696+ # Include either max_epochs or max_iters based on which was originally set
697+ if self .state .max_iters is not None :
698+ keys += ("max_iters" ,)
699+ else :
700+ keys += ("max_epochs" ,)
695701 keys += tuple (self ._state_dict_user_keys )
696702 return OrderedDict ([(k , getattr (self .state , k )) for k in keys ])
697703
698704 def load_state_dict (self , state_dict : Mapping ) -> None :
699705 """Setups engine from `state_dict`.
700706
701- State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`.
707+ State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` or `max_iters`, and `epoch_length`.
702708 If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
703709 Iteration and epoch values are 0-based: the first iteration or epoch is zero.
704710
@@ -709,10 +715,12 @@ def load_state_dict(self, state_dict: Mapping) -> None:
709715
710716 .. code-block:: python
711717
712- # Restore from the 4rd epoch
718+ # Restore from the 4th epoch
713719 state_dict = {"epoch": 3, "max_epochs": 100, "epoch_length": len(data_loader)}
714720 # or 500th iteration
715721 # state_dict = {"iteration": 499, "max_epochs": 100, "epoch_length": len(data_loader)}
722+ # or with max_iters
723+ # state_dict = {"iteration": 499, "max_iters": 1000, "epoch_length": len(data_loader)}
716724
717725 trainer = Engine(...)
718726 trainer.load_state_dict(state_dict)
@@ -721,22 +729,20 @@ def load_state_dict(self, state_dict: Mapping) -> None:
721729 """
722730 super (Engine , self ).load_state_dict (state_dict )
723731
724- for k in self ._state_dict_user_keys :
725- if k not in state_dict :
726- raise ValueError (
727- f"Required user state attribute '{ k } ' is absent in provided state_dict '{ state_dict .keys ()} '"
728- )
729- self .state .max_epochs = state_dict ["max_epochs" ]
732+ # Set epoch_length
730733 self .state .epoch_length = state_dict ["epoch_length" ]
734+
735+ # Set user keys
731736 for k in self ._state_dict_user_keys :
732737 setattr (self .state , k , state_dict [k ])
733738
739+ # Set iteration or epoch
734740 if "iteration" in state_dict :
735741 self .state .iteration = state_dict ["iteration" ]
736742 self .state .epoch = 0
737- if self .state .epoch_length is not None :
743+ if self .state .epoch_length is not None and self . state . epoch_length > 0 :
738744 self .state .epoch = self .state .iteration // self .state .epoch_length
739- elif " epoch" in state_dict :
745+ else : # epoch is in state_dict
740746 self .state .epoch = state_dict ["epoch" ]
741747 if self .state .epoch_length is None :
742748 raise ValueError (
@@ -745,6 +751,36 @@ def load_state_dict(self, state_dict: Mapping) -> None:
745751 )
746752 self .state .iteration = self .state .epoch_length * self .state .epoch
747753
754+ # Set max_epochs or max_iters with validation
755+ max_epochs_value = state_dict .get ("max_epochs" , None )
756+ max_iters_value = state_dict .get ("max_iters" , None )
757+
758+ # Validate max_epochs if present
759+ if max_epochs_value is not None :
760+ if max_epochs_value < 1 :
761+ raise ValueError ("max_epochs in state_dict is invalid. Please, set a correct max_epochs positive value" )
762+ if max_epochs_value < self .state .epoch :
763+ raise ValueError (
764+ "max_epochs in state_dict should be larger than or equal to the current epoch "
765+ f"defined in the state: { max_epochs_value } vs { self .state .epoch } . "
766+ )
767+ self .state .max_epochs = max_epochs_value
768+ else :
769+ self .state .max_epochs = None
770+
771+ # Validate max_iters if present
772+ if max_iters_value is not None :
773+ if max_iters_value < 1 :
774+ raise ValueError ("max_iters in state_dict is invalid. Please, set a correct max_iters positive value" )
775+ if max_iters_value < self .state .iteration :
776+ raise ValueError (
777+ "max_iters in state_dict should be larger than or equal to the current iteration "
778+ f"defined in the state: { max_iters_value } vs { self .state .iteration } . "
779+ )
780+ self .state .max_iters = max_iters_value
781+ else :
782+ self .state .max_iters = None
783+
748784 @staticmethod
749785 def _is_done (state : State ) -> bool :
750786 is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
@@ -756,6 +792,59 @@ def _is_done(state: State) -> bool:
756792 is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
757793 return is_done_iters or is_done_count or is_done_epochs
758794
def _check_and_set_max_epochs(self, max_epochs: Optional[int] = None) -> None:
    """Validate a requested ``max_epochs`` and store it on ``self.state``.

    Nothing happens when ``max_epochs`` is ``None``; the value already held
    by the state is left as it is.

    Args:
        max_epochs: requested number of epochs, or ``None`` to keep the
            current setting.

    Raises:
        ValueError: if ``max_epochs`` is smaller than 1, or if the state
            already holds a ``max_epochs`` value and the requested one is
            smaller than ``self.state.epoch``.
    """
    if max_epochs is None:
        return

    if max_epochs < 1:
        raise ValueError("Argument max_epochs is invalid. Please, set a correct max_epochs positive value")

    state = self.state
    # The comparison against the current epoch is only made when the state
    # already carries a max_epochs value; with max_epochs unset on the state
    # any positive value is accepted.
    has_previous_limit = state.max_epochs is not None
    if has_previous_limit and max_epochs < state.epoch:
        raise ValueError(
            "Argument max_epochs should be greater than or equal to the start "
            f"epoch defined in the state: {max_epochs} vs {state.epoch} . "
            "Please, set engine.state.max_epochs = None "
            "before calling engine.run() in order to restart the training from the beginning."
        )

    state.max_epochs = max_epochs
809+
def _check_and_set_max_iters(self, max_iters: Optional[int] = None) -> None:
    """Validate a requested ``max_iters`` and store it on ``self.state``.

    Nothing happens when ``max_iters`` is ``None``; the value already held
    by the state is left as it is.

    Args:
        max_iters: requested number of iterations, or ``None`` to keep the
            current setting.

    Raises:
        ValueError: if ``max_iters`` is smaller than 1, or if the state
            already holds a ``max_iters`` value and the requested one is
            smaller than ``self.state.iteration``.
    """
    if max_iters is None:
        return

    if max_iters < 1:
        raise ValueError("Argument max_iters is invalid. Please, set a correct max_iters positive value")

    state = self.state
    # The comparison against the current iteration is only made when the
    # state already carries a max_iters value; with max_iters unset on the
    # state any positive value is accepted.
    has_previous_limit = state.max_iters is not None
    if has_previous_limit and max_iters < state.iteration:
        raise ValueError(
            "Argument max_iters should be greater than or equal to the start "
            f"iteration defined in the state: {max_iters} vs {state.iteration} . "
            "Please, set engine.state.max_iters = None "
            "before calling engine.run() in order to restart the training from the beginning."
        )

    state.max_iters = max_iters
824+
def _check_and_set_epoch_length(self, data: Optional[Iterable], epoch_length: Optional[int] = None) -> None:
    """Validate ``epoch_length`` and store it on ``self.state``.

    When the state already has an epoch length, an explicitly passed
    ``epoch_length`` must be equal to it. When the state has none and no
    ``epoch_length`` is passed, the value is taken from
    ``self._get_data_length(data)`` if ``data`` is given. If no epoch length
    can be determined, the state is left untouched.

    Args:
        data: data to derive the epoch length from, or ``None``.
        epoch_length: explicitly requested epoch length, or ``None``.

    Raises:
        ValueError: if ``epoch_length`` differs from the one already stored
            in the state, or if the resulting epoch length is smaller than 1.
    """
    stored_length = self.state.epoch_length

    if stored_length is None:
        # Nothing stored yet: fall back to the length of the data, but only
        # when the caller did not pass an explicit value.
        if epoch_length is None and data is not None:
            epoch_length = self._get_data_length(data)
    elif epoch_length is not None and epoch_length != stored_length:
        raise ValueError(
            "Argument epoch_length should be same as in the state, "
            f"but given {epoch_length} vs {stored_length} "
        )

    if epoch_length is None:
        return

    if epoch_length < 1:
        raise ValueError(
            "Argument epoch_length is invalid. Please, either set a correct epoch_length value or "
            "check if input data has non-zero size."
        )

    self.state.epoch_length = epoch_length
847+
759848 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
760849 """Method to set data. After calling the method the next batch passed to `processing_function` is
761850 from newly provided data. Please, note that epoch length is not modified.
@@ -854,59 +943,98 @@ def switch_batch(engine):
854943 if data is not None and not isinstance (data , Iterable ):
855944 raise TypeError ("Argument data should be iterable" )
856945
857- if self .state .max_epochs is not None :
858- # Check and apply overridden parameters
859- if max_epochs is not None :
860- if max_epochs < self .state .epoch :
861- raise ValueError (
862- "Argument max_epochs should be greater than or equal to the start "
863- f"epoch defined in the state: { max_epochs } vs { self .state .epoch } . "
864- "Please, set engine.state.max_epochs = None "
865- "before calling engine.run() in order to restart the training from the beginning."
866- )
867- self .state .max_epochs = max_epochs
868- if epoch_length is not None :
869- if epoch_length != self .state .epoch_length :
870- raise ValueError (
871- "Argument epoch_length should be same as in the state, "
872- f"but given { epoch_length } vs { self .state .epoch_length } "
873- )
946+ if max_epochs is not None and max_iters is not None :
947+ raise ValueError (
948+ "Arguments max_iters and max_epochs are mutually exclusive."
949+ "Please provide only max_epochs or max_iters."
950+ )
874951
875- if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
876- # Create new state
877- if epoch_length is None :
878- if data is None :
879- raise ValueError ("epoch_length should be provided if data is None" )
952+ # Check if we need to create new state or resume
953+ # Create new state if:
954+ # 1. No termination params set (first run), OR
955+ # 2. Training is done AND generator is None AND no new params provided
956+ # 3. Training is done AND same termination params provided (restart case)
957+ should_create_new_state = (
958+ (self .state .max_epochs is None and self .state .max_iters is None )
959+ or (
960+ self ._is_done (self .state )
961+ and self ._internal_run_generator is None
962+ and max_epochs is None
963+ and max_iters is None
964+ )
965+ or (
966+ self ._is_done (self .state )
967+ and self ._internal_run_generator is None
968+ and (
969+ (max_epochs is not None and max_epochs == self .state .max_epochs )
970+ or (max_iters is not None and max_iters == self .state .max_iters )
971+ )
972+ )
973+ )
880974
881- epoch_length = self ._get_data_length (data )
882- if epoch_length is not None and epoch_length < 1 :
883- raise ValueError ("Input data has zero size. Please provide non-empty data" )
975+ if should_create_new_state :
976+ # Create new state
977+ if data is None and epoch_length is None and self .state .epoch_length is None :
978+ raise ValueError ("epoch_length should be provided if data is None" )
884979
980+ # Set epoch_length for new state
981+ if epoch_length is None :
982+ # Try to get from data first, then fall back to existing state
983+ if data is not None :
984+ epoch_length = self ._get_data_length (data )
985+ if epoch_length is None and self .state .epoch_length is not None :
986+ epoch_length = self .state .epoch_length
987+ if epoch_length is not None and epoch_length < 1 :
988+ raise ValueError ("Input data has zero size. Please provide non-empty data" )
989+
990+ # Determine max_epochs/max_iters
885991 if max_iters is None :
886992 if max_epochs is None :
887993 max_epochs = 1
888994 else :
889- if max_epochs is not None :
890- raise ValueError (
891- "Arguments max_iters and max_epochs are mutually exclusive."
892- "Please provide only max_epochs or max_iters."
893- )
894995 if epoch_length is not None :
895996 max_epochs = math .ceil (max_iters / epoch_length )
896997
998+ # Initialize new state
897999 self .state .iteration = 0
8981000 self .state .epoch = 0
8991001 self .state .max_epochs = max_epochs
9001002 self .state .max_iters = max_iters
9011003 self .state .epoch_length = epoch_length
9021004 # Reset generator if previously used
9031005 self ._internal_run_generator = None
904- self .logger .info (f"Engine run starting with max_epochs={ max_epochs } ." )
1006+
1007+ # Log start message
1008+ if self .state .max_epochs is not None :
1009+ self .logger .info (f"Engine run starting with max_epochs={ self .state .max_epochs } ." )
1010+ else :
1011+ self .logger .info (f"Engine run starting with max_iters={ self .state .max_iters } ." )
9051012 else :
906- self .logger .info (
907- f"Engine run resuming from iteration { self .state .iteration } , "
908- f"epoch { self .state .epoch } until { self .state .max_epochs } epochs"
909- )
1013+ # Resume from existing state
1014+ # Apply overridden parameters using helper methods
1015+ self ._check_and_set_max_epochs (max_epochs )
1016+ self ._check_and_set_max_iters (max_iters )
1017+
1018+ # Handle epoch_length validation (simplified from original)
1019+ if epoch_length is not None :
1020+ if epoch_length != self .state .epoch_length :
1021+ raise ValueError (
1022+ "Argument epoch_length should be same as in the state, "
1023+ f"but given { epoch_length } vs { self .state .epoch_length } "
1024+ )
1025+
1026+ # Log resuming message
1027+ if self .state .max_epochs is not None :
1028+ self .logger .info (
1029+ f"Engine run resuming from iteration { self .state .iteration } , "
1030+ f"epoch { self .state .epoch } until { self .state .max_epochs } epochs"
1031+ )
1032+ else :
1033+ self .logger .info (
1034+ f"Engine run resuming from iteration { self .state .iteration } , "
1035+ f"epoch { self .state .epoch } until { self .state .max_iters } iterations"
1036+ )
1037+
9101038 if self .state .epoch_length is None and data is None :
9111039 raise ValueError ("epoch_length should be provided if data is None" )
9121040
0 commit comments