Skip to content

Commit 0e153ef

Browse files
committed
chore: Merge branch 'main' into monitoring_improvements
2 parents eb747fd + 9420c0b commit 0e153ef

46 files changed

Lines changed: 2247 additions & 446 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ config_files/instruction_tuning
172172
data/lorem_ipsum_instruct.jsonl
173173
tutorials/scaling_up/logs*
174174
tutorials/scaling_up/experiments_old/*
175-
176175
results/*
177176
tutorials/einsum_transformer/experiments/*
178-
tutorials/warmstart/experiments/*
177+
tutorials/warmstart/experiments/*
178+

config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ device_mesh:
194194
config:
195195
device_type: cuda
196196
data_parallel_replicate_degree: 1
197-
pipeline_parallel_degree: 2
197+
pipeline_parallel_degree: 4
198198
data_parallel_shard_degree: -1
199199
world_size: ${settings.cuda_env.world_size}
200200

@@ -251,7 +251,7 @@ scheduled_pipeline:
251251
loss_fn:
252252
instance_key: loss_fn
253253
pass_type: BY_REFERENCE
254-
pp_schedule_name: gpipe
254+
pp_schedule_name: Interleaved1F1B
255255
batch_size: ${settings.step_profile.local_train_micro_batch_size}
256256
microbatch_size: 2
257257
pp_degree: ${device_mesh.config.pipeline_parallel_degree}
@@ -318,7 +318,7 @@ staged_pipeline:
318318
instance_key: device_mesh
319319
pass_type: BY_REFERENCE
320320
local_rank: ${settings.cuda_env.local_rank}
321-
pp_schedule_name: gpipe
321+
pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name}
322322
num_layers_per_stage: 2
323323

324324
model_raw:
@@ -332,7 +332,7 @@ model_raw:
332332
sequence_length: ${settings.step_profile.sequence_length}
333333
prediction_key: ${loss_fn.config.prediction_key}
334334
vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
335-
n_layer: 2
335+
n_layer: 6
336336
n_head_q: 8
337337
n_head_kv: 4
338338
ffn_hidden: 128

config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ staged_pipeline:
308308
instance_key: device_mesh
309309
pass_type: BY_REFERENCE
310310
local_rank: ${settings.cuda_env.local_rank}
311-
pp_schedule_name: gpipe
311+
pp_schedule_name: ${scheduled_pipeline.config.pp_schedule_name}
312312
num_layers_per_stage: 2
313313

314314
model_raw:

src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def _save_checkpoint(self, app_state: AppState, training_progress: TrainingProgr
8989
# saving the model via FULL_STATE_DICT and checkpoint via FULL_OPTIM_STATE_DICT
9090
model_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
9191
optim_save_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
92-
model = app_state.model
92+
assert len(app_state.model_parts) == 1, "FSDP1CheckpointSaving only supports a single model part."
93+
model = app_state.model_parts[0]
9394
optimizer = app_state.optimizer
9495
with FSDP.state_dict_type(
9596
module=model,

src/modalities/checkpointing/stateful/app_state.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from torch.optim import Optimizer
1616
from torch.optim.lr_scheduler import LRScheduler
1717

18+
from modalities.optimizers.optimizer_list import OptimizersList
19+
1820

1921
class StatefulComponents(Enum):
2022
MODEL = "model"
@@ -34,15 +36,18 @@ class AppState(Stateful):
3436
https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
3537
"""
3638

37-
def __init__(self, model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None):
39+
def __init__(
40+
self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
41+
):
3842
"""Initializes the AppState object.
3943
4044
Args:
41-
model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
45+
model (nn.Module | list[nn.Module]): The model or model parts can be either
46+
a non-sharded model, FSDP1 or FSDP2 model.
4247
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
4348
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
4449
"""
45-
self._model = model
50+
self._model_parts = list(model) if isinstance(model, list) else [model]
4651
self._optimizer = optimizer
4752
self._lr_scheduler = lr_scheduler
4853
self._is_loaded = False
@@ -56,8 +61,8 @@ def is_loaded(self) -> bool:
5661
return self._is_loaded
5762

5863
@property
59-
def model(self) -> nn.Module:
60-
return self._model
64+
def model_parts(self) -> list[nn.Module]:
65+
return self._model_parts
6166

6267
@property
6368
def optimizer(self) -> Optimizer:
@@ -153,15 +158,18 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
153158
class ModelStateRetriever(StateRetrieverIF):
154159
@staticmethod
155160
def get_state_dict(app_state: AppState) -> dict[str, Any]:
156-
"""Returns the state dict of the model in the AppState object.
161+
"""Returns the flattened state dicts of the model parts in the AppState object.
157162
158163
Args:
159164
app_state (AppState): The app_state object containing the model.
160165
161166
Returns:
162167
dict[str, Any]: The state dict of the model in the AppState object.
163168
"""
164-
return get_model_state_dict(model=app_state.model)
169+
state_dicts = list(map(get_model_state_dict, app_state.model_parts))
170+
state_dict_keys = sum((list(sd.keys()) for sd in state_dicts), [])
171+
assert len(state_dict_keys) == len(set(state_dict_keys)), "State dict keys are not unique across model parts."
172+
return {k: v for sd in state_dicts for k, v in sd.items()}
165173

166174
@staticmethod
167175
def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
@@ -171,7 +179,8 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
171179
app_state (AppState): The app_state object containing the model.
172180
state_dict (dict[str, Any]): The state dict to load into the model.
173181
"""
174-
set_model_state_dict(model=app_state.model, model_state_dict=state_dict, options=StateDictOptions(strict=False))
182+
for model in app_state.model_parts:
183+
set_model_state_dict(model=model, model_state_dict=state_dict, options=StateDictOptions(strict=False))
175184

176185

177186
class OptimizerStateRetriever(StateRetrieverIF):
@@ -185,13 +194,17 @@ def get_state_dict(app_state: AppState) -> dict[str, Any]:
185194
Returns:
186195
dict[str, Any]: The state dict of the optimizer in the AppState object.
187196
"""
188-
sd = get_optimizer_state_dict(
189-
model=app_state.model,
190-
optimizers=app_state.optimizer,
191-
# NOTE: Flattening is required for pipeline parallelism to work correctly.
192-
# see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
193-
options=StateDictOptions(flatten_optimizer_state_dict=True),
194-
)
197+
if isinstance(app_state.optimizer, OptimizersList):
198+
sd = app_state.optimizer.state_dict()
199+
else:
200+
assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
201+
sd = get_optimizer_state_dict(
202+
model=app_state.model_parts[0],
203+
optimizers=app_state.optimizer,
204+
# NOTE: Flattening is required for pipeline parallelism to work correctly.
205+
# see https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/components/checkpoint.py#L193-L214
206+
options=StateDictOptions(flatten_optimizer_state_dict=True),
207+
)
195208
return sd
196209

197210
@staticmethod
@@ -202,12 +215,16 @@ def load_state_dict_(app_state: AppState, state_dict: dict[str, Any]) -> None:
202215
app_state (AppState): The app_state object containing the optimizer.
203216
state_dict (dict[str, Any]): The state dict to load into the optimizer.
204217
"""
205-
set_optimizer_state_dict(
206-
model=app_state.model,
207-
optimizers=app_state.optimizer,
208-
optim_state_dict=state_dict,
209-
options=StateDictOptions(flatten_optimizer_state_dict=True),
210-
)
218+
if isinstance(app_state.optimizer, OptimizersList):
219+
app_state.optimizer.load_state_dict(state_dict)
220+
else:
221+
assert len(app_state.model_parts) == 1, "Expected a single model part for non-OptimizersList optimizer."
222+
set_optimizer_state_dict(
223+
model=app_state.model_parts[0],
224+
optimizers=app_state.optimizer,
225+
optim_state_dict=state_dict,
226+
options=StateDictOptions(flatten_optimizer_state_dict=True),
227+
)
211228

212229

213230
class LRSchedulerStateRetriever(StateRetrieverIF):

src/modalities/checkpointing/stateful/app_state_factory.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ class AppStateFactory:
1515

1616
@staticmethod
1717
def get_raw_app_state(
18-
model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
18+
model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
1919
) -> AppState:
2020
"""Creates a new (non-checkpoint loaded) AppState object from an instantiated
2121
model, optimizer, and optional learning rate scheduler.
2222
2323
Args:
24-
model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model.
24+
model (nn.Module | list[nn.Module]): The model (parts) can be either
25+
a non-sharded model, FSDP1 or FSDP2 model.
2526
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
2627
lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None.
2728

src/modalities/config/component_factory.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Type, TypeVar
22

3-
from pydantic import BaseModel
3+
from pydantic import AliasChoices, BaseModel
4+
from pydantic.fields import FieldInfo
45

56
from modalities.registry.registry import Registry
67
from modalities.util import print_rank_0
@@ -164,30 +165,53 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co
164165
config_dict=config_dict,
165166
component_config_type=component_config_type,
166167
)
167-
comp_config = component_config_type(**config_dict, strict=True)
168+
comp_config = component_config_type.model_validate(config_dict, extra="forbid")
168169
return comp_config
169170

170171
def _assert_valid_config_keys(
171172
self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild]
172173
) -> None:
173-
required_keys = []
174-
optional_keys = []
175-
for key, field in component_config_type.model_fields.items():
174+
# Collect required and optional keys, including aliases if defined.
175+
required_keys: list[str] = []
176+
optional_keys: list[str] = []
177+
# Map aliases to canonical field names for clearer error messages.
178+
alias_to_field: dict[str, str] = {}
179+
180+
for field_name, field in component_config_type.model_fields.items():
181+
names_for_field = self._parse_str_aliases(alias_to_field, field_name, field)
176182
if field.is_required():
177-
required_keys.append(key)
183+
required_keys.extend(names_for_field)
178184
else:
179-
optional_keys.append(key)
185+
optional_keys.extend(names_for_field)
180186

181-
invalid_keys = []
182-
for key in config_dict.keys():
183-
if key not in required_keys and key not in optional_keys:
184-
invalid_keys.append(key)
187+
all_valid_keys = set(required_keys) | set(optional_keys)
188+
189+
invalid_keys = [key for key in config_dict.keys() if key not in all_valid_keys]
185190
if len(invalid_keys) > 0:
186191
message = f"Invalid keys {invalid_keys} for config `{component_key}.{variant_key}`"
187192
message += f" of type {component_config_type}:\n{config_dict}\n"
188-
message += f"Required keys: {required_keys}\nOptional keys: {optional_keys}"
193+
if alias_to_field:
194+
message += f"Alias to field mapping: {alias_to_field}\n"
195+
message += f"Required keys (including aliases): {required_keys}\n"
196+
message += f"Optional keys (including aliases): {optional_keys}\n"
189197
raise ValueError(message)
190198

199+
def _parse_str_aliases(self, alias_to_field: dict[str, str], field_name: str, field: FieldInfo) -> set[str]:
200+
names_for_field = {field_name}
201+
if field.alias and field.alias != field_name:
202+
names_for_field.add(field.alias)
203+
alias_to_field[field.alias] = field_name
204+
if field.validation_alias and field.validation_alias != field_name:
205+
if isinstance(field.validation_alias, str):
206+
names_for_field.add(field.validation_alias)
207+
alias_to_field[field.validation_alias] = field_name
208+
elif isinstance(field.validation_alias, AliasChoices):
209+
for alias in field.validation_alias.choices:
210+
if isinstance(alias, str):
211+
names_for_field.add(alias)
212+
alias_to_field[alias] = field_name
213+
return names_for_field
214+
191215
def _instantiate_component(self, component_key: str, variant_key: str, component_config: BaseModel) -> Any:
192216
component_type: Type = self.registry.get_component(component_key, variant_key)
193217
component_config_dict = self._base_model_to_dict(component_config)

src/modalities/config/config.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
PydanticModelInitializationIFType,
2828
PydanticOptimizerIFType,
2929
PydanticPytorchDeviceType,
30+
PydanticPytorchModuleOrListType,
3031
PydanticPytorchModuleType,
3132
PydanticSamplerIFType,
3233
PydanticTokenizerIFType,
@@ -43,6 +44,7 @@
4344
ActivationCheckpointingVariants,
4445
)
4546
from modalities.util import parse_enum_by_name
47+
from modalities.utils.deprecated_alias import add_deprecated_alias
4648

4749

4850
class ProcessGroupBackendType(LookupEnum):
@@ -145,7 +147,7 @@ class CheckpointSavingConfig(BaseModel):
145147

146148
class AdamOptimizerConfig(BaseModel):
147149
lr: float
148-
wrapped_model: PydanticPytorchModuleType
150+
wrapped_model: PydanticPytorchModuleOrListType
149151
betas: tuple[float, float]
150152
eps: float
151153
weight_decay: float
@@ -156,7 +158,7 @@ class AdamOptimizerConfig(BaseModel):
156158

157159
class AdamWOptimizerConfig(BaseModel):
158160
lr: float
159-
wrapped_model: PydanticPytorchModuleType
161+
wrapped_model: PydanticPytorchModuleOrListType
160162
betas: tuple[float, float]
161163
eps: float
162164
weight_decay: float
@@ -268,7 +270,7 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy:
268270

269271

270272
class FSDP2WrappedModelConfig(BaseModel):
271-
model: PydanticPytorchModuleType
273+
model: PydanticPytorchModuleOrListType
272274
block_names: list[str]
273275
mixed_precision_settings: FSDP2MixedPrecisionSettings
274276
reshard_after_forward: bool = True
@@ -293,7 +295,7 @@ def validate_dp_mesh_existence(self):
293295

294296

295297
class DebuggingEnrichedModelConfig(BaseModel):
296-
model: PydanticPytorchModuleType
298+
model: PydanticPytorchModuleOrListType
297299
logging_dir_path: Path
298300
tracked_ranks: Optional[Set[int]] = None
299301
log_interval_steps: Optional[int] = 1
@@ -306,7 +308,7 @@ def convert_list_to_set(cls, v: Iterable[int] | None) -> Set[int] | None:
306308

307309

308310
class GPT2ModelTPConfig(BaseModel):
309-
model: PydanticPytorchModuleType # TODO set proper type
311+
model: PydanticPytorchModuleOrListType # TODO set proper type
310312
device_mesh: PydanticDeviceMeshIFType
311313

312314
@model_validator(mode="after")
@@ -338,7 +340,7 @@ class CompiledModelConfig(BaseModel):
338340

339341

340342
class WeightInitializedModelConfig(BaseModel):
341-
model: PydanticPytorchModuleType
343+
model: PydanticPytorchModuleOrListType
342344
model_initializer: PydanticModelInitializationIFType
343345

344346
# avoid warning about protected namespace 'model_', see
@@ -363,12 +365,12 @@ class SelectiveOpACParams(BaseModel):
363365

364366
ac_variant: ActivationCheckpointingVariants
365367
layers_fqn: str
366-
model: PydanticPytorchModuleType
368+
model: PydanticPytorchModuleOrListType
367369
ac_fun_params: FullACParams | SelectiveLayerACParams | SelectiveOpACParams
368370

369371

370372
class RawAppStateConfig(BaseModel):
371-
model: PydanticPytorchModuleType
373+
model: PydanticPytorchModuleOrListType
372374
optimizer: PydanticOptimizerIFType
373375
lr_scheduler: Optional[PydanticLRSchedulerIFType] = None
374376

@@ -493,12 +495,13 @@ class RichResultSubscriberConfig(BaseModel):
493495
global_rank: int
494496

495497

498+
@add_deprecated_alias("model_parts", "wrapped_model")
496499
class GPT2MFUCalculatorConfig(BaseModel):
497500
n_layer: Annotated[int, Field(strict=True, gt=0)]
498501
sequence_length: Annotated[int, Field(strict=True, gt=0)]
499502
n_embd: Annotated[int, Field(strict=True, gt=0)]
500503
world_size: Annotated[int, Field(strict=True, gt=0)]
501-
wrapped_model: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType
504+
model_parts: PydanticFSDP1ModuleType | PydanticFSDP2ModuleType | list[PydanticFSDP2ModuleType]
502505
device_mesh: Optional[PydanticDeviceMeshIFType] = None
503506

504507

src/modalities/config/pydantic_if_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __get_pydantic_core_schema__(
6767
CheckpointSavingExecutionABC, PydanticThirdPartyTypeIF(CheckpointSavingExecutionABC)
6868
]
6969
PydanticPytorchModuleType = Annotated[nn.Module, PydanticThirdPartyTypeIF(nn.Module)]
70+
PydanticPytorchModuleOrListType = PydanticPytorchModuleType | list[PydanticPytorchModuleType]
7071
PydanticFSDP1ModuleType = Annotated[FSDP1, PydanticThirdPartyTypeIF(FSDP1)]
7172
PydanticFSDP2ModuleType = Annotated[FSDP2, PydanticThirdPartyTypeIF(FSDP2)]
7273
PydanticTokenizerIFType = Annotated[TokenizerWrapper, PydanticThirdPartyTypeIF(TokenizerWrapper)]

0 commit comments

Comments
 (0)