Llama3 Context Parallel Fixes and Performance edits #1461
Conversation
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (1)

50-87: ⚠️ Potential issue | 🟡 Minor

Widen the `grad_norm` type annotation to accept both `Tensor` and `float`.

Training code passes `float` to `log_step()` (via `.item()` on the gradient norm), but the function signature expects `torch.Tensor`. This causes a type mismatch that Pyright will flag. Update the annotation to `grad_norm: torch.Tensor | float`.

Additionally, `min_loss` is returned as a raw tensor in training entrypoints; consider adding a property accessor for scalar access:

♻️ Suggested changes

```diff
-    grad_norm: torch.Tensor,
+    grad_norm: torch.Tensor | float,
```

and optionally:

```diff
+    @property
+    def min_loss_value(self) -> float:
+        """Return min_loss as a Python float for external consumption."""
+        return float(self.min_loss.item())
```
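For illustration, a minimal sketch of the widened call path; the function body and parameters are assumptions, not the recipe's actual `log_step()` signature:

```python
import torch


def log_step(grad_norm: torch.Tensor | float) -> None:
    # Accept either the raw tensor or the float produced by `.item()` at the call site.
    grad_norm_value = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
    print(f"grad_norm: {grad_norm_value:.3g}")


log_step(torch.tensor(2.5))         # tensor path
log_step(torch.tensor(2.5).item())  # float path, as the training loop does
```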
bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py (6)

88-101: ⚠️ Potential issue | 🟡 Minor

Duplicated `--standalone` flag in torchrun command.

The `--standalone` flag appears twice (lines 91 and 94). This is likely unintentional and may cause issues with torchrun.

🐛 Proposed fix to remove duplicate flag

```diff
 run_train_cmd(
     [
         "torchrun",
         "--standalone",
         "--nproc_per_node",
         "2",  # 2 processes = 2 GPUs
-        "--standalone",  # Single node mode
         "train_ddp.py",
         "--config-name",
         "L0_sanity",
         "num_train_steps=4",  # Just 4 steps for speed
     ],
     recipe_path,
 )
```
116-129: ⚠️ Potential issue | 🟡 Minor

Duplicated `--standalone` flag in torchrun command.

Same issue as above: duplicate `--standalone` flags.

🐛 Proposed fix

```diff
 run_train_cmd(
     [
         "torchrun",
         "--standalone",
         "--nproc_per_node",
         "2",  # 2 processes = 2 GPUs
-        "--standalone",  # Single node mode
         "train_fsdp2.py",
```
141-157: ⚠️ Potential issue | 🟡 Minor

Duplicated `--standalone` flag in torchrun command.

Same duplicate flag issue in the checkpointing test.

🐛 Proposed fix

```diff
 run_train_cmd(
     [
         "torchrun",
         "--standalone",
         "--nproc_per_node",
         "2",
-        "--standalone",
         "train_ddp.py",
```
174-190: ⚠️ Potential issue | 🟡 Minor

Duplicated `--standalone` flag in torchrun command.

Same duplicate flag issue.

🐛 Proposed fix

```diff
 run_train_cmd(
     [
         "torchrun",
         "--standalone",
         "--nproc_per_node",
         "2",
-        "--standalone",
         "train_fsdp2.py",
```
200-218: ⚠️ Potential issue | 🟡 Minor

Duplicated `--standalone` flag in torchrun command.

Same duplicate flag issue in the BSHD CP test.

🐛 Proposed fix

```diff
 run_train_cmd(
     [
         "torchrun",
         "--standalone",
         "--nproc_per_node=2",
-        "--standalone",
         "train_fsdp2_cp.py",
```
224-242: ⚠️ Potential issue | 🟡 Minor

Duplicated `--standalone` flag in torchrun command.

Same duplicate flag issue in the THD CP test.

🐛 Proposed fix

```diff
 run_train_cmd(
     [
         "torchrun",
         "--standalone",
         "--nproc_per_node=2",
-        "--standalone",
         "train_fsdp2_cp.py",
```
🤖 Fix all issues with AI agents
In `@bionemo-recipes/recipes/llama3_native_te/Dockerfile`:
- Around line 2-6: The Dockerfile currently hardcodes an internal base image
(gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-base)
which breaks external builds; replace the fixed FROM with a build ARG pattern so
the default is the public NVIDIA image (e.g., nvidia/pytorch:26.01-py3) but CI
can override with the internal image. Concretely, add an ARG like BASE_IMAGE
with the public image as the default and change the FROM to use that ARG (refer
to the existing hardcoded image string and the commented public image name in
the file), keeping the explanatory note intact; this mirrors the codonfm_ptl_te
recipe pattern so external users can build while CI can pass the internal
registry value.
🧹 Nitpick comments (4)
bionemo-recipes/models/llama3/collator.py (1)
504-512: Broad exception handling may silently swallow errors.

Catching all exceptions and converting them to `StopIteration` could hide important failures (e.g., CUDA errors, assertion failures). Consider logging the exception before signaling stop.

♻️ Proposed fix to log exceptions

```diff
 def _do_one_prefetch(self):
     """Fetch one batch in the background. Stores result in _prefetch_result."""
     if self._cuda_device is not None:
         torch.cuda.set_device(self._cuda_device)
     try:
         self._prefetch_result = self._send_data_to_cp_tp_ranks()
-    except Exception:  # Process group may have been destroyed; signal stop.
+    except Exception as e:  # Process group may have been destroyed; signal stop.
+        logger.debug("Prefetch exception (may be expected at shutdown): %s", e)
         self._prefetch_result = StopIteration()
```
bionemo-recipes/recipes/esm2_native_te/collator.py (2)

504-512: Same broad exception handling concern as llama3/collator.py.

Consider logging the exception before converting to `StopIteration`.

♻️ Proposed fix to log exceptions

```diff
 def _do_one_prefetch(self):
     """Fetch one batch in the background. Stores result in _prefetch_result."""
     if self._cuda_device is not None:
         torch.cuda.set_device(self._cuda_device)
     try:
         self._prefetch_result = self._send_data_to_cp_tp_ranks()
-    except Exception:  # Process group may have been destroyed; signal stop.
+    except Exception as e:  # Process group may have been destroyed; signal stop.
+        logger.debug("Prefetch exception (may be expected at shutdown): %s", e)
         self._prefetch_result = StopIteration()
```
1-948: Significant code duplication across collator files.

This file is nearly identical to `bionemo-recipes/models/llama3/collator.py` and `bionemo-recipes/models/esm2/src/esm/collator.py`. Consider consolidating into a shared module to reduce maintenance burden.
bionemo-recipes/recipes/llama3_native_te/collator.py (1)

488-497: Consider adding thread-safety documentation or synchronization.

The current implementation relies on the GIL and on the fact that `_prefetch_thread.join()` completes before `_prefetch_result` is accessed. While this is correct in practice, the shared mutable state (`_prefetch_result`) accessed from both the main thread and the prefetch thread could benefit from explicit documentation noting the synchronization invariant (join completes before result access).
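For illustration, a minimal, self-contained sketch of the invariant being described, with names modeled on the comment (`_prefetch_thread`, `_prefetch_result`) rather than the recipe's actual class:

```python
import threading


class PrefetchingIterator:
    """Illustrates the join-before-read invariant.

    `_prefetch_result` is written only by the background thread; the consumer reads it
    only after `join()` returns, so the join provides the happens-before edge and no
    extra lock is needed.
    """

    def __init__(self, produce):
        self._produce = produce
        self._prefetch_result = None
        self._prefetch_thread = threading.Thread(target=self._do_one_prefetch)
        self._prefetch_thread.start()

    def _do_one_prefetch(self):
        self._prefetch_result = self._produce()

    def get(self):
        self._prefetch_thread.join()  # synchronization point: result is fully written
        return self._prefetch_result


print(PrefetchingIterator(lambda: {"input_ids": [1, 2, 3]}).get())
```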
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
bionemo-recipes/recipes/esm2_native_te/perf_logger.py (1)
45-51: ⚠️ Potential issue | 🟠 Major

Fix `min_loss` return type annotation mismatch.

`main()` is annotated to return `float | None` (line 48 in train_fsdp2.py) but actually returns `perf_logger.min_loss`, which is a `torch.Tensor` (initialized as a CUDA tensor on line 50 of perf_logger.py and updated via `torch.minimum()` on line 103). This type mismatch affects all training scripts in this recipe (train_fsdp2.py, train_mfsdp.py, train_fsdp2_cp.py, train_ddp_cp.py, train_ddp.py).

Either convert `min_loss` to a scalar before returning (e.g., `perf_logger.min_loss.item()`) or update the return type annotation to `torch.Tensor`.
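A minimal sketch of the first option, assuming the entrypoint shape described above (`perf_logger` is the recipe's logger instance; the training loop is elided):

```python
def main() -> float | None:
    # ... training loop elided; `perf_logger` is assumed to exist in scope ...
    # Convert the CUDA tensor to a host scalar so the `float | None` annotation holds.
    return float(perf_logger.min_loss.item())
```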
bionemo-recipes/models/llama3/collator.py (1)

282-305: ⚠️ Potential issue | 🟠 Major

Guard split_samples when padding inflates per-sequence length.

When padding makes `tokens_available` ≥ the raw sample length (possible if `max_tokens_per_batch` isn't divisible by the pad multiple), `_split_sample_by_num_tokens` can raise. Falling back to starting a new batch avoids hard failures in that configuration.

🛠️ Proposed fix

```diff
     else:
         # Calculate how many padded tokens are already in the batch
         tokens_in_batch = current_length - self._padded_len(len(sample["input_ids"]))
         # Calculate how many tokens we can fit from this sample
         tokens_available = self.max_tokens_per_batch - tokens_in_batch
-        first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
-        yield [*samples, first_part]
-        samples = [remaining_part]
+        sample_len = len(sample["input_ids"])
+        if tokens_available <= 0 or tokens_available >= sample_len:
+            if samples:
+                yield samples
+            samples = [sample]
+        else:
+            first_part, remaining_part = _split_sample_by_num_tokens(sample, tokens_available)
+            yield [*samples, first_part]
+            samples = [remaining_part]
     current_length = self._padded_len(len(samples[0]["input_ids"]))
```
bionemo-recipes/recipes/llama3_native_te/train_ddp.py (1)

136-139: ⚠️ Potential issue | 🟠 Major

Remove the per-batch `print` call.

Printing every batch will bottleneck I/O and spam logs, especially in multi-GPU runs. Prefer gated debug logging or remove it entirely.

🔧 Suggested fix

```diff
-    print(batch["input_ids"].shape)
+    if dist_config.local_rank == 0 and logger.isEnabledFor(logging.DEBUG):
+        logger.debug("batch input_ids shape: %s", batch["input_ids"].shape)
```
bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py (1)

401-419: ⚠️ Potential issue | 🟡 Minor

Fail fast if the rank-0 scatter payload never arrives.

`data_ready.wait(timeout=5)` returns a boolean that's ignored; a timeout will surface as a cryptic `KeyError` later. Add an explicit check for clearer failures (apply to both occurrences).

🛠️ Suggested fix

```diff
-    data_ready.wait(timeout=5)
-    scatter_object_output_list[0] = scatter_payload["data"][1]
+    if not data_ready.wait(timeout=5):
+        raise AssertionError("Timed out waiting for rank 0 scatter payload")
+    scatter_object_output_list[0] = scatter_payload["data"][1]
```
bionemo-recipes/recipes/llama3_native_te/tests/test_train_two_gpu.py (1)

88-99: ⚠️ Potential issue | 🟡 Minor

Remove duplicated `--standalone` flags in all torchrun invocations.

`--standalone` appears twice in six torchrun command lists across this file (lines 91–94, 119–122, 144–147, 177–180, 203–205, 227–229). Keep only one instance per command to improve clarity; while argparse treats duplicates as redundant, they are confusing to readers.

Suggested fix for lines 88–99

```diff
 run_train_cmd(
     [
         "torchrun",
         "--standalone",
         "--nproc_per_node",
         "2",  # 2 processes = 2 GPUs
-        "--standalone",  # Single node mode
         "train_ddp.py",
```

Apply the same removal to the other five occurrences in this file.
🤖 Fix all issues with AI agents
In `@bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py`:
- Around line 1071-1072: The test instantiation uses an invalid init parameter
cp_world_size for DataCollatorForContextParallel (which defines cp_world_size as
field(init=False)); replace that argument by passing
device_mesh=_DummyCollatorMesh(cp_size=cp_world_size) when constructing
DataCollatorForContextParallel (keep collator=base_collator and
qkv_format="thd"), so locate the line creating cp_collator and swap
cp_world_size=cp_world_size for
device_mesh=_DummyCollatorMesh(cp_size=cp_world_size).
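For illustration, a minimal sketch of the suggested construction, using only the names referenced in this prompt (`DataCollatorForContextParallel`, `_DummyCollatorMesh`, `base_collator`, `cp_world_size`); the surrounding test setup is omitted:

```python
# Before (invalid: cp_world_size is declared with field(init=False)):
# cp_collator = DataCollatorForContextParallel(
#     collator=base_collator, qkv_format="thd", cp_world_size=cp_world_size
# )

# After: let the collator derive the CP size from a dummy device mesh instead.
cp_collator = DataCollatorForContextParallel(
    collator=base_collator,
    qkv_format="thd",
    device_mesh=_DummyCollatorMesh(cp_size=cp_world_size),
)
```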
In `@bionemo-recipes/models/llama3/collator.py`:
- Around line 488-513: The __next__ implementation must guard against
_prefetch_thread being None and must surface background exceptions instead of
swallowing them; change __next__ to check if self._prefetch_thread is not None
before calling join and to re-raise any Exception objects stored in
self._prefetch_result; modify _do_one_prefetch to store actual exceptions (e.g.,
the Exception instance) into self._prefetch_result rather than converting
everything to StopIteration, and reserve StopIteration only for genuine
iteration termination returned by _send_data_to_cp_tp_ranks; keep using
_kick_prefetch to start the thread, ensure _send_data_to_cp_tp_ranks remains the
producer of StopIteration for end-of-iteration, and continue to set
torch.cuda.set_device(self._cuda_device) when _cuda_device is not None.
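A minimal sketch of the pattern described in this prompt, using the method and attribute names it references; the real collator carries more state, so treat this as an outline rather than the implementation:

```python
def _do_one_prefetch(self):
    """Fetch one batch in the background; store either the batch or the raised exception."""
    if self._cuda_device is not None:
        torch.cuda.set_device(self._cuda_device)
    try:
        # _send_data_to_cp_tp_ranks() itself returns StopIteration() at genuine end of iteration.
        self._prefetch_result = self._send_data_to_cp_tp_ranks()
    except Exception as e:  # noqa: BLE001 - surfaced to the consumer in __next__
        self._prefetch_result = e


def __next__(self):
    if self._prefetch_thread is not None:
        self._prefetch_thread.join()
    result = self._prefetch_result
    if isinstance(result, StopIteration):
        raise StopIteration
    if isinstance(result, Exception):
        raise result  # surface background failures instead of silently ending iteration
    self._kick_prefetch()  # start prefetching the next batch
    return result
```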
In `@bionemo-recipes/recipes/esm2_native_te/perf_logger.py`:
- Around line 98-138: metrics.compute() can return GPU tensors which break
formatting and wandb; after calling metrics = self.metrics.compute() (and before
self.metrics.reset(), wandb.log, and logger.info), convert any tensor values to
host Python scalars (e.g., v.detach().cpu().item() for scalar tensors) or to CPU
tensors as appropriate, replacing entries in the metrics dict with those CPU
scalars so wandb.log(metrics, step=step) and the logger.info(",
".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()])) formatting
work without errors.
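A minimal sketch of the conversion described above, assuming `self.metrics` is a torchmetrics-style collection whose `compute()` returns scalar (possibly GPU) tensors; intended as a fragment of the logging method, not verbatim recipe code:

```python
metrics = self.metrics.compute()
# Move scalar tensors to host Python floats so f-string formatting and wandb.log work.
metrics = {
    k: (v.detach().cpu().item() if isinstance(v, torch.Tensor) else v)
    for k, v in metrics.items()
}
self.metrics.reset()
wandb.log(metrics, step=step)
logger.info(", ".join(f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()))
```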
In `@bionemo-recipes/recipes/llama3_native_te/collator.py`:
- Around line 296-304: The split logic can overfill when
pad_sequences_to_be_divisible_by doesn't divide max_tokens_per_batch; modify the
splitting branch (where tokens_available is computed) to ensure the padded
length of the chosen split fits the remaining capacity: after computing
tokens_available = self.max_tokens_per_batch - tokens_in_batch, reduce
tokens_available (e.g., decrement in a loop) until
self._padded_len(tokens_available) <= self.max_tokens_per_batch -
tokens_in_batch (or zero), then call _split_sample_by_num_tokens with that
adjusted tokens_available; alternatively, you can detect the incompatible
configuration in __post_init__ (when split_samples=True and max_tokens_per_batch
% pad_sequences_to_be_divisible_by != 0) and raise a clear error — reference
functions/fields: _split_sample_by_num_tokens, _padded_len,
max_tokens_per_batch, pad_sequences_to_be_divisible_by, split_samples, and
__post_init__.
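A minimal sketch of the second option (failing fast in `__post_init__`); the field names come from the prompt above, and the message text is illustrative:

```python
def __post_init__(self):
    if (
        self.split_samples
        and self.pad_sequences_to_be_divisible_by is not None
        and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0
    ):
        raise ValueError(
            "max_tokens_per_batch must be a multiple of pad_sequences_to_be_divisible_by "
            "when split_samples=True; otherwise a split can overfill the batch."
        )
```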
In `@bionemo-recipes/recipes/llama3_native_te/perf_logger.py`:
- Around line 110-121: The sampling mismatch is that running_loss and
grad_acc_step_count are updated every micro-step (see running_loss and
grad_acc_step_count) while num_tokens and num_unpadded_tokens are only updated
at logging intervals (controlled by logging_frequency and step); fix by making
sampling consistent: either move the num_tokens and num_unpadded_tokens
increments to the same micro-step scope so they are updated every micro-step
using batch["input_ids"] and batch.get("attention_mask") (or its fallback) or
only increment grad_acc_step_count when you actually update token counters at
the logging interval; update comments to document the chosen behavior and
reference running_loss, grad_acc_step_count, num_tokens, num_unpadded_tokens,
logging_frequency, and batch so reviewers can find the change.
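A minimal sketch of the first option (counting tokens on every micro-step so all counters share the same cadence); `loss` and the exact attribute layout are assumptions based on the names in the prompt:

```python
# Per-micro-step path of the logger:
self.running_loss += loss
self.grad_acc_step_count += 1
# Count tokens at the same cadence so throughput/averages share one denominator.
self.num_tokens += batch["input_ids"].numel()
attention_mask = batch.get("attention_mask")
self.num_unpadded_tokens += (
    int(attention_mask.sum().item()) if attention_mask is not None else batch["input_ids"].numel()
)
```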
🧹 Nitpick comments (1)
bionemo-recipes/recipes/llama3_native_te/collator.py (1)
504-512: Broad exception handling may hide real errors.

The `except Exception` block at lines 510-512 catches all exceptions and converts them to `StopIteration`. While the comment mentions process group destruction, this could mask other errors (e.g., CUDA OOM, data corruption). Consider either:

- Catching specific exceptions like `RuntimeError` from distributed ops
- Logging the exception before converting to `StopIteration`

♻️ Proposed fix to log exceptions

```diff
 def _do_one_prefetch(self):
     """Fetch one batch in the background. Stores result in _prefetch_result."""
     if self._cuda_device is not None:
         torch.cuda.set_device(self._cuda_device)
     try:
         self._prefetch_result = self._send_data_to_cp_tp_ranks()
-    except Exception:  # Process group may have been destroyed; signal stop.
+    except Exception as e:  # Process group may have been destroyed; signal stop.
+        logger.debug("Prefetch exception (treating as end of iteration): %s", e)
         self._prefetch_result = StopIteration()
```
bionemo-recipes/models/esm2/tests/test_collator_context_parallel.py (outdated review comment, resolved)
@coderabbitai resolve

✅ Actions performed: Comments resolved.
```python
    self.pad_sequences_to_be_divisible_by is not None
    and self.max_tokens_per_batch % self.pad_sequences_to_be_divisible_by != 0
):
    logger.warning(
```
So it just fills it as much as possible but maintains the divisibility factor?
| """Return the padded length of a sequence, rounding up to the nearest multiple of pad_sequences_to_be_divisible_by.""" | ||
| if self.pad_sequences_to_be_divisible_by is None: | ||
| return length | ||
| return -(-length // self.pad_sequences_to_be_divisible_by) * self.pad_sequences_to_be_divisible_by |
What's with the `-(-X)`?
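(For context: `-(-length // d) * d` is the standard integer ceiling idiom, since `-(-length // d)` equals `ceil(length / d)` using only floor division. A small standalone illustration with a hypothetical helper:)

```python
def padded_len(length: int, d: int) -> int:
    # -(-length // d) == ceil(length / d), so this rounds length up to the next multiple of d.
    return -(-length // d) * d


assert padded_len(10, 8) == 16
assert padded_len(16, 8) == 16
assert padded_len(1, 8) == 8
```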
```python
    tokens_available = self.max_tokens_per_batch - tokens_in_batch
    if self.pad_sequences_to_be_divisible_by is not None:
        d = self.pad_sequences_to_be_divisible_by
        tokens_available = (tokens_available // d) * d
```
nit: why do we need `d` here? Can't we just use self.pad_seq... or is it ugly because it's gonna be too long a line?
```python
    pad_between_seqs: bool


    @nvtx.annotate("collator._scatter_batch_to_cp_tp_ranks", color="green")
```
The decorator at runtime does nothing if we aren't doing profiling, right?
jomitchellnv left a comment
LGTM but you may wanna move some of the changes related to prefetch into the other MR
Signed-off-by: Peter St. John <[email protected]>
Force-pushed from e424e40 to 008787a
A collection of small performance improvements and bugfixes for llama3 CP training
Summary by CodeRabbit
Release Notes
New Features
Documentation
Refactor