[PyTorch] Prototype of torch.compile support for TE Sequential #2608
base: main
Conversation
> if TYPE_CHECKING:
>     from ..compile_compat.tensor_info import TensorInfo, PseudoForwardResult
Why not import directly? compile_compat doesn't depend on anything in te.ops.basic, so if we import in the correct order then we shouldn't have circular dependencies.
There is a lot of poor-quality AI-generated code here, so it's not worth reviewing this PR in detail. I think we need to reach some agreement on the high-level design, and then I will reimplement it from scratch. Maybe I will elaborate more tomorrow on why this PR works the way it does.
| """ | ||
|
|
||
| def pseudo_forward( |
Nit: `pseudo_forward` seems too vague to me. Something like `torch_compile_forward` or `compile_forward` is more intuitive.
> This default implementation provides basic shape propagation.
> Subclasses should override if they have different output shapes
> or need to save tensors for backward.
I see that we override `pseudo_forward` in the bias op, which doesn't save any backward tensors. Do we need to override it whenever the ctx is non-trivial in any way? If that's the case, it seems very burdensome for an external user writing a custom op outside of TE (see #2597).

If the base class can't handle it automatically, how about we use an opt-in approach? If a fuser has ops that support torch.compile, we'll use the infrastructure in compile_compat. If any op does not support torch.compile, then we'll fall back to a naive impl with graph breaks.
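A minimal sketch of what that opt-in dispatch could look like, assuming a hypothetical `supports_torch_compile` attribute on each op; none of these names exist in TE today.

```python
# Hypothetical opt-in dispatch; "supports_torch_compile", make_fuser, and the
# fuser class arguments are illustrative names, not existing TE APIs.
def make_fuser(ops, naive_fuser_cls, compile_fuser_cls):
    if all(getattr(op, "supports_torch_compile", False) for op in ops):
        # Every op opted in: use the compile_compat infrastructure.
        return compile_fuser_cls(ops)
    # At least one op did not opt in: fall back to the existing
    # implementation and accept graph breaks under torch.compile.
    return naive_fuser_cls(ops)
```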
Yes, it will be burdensome, so opt-in sounds like a good idea.
> return t[:idx], t[idx:]
>
> class OpsContainer:
It seems that this class duplicates a lot of logic from `Fuser`, especially related to op fusion and recipe initialization. I wonder if we can refactor `Fuser` so we can reuse this code, and then switch between launching an autograd function and launching the torch.compile infrastructure (e.g. maybe `TorchCompileCompatibleFuser` can inherit from `Fuser` and override the function that launches the autograd function).
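A rough structural sketch of that inheritance idea, assuming a hypothetical `_launch()` hook; the shared fusion/recipe logic and the launch bodies are stubbed out and are not the PR's actual code.

```python
# Structural sketch only: method names other than the two class names
# discussed above are hypothetical, and the bodies are stubs.
class OperationFuser:
    def __init__(self, ops):
        self.ops = ops

    def _fuse_and_prepare(self):
        # Shared logic: op fusion, recipe initialization, etc.
        ...

    def _launch(self, *inputs):
        # Default path: run the fused ops through the autograd function.
        ...

    def __call__(self, *inputs):
        self._fuse_and_prepare()
        return self._launch(*inputs)


class TorchCompileCompatibleFuser(OperationFuser):
    def _launch(self, *inputs):
        # Compile-friendly path: route through the opaque custom ops
        # instead of the autograd function.
        ...
```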
[WIP] torch.compile Support for te.ops API
This is an early draft - NOT ready for merge. I'm sharing this to show progress and encourage discussion on the design approach.
Overview
This PR introduces `torch.compile(fullgraph=True)` support for `te.ops` by wrapping fusion logic in custom operators. The key insight is to make fusion decisions invisible to torch.compile by encapsulating them inside opaque custom ops.

What are Opaque Types?
PyTorch's `torch._library.opaque_object.register_opaque_type` allows passing custom Python objects through `torch.compile` without causing graph breaks. During tracing, the object becomes a `FakeScriptObject`; torch.compile doesn't see inside it.

Two variants:
- Reference (`typ="reference"`) - for mutable/stateful objects, passed as graph inputs at runtime
- Value (`typ="value"`) - for immutable data, baked into the graph as constants

This lets us hide complex TE logic (fusion decisions, FP8 recipes, operation containers) from the compiler while still being fully compilable.
Design Highlights
Opaque Types Used:
- `Recipe` - reference type (delayed scaling mutates internal state)
- `OpsContainer` - reference type, holds the `BasicOperation` list, created outside the compiled region
- `OpaqueKwargs` - value type (immutable, baked into the graph)
- `TensorInfo` - value type for lightweight shape/dtype descriptors

Key Abstraction - `pseudo_forward`:
- added to `BasicOperation` for shape inference without actual computation
- used by the fake implementations (`register_fake`) for output shape prediction
- `ctx_data` is rebuilt from the saved `TensorInfo`s rather than carried in a `CtxContainer` - deterministic reconstruction instead
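A simplified sketch of the `pseudo_forward` contract, using stand-in dataclasses for `TensorInfo` and `PseudoForwardResult` (the real classes in `compile_compat` may differ), with a bias-like op as the example.

```python
# Simplified stand-ins for illustration; the actual TensorInfo /
# PseudoForwardResult / BasicOperation definitions in this PR may differ.
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass(frozen=True)
class TensorInfo:
    shape: tuple
    dtype: torch.dtype


@dataclass
class PseudoForwardResult:
    outputs: List[TensorInfo]
    saved_for_backward: List[TensorInfo] = field(default_factory=list)


class BiasLikeOp:
    def pseudo_forward(self, inp: TensorInfo) -> PseudoForwardResult:
        # Shape inference only, no computation: a bias add is elementwise,
        # so the output has the same shape/dtype as the input and nothing
        # needs to be saved for backward.
        return PseudoForwardResult(outputs=[inp])
```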
Custom Operators:
- `te_ops::fused_forward`
- `te_ops::fused_backward`
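A self-contained sketch of the custom-op pattern these operators rely on, using `torch.library.custom_op` plus a fake implementation for shape prediction. The real `te_ops::fused_forward`/`fused_backward` signatures in this PR differ (they also carry the opaque `OpsContainer`/`Recipe` objects and multiple tensors).

```python
# Illustrative pattern only, not the PR's actual operator definitions.
import torch


@torch.library.custom_op("te_ops_sketch::fused_forward", mutates_args=())
def fused_forward(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Eager implementation: actually run the (fused) computation.
    return inp @ weight.t()


@fused_forward.register_fake
def _(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Fake implementation used while torch.compile traces the graph: only
    # shapes/dtypes are produced, which is where pseudo_forward-style
    # shape propagation plugs in.
    return inp.new_empty(inp.shape[0], weight.shape[0])
```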
Integration with TE Sequential
`te.Sequential` groups consecutive `FusibleOperation`s and wraps each group in an `OperationFuser` (see `_make_module_groups()` in `sequential.py`). The `OperationFuser`:
- flattens `FusibleOperation`s into `BasicOperation`s
- fuses ops via `maybe_fuse_ops()` (e.g., `fuse_forward_linear_bias_activation`, `fuse_backward_activation_bias`)
- launches a single `_OperationFuserAutogradFunction`

Problem: `OperationFuser` is not compile-friendly - fusion decisions, dynamic op lists, and complex autograd logic cause graph breaks.

Solution: `TorchCompileCompatibleFuser` is a drop-in replacement for `OperationFuser`:
- keeps the `BasicOperation` list in an `OpsContainer` (opaque to torch.compile)
- fusion logic (`_fuse_forward_ops`, `_fuse_backward_ops`) moves inside `OpsContainer.run_forward`/`run_backward`
- the custom ops (`te_ops::fused_forward`/`backward`) expose a clean tensor-in-tensor-out interface
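If the design holds up, end-to-end usage could look roughly like this; the module and op names below are assumptions based on the description above, not code from this PR.

```python
# Hypothetical usage; te.ops module/class names and fullgraph=True support
# are assumptions based on the PR description.
import torch
import transformer_engine.pytorch as te

model = te.ops.Sequential(
    te.ops.Linear(1024, 1024),
    te.ops.GELU(),
).cuda()

compiled = torch.compile(model, fullgraph=True)
out = compiled(torch.randn(8, 1024, device="cuda"))
```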