
Add fusion rule to remove Expand before broadcast-capable binary operators #2862

Open
Copilot wants to merge 11 commits into main from copilot/create-fusion-rule-remove-expand-node

Conversation

Contributor

Copilot AI commented Mar 20, 2026

Adds a rewrite rule that eliminates redundant Expand nodes preceding binary operators that natively support NumPy-style broadcasting.

Pattern

BinaryOp(Expand(x, shape), y)  →  BinaryOp(x, y)
BinaryOp(x, Expand(y, shape))  →  BinaryOp(x, y)
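The intuition behind the pattern can be checked with plain NumPy, whose broadcasting rules match ONNX's. This is an illustrative sketch, not code from the PR: explicitly expanding x before the binary op produces the same result as letting the op broadcast x itself.

```python
import numpy as np

# Sketch: Add broadcasts (3, 1) against (3, 4) on its own, so an
# explicit Expand of x to (3, 4) beforehand is redundant.
x = np.arange(3, dtype=np.float32).reshape(3, 1)  # shape (3, 1)
y = np.ones((3, 4), dtype=np.float32)             # shape (3, 4)

with_expand = np.broadcast_to(x, (3, 4)) + y  # Add(Expand(x, [3, 4]), y)
without_expand = x + y                        # Add(x, y)

assert np.array_equal(with_expand, without_expand)
```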

Safety check

The rule applies a dimension-by-dimension analysis to determine whether the Expand is redundant. For each dimension i, the Expand is safe to remove if any of the following hold:

  • expand_shape[i] == 1 - Expand never shrinks a dimension, so a target extent of 1 leaves x unchanged.
  • x.shape[i] == expand_shape[i] - the Expand is a no-op at this dimension.
  • y.shape[i] == expand_shape[i] - y already supplies this extent through the binary op's own broadcasting.
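The per-dimension conditions above can be sketched as a small helper. This is hypothetical illustration code, not the PR's _check_dims_sufficient, and it assumes all dimensions are concrete ints (symbolic dims are what strategies 2 and 3 exist for):

```python
def dims_sufficient(x_shape, y_shape, expand_shape):
    """Return True if removing the Expand is provably safe, dim by dim."""
    rank = len(expand_shape)
    # Right-align shapes and pad with 1s, NumPy/ONNX broadcasting style.
    x = (1,) * (rank - len(x_shape)) + tuple(x_shape)
    y = (1,) * (rank - len(y_shape)) + tuple(y_shape)
    for x_d, y_d, e_d in zip(x, y, expand_shape):
        if e_d == 1:    # Expand never shrinks a dimension: no-op here
            continue
        if x_d == e_d:  # x already has this extent
            continue
        if y_d == e_d:  # y supplies the extent via the op's broadcasting
            continue
        return False    # cannot prove safety: fail conservatively
    return True
```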

Otherwise the check fails conservatively. Three producer-agnostic strategies are used to resolve the expand target shape:

  1. Constant expand shape: When the shape argument is a compile-time constant, the check is applied directly. Individual dimensions of x or y may still be symbolic. For example, Add(Expand(x=[N], shape=[3,4]), y=[3,4]) is optimized to Add(x, y) because y statically provides all expansion dimensions.

  2. Expand output shape annotation: When shape is dynamic but the Expand node's output value already carries a shape annotation (e.g. after ONNX shape inference has been applied), those dimension values are used directly for the check. For example, after onnx.shape_inference.infer_shapes, Expand(x=[N,1], Concat(Shape(x,0:1), Shape(x,1:2))) gets output shape [N,1] and the rule fires.

  3. Binary op output shape: When neither of the above is available, the rule verifies that broadcast(x.shape, y.shape) symbolically equals the binary op's output shape. If they agree, the binary op's own broadcasting already accounts for all the expansion and the Expand is redundant.
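Strategy 3 boils down to a symbolic broadcast computation. The sketch below is an illustrative stand-in for the PR's _compute_broadcast_shape (the actual implementation may differ); dims are ints or strings naming symbolic dimensions, and incomparable dims make the result None so the check fails conservatively:

```python
def broadcast_shape(a, b):
    """Symbolically broadcast two shapes; None if not provably compatible."""
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    out = []
    for d1, d2 in zip(a, b):
        if d1 == 1:
            out.append(d2)          # 1 broadcasts to the other extent
        elif d2 == 1 or d1 == d2:
            out.append(d1)          # equal dims (incl. same symbol) agree
        else:
            return None             # unknown or incompatible symbolically
    return tuple(out)

# If this equals the binary op's recorded output shape, the op's own
# broadcasting already covers everything the Expand would have done.
assert broadcast_shape(("N", 1), (1, 4)) == ("N", 4)
```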

Supported ops

Add, Sub, Mul, Div, Pow, And, Or, Xor, BitwiseAnd, BitwiseOr, BitwiseXor, Greater, Less, Equal, GreaterOrEqual, LessOrEqual, Mod, PRelu, BitShift

Changes

  • _remove_expand_before_binary_op.py — new module with _ExpandFirstInput / _ExpandSecondInput rule classes, _compute_broadcast_shape / _check_dims_sufficient helpers, and the exported expand_before_binary_op_rules RewriteRuleSet; rule classes access context.root to obtain the Expand output and binary op output values
  • _remove_expand_before_binary_op_test.py — tests covering removal when safe (including dynamic shapes via shape annotations and binary op output shape matching), and non-removal when the expansion cannot be statically verified
  • rules/common/__init__.py — exports expand_before_binary_op_rules
Original prompt

This section details the original issue to resolve.

<issue_title>create a fusion rule to remove an expand node before a binary operator if this op can handle it through broadcasting</issue_title>
<issue_description></issue_description>

Comments on the Issue (you are @copilot in this section)



Copilot AI changed the title [WIP] Add fusion rule to remove expand node before binary operator Add fusion rule to remove Expand before broadcast-capable binary operators Mar 20, 2026
Copilot AI requested a review from xadupre March 20, 2026 16:29
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
xadupre added 3 commits March 20, 2026 18:23
Signed-off-by: Xavier Dupré <xadupre@microsoft.com>
@codecov

codecov bot commented Mar 20, 2026

Codecov Report

❌ Patch coverage is 87.55187% with 30 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.99%. Comparing base (19e5284) to head (a47d985).
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
...er/rules/common/_remove_expand_before_binary_op.py 77.04% 15 Missing and 13 partials ⚠️
...les/common/_remove_expand_before_binary_op_test.py 98.30% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2862      +/-   ##
==========================================
+ Coverage   71.86%   71.99%   +0.12%     
==========================================
  Files         239      241       +2     
  Lines       29138    29379     +241     
  Branches     2875     2899      +24     
==========================================
+ Hits        20941    21152     +211     
- Misses       7219     7235      +16     
- Partials      978      992      +14     


Agent-Logs-Url: https://github.com/microsoft/onnxscript/sessions/4d4f4fb8-b66e-456e-a1d6-b1eb5ca1b532
Copilot AI requested a review from xadupre March 23, 2026 12:17
@xadupre xadupre marked this pull request as ready for review March 25, 2026 10:06
expand_shape = tuple(int(v) for v in expand_shape_val.tolist())
expand_rank = len(expand_shape)

for rev_i in range(expand_rank):
Collaborator

Looks like the names i and rev_i should be swapped; it is confusing in its current form.

Collaborator

@copilot can you swap the names i and rev_i in this code fragment?

if isinstance(y_d, int) and y_d == e_d:
continue # y already supplies this dimension

return check_result.fail(
Collaborator

Maybe it should not fail here, but try strategy 2/3 (in case some of the dims are symbolic).

Collaborator

@copilot, if some of the dims are symbolic, then the checks above will fail. Instead of failing, perhaps it should continue on to try strategies 2/3? There should be some strategy that handles symbolic dims.


Successfully merging this pull request may close these issues.

create a fusion rule to remove an expand node before a binary operator if this op can handle it through broadcasting

3 participants