pggPL (Collaborator) commented Jan 21, 2026

[WIP] torch.compile Support for te.ops API

This is an early draft - NOT ready for merge. I'm sharing this to show progress and encourage discussion on the design approach.

Overview

This PR introduces torch.compile(fullgraph=True) support for te.ops by wrapping fusion logic in custom operators. The key insight is to make fusion decisions invisible to torch.compile by encapsulating them inside opaque custom ops.

What are Opaque Types?

PyTorch's torch._library.opaque_object.register_opaque_type allows passing custom Python objects through torch.compile without causing graph breaks. During tracing, the object becomes a FakeScriptObject - torch.compile doesn't see inside it.

Two variants:

  • Reference types (typ="reference") - for mutable/stateful objects, passed as graph inputs at runtime
  • Value types (typ="value") - for immutable data, baked into graph as constants

This lets us hide complex TE logic (fusion decisions, FP8 recipes, operation containers) from the compiler while keeping the surrounding graph fully compilable.
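
For reference, registration roughly looks like this, using two of the types from this PR as examples. This is a minimal sketch assuming the opaque-object API behaves as described above; the exact register_opaque_type signature may differ between PyTorch versions.

from torch._library.opaque_object import register_opaque_type


class OpsContainer:
    """Holds the BasicOperation list; created and mutated outside the compiled region."""

    def __init__(self, ops):
        self.ops = ops


# Reference type: the real object is passed into the graph as an input at runtime.
register_opaque_type(OpsContainer, typ="reference")


class OpaqueKwargs:
    """Immutable bundle of non-tensor kwargs; safe to bake into the graph as a constant."""

    def __init__(self, **kwargs):
        self.kwargs = dict(kwargs)


# Value type: treated as a compile-time constant during tracing.
register_opaque_type(OpaqueKwargs, typ="value")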

Design Highlights

Opaque Types Used:

  • Recipe - reference type (delayed scaling mutates internal state)
  • OpsContainer - reference type, holds BasicOperation list, created outside compiled region
  • OpaqueKwargs - value type (immutable, baked into graph)
  • TensorInfo - value type for lightweight shape/dtype descriptors

Key Abstraction - pseudo_forward:

  • New method on BasicOperation for shape inference without actual computation
  • Called at compile time (via register_fake) for output shape prediction
  • Called at backward time to reconstruct ctx_data from saved TensorInfo
  • Eliminates the need for a mutable CtxContainer - deterministic reconstruction instead (see the sketch below)
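
A rough sketch of the pseudo_forward contract; the field names on PseudoForwardResult are illustrative guesses, not the draft's exact definitions:

from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class TensorInfo:
    """Lightweight shape/dtype descriptor."""
    shape: tuple
    dtype: torch.dtype


@dataclass
class PseudoForwardResult:
    output_info: TensorInfo    # predicted output shape/dtype
    saved_tensor_infos: list   # descriptors of the tensors run_forward will save
    ctx_data: dict             # lightweight, deterministically reconstructible context


class BasicOperation:
    def pseudo_forward(self, input_info: TensorInfo) -> PseudoForwardResult:
        # Default: shape-preserving op that saves nothing for backward.
        return PseudoForwardResult(
            output_info=input_info,
            saved_tensor_infos=[],
            ctx_data={},
        )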

Custom Operators

te_ops::fused_forward

def fused_forward_impl(
    x: torch.Tensor,
    ops_container: OpsContainer,
    recipe: Recipe,
    kwargs_opaque: OpaqueKwargs,
    params: list[torch.Tensor],
    extra_inputs: list[torch.Tensor],
) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
    """
    Perform fused forward pass.
    Fusion logic happens inside run_forward (invisible to torch.compile).
    
    Returns: (output, tensors_to_save, extra_outputs)
    """

te_ops::fused_backward

def fused_backward_impl(
    grad_output: torch.Tensor,
    grad_extra_outputs: list[torch.Tensor],
    ops_container: OpsContainer,
    recipe: Recipe,
    kwargs_opaque: OpaqueKwargs,
    input_info: TensorInfo,
    extra_inputs_info: tuple[TensorInfo, ...],
    tensors_saved: list[torch.Tensor],
    params: list[torch.Tensor],
) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
    """
    Perform fused backward pass.
    Fusion logic happens inside run_backward (invisible to torch.compile).
    
    Returns: (grad_input, grad_params, grad_extra_inputs)
    """

Integration with TE Sequential

te.Sequential groups consecutive FusibleOperations and wraps each group in an OperationFuser (see _make_module_groups() in sequential.py). The OperationFuser:

  1. Flattens FusibleOperations into BasicOperations
  2. Applies fusion passes in maybe_fuse_ops() (e.g., fuse_forward_linear_bias_activation, fuse_backward_activation_bias)
  3. Runs forward/backward via _OperationFuserAutogradFunction

Problem: OperationFuser is not compile-friendly - fusion decisions, dynamic op lists, and complex autograd logic cause graph breaks.

Solution: TorchCompileCompatibleFuser is a drop-in replacement for OperationFuser:

  • Wraps the same BasicOperation list in an OpsContainer (opaque to torch.compile)
  • Fusion logic (_fuse_forward_ops, _fuse_backward_ops) moves inside OpsContainer.run_forward/run_backward
  • Custom ops (te_ops::fused_forward/backward) expose a clean tensor-in-tensor-out interface
  • All fusion complexity is hidden from the compiler while preserving TE optimizations (see the sketch below)
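
Roughly, the drop-in fuser's forward path looks like this. Sketch only: flatten_to_basic_ops is a hypothetical helper, and the autograd wiring through te_ops::fused_backward is elided.

import torch


class TorchCompileCompatibleFuser:
    def __init__(self, fusible_ops, recipe):
        basic_ops = flatten_to_basic_ops(fusible_ops)  # hypothetical helper
        self.ops_container = OpsContainer(basic_ops)   # opaque; created outside the compiled region
        self.recipe = recipe                           # opaque reference type

    def __call__(self, x, params, extra_inputs, **kwargs):
        kwargs_opaque = OpaqueKwargs(**kwargs)         # value type, baked into the graph
        out, tensors_to_save, extra_outputs = torch.ops.te_ops.fused_forward(
            x, self.ops_container, self.recipe, kwargs_opaque, params, extra_inputs
        )
        # Backward goes through te_ops::fused_backward with the saved tensors,
        # TensorInfo descriptors, and params (autograd plumbing omitted here).
        return out, extra_outputs

Because both custom ops take only tensors plus opaque objects, the whole call can stay inside a single torch.compile(fullgraph=True) graph.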

Comment on lines +12 to +13
if TYPE_CHECKING:
    from ..compile_compat.tensor_info import TensorInfo, PseudoForwardResult

Collaborator:
Why not import directly? compile_compat doesn't depend on anything in te.ops.basic, so if we import in the correct order then we shouldn't have circular dependencies.

Collaborator Author:
There is a lot of poor-quality AI-generated code here, so it's not worth reviewing this PR in detail. I think we first need agreement on the high-level design, and then I will reimplement it from scratch. I will elaborate more tomorrow on why this PR works the way it does.

"""

def pseudo_forward(

Collaborator:
Nit: pseudo_forward seems too vague to me. Something like torch_compile_forward or compile_forward is more intuitive.

Comment on lines +484 to +486
This default implementation provides basic shape propagation.
Subclasses should override if they have different output shapes
or need to save tensors for backward.

Collaborator:
I see that we override pseudo_forward in the bias op, which doesn't save any backward tensors. Do we need to override it whenever the ctx is non-trivial in any way? If that's the case, it seems very burdensome for an external user writing a custom op outside of TE (see #2597).

If the base class can't handle it automatically, how about we use an opt-in approach? If all of a fuser's ops support torch.compile, we'll use the infrastructure in compile_compat. If any op does not support torch.compile, we'll fall back to a naive impl with graph breaks.

Collaborator Author:
Yes, it would be burdensome - opt-in sounds like a good idea.
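
Roughly something like this (sketch only; supports_torch_compile is a hypothetical attribute name and the constructor arguments are illustrative):

def make_fuser(basic_ops, recipe):
    if all(getattr(op, "supports_torch_compile", False) for op in basic_ops):
        # Every op opted in: use the compile_compat infrastructure.
        return TorchCompileCompatibleFuser(basic_ops, recipe)
    # At least one op did not opt in: fall back to the existing fuser and
    # accept graph breaks under torch.compile.
    return OperationFuser(basic_ops, recipe)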

return t[:idx], t[idx:]


class OpsContainer:

Collaborator:
It seems that this class duplicates a lot of logic from Fuser, especially related to op fusion and recipe initialization. I wonder if we can refactor Fuser so we can reuse this code, and then switch between launching an autograd function and launching torch.compile infrastructure (e.g. maybe TorchCompileCompatibleFuser can inherit from Fuser and override the function that launches the autograd function).
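
Something along these lines, for concreteness (sketch; the method names are illustrative, not the actual OperationFuser API):

class OperationFuser:
    def _launch(self, *args, **kwargs):
        # Existing path: launch the fused ops through the autograd function.
        return _OperationFuserAutogradFunction.apply(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # Shared logic: flatten ops, maybe_fuse_ops(), recipe initialization, ...
        return self._launch(*args, **kwargs)


class TorchCompileCompatibleFuser(OperationFuser):
    def _launch(self, *args, **kwargs):
        # Reuse the parent's fusion/recipe setup; only the launch path differs
        # (autograd wiring around the custom op elided).
        return torch.ops.te_ops.fused_forward(*args, **kwargs)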

timmoon10 self-requested a review January 21, 2026 18:53