Add fusion rule to remove Expand before broadcast-capable binary operators#2862
Add fusion rule to remove Expand before broadcast-capable binary operators#2862
Conversation
… operators Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py
Outdated
Show resolved
Hide resolved
onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py
Outdated
Show resolved
Hide resolved
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
…://github.com/microsoft/onnxscript into copilot/create-fusion-rule-remove-expand-node
Signed-off-by: Xavier Dupré <xadupre@microsoft.com>
onnxscript/rewriter/rules/common/_remove_expand_before_binary_op_test.py
Show resolved
Hide resolved
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #2862 +/- ##
==========================================
+ Coverage 71.86% 71.99% +0.12%
==========================================
Files 239 241 +2
Lines 29138 29379 +241
Branches 2875 2899 +24
==========================================
+ Hits 20941 21152 +211
- Misses 7219 7235 +16
- Partials 978 992 +14 ☔ View full report in Codecov by Sentry. |
onnxscript/rewriter/rules/common/_remove_expand_before_binary_op_test.py
Fixed
Show fixed
Hide fixed
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py
Outdated
Show resolved
Hide resolved
…adcast comparison Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com> Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/4d4f4fb8-b66e-456e-a1d6-b1eb5ca1b532
| expand_shape = tuple(int(v) for v in expand_shape_val.tolist()) | ||
| expand_rank = len(expand_shape) | ||
|
|
||
| for rev_i in range(expand_rank): |
There was a problem hiding this comment.
Looks like the names i and rev_i should be swapped ... it is confusing in current form
There was a problem hiding this comment.
@copilot can you swap the names i and rev_i in this code fragment?
| if isinstance(y_d, int) and y_d == e_d: | ||
| continue # y already supplies this dimension | ||
|
|
||
| return check_result.fail( |
There was a problem hiding this comment.
Maybe it should not fail here, but try strategy 2/3 (in case some of the dims are symbolic).
There was a problem hiding this comment.
@copilot , if some of the dims are symbolic, then the checks above will fail. Instead of failing, perhaps it should continue on to try strategies 2/3? There should be some strategy that will handle symbolic dims
Adds a rewrite rule that eliminates redundant
Expandnodes preceding binary operators that natively support NumPy-style broadcasting.Pattern
Safety check
The rule applies a dimension-by-dimension analysis to determine if the
Expandis redundant. For each dimensioni, the expand is safe to remove if any of the following hold:expand_shape[i] == 1- expand cannot shrink a dimension, so it is a no-op.x.shape[i] == expand_shape[i]- the expand is a no-op at this dimension.y.shape[i] == expand_shape[i]-yalready covers the expansion via its own broadcasting.Otherwise the check fails conservatively. Three producer-agnostic strategies are used to resolve the expand target shape:
Constant expand shape: When the
shapeargument is a compile-time constant, the check is applied directly. Individual dimensions ofxorymay still be symbolic. For example,Add(Expand(x=[N], shape=[3,4]), y=[3,4])is optimized toAdd(x, y)becauseystatically provides all expansion dimensions.Expand output shape annotation: When
shapeis dynamic but the Expand node's output value already carries a shape annotation (e.g. after ONNX shape inference has been applied), those dimension values are used directly for the check. For example, afteronnx.shape_inference.infer_shapes,Expand(x=[N,1], Concat(Shape(x,0:1), Shape(x,1:2)))gets output shape[N,1]and the rule fires.Binary op output shape: When neither of the above is available, the rule verifies that
broadcast(x.shape, y.shape)symbolically equals the binary op's output shape. If they agree, the binary op's own broadcasting already accounts for all the expansion and the Expand is redundant.Supported ops
Add,Sub,Mul,Div,Pow,And,Or,Xor,BitwiseAnd,BitwiseOr,BitwiseXor,Greater,Less,Equal,GreaterOrEqual,LessOrEqual,Mod,PRelu,BitShiftChanges
_remove_expand_before_binary_op.py— new module with_ExpandFirstInput/_ExpandSecondInputrule classes,_compute_broadcast_shape/_check_dims_sufficienthelpers, and the exportedexpand_before_binary_op_rulesRewriteRuleSet; rule classes accesscontext.rootto obtain the Expand output and binary op output values_remove_expand_before_binary_op_test.py— tests covering removal when safe (including dynamic shapes via shape annotations and binary op output shape matching), and non-removal when the expansion cannot be statically verifiedrules/common/__init__.py— exportsexpand_before_binary_op_rulesOriginal prompt
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.