Add Mixtral model for MoE demo #1458
base: main
Conversation
Signed-off-by: Peter St. John <[email protected]>
Signed-off-by: Peter St. John <[email protected]>
Signed-off-by: Peter St. John <[email protected]>
Signed-off-by: Peter St. John <[email protected]>
Signed-off-by: Peter St. John <[email protected]>
Actionable comments posted: 17
🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/mixtral/collator.py`:
- Around line 98-180: The __call__ method ignores the return_tensors parameter;
validate that return_tensors is either None or "pt" at the start of __call__ and
raise NotImplementedError for other values, and forward return_tensors into
downstream calls so tensor backend is consistent: call self.collator(features,
return_tensors=return_tensors) instead of self.collator(features) and call
_pt_flatten_collate(features, return_position_ids=self.return_position_ids,
return_tensors=return_tensors) (or the appropriate flatten helper that accepts
return_tensors); keep the rest of the logic (masked_input_ids/labels,
separator_id handling, and padding via
_pad_batch_to_multiple_of/_pad_sequences_to_be_divisible_by) unchanged. (A sketch of this guard follows this list.)
- Around line 305-351: In the __call__ method of the collator, the computation
batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length *
round(max_length / 64) is incorrect; replace it with a ceil-to-multiple-of-64
calculation (e.g. compute padded_max = ((max_length + 63) // 64) * 64) and
assign padded_max to both batch_shard["max_length_k"] and
batch_shard["max_length_q"] so the result is the next multiple of 64 without
floating point rounding or inflation. (A worked example follows this list.)
- Around line 238-277: The iterator currently calls
_split_sample_by_num_tokens(sample, tokens_available) when tokens_available can
be 0 (causing an error) and can yield an empty batch when a single sample
exceeds max_tokens_per_batch with split_samples=False; fix __iter__ to first
check tokens_available <= 0 and in that case yield the current samples and reset
samples before handling the incoming sample, and then reprocess the incoming
sample: if split_samples is True, split the sample into chunks of size up to
max_tokens_per_batch (call _split_sample_by_num_tokens in a loop using chunk
size = max_tokens_per_batch) and yield/fill batches accordingly; if
split_samples is False, ensure you never yield an empty batch by appending the
oversized sample as its own batch (or yield it immediately) instead of yielding
samples=[]; update references in __iter__ to use tokens_available guard,
_split_sample_by_num_tokens, split_samples, max_tokens_per_batch, samples and
current_length.
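For the first item above (the ignored `return_tensors` parameter), here is a minimal sketch of the guarded method body; the `self.collator`, `_pt_flatten_collate`, and `return_position_ids` names are taken from the comment and are otherwise assumptions:

```python
# Sketch only: assumes the collator attributes named in the comment above.
def __call__(self, features, return_tensors=None):
    if return_tensors not in (None, "pt"):
        raise NotImplementedError(
            f"Only return_tensors='pt' is supported, got {return_tensors!r}"
        )
    # Forward the backend choice so nested collation stays consistent.
    batch = self.collator(features, return_tensors=return_tensors)
    flat_batch = _pt_flatten_collate(
        features,
        return_position_ids=self.return_position_ids,
        return_tensors=return_tensors,
    )
    ...  # masked_input_ids/labels, separator handling, and padding unchanged
```

For the second item (padding `max_length_q`/`max_length_k` up to a multiple of 64), integer ceiling avoids both the float `round` and the accidental multiplication; a small self-contained example:

```python
def ceil_to_multiple(value: int, multiple: int = 64) -> int:
    """Round value up to the next multiple using integer arithmetic only."""
    return ((value + multiple - 1) // multiple) * multiple

# Inside __call__ this would become (names as in the comment above):
# padded_max = ceil_to_multiple(max_length, 64)
# batch_shard["max_length_k"] = batch_shard["max_length_q"] = padded_max

assert ceil_to_multiple(1) == 64
assert ceil_to_multiple(64) == 64
assert ceil_to_multiple(65) == 128
```

For the third item (the `tokens_available <= 0` and oversized-sample cases in `__iter__`), a standalone sketch of the batching loop under stated assumptions: lists of token ids stand in for real samples, and plain slicing stands in for `_split_sample_by_num_tokens`:

```python
def iter_token_batches(samples, max_tokens_per_batch, split_samples):
    """Yield batches of samples whose total length stays within max_tokens_per_batch."""
    batch, current_length = [], 0
    for sample in samples:
        tokens_available = max_tokens_per_batch - current_length
        if tokens_available <= 0:
            # Flush before touching the incoming sample: never split with 0 tokens available.
            yield batch
            batch, current_length = [], 0
            tokens_available = max_tokens_per_batch
        if len(sample) <= tokens_available:
            batch.append(sample)
            current_length += len(sample)
        elif split_samples:
            # Fill the current batch, then emit full-size chunks of the remainder.
            batch.append(sample[:tokens_available])
            yield batch
            rest = sample[tokens_available:]
            while len(rest) > max_tokens_per_batch:
                yield [rest[:max_tokens_per_batch]]
                rest = rest[max_tokens_per_batch:]
            batch, current_length = ([rest], len(rest)) if rest else ([], 0)
        else:
            # Oversized sample with splitting disabled: never yield an empty batch;
            # emit the sample as its own batch instead.
            if batch:
                yield batch
            yield [sample]
            batch, current_length = [], 0
    if batch:
        yield batch
```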
In `@bionemo-recipes/models/mixtral/convert.py`:
- Around line 136-137: The current filtering uses
MixtralConfig.__init__.__code__.co_varnames which is fragile because co_varnames
contains locals too; update the logic that builds valid_keys (used to create
filtered_config from te_config_dict) to derive parameter names from
MixtralConfig.__init__ via inspect.signature (e.g.,
inspect.signature(MixtralConfig.__init__).parameters) and then filter
te_config_dict by those parameter names so only actual constructor args are
preserved.
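A minimal sketch of that filter, with an illustrative `te_config_dict` (the real dict is built elsewhere in convert.py):

```python
import inspect

from transformers import MixtralConfig

te_config_dict = {"hidden_size": 4096, "not_a_real_config_arg": 123}  # illustrative only

# signature(...).parameters contains only real constructor arguments, whereas
# __init__.__code__.co_varnames also lists local variables defined in the body.
valid_keys = set(inspect.signature(MixtralConfig.__init__).parameters) - {"self"}
filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys}
config = MixtralConfig(**filtered_config)
```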
In `@bionemo-recipes/models/mixtral/export.py`:
- Line 53: The copy call uses a relative path so it breaks when the working
directory isn't the file's folder; change the source to be anchored to this
module by resolving Path(__file__).parent / "modeling_mixtral_te.py" and pass
that resolved path as the first argument to shutil.copy when copying to
export_path (the existing export_path / "modeling_mixtral_te.py" destination can
remain); update the code that calls shutil.copy accordingly (referencing
shutil.copy and the filename "modeling_mixtral_te.py"). (A sketch follows this list.)
- Around line 22-37: The export_hf_checkpoint function is creating a randomly
initialized model by calling AutoModelForCausalLM.from_config; replace that with
loading the actual pretrained weights by calling
AutoModelForCausalLM.from_pretrained using the same tag (keep or remove the
separate AutoConfig.from_pretrained as needed), so model_hf holds the real
checkpoint weights before export; update any related uses of model_hf and ensure
tokenizer/config are loaded from_pretrained as well if required.
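For the copy-path item, a sketch that anchors the source to the module rather than the working directory; the destination path here is purely illustrative:

```python
import shutil
from pathlib import Path

export_path = Path("/tmp/mixtral_export")  # illustrative destination
export_path.mkdir(parents=True, exist_ok=True)

# Anchor the source to this module's directory, not the current working directory.
source = Path(__file__).parent / "modeling_mixtral_te.py"
shutil.copy(source, export_path / "modeling_mixtral_te.py")
```

For the random-initialization item, loading real weights is a one-call change; `tag` stands for whatever checkpoint identifier export.py already uses, and the value shown is a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tag = "mistralai/Mixtral-8x7B-v0.1"  # placeholder; use the tag export.py already receives

model_hf = AutoModelForCausalLM.from_pretrained(tag)  # real checkpoint weights, not random init
tokenizer = AutoTokenizer.from_pretrained(tag)
config = model_hf.config  # AutoConfig.from_pretrained(tag) also works if a separate config is needed
```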
In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Line 453: The module sets torch._dynamo.config.capture_scalar_outputs = True
at import time, which mutates global TorchDynamo state and can affect other
code; change this to a local, temporary setting or document it: in the functions
that require this behavior (e.g., the compile/optimization entry points in
modeling_mixtral_te), save the current
torch._dynamo.config.capture_scalar_outputs value, set it to True only for the
scope where you call torch.compile/torch._dynamo operations, then restore the
original value in a finally block (or use a small context manager) so the global
config isn't mutated at module load; alternatively, add a clear comment in the
module/top-level docstring explaining the global requirement if a global change
is unavoidable.
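A small context manager along those lines, assuming only the public `torch._dynamo.config.capture_scalar_outputs` flag:

```python
import contextlib

import torch


@contextlib.contextmanager
def capture_scalar_outputs():
    """Temporarily enable capture_scalar_outputs, restoring the previous value on exit."""
    previous = torch._dynamo.config.capture_scalar_outputs
    torch._dynamo.config.capture_scalar_outputs = True
    try:
        yield
    finally:
        torch._dynamo.config.capture_scalar_outputs = previous


# Usage at the compile/optimization entry point instead of at import time:
# with capture_scalar_outputs():
#     compiled_model = torch.compile(model)
```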
In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Around line 1-5: The requirements file currently leaves torch, transformers,
transformer_engine[pytorch], torchao, and lm-eval unpinned which causes
reproducibility and compatibility issues; update the requirements (the entries
referencing torch, transformers, transformer_engine[pytorch], torchao!=0.14.0,
and lm-eval) to pin explicit compatible versions (or reference a constraints
file) that match the project's tested torch version matrix and known working
transformer_engine and torchao releases, ensuring transformer_engine is the
pytorch build and avoiding the excluded torchao 0.14.0 conflict; provide exact
version specs consistent with the root project strategy so installs are
deterministic.
In `@bionemo-recipes/models/mixtral/state.py`:
- Line 321: Fix the typo in the logger.error message: change "Enountered" to
"Encountered" in the error string used in the code that logs IndexError during
transform (the logger.error call that includes source_matches and
target_matches); update the f-string to read logger.error(f"Encountered
IndexError during transform.\n{source_matches=}\n{target_matches=}") so the
message is spelled correctly while preserving the variables and formatting.
- Around line 70-72: The dataclass/module-level defaults use mutable lists
(transforms, state_dict_ignored_entries); change their default values from [] to
None and initialize them to empty lists at construction (e.g., in __post_init__
of the dataclass or where the object is created) to avoid shared mutable state;
keep cast_dtype as Optional[torch.dtype] = None unchanged and ensure any code
referencing transforms or state_dict_ignored_entries handles the None-to-list
initialization (refer to the symbols transforms, state_dict_ignored_entries,
TransformCTX, cast_dtype).
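A sketch of the None-plus-`__post_init__` pattern; the field names follow the comment, and the dataclass shape itself is an assumption:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TransformCTX:  # field names from the comment above; the actual class shape may differ
    transforms: Optional[list] = None
    state_dict_ignored_entries: Optional[list] = None
    cast_dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        # Each instance gets its own list, avoiding shared mutable defaults.
        if self.transforms is None:
            self.transforms = []
        if self.state_dict_ignored_entries is None:
            self.state_dict_ignored_entries = []
```

`dataclasses.field(default_factory=list)` achieves the same thing with less boilerplate if plain dataclass defaults are acceptable here.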
In `@bionemo-recipes/models/mixtral/tests/common/__init__.py`:
- Around line 1-29: Remove the duplicated license header blocks and leave a
single, consistent header: keep one SPDX copyright line with the correct year
(use 2026), set SPDX-License-Identifier to "Apache-2.0", and retain the standard
Apache-2.0 license text that follows; delete the other entire license block so
only one header remains at the top of the file.
In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py`:
- Around line 141-143: Replace the direct environment deletions with safe
removals using os.environ.pop for keys "NVTE_FUSED_ATTN" and "NVTE_FLASH_ATTN"
to avoid KeyError if they were never set; keep the existing update to
_attention_backends["backend_selection_requires_update"] = True unchanged so the
backend refresh still occurs. Use os.environ.pop("NVTE_FUSED_ATTN", None) and
os.environ.pop("NVTE_FLASH_ATTN", None) in the teardown (where the current del
os.environ[...] calls are) to safely remove the variables.
- Line 66: Replace the unsafe deletion del os.environ["NVTE_DEBUG"] in the
fixtures cleanup with a safe pop call so removing the NVTE_DEBUG env var cannot
raise KeyError; locate the occurrence of del os.environ["NVTE_DEBUG"] in
bionemo-recipes/models/mixtral/tests/common/fixtures.py (the teardown/cleanup
for the test fixture) and use os.environ.pop("NVTE_DEBUG", None) to silently
handle the case where the variable is already absent.
- Around line 1-30: The file contains two SPDX/Apache-2.0 license header blocks
with conflicting years; remove the second duplicate header (the block that
begins with "SPDX-FileCopyrightText: Copyright (c) 2025" and its following
Apache-2.0 license text) so only the first header (the 2026 SPDX header)
remains; ensure no extra blank lines are left at the top after removal.
In `@bionemo-recipes/models/mixtral/tests/common/README.md`:
- Around line 7-13: The fenced code block in README.md under tests/common is
missing a language marker (MD040); update the snippet in that file (the
three-line tree block within the fenced code) to include a language tag such as
"text" or "plaintext" after the opening backticks so linting passes (e.g.,
change ``` to ```text).
In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 316-354: The HF API calls use the wrong kwarg name: replace the
"dtype" param with "torch_dtype" in both upstream_class.from_pretrained(...)
inside get_reference_model and AutoConfig.from_pretrained(...) inside
get_reference_model_no_weights so the requested precision (e.g., torch.bfloat16
or torch.float32) is honored; keep the existing attn_implementation and revision
logic and leave model.to("cuda") as-is. (A sketch follows this list.)
- Around line 34-36: The HAS_DATA_CENTER_GPU probe calls
torch.cuda.get_device_name(0) unguarded which raises in CPU-only environments;
update the definition of HAS_DATA_CENTER_GPU to first check
torch.cuda.is_available() (and optionally wrap the probe in a try/except for
RuntimeError/AssertionError) and only call torch.cuda.get_device_name(0) when
CUDA is available, otherwise set HAS_DATA_CENTER_GPU to False; modify the
variable in the test module (the HAS_DATA_CENTER_GPU assignment) to implement
this guard so imports do not fail on CPU-only systems.
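For the dtype item, the call shape the comment asks for; the checkpoint tag is illustrative only:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

checkpoint = "mistralai/Mixtral-8x7B-v0.1"  # illustrative tag only

model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
config = AutoConfig.from_pretrained(checkpoint, torch_dtype=torch.float32)
```

For the CPU-only probe item, a guarded version of the flag; the device-name substrings are an assumption about the heuristic the test module actually uses:

```python
import torch


def _has_data_center_gpu() -> bool:
    """Return True only when CUDA is available and the device looks like a data-center GPU."""
    if not torch.cuda.is_available():
        return False
    try:
        name = torch.cuda.get_device_name(0)
    except (RuntimeError, AssertionError):
        return False
    return any(tag in name for tag in ("A100", "H100", "B200"))  # assumed heuristic


HAS_DATA_CENTER_GPU = _has_data_center_gpu()
```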
🧹 Nitpick comments (6)
bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py (1)
104-105: Remove redundant pad_token check. This check duplicates lines 69-70 in `get_tokenizer()`. Since `get_test_input_data` calls `self.get_tokenizer()` at line 97, the pad_token is already set.
♻️ Proposed fix
```diff
     data_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
         pad_to_multiple_of=pad_to_multiple_of,
         mlm=False,
     )
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
     if format == "thd":
```
bionemo-recipes/models/mixtral/modeling_mixtral_te.py (1)
380-401: Consider standard import for GenerationMixin. The dynamic import `__import__("transformers").GenerationMixin` works but is unconventional. A standard import would be clearer.
♻️ Proposed fix
```diff
+from transformers import GenerationMixin
+
 ...
-class NVMixtralForCausalLM(NVMixtralPreTrainedModel, __import__("transformers").GenerationMixin):
+class NVMixtralForCausalLM(NVMixtralPreTrainedModel, GenerationMixin):
```
bionemo-recipes/models/mixtral/convert.py (1)
60-71: Consider an alternative to `exec()` for dynamic function creation. Using `exec()` works here but is harder to debug and understand. A simpler approach using `*args`:
♻️ Proposed alternative
```diff
 def _make_merge_experts_fn(num_experts: int):
-    """Create a merge function with the correct number of named parameters.
-
-    The state.py transform system maps function parameter names to source keys, so we need a function
-    with exactly `num_experts` named parameters (weight0, weight1, ...).
-    """
-    param_names = [f"weight{i}" for i in range(num_experts)]
-    code = f"def merge_experts({', '.join(param_names)}):\n    return torch.stack([{', '.join(param_names)}])"
-    local_ns = {"torch": torch}
-    exec(code, local_ns)
-    return local_ns["merge_experts"]
+    """Create a merge function that stacks expert weights."""
+    def merge_experts(*weights):
+        if len(weights) != num_experts:
+            raise ValueError(f"Expected {num_experts} weights, got {len(weights)}")
+        return torch.stack(weights)
+    return merge_experts
```
Note: This assumes the state transform system supports `*args`. If named parameters are strictly required by the transform system, the current `exec()` approach is acceptable but should be documented.
bionemo-recipes/models/mixtral/tests/common/fixtures.py (1)
62-63: Remove redundant `import os`. `os` is already imported at line 33.
♻️ Proposed fix
```diff
 def use_te_debug():
     """Auto-use fixture to enable TransformerEngine debugging.

     This fixture automatically enables debug mode for TransformerEngine in all tests for better error messages.
     """
-    import os
-
     os.environ["NVTE_DEBUG"] = "1"
```
bionemo-recipes/models/mixtral/state.py (2)
161-161: Use logger instead of print for unexpected keys. Consistency with the rest of the module, which uses `logger.debug`, `logger.warning`, etc.
♻️ Proposed fix
```diff
-    print(f"Unexpected key: {name} not in target model but is in source model.")
+    logger.warning(f"Unexpected key: {name} not in target model but is in source model.")
```
265-265: Avoid global side effect from `np.set_printoptions`. This modifies global numpy print settings on every transform call. If needed for debugging, consider using a context manager or removing it.
♻️ Proposed fix
```diff
-    np.set_printoptions(threshold=10)
```
If this is needed for debugging specific issues, use `np.printoptions(threshold=10)` as a context manager around the specific debug output instead.
Signed-off-by: Peter St. John <[email protected]>
```python
@pytest.mark.skip(
    reason="MoE routing is batch-dependent: padding tokens in BSHD affect softmax normalization "
    "and top-k expert selection, so BSHD and THD produce fundamentally different routing decisions."
)
def test_golden_values_thd(self, te_attn_backend):
    """Skip: BSHD vs THD comparison is not meaningful for MoE models."""
```
is this correct?
Signed-off-by: Peter St. John <[email protected]>
@coderabbitai resolve
✅ Actions performed: Comments resolved.
Summary
Adds a Mixtral Mixture-of-Experts model to bionemo-recipes/models/mixtral/ using TransformerEngine, following the same pattern as the existing llama3
model.
MoE Implementation
The core MoE block (NVMixtralSparseMoeBlock) implements top-k routing over experts backed by TransformerEngine's GroupedLinear.
Weight conversion handles the structural difference between HF's stacked 3D expert tensors ([num_experts, out, in]) and TE's per-expert GroupedLinear
weights (weight0, weight1, ...).
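A toy illustration of that structural mapping; the shapes are deliberately tiny and illustrative, and the real mapping lives in the convert.py/state.py transforms:

```python
import torch

num_experts, out_features, in_features = 8, 6, 4  # tiny illustrative shapes

# HF-style stacked expert tensor: [num_experts, out, in].
stacked = torch.randn(num_experts, out_features, in_features)

# TE's GroupedLinear stores one tensor per expert, registered as weight0, weight1, ...
per_expert = {f"weight{i}": stacked[i] for i in range(num_experts)}

# Reverse direction: merge the per-expert weights back into the stacked layout.
restacked = torch.stack([per_expert[f"weight{i}"] for i in range(num_experts)])
assert torch.equal(restacked, stacked)
```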
Base test class improvements
Summary by CodeRabbit
Release Notes
New Features
Improvements
Tests