
Conversation

@pstjohn (Collaborator) commented Feb 9, 2026

Summary

Adds a Mixtral Mixture-of-Experts model to bionemo-recipes/models/mixtral/ using TransformerEngine, following the same pattern as the existing llama3
model.

MoE Implementation

The core MoE block (NVMixtralSparseMoeBlock) uses the following (a conceptual sketch follows the list):

  • te.GroupedLinear for efficient parallel expert FFN computation (gate_up + down projections)
  • te.moe_permute / te.moe_unpermute with map_type="index" for token-to-expert routing
  • Standard nn.Linear router (kept in bf16 during FP8 training via te.autocast(enabled=False))
  • SwiGLU activation with fused gate/up projection split
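For orientation, here is a minimal plain-PyTorch reference of the routing and SwiGLU expert math this block performs. It is a conceptual sketch only: the actual implementation uses the TE primitives listed above (GroupedLinear, moe_permute/moe_unpermute), and the function and argument names here are illustrative, with the gate-before-up split order assumed.

```python
import torch
import torch.nn.functional as F

def moe_forward_reference(hidden, router_w, gate_up_w, down_w, top_k=2):
    """Dense reference for top-k MoE routing with SwiGLU experts (illustrative only).

    hidden:    [tokens, hidden_dim]
    router_w:  [num_experts, hidden_dim]             router (nn.Linear) weight
    gate_up_w: [num_experts, 2*ffn_dim, hidden_dim]  fused gate+up projection per expert
    down_w:    [num_experts, hidden_dim, ffn_dim]    down projection per expert
    """
    logits = hidden @ router_w.t()                          # [tokens, num_experts]
    probs = F.softmax(logits, dim=-1, dtype=torch.float32)
    weights, experts = torch.topk(probs, top_k, dim=-1)
    weights = (weights / weights.sum(dim=-1, keepdim=True)).to(hidden.dtype)

    out = torch.zeros_like(hidden)
    for e in range(router_w.shape[0]):
        token_idx, slot_idx = (experts == e).nonzero(as_tuple=True)  # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        x = hidden[token_idx]
        gate, up = (x @ gate_up_w[e].t()).chunk(2, dim=-1)   # fused projection, then split
        expert_out = (F.silu(gate) * up) @ down_w[e].t()     # SwiGLU, then down projection
        out.index_add_(0, token_idx, expert_out * weights[token_idx, slot_idx].unsqueeze(-1))
    return out
```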

Weight conversion handles the structural difference between HF's stacked 3D expert tensors ([num_experts, out, in]) and TE's per-expert GroupedLinear
weights (weight0, weight1, ...).
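A hedged sketch of that structural mapping, separate from the real convert.py/state.py machinery (the prefix and helper names are illustrative):

```python
import torch

def split_stacked_experts(stacked: torch.Tensor, prefix: str) -> dict[str, torch.Tensor]:
    """Split an HF-style stacked [num_experts, out_features, in_features] tensor into
    per-expert GroupedLinear-style entries named weight0, weight1, ...
    """
    return {f"{prefix}.weight{i}": stacked[i].contiguous() for i in range(stacked.shape[0])}

def stack_per_expert_weights(per_expert: list[torch.Tensor]) -> torch.Tensor:
    """Inverse direction: per-expert 2D weights back into the HF stacked 3D layout."""
    return torch.stack(per_expert, dim=0)
```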

Base test class improvements

  • Added a clear_gpu_memory fixture (gc collect + CUDA cache clear before/after each test) to prevent OOM errors from cascading across tests; see the sketch after this list
  • get_converted_te_model_checkpoint now frees the HF model and moves TE to CPU before saving (save_pretrained clones state dict internally)
  • test_golden_values and test_golden_values_thd now run models sequentially to support large models
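A minimal sketch of what such a fixture typically looks like (assumed shape; the actual clear_gpu_memory fixture lives in the shared test base):

```python
import gc

import pytest
import torch

@pytest.fixture(autouse=True)
def clear_gpu_memory():
    """Free cached GPU memory before and after each test to avoid cascading OOMs."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    yield
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```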

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Mixtral model implementation with Transformer Engine backend support
    • Introduced data collation and context-parallel batching utilities for efficient model training
    • Added model conversion tools to seamlessly convert between HuggingFace and Transformer Engine formats
  • Improvements

    • Enhanced GPU memory management for large model testing with automatic cleanup
    • Added comprehensive test framework for model validation and compatibility checks
  • Tests

    • Introduced BaseModelTest class and testing utilities for standardized model validation across formats

Signed-off-by: Peter St. John <[email protected]>
coderabbitai bot (Contributor) commented Feb 9, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


Signed-off-by: Peter St. John <[email protected]>
Signed-off-by: Peter St. John <[email protected]>
Signed-off-by: Peter St. John <[email protected]>
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 17

🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/mixtral/collator.py`:
- Around line 98-180: The __call__ method ignores the return_tensors parameter;
validate that return_tensors is either None or "pt" at the start of __call__ and
raise NotImplementedError for other values, and forward return_tensors into
downstream calls so tensor backend is consistent: call self.collator(features,
return_tensors=return_tensors) instead of self.collator(features) and call
_pt_flatten_collate(features, return_position_ids=self.return_position_ids,
return_tensors=return_tensors) (or the appropriate flatten helper that accepts
return_tensors); keep the rest of the logic (masked_input_ids/labels,
separator_id handling, and padding via
_pad_batch_to_multiple_of/_pad_sequences_to_be_divisible_by) unchanged.
- Around line 305-351: In the __call__ method of the collator, the computation
batch_shard["max_length_k"] = batch_shard["max_length_q"] = max_length *
round(max_length / 64) is incorrect; replace it with a ceil-to-multiple-of-64
calculation (e.g. compute padded_max = ((max_length + 63) // 64) * 64) and
assign padded_max to both batch_shard["max_length_k"] and
batch_shard["max_length_q"] so the result is the next multiple of 64 without
floating point rounding or inflation.
- Around line 238-277: The iterator currently calls
_split_sample_by_num_tokens(sample, tokens_available) when tokens_available can
be 0 (causing an error) and can yield an empty batch when a single sample
exceeds max_tokens_per_batch with split_samples=False; fix __iter__ to first
check tokens_available <= 0 and in that case yield the current samples and reset
samples before handling the incoming sample, and then reprocess the incoming
sample: if split_samples is True, split the sample into chunks of size up to
max_tokens_per_batch (call _split_sample_by_num_tokens in a loop using chunk
size = max_tokens_per_batch) and yield/fill batches accordingly; if
split_samples is False, ensure you never yield an empty batch by appending the
oversized sample as its own batch (or yield it immediately) instead of yielding
samples=[]; update references in __iter__ to use tokens_available guard,
_split_sample_by_num_tokens, split_samples, max_tokens_per_batch, samples and
current_length.
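The padding arithmetic requested in the second item above reduces to a standard ceil-to-multiple helper; a sketch (the helper name is illustrative):

```python
def ceil_to_multiple(value: int, multiple: int = 64) -> int:
    """Round value up to the next multiple, e.g. 65 -> 128, 64 -> 64."""
    return ((value + multiple - 1) // multiple) * multiple

# Inside the collator (illustrative):
# padded_max = ceil_to_multiple(max_length, 64)
# batch_shard["max_length_k"] = batch_shard["max_length_q"] = padded_max
```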

In `@bionemo-recipes/models/mixtral/convert.py`:
- Around line 136-137: The current filtering uses
MixtralConfig.__init__.__code__.co_varnames which is fragile because co_varnames
contains locals too; update the logic that builds valid_keys (used to create
filtered_config from te_config_dict) to derive parameter names from
MixtralConfig.__init__ via inspect.signature (e.g.,
inspect.signature(MixtralConfig.__init__).parameters) and then filter
te_config_dict by those parameter names so only actual constructor args are
preserved.
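A sketch of the suggested signature-based filtering (assuming te_config_dict is a plain dict; the helper name is illustrative):

```python
import inspect

from transformers import MixtralConfig

def filter_to_config_kwargs(te_config_dict: dict) -> dict:
    """Keep only keys that are real MixtralConfig.__init__ parameters."""
    valid_keys = set(inspect.signature(MixtralConfig.__init__).parameters) - {"self", "kwargs"}
    return {k: v for k, v in te_config_dict.items() if k in valid_keys}
```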

In `@bionemo-recipes/models/mixtral/export.py`:
- Line 53: The copy call uses a relative path so it breaks when the working
directory isn't the file's folder; change the source to be anchored to this
module by resolving Path(__file__).parent / "modeling_mixtral_te.py" and pass
that resolved path as the first argument to shutil.copy when copying to
export_path (the existing export_path / "modeling_mixtral_te.py" destination can
remain); update the code that calls shutil.copy accordingly (referencing
shutil.copy and the filename "modeling_mixtral_te.py").
- Around line 22-37: The export_hf_checkpoint function is creating a randomly
initialized model by calling AutoModelForCausalLM.from_config; replace that with
loading the actual pretrained weights by calling
AutoModelForCausalLM.from_pretrained using the same tag (keep or remove the
separate AutoConfig.from_pretrained as needed), so model_hf holds the real
checkpoint weights before export; update any related uses of model_hf and ensure
tokenizer/config are loaded from_pretrained as well if required.
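A combined sketch of both export.py fixes, assuming an export_hf_checkpoint(tag, export_path) signature (the real function may differ):

```python
import shutil
from pathlib import Path

from transformers import AutoModelForCausalLM, AutoTokenizer

def export_hf_checkpoint(tag: str, export_path: Path) -> None:
    """Load real pretrained weights (not a random from_config init) and copy the
    TE modeling file using a path anchored to this module rather than the CWD.
    """
    model_hf = AutoModelForCausalLM.from_pretrained(tag)
    tokenizer = AutoTokenizer.from_pretrained(tag)
    model_hf.save_pretrained(export_path)
    tokenizer.save_pretrained(export_path)
    shutil.copy(Path(__file__).parent / "modeling_mixtral_te.py",
                export_path / "modeling_mixtral_te.py")
```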

In `@bionemo-recipes/models/mixtral/modeling_mixtral_te.py`:
- Line 453: The module sets torch._dynamo.config.capture_scalar_outputs = True
at import time, which mutates global TorchDynamo state and can affect other
code; change this to a local, temporary setting or document it: in the functions
that require this behavior (e.g., the compile/optimization entry points in
modeling_mixtral_te), save the current
torch._dynamo.config.capture_scalar_outputs value, set it to True only for the
scope where you call torch.compile/torch._dynamo operations, then restore the
original value in a finally block (or use a small context manager) so the global
config isn't mutated at module load; alternatively, add a clear comment in the
module/top-level docstring explaining the global requirement if a global change
is unavoidable.
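One way to scope the flag as suggested, via a small context manager (a sketch; the name is illustrative):

```python
import contextlib

import torch

@contextlib.contextmanager
def capture_scalar_outputs():
    """Temporarily enable TorchDynamo scalar-output capture, then restore the prior value."""
    prev = torch._dynamo.config.capture_scalar_outputs
    torch._dynamo.config.capture_scalar_outputs = True
    try:
        yield
    finally:
        torch._dynamo.config.capture_scalar_outputs = prev

# Usage (illustrative):
# with capture_scalar_outputs():
#     compiled_model = torch.compile(model)
```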

In `@bionemo-recipes/models/mixtral/requirements.txt`:
- Around line 1-5: The requirements file currently leaves torch, transformers,
transformer_engine[pytorch], torchao, and lm-eval unpinned which causes
reproducibility and compatibility issues; update the requirements (the entries
referencing torch, transformers, transformer_engine[pytorch], torchao!=0.14.0,
and lm-eval) to pin explicit compatible versions (or reference a constraints
file) that match the project's tested torch version matrix and known working
transformer_engine and torchao releases, ensuring transformer_engine is the
pytorch build and avoiding the excluded torchao 0.14.0 conflict; provide exact
version specs consistent with the root project strategy so installs are
deterministic.

In `@bionemo-recipes/models/mixtral/state.py`:
- Line 321: Fix the typo in the logger.error message: change "Enountered" to
"Encountered" in the error string used in the code that logs IndexError during
transform (the logger.error call that includes source_matches and
target_matches); update the f-string to read logger.error(f"Encountered
IndexError during transform.\n{source_matches=}\n{target_matches=}") so the
message is spelled correctly while preserving the variables and formatting.
- Around line 70-72: The dataclass/module-level defaults use mutable lists
(transforms, state_dict_ignored_entries); change their default values from [] to
None and initialize them to empty lists at construction (e.g., in __post_init__
of the dataclass or where the object is created) to avoid shared mutable state;
keep cast_dtype as Optional[torch.dtype] = None unchanged and ensure any code
referencing transforms or state_dict_ignored_entries handles the None-to-list
initialization (refer to the symbols transforms, state_dict_ignored_entries,
TransformCTX, cast_dtype).
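The None-plus-__post_init__ pattern the second item describes, sketched on an illustrative dataclass (field names follow the comment; dataclasses.field(default_factory=list) is the equally valid shorthand):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class TransformCTX:  # illustrative shape only
    transforms: Optional[list] = None
    state_dict_ignored_entries: Optional[list] = None
    cast_dtype: Optional[torch.dtype] = None

    def __post_init__(self):
        # Avoid a single list instance being shared across all objects.
        if self.transforms is None:
            self.transforms = []
        if self.state_dict_ignored_entries is None:
            self.state_dict_ignored_entries = []
```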

In `@bionemo-recipes/models/mixtral/tests/common/__init__.py`:
- Around line 1-29: Remove the duplicated license header blocks and leave a
single, consistent header: keep one SPDX copyright line with the correct year
(use 2026), set SPDX-License-Identifier to "Apache-2.0", and retain the standard
Apache-2.0 license text that follows; delete the other entire license block so
only one header remains at the top of the file.

In `@bionemo-recipes/models/mixtral/tests/common/fixtures.py`:
- Around line 141-143: Replace the direct environment deletions with safe
removals using os.environ.pop for keys "NVTE_FUSED_ATTN" and "NVTE_FLASH_ATTN"
to avoid KeyError if they were never set; keep the existing update to
_attention_backends["backend_selection_requires_update"] = True unchanged so the
backend refresh still occurs. Use os.environ.pop("NVTE_FUSED_ATTN", None) and
os.environ.pop("NVTE_FLASH_ATTN", None) in the teardown (where the current del
os.environ[...] calls are) to safely remove the variables.
- Line 66: Replace the unsafe deletion del os.environ["NVTE_DEBUG"] in the
fixtures cleanup with a safe pop call so removing the NVTE_DEBUG env var cannot
raise KeyError; locate the occurrence of del os.environ["NVTE_DEBUG"] in
bionemo-recipes/models/mixtral/tests/common/fixtures.py (the teardown/cleanup
for the test fixture) and use os.environ.pop("NVTE_DEBUG", None) to silently
handle the case where the variable is already absent.
- Around line 1-30: The file contains two SPDX/Apache-2.0 license header blocks
with conflicting years; remove the second duplicate header (the block that
begins with "SPDX-FileCopyrightText: Copyright (c) 2025" and its following
Apache-2.0 license text) so only the first header (the 2026 SPDX header)
remains; ensure no extra blank lines are left at the top after removal.
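The safe-teardown pattern the first two items describe, sketched (the helper name is illustrative):

```python
import os

def _cleanup_nvte_env() -> None:
    """Remove NVTE overrides without raising if a variable was never set."""
    os.environ.pop("NVTE_DEBUG", None)
    os.environ.pop("NVTE_FUSED_ATTN", None)
    os.environ.pop("NVTE_FLASH_ATTN", None)
    # The existing backend refresh should still run afterwards, e.g.:
    # _attention_backends["backend_selection_requires_update"] = True
```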

In `@bionemo-recipes/models/mixtral/tests/common/README.md`:
- Around line 7-13: The fenced code block in README.md under tests/common is
missing a language marker (MD040); update the snippet in that file (the
three-line tree block within the fenced code) to include a language tag such as
"text" or "plaintext" after the opening backticks so linting passes (e.g.,
change ``` to ```text).

In `@bionemo-recipes/models/mixtral/tests/common/test_modeling_common.py`:
- Around line 316-354: The HF API calls use the wrong kwarg name: replace the
"dtype" param with "torch_dtype" in both upstream_class.from_pretrained(...)
inside get_reference_model and AutoConfig.from_pretrained(...) inside
get_reference_model_no_weights so the requested precision (e.g., torch.bfloat16
or torch.float32) is honored; keep the existing attn_implementation and revision
logic and leave model.to("cuda") as-is.
- Around line 34-36: The HAS_DATA_CENTER_GPU probe calls
torch.cuda.get_device_name(0) unguarded which raises in CPU-only environments;
update the definition of HAS_DATA_CENTER_GPU to first check
torch.cuda.is_available() (and optionally wrap the probe in a try/except for
RuntimeError/AssertionError) and only call torch.cuda.get_device_name(0) when
CUDA is available, otherwise set HAS_DATA_CENTER_GPU to False; modify the
variable in the test module (the HAS_DATA_CENTER_GPU assignment) to implement
this guard so imports do not fail on CPU-only systems.
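A sketch of the guarded probe from the second item; which device names count as data-center parts is an assumption here:

```python
import torch

def _has_data_center_gpu() -> bool:
    """True only when CUDA is available and the device looks like a data-center GPU."""
    if not torch.cuda.is_available():
        return False
    try:
        name = torch.cuda.get_device_name(0)
    except (RuntimeError, AssertionError):
        return False
    return any(tag in name for tag in ("A100", "H100", "H200", "B200"))  # illustrative list

HAS_DATA_CENTER_GPU = _has_data_center_gpu()
```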
🧹 Nitpick comments (6)
bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py (1)

104-105: Remove redundant pad_token check.

This check duplicates lines 69-70 in get_tokenizer(). Since get_test_input_data calls self.get_tokenizer() at line 97, the pad_token is already set.

♻️ Proposed fix
         data_collator = DataCollatorForLanguageModeling(
             tokenizer=tokenizer,
             pad_to_multiple_of=pad_to_multiple_of,
             mlm=False,
         )

-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
         if format == "thd":
bionemo-recipes/models/mixtral/modeling_mixtral_te.py (1)

380-401: Consider standard import for GenerationMixin.

The dynamic import __import__("transformers").GenerationMixin works but is unconventional. A standard import would be clearer.

♻️ Proposed fix
+from transformers import GenerationMixin
+
 ...
 
-class NVMixtralForCausalLM(NVMixtralPreTrainedModel, __import__("transformers").GenerationMixin):
+class NVMixtralForCausalLM(NVMixtralPreTrainedModel, GenerationMixin):
bionemo-recipes/models/mixtral/convert.py (1)

60-71: Consider alternative to exec() for dynamic function creation.

Using exec() works here but is harder to debug and understand. A simpler approach using *args:

♻️ Proposed alternative
 def _make_merge_experts_fn(num_experts: int):
-    """Create a merge function with the correct number of named parameters.
-
-    The state.py transform system maps function parameter names to source keys, so we need a function
-    with exactly `num_experts` named parameters (weight0, weight1, ...).
-    """
-    param_names = [f"weight{i}" for i in range(num_experts)]
-    code = f"def merge_experts({', '.join(param_names)}):\n    return torch.stack([{', '.join(param_names)}])"
-    local_ns = {"torch": torch}
-    exec(code, local_ns)
-    return local_ns["merge_experts"]
+    """Create a merge function that stacks expert weights."""
+    def merge_experts(*weights):
+        if len(weights) != num_experts:
+            raise ValueError(f"Expected {num_experts} weights, got {len(weights)}")
+        return torch.stack(weights)
+    return merge_experts

Note: This assumes the state transform system supports *args. If named parameters are strictly required by the transform system, the current exec() approach is acceptable but should be documented.

bionemo-recipes/models/mixtral/tests/common/fixtures.py (1)

62-63: Remove redundant import os.

os is already imported at line 33.

♻️ Proposed fix
 def use_te_debug():
     """Auto-use fixture to enable TransformerEngine debugging.
 
     This fixture automatically enables debug mode for TransformerEngine
     in all tests for better error messages.
     """
-    import os
-
     os.environ["NVTE_DEBUG"] = "1"
bionemo-recipes/models/mixtral/state.py (2)

161-161: Use logger instead of print for unexpected keys.

Consistency with the rest of the module which uses logger.debug, logger.warning, etc.

♻️ Proposed fix
-            print(f"Unexpected key: {name} not in target model but is in source model.")
+            logger.warning(f"Unexpected key: {name} not in target model but is in source model.")

265-265: Avoid global side effect from np.set_printoptions.

This modifies global numpy print settings on every transform call. If needed for debugging, consider using a context manager or removing it.

♻️ Proposed fix
-        np.set_printoptions(threshold=10)

If this is needed for debugging specific issues, use np.printoptions(threshold=10) as a context manager around the specific debug output instead.

Signed-off-by: Peter St. John <[email protected]>
Comment on lines +142 to +147
@pytest.mark.skip(
reason="MoE routing is batch-dependent: padding tokens in BSHD affect softmax normalization "
"and top-k expert selection, so BSHD and THD produce fundamentally different routing decisions."
)
def test_golden_values_thd(self, te_attn_backend):
"""Skip: BSHD vs THD comparison is not meaningful for MoE models."""
pstjohn (Collaborator, Author) commented:
is this correct?

Signed-off-by: Peter St. John <[email protected]>
@pstjohn (Collaborator, Author) commented Feb 11, 2026

@coderabbitai resolve

coderabbitai bot (Contributor) commented Feb 11, 2026

✅ Actions performed

Comments resolved.
