Skip to content

Conversation

@Wohox
Copy link
Contributor

@Wohox Wohox commented Jan 22, 2026

Description

This PR adds get_backward_dw_params for TE modules, which helps manage the hooks of parameters.

For Megatron-LM, get_backward_dw_params will be called once the wgrad cuda graph is executed. Currently the backward_post_hook of wgrad computation is discarded and will cause parameters to skip grad reduce.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 22, 2026

Greptile Summary

Adds get_backward_dw_params() API to TransformerEngineBaseModule to support CUDA graph execution in Megatron-LM. This method returns the parameters involved in delayed weight gradient computation, enabling proper hook registration for gradient reduction operations.

Key changes:

  • New get_backward_dw_params() method that mirrors the parameter retrieval logic from backward_dw()
  • Returns concatenated weight and bias parameters using noop_cat() helper
  • Enables Megatron-LM to register hooks on these parameters before CUDA graph execution
  • Fixes issue where backward_post_hook was being discarded during wgrad CUDA graph execution, causing parameters to skip gradient reduce

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The implementation is simple and correct: it adds a single utility method that extracts and returns parameters using the same logic already validated in backward_dw(). The method uses noop_cat() which is a well-tested helper function throughout the codebase. No new logic or side effects are introduced. The change is narrowly scoped to modules that use delayed weight gradient computation (Linear, LayerNormLinear, etc.) which all have the required use_bias and bias_names attributes.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/base.py Added get_backward_dw_params() method that returns parameters for delayed weight gradient computation by concatenating weight and bias tensors using noop_cat(). Mirrors logic in backward_dw() method.

Sequence Diagram

sequenceDiagram
    participant ML as Megatron-LM
    participant TE as TE Module
    participant WS as wgrad_store
    participant Hooks as wgrad_hooks
    
    Note over ML,Hooks: CUDA Graph Execution Context
    
    ML->>TE: backward() [main backward pass]
    TE->>WS: store wgrad computation
    Note over WS: Delayed wgrad compute enabled
    
    ML->>ML: Execute wgrad CUDA graph
    ML->>TE: get_backward_dw_params()
    TE->>TE: noop_cat(weight_tensors)
    TE->>TE: noop_cat(bias_tensors) [if use_bias]
    TE-->>ML: Return [weight_param, bias_param]
    
    Note over ML: Parameter hooks registered on<br/>returned parameters
    
    ML->>TE: backward_dw()
    TE->>WS: pop wgrad, bgrad
    TE->>TE: weight.grad = wgrad
    TE->>TE: bias.grad = bgrad [if use_bias]
    TE->>Hooks: Execute wgrad_accumulation_and_reduce_hooks()
    Hooks-->>TE: Hook execution completes
    Note over Hooks: Hooks now properly fire<br/>for grad reduce
Loading

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 22, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@Wohox
Copy link
Contributor Author

Wohox commented Jan 22, 2026

@buptzyb @lhb8125 Please help review this PR, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant