diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 841cdf04ca..8efe6a53b5 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1529,6 +1529,16 @@ def backward_dw(self):
         for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
             wgrad_accumulation_and_reduce_hook()
 
+    def get_backward_dw_params(self):
+        """
+        Get the parameters for the backward weight gradient computation.
+        """
+        params = []
+        params.append(noop_cat(self._get_weight_tensors()))
+        if self.use_bias:
+            params.append(noop_cat([getattr(self, name) for name in self.bias_names]))
+        return params
+
     def is_debug_iter(self) -> bool:
         """
         This function checks if the debug should be enabled for this layer.
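
A minimal usage sketch of the new helper, assuming the patched base module is in use: `get_backward_dw_params()` returns the (concatenated) weight tensor and, when the layer has a bias, the concatenated bias tensor, i.e. the parameters whose gradients the deferred `backward_dw()` call will produce. The training-loop context and layer sizes below are illustrative only and require `transformer_engine` with a CUDA device.

```python
import torch
import transformer_engine.pytorch as te

# Any module deriving from TransformerEngineBaseModule gains the new helper;
# te.Linear is used here purely as an example.
layer = te.Linear(1024, 1024, bias=True, device="cuda")

# Parameters touched by the deferred weight-gradient computation:
# the weights first, then the bias if the layer has one.
dw_params = layer.get_backward_dw_params()
for p in dw_params:
    print(p.shape)
```
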