1515from torch .optim import Optimizer
1616from torch .optim .lr_scheduler import LRScheduler
1717
18+ from modalities .optimizers .optimizer_list import OptimizersList
19+
1820
1921class StatefulComponents (Enum ):
2022 MODEL = "model"
@@ -34,15 +36,18 @@ class AppState(Stateful):
3436 https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
3537 """
3638
37- def __init__ (self , model : nn .Module , optimizer : Optimizer , lr_scheduler : Optional [LRScheduler ] = None ):
39+ def __init__ (
40+ self , model : nn .Module | list [nn .Module ], optimizer : Optimizer , lr_scheduler : Optional [LRScheduler ] = None
41+ ):
3842 """Initializes the AppState object.
3943
4044 Args:
41- model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
45+ model (nn.Module | list[nn.Module]): The model or model parts can be either
46+ a non-sharded model, FSDP1 or FSDP2 model.
4247 optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
4348 lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
4449 """
45- self ._model = model
50+ self ._model_parts = list ( model ) if isinstance ( model , list ) else [ model ]
4651 self ._optimizer = optimizer
4752 self ._lr_scheduler = lr_scheduler
4853 self ._is_loaded = False
@@ -56,8 +61,8 @@ def is_loaded(self) -> bool:
5661 return self ._is_loaded
5762
5863 @property
59- def model (self ) -> nn .Module :
60- return self ._model
64+ def model_parts (self ) -> list [ nn .Module ] :
65+ return self ._model_parts
6166
6267 @property
6368 def optimizer (self ) -> Optimizer :
@@ -153,15 +158,18 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
153158class ModelStateRetriever (StateRetrieverIF ):
154159 @staticmethod
155160 def get_state_dict (app_state : AppState ) -> dict [str , Any ]:
156- """Returns the state dict of the model in the AppState object.
161+ """Returns the flattened state dicts of the model parts in the AppState object.
157162
158163 Args:
159164 app_state (AppState): The app_state object containing the model.
160165
161166 Returns:
162167 dict[str, Any]: The state dict of the model in the AppState object.
163168 """
164- return get_model_state_dict (model = app_state .model )
169+ state_dicts = list (map (get_model_state_dict , app_state .model_parts ))
170+ state_dict_keys = sum ((list (sd .keys ()) for sd in state_dicts ), [])
171+ assert len (state_dict_keys ) == len (set (state_dict_keys )), "State dict keys are not unique across model parts."
172+ return {k : v for sd in state_dicts for k , v in sd .items ()}
165173
166174 @staticmethod
167175 def load_state_dict_ (app_state : AppState , state_dict : dict [str , Any ]) -> None :
@@ -171,7 +179,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
171179 app_state (AppState): The app_state object containing the model.
172180 state_dict (dict[str, Any]): The state dict to load into the model.
173181 """
174- set_model_state_dict (model = app_state .model , model_state_dict = state_dict , options = StateDictOptions (strict = False ))
182+ for model in app_state .model_parts :
183+ set_model_state_dict (model = model , model_state_dict = state_dict , options = StateDictOptions (strict = False ))
175184
176185
177186class OptimizerStateRetriever (StateRetrieverIF ):
@@ -185,13 +194,17 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
185194 Returns:
186195 dict[str, Any]: The state dict of the optimizer in the AppState object.
187196 """
188- sd = get_optimizer_state_dict (
189- model = app_state .model ,
190- optimizers = app_state .optimizer ,
191- # NOTE: Flattening is required for pipeline parallelism to work correctly.
192- # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
193- options = StateDictOptions (flatten_optimizer_state_dict = True ),
194- )
197+ if isinstance (app_state .optimizer , OptimizersList ):
198+ sd = app_state .optimizer .state_dict ()
199+ else :
200+ assert len (app_state .model_parts ) == 1 , "Expected a single model part for non-OptimizersList optimizer."
201+ sd = get_optimizer_state_dict (
202+ model = app_state .model_parts [0 ],
203+ optimizers = app_state .optimizer ,
204+ # NOTE: Flattening is required for pipeline parallelism to work correctly.
205+ # see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
206+ options = StateDictOptions (flatten_optimizer_state_dict = True ),
207+ )
195208 return sd
196209
197210 @staticmethod
@@ -202,12 +215,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
202215 app_state (AppState): The app_state object containing the optimizer.
203216 state_dict (dict[str, Any]): The state dict to load into the optimizer.
204217 """
205- set_optimizer_state_dict (
206- model = app_state .model ,
207- optimizers = app_state .optimizer ,
208- optim_state_dict = state_dict ,
209- options = StateDictOptions (flatten_optimizer_state_dict = True ),
210- )
218+ if isinstance (app_state .optimizer , OptimizersList ):
219+ app_state .optimizer .load_state_dict (state_dict )
220+ else :
221+ assert len (app_state .model_parts ) == 1 , "Expected a single model part for non-OptimizersList optimizer."
222+ set_optimizer_state_dict (
223+ model = app_state .model_parts [0 ],
224+ optimizers = app_state .optimizer ,
225+ optim_state_dict = state_dict ,
226+ options = StateDictOptions (flatten_optimizer_state_dict = True ),
227+ )
211228
212229
213230class LRSchedulerStateRetriever (StateRetrieverIF ):
0 commit comments