diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index 76d9e4f4b0..648f2c3db4 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -12,6 +12,7 @@ "div_by_1_rule", "dropout_inference_rule", "dropout_zero_rule", + "expand_before_binary_op_rules", "flatten_to_reshape_rule", "fuse_batchnorm_into_conv_rule", "fuse_batchnorm_into_conv_transpose_rule", @@ -125,6 +126,9 @@ no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule, ) +from onnxscript.rewriter.rules.common._remove_expand_before_binary_op import ( + expand_before_binary_op_rules, +) from onnxscript.rewriter.rules.common._remove_optional_bias import ( remove_optional_bias_from_conv_rule, remove_optional_bias_from_conv_transpose_rule, diff --git a/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py b/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py new file mode 100644 index 0000000000..92a83c1dcf --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op.py @@ -0,0 +1,295 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Fusion rule to remove an Expand node before a binary operator. + +This implements the optimization: + + BinaryOp(Expand(x, shape), y) -> BinaryOp(x, y) + BinaryOp(x, Expand(y, shape)) -> BinaryOp(x, y) + +This is valid when the binary operator's broadcasting semantics would produce +the same output shape as first expanding the input and then applying the op. +""" + +from __future__ import annotations + +from onnxscript import ir +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._ir_utils import get_numpy_value +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + +# Binary operators in ONNX standard opset that support numpy-style broadcasting. +_BROADCAST_BINARY_OPS: tuple[str, ...] 
= ( + "Add", + "And", + "BitShift", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + "Div", + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + "Mod", + "Mul", + "Or", + "Pow", + "Sub", + "Xor", +) + + + def _compute_broadcast_dim(d1, d2): + """Return the numpy broadcast of two dimension values. + + Each dimension value may be an ``int`` or an ``onnx_ir.SymbolicDim``. + Returns ``None`` when the result cannot be determined statically (e.g. two + distinct symbolic values neither of which is known to be 1). + """ + if d1 == 1: + return d2 + if d2 == 1: + return d1 + if d1 == d2: + return d1 + return None + + + def _compute_broadcast_shape(shape1: ir.Shape, shape2: ir.Shape) -> list | None: + """Compute numpy-style broadcast shape symbolically. + + Returns the broadcast shape as a list of dimension values (``int`` or + ``SymbolicDim``), or ``None`` when the result cannot be determined (e.g. + unknown ranks or incompatible static dims). + """ + rank1 = shape1.rank() + rank2 = shape2.rank() + if rank1 is None or rank2 is None: + return None + rank = max(rank1, rank2) + result = [] + for i in range(rank): + idx1 = rank1 - rank + i + d1 = shape1[idx1] if idx1 >= 0 else 1 + idx2 = rank2 - rank + i + d2 = shape2[idx2] if idx2 >= 0 else 1 + d = _compute_broadcast_dim(d1, d2) + if d is None: + return None + result.append(d) + return result + + + def _check_dims_sufficient( + expand_shape: ir.Shape, + x_shape: ir.Shape, + y_shape: ir.Shape, + ) -> MatchResult: + """Check that x and y together cover every dimension of the expand target. + + For each dimension ``i`` of *expand_shape* (right-aligned) the expand is + considered redundant when at least one of the following holds: + + - ``expand_shape[i] == 1`` - expand cannot shrink a dim, so ``x_d`` must + also be 1 and both with and without expand produce ``y_d``. + - ``x_d == expand_shape[i]`` - the expand is a no-op at this dim. + - ``y_d == expand_shape[i]`` - ``y`` already supplies this expansion. 
+ + Comparisons work for both ``int`` and ``SymbolicDim`` values. + """ + check_result = MatchResult() + e_rank = expand_shape.rank() + x_rank = x_shape.rank() + y_rank = y_shape.rank() + if e_rank is None: + return check_result.fail("Expand output rank is unknown.") + + for rev_i in range(e_rank): + i = e_rank - 1 - rev_i + e_d = expand_shape[i] + + if isinstance(e_d, int) and e_d == 1: + continue # expand cannot shrink; x_d is also 1, no-op + + x_idx = x_rank - 1 - rev_i + x_d = x_shape[x_idx] if x_idx >= 0 else 1 + if x_d == e_d: + continue # expand is a no-op at this dimension + + y_idx = y_rank - 1 - rev_i + y_d = y_shape[y_idx] if y_idx >= 0 else 1 + if y_d == e_d: + continue # y already supplies this dimension + + return check_result.fail( + f"Cannot verify that removing Expand is safe at dimension {i}: " + f"x_d={x_d!r}, expand_d={e_d!r}, y_d={y_d!r}." + ) + + return check_result + + +def _check_expand_removable( + expand_input: ir.Value, + shape: ir.Value, + other_input: ir.Value, + expand_output: ir.Value | None = None, + binary_op_output: ir.Value | None = None, +) -> MatchResult: + """Check if an Expand node can be safely removed before a binary op. + + The Expand ``expanded_x = Expand(x, expand_shape)`` before a binary op + ``out = BinaryOp(expanded_x, y)`` is redundant when the binary op's own + broadcasting produces the same output as if the expand had been applied. + + Three strategies are tried in order: + + 1. **Constant expand shape** - When ``shape`` is a compile-time constant, + the dimension values are extracted from it and the check is performed + directly. + + 2. **Expand output shape annotation** - When ``shape`` is dynamic but the + Expand node's output value already carries a shape annotation (e.g. + after ONNX shape inference has been applied to the model), those + dimension values are used for the check. + + 3. 
**Binary op output shape** - When neither of the above is available, + the rule verifies that ``broadcast(x.shape, y.shape)`` symbolically + equals the binary op's output shape. If they agree, the binary op's + own broadcasting already accounts for all the expansion and the + Expand is redundant. + + Args: + expand_input: The value fed into the Expand node (``x``). + shape: The target shape operand of the Expand node. + other_input: The other operand of the binary op (``y``). + expand_output: The output value of the Expand node. Required for + strategy 2. + binary_op_output: The output value of the binary op. Required for + strategy 3. + + Returns: + A :class:`MatchResult` that is successful when the Expand can be + removed. + """ + check_result = MatchResult() + + x_shape = expand_input.shape + y_shape = other_input.shape + if x_shape is None or y_shape is None: + return check_result.fail("Input shapes are not known.") + + x_rank = x_shape.rank() + y_rank = y_shape.rank() + + # --- Strategy 1: expand target shape is a compile-time constant --- + expand_shape_val = get_numpy_value(shape) + if expand_shape_val is not None: + expand_shape = tuple(int(v) for v in expand_shape_val.tolist()) + expand_rank = len(expand_shape) + + for rev_i in range(expand_rank): + i = expand_rank - 1 - rev_i + e_d = expand_shape[i] # always a known integer from numpy + + if e_d == 1: + continue # expand cannot shrink; x_d is also 1, no-op + + x_idx = x_rank - 1 - rev_i + x_d = x_shape[x_idx] if x_idx >= 0 else 1 + + if isinstance(x_d, int) and x_d == e_d: + continue # expand is a no-op at this dimension + + y_idx = y_rank - 1 - rev_i + y_d = y_shape[y_idx] if y_idx >= 0 else 1 + + if isinstance(y_d, int) and y_d == e_d: + continue # y already supplies this dimension + + return check_result.fail( + f"Cannot verify that removing Expand is safe at dimension {i}: " + f"x_d={x_d!r}, expand_d={e_d}, y_d={y_d!r}." 
+ ) + + return check_result + + # --- Strategy 2: Expand output shape is known (e.g. from shape inference) --- + if expand_output is not None and expand_output.shape is not None: + return _check_dims_sufficient(expand_output.shape, x_shape, y_shape) + + # --- Strategy 3: use the binary op's output shape --- + # broadcast(x.shape, y.shape) must equal the binary op's output shape. + # If it does, the binary op's own broadcasting already produces the same + # result as first expanding x and then broadcasting. + if binary_op_output is not None and binary_op_output.shape is not None: + op_output_shape = binary_op_output.shape + if op_output_shape.rank() is not None: + computed = _compute_broadcast_shape(x_shape, y_shape) + if computed is not None and len(computed) == op_output_shape.rank(): + if all(c == a for c, a in zip(computed, op_output_shape)): + return check_result + return check_result.fail( + "broadcast(x.shape, y.shape) does not match the binary op output shape." + ) + + return check_result.fail( + "Expand target shape is not a constant and no shape annotations are available." 
+ ) + + +class _ExpandFirstInput(RewriteRuleClassBase): + """Removes ``BinaryOp(Expand(x, shape), y)`` -> ``BinaryOp(x, y)``.""" + + def __init__(self, op_type: str) -> None: + super().__init__(f"ExpandFirst_{op_type}", remove_nodes=False) + self._op_type = op_type + + def pattern(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value: + return getattr(op, self._op_type)(op.Expand(x, shape), y) + + def check(self, context, x: ir.Value, shape: ir.Value, y: ir.Value) -> MatchResult: + expand_output = context.root.inputs[0] if context.root.inputs else None + binary_op_output = context.root.outputs[0] if context.root.outputs else None + return _check_expand_removable( + x, shape, y, expand_output=expand_output, binary_op_output=binary_op_output + ) + + def rewrite(self, op, x: ir.Value, shape: ir.Value, y: ir.Value) -> ir.Value: + return getattr(op, self._op_type)(x, y) + + +class _ExpandSecondInput(RewriteRuleClassBase): + """Removes ``BinaryOp(x, Expand(y, shape))`` -> ``BinaryOp(x, y)``.""" + + def __init__(self, op_type: str) -> None: + super().__init__(f"ExpandSecond_{op_type}", remove_nodes=False) + self._op_type = op_type + + def pattern(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value: + return getattr(op, self._op_type)(x, op.Expand(y, shape)) + + def check(self, context, x: ir.Value, y: ir.Value, shape: ir.Value) -> MatchResult: + expand_output = context.root.inputs[1] if context.root.inputs else None + binary_op_output = context.root.outputs[0] if context.root.outputs else None + return _check_expand_removable( + y, shape, x, expand_output=expand_output, binary_op_output=binary_op_output + ) + + def rewrite(self, op, x: ir.Value, y: ir.Value, shape: ir.Value) -> ir.Value: + return getattr(op, self._op_type)(x, y) + + +def _make_expand_before_binary_op_rules() -> list: + """Create rewrite rules for removing Expand before each supported binary op.""" + rules = [] + for op_type in _BROADCAST_BINARY_OPS: + 
rules.append(_ExpandFirstInput.rule(op_type)) + rules.append(_ExpandSecondInput.rule(op_type)) + return rules + + +expand_before_binary_op_rules = RewriteRuleSet(_make_expand_before_binary_op_rules()) diff --git a/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op_test.py b/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op_test.py new file mode 100644 index 0000000000..1a3bfb5451 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_expand_before_binary_op_test.py @@ -0,0 +1,348 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Tests for the remove-Expand-before-binary-op fusion rule.""" + +from __future__ import annotations + +import unittest + +import numpy as np +import onnx +import onnx.helper +import onnx.numpy_helper +import onnx.reference +import onnx.shape_inference +import parameterized + +import onnxscript.ir as ir +from onnxscript.rewriter.rules.common import _remove_expand_before_binary_op as mod + + +def _run_model(model: ir.Model, feeds: dict) -> list: + """Run a model using the ONNX reference evaluator.""" + proto = ir.to_proto(model) + ref = onnx.reference.ReferenceEvaluator(proto) + return ref.run(None, feeds) + + +class RemoveExpandBeforeBinaryOpTest(unittest.TestCase): + """Tests for _remove_expand_before_binary_op rules.""" + + def _apply_and_check( + self, + model_text: str, + expected_count: int, + expected_op_types: list[str], + ) -> ir.Model: + """Helper: apply the rules and verify the result.""" + model = ir.from_onnx_text(model_text) + count = mod.expand_before_binary_op_rules.apply_to_model(model) + self.assertEqual(count, expected_count) + actual_op_types = [node.op_type for node in model.graph] + self.assertEqual(actual_op_types, expected_op_types) + return model + + # ------------------------------------------------------------------ + # Cases where the Expand should be removed + # ------------------------------------------------------------------ + + 
@parameterized.parameterized.expand( + [ + ("Add",), + ("Sub",), + ("Mul",), + ("Div",), + ] + ) + def test_expand_first_input_same_shape_is_removed(self, op_type: str): + """Expand producing same shape as input should be removed from BinaryOp.""" + model_text = f""" + + agraph (float[3, 4] x, float[3, 4] y) => (float[3, 4] output) + + {{ + expanded = Expand(x, shape) + output = {op_type}(expanded, y) + }} + """ + model = self._apply_and_check(model_text, 1, [op_type]) + + # Verify numerical correctness + x = np.random.randn(3, 4).astype(np.float32) + y = np.random.randn(3, 4).astype(np.float32) + original = ir.from_onnx_text(model_text) + expected = _run_model(original, {"x": x, "y": y}) + got = _run_model(model, {"x": x, "y": y}) + np.testing.assert_allclose(got[0], expected[0], rtol=1e-5) + + def test_expand_first_input_broadcast_covered_by_other_input(self): + """Expand from [3, 4] to [4, 3, 4] can be removed when y has shape [4, 3, 4].""" + model_text = """ + + agraph (float[3, 4] x, float[4, 3, 4] y) => (float[4, 3, 4] output) + + { + expanded = Expand(x, shape) + output = Add(expanded, y) + } + """ + model = self._apply_and_check(model_text, 1, ["Add"]) + + x = np.random.randn(3, 4).astype(np.float32) + y = np.random.randn(4, 3, 4).astype(np.float32) + original = ir.from_onnx_text(model_text) + expected = _run_model(original, {"x": x, "y": y}) + got = _run_model(model, {"x": x, "y": y}) + np.testing.assert_allclose(got[0], expected[0], rtol=1e-5) + + def test_expand_second_input_is_removed(self): + """Expand on the second input of a binary op should be removed.""" + model_text = """ + + agraph (float[4, 3, 4] x, float[3, 4] y) => (float[4, 3, 4] output) + + { + expanded = Expand(y, shape) + output = Mul(x, expanded) + } + """ + model = self._apply_and_check(model_text, 1, ["Mul"]) + + x = np.random.randn(4, 3, 4).astype(np.float32) + y = np.random.randn(3, 4).astype(np.float32) + original = ir.from_onnx_text(model_text) + expected = _run_model(original, 
{"x": x, "y": y}) + got = _run_model(model, {"x": x, "y": y}) + np.testing.assert_allclose(got[0], expected[0], rtol=1e-5) + + def test_expand_with_broadcast_compatible_other_input(self): + """Expand from [3] to [4, 3] can be removed when y has shape [4, 1].""" + model_text = """ + + agraph (float[3] x, float[4, 1] y) => (float[4, 3] output) + + { + expanded = Expand(x, shape) + output = Add(expanded, y) + } + """ + model = self._apply_and_check(model_text, 1, ["Add"]) + + x = np.random.randn(3).astype(np.float32) + y = np.random.randn(4, 1).astype(np.float32) + original = ir.from_onnx_text(model_text) + expected = _run_model(original, {"x": x, "y": y}) + got = _run_model(model, {"x": x, "y": y}) + np.testing.assert_allclose(got[0], expected[0], rtol=1e-5) + + def test_expand_sub_first_input_is_removed(self): + """Expand on the first input of Sub should be removed.""" + model_text = """ + + agraph (float[3, 4] x, float[3, 4] y) => (float[3, 4] output) + + { + expanded = Expand(x, shape) + output = Sub(expanded, y) + } + """ + model = self._apply_and_check(model_text, 1, ["Sub"]) + + x = np.random.randn(3, 4).astype(np.float32) + y = np.random.randn(3, 4).astype(np.float32) + original = ir.from_onnx_text(model_text) + expected = _run_model(original, {"x": x, "y": y}) + got = _run_model(model, {"x": x, "y": y}) + np.testing.assert_allclose(got[0], expected[0], rtol=1e-5) + + def test_expand_div_second_input_is_removed(self): + """Expand on the second input of Div should be removed.""" + model_text = """ + + agraph (float[4, 3, 4] x, float[3, 4] y) => (float[4, 3, 4] output) + + { + expanded = Expand(y, shape) + output = Div(x, expanded) + } + """ + model = self._apply_and_check(model_text, 1, ["Div"]) + + x = np.random.randn(4, 3, 4).astype(np.float32) + y = np.random.randn(3, 4).astype(np.float32) + 2.0 # avoid division by zero + original = ir.from_onnx_text(model_text) + expected = _run_model(original, {"x": x, "y": y}) + got = _run_model(model, {"x": x, "y": y}) + 
np.testing.assert_allclose(got[0], expected[0], rtol=1e-5) + + # ------------------------------------------------------------------ + # Cases where the Expand should NOT be removed + # ------------------------------------------------------------------ + + def test_expand_changes_output_shape_not_removed(self): + """Expand that changes the output shape compared to direct broadcast must be kept.""" + # x has shape [3], expand to [4, 3], other is a scalar. + # With expand: broadcast([4, 3], []) = [4, 3] + # Without expand: broadcast([3], []) = [3] <- different! + model_text = """ + <ir_version: 10, opset_import: ["" : 20]> + agraph (float[3] x) => (float[4, 3] output) + <int64[2] shape = {4, 3}, float one = {1.0}> + { + expanded = Expand(x, shape) + output = Add(expanded, one) + } + """ + model = ir.from_onnx_text(model_text) + count = mod.expand_before_binary_op_rules.apply_to_model(model) + self.assertEqual(count, 0) + + def test_expand_target_shape_not_constant_removed_via_output_shape(self): + """Expand with a dynamic shape is removed when the binary op output shape + confirms the expansion is redundant. + + x=[3, 4] has no dimension equal to 1, so Expand can only output [3, 4]. + With y=[3, 4] the binary op output shape is also [3, 4], and + broadcast([3, 4], [3, 4]) = [3, 4] matches, so the expand is provably a + no-op and is safely removed. + """ + model_text = """ + <ir_version: 10, opset_import: ["" : 20]> + agraph (float[3, 4] x, float[3, 4] y, int64[2] shape) => (float[3, 4] output) + { + expanded = Expand(x, shape) + output = Add(expanded, y) + } + """ + model = ir.from_onnx_text(model_text) + count = mod.expand_before_binary_op_rules.apply_to_model(model) + self.assertEqual(count, 1) + + def test_expand_with_symbolic_x_dim_removed(self): + """Expand with a symbolic x dim can be removed when y statically covers the expansion. + + x=[N], expand_shape=[3, 4], y=[3, 4]: since y provides all expand dimensions + as known integers, the expand is redundant regardless of N's runtime value. 
+ """ + model_text = """ + + agraph (float[N] x, float[3, 4] y) => (float[3, 4] output) + + { + expanded = Expand(x, shape) + output = Add(expanded, y) + } + """ + model = ir.from_onnx_text(model_text) + count = mod.expand_before_binary_op_rules.apply_to_model(model) + self.assertEqual(count, 1) + + def test_expand_with_symbolic_y_dim_not_removed(self): + """Expand cannot be removed when y has a symbolic dim in a position where the + expand is doing work and that symbolic dim cannot be verified to equal expand_d. + """ + # x=[3], expand_shape=[4, 3], y=[M, 3]. + # At dim 0 (expand adds dim 4): x_d=1 (virtual), y_d=M (symbolic) -> can't verify. + model_text = """ + + agraph (float[3] x, float[M, 3] y) => (float[4, 3] output) + + { + expanded = Expand(x, shape) + output = Add(expanded, y) + } + """ + model = ir.from_onnx_text(model_text) + count = mod.expand_before_binary_op_rules.apply_to_model(model) + self.assertEqual(count, 0) + + def test_full_optimization(self): + oh = onnx.helper + model_proto = oh.make_model( + oh.make_graph( + [ + oh.make_node("Shape", ["x"], ["n"], start=0, end=1), + oh.make_node("Shape", ["x"], ["b"], start=1, end=2), + oh.make_node("Concat", ["n", "b"], ["shape"], axis=0), + oh.make_node("Expand", ["x", "shape"], ["expanded"]), + oh.make_node("Add", ["expanded", "y1"], ["z1"]), + oh.make_node("Add", ["expanded", "y2"], ["z2"]), + oh.make_node("Add", ["expanded", "y3"], ["z3"]), + oh.make_node("Add", ["z1", "z2"], ["z12"]), + oh.make_node("Add", ["z12", "z3"], ["z"]), + ], + "test", + [ + oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, ["N", 1]), + oh.make_tensor_value_info("y1", onnx.TensorProto.FLOAT, [1, "B"]), + oh.make_tensor_value_info("y2", onnx.TensorProto.FLOAT, [1, "B"]), + oh.make_tensor_value_info("y3", onnx.TensorProto.FLOAT, [1, "B"]), + ], + [ + oh.make_tensor_value_info("z", onnx.TensorProto.FLOAT, ["N", "B"]), + ], + ), + ir_version=10, + opset_imports=[oh.make_opsetid("", 20)], + ) + 
onnx.checker.check_model(model_proto) + # Shape inference is required so that the Expand output carries its + # shape annotation ([N, 1]). Without it the rule cannot verify that + # the expansion is redundant. + inferred_proto = onnx.shape_inference.infer_shapes(model_proto, data_prop=True) + model = ir.serde.deserialize_model(inferred_proto) + count = mod.expand_before_binary_op_rules.apply_to_model(model) + self.assertEqual(count, 3) + self.assertEqual(len(model.graph), 5) + + def test_full_optimization_more_complex(self): + oh = onnx.helper + onh = onnx.numpy_helper + + model_proto = oh.make_model( + oh.make_graph( + [ + oh.make_node("Shape", ["x"], ["n"], start=0, end=1), + oh.make_node("Shape", ["x"], ["b"], start=1, end=2), + oh.make_node("Concat", ["n", "b"], ["shape"], axis=0), + oh.make_node("Add", ["shape", "one"], ["shape1"]), + oh.make_node("Sub", ["shape1", "one"], ["shape2"]), + oh.make_node("Expand", ["x", "shape2"], ["expanded"]), + oh.make_node("Add", ["expanded", "y1"], ["z1"]), + oh.make_node("Add", ["expanded", "y2"], ["z2"]), + oh.make_node("Add", ["expanded", "y3"], ["z3"]), + oh.make_node("Add", ["z1", "z2"], ["z12"]), + oh.make_node("Add", ["z12", "z3"], ["z"]), + ], + "test", + [ + oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, ["N", 1]), + oh.make_tensor_value_info("y1", onnx.TensorProto.FLOAT, [1, "B"]), + oh.make_tensor_value_info("y2", onnx.TensorProto.FLOAT, [1, "B"]), + oh.make_tensor_value_info("y3", onnx.TensorProto.FLOAT, [1, "B"]), + ], + [ + oh.make_tensor_value_info("z", onnx.TensorProto.FLOAT, ["N", "B"]), + ], + [onh.from_array(np.array([1], dtype=np.int64), "one")], + # Explicit shape annotations on intermediate values (as produced by + # shape inference or by the model creator). These allow the rule to + # verify that the Expand is redundant without tracing the exact + # computation that produced the shape tensor. 
+ value_info=[ + oh.make_tensor_value_info("expanded", onnx.TensorProto.FLOAT, ["N", 1]), + oh.make_tensor_value_info("z1", onnx.TensorProto.FLOAT, ["N", "B"]), + oh.make_tensor_value_info("z2", onnx.TensorProto.FLOAT, ["N", "B"]), + oh.make_tensor_value_info("z3", onnx.TensorProto.FLOAT, ["N", "B"]), + ], + ), + ir_version=10, + opset_imports=[oh.make_opsetid("", 20)], + ) + onnx.checker.check_model(model_proto) + model = ir.serde.deserialize_model(model_proto) + count = mod.expand_before_binary_op_rules.apply_to_model(model) + self.assertEqual(count, 3) + self.assertEqual(len(model.graph), 5) + + +if __name__ == "__main__": + unittest.main()