Skip to content

Commit de06a34

Browse files
minitu and timmoon10 authored
Add NVTX ranges to FP8 amax AR and grad output preprocessing (NVIDIA#1530)
Add NVTX ranges

Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Jaemin Choi <jaeminc@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
1 parent 13bd745 commit de06a34

2 files changed

Lines changed: 8 additions & 0 deletions

File tree

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ def backward(
522522

523523
if ctx.grad_output_quantizer is not None:
524524
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
525+
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
525526
(
526527
grad_output,
527528
grad_bias,
@@ -531,6 +532,7 @@ def backward(
531532
ctx.parallel_mode == "row",
532533
ctx.grad_output_quantizer,
533534
)
535+
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
534536

535537
# Prepare GEMM input
536538
# Note: Perform tensor-parallel communication if needed
@@ -747,7 +749,9 @@ def backward(
747749
wgrad = None
748750

749751
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
752+
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
750753
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
754+
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
751755

752756
# Scatter fp8 weight buffers
753757
# if ctx.fp8 and not isinstance(weight, QuantizedTensor):

transformer_engine/pytorch/module/linear.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
427427
# Note: Cast to expected dtype and perform tensor-parallel communication
428428
if ctx.grad_output_quantizer is not None:
429429
ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True)
430+
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
430431
(
431432
grad_output,
432433
grad_bias,
@@ -436,6 +437,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
436437
ctx.parallel_mode == "row",
437438
ctx.grad_output_quantizer,
438439
)
440+
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
439441

440442
# Prepare input tensor
441443
# Note: Perform tensor-parallel communication if needed
@@ -623,7 +625,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
623625
wgrad = None
624626

625627
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
628+
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
626629
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
630+
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
627631

628632
# Scatter fp8 weight buffers
629633
if ctx.fp8 and not isinstance(weight, QuantizedTensor):

0 commit comments

Comments (0)