Fix aten_gru: add linear_before_reset=1 to match PyTorch GRU semantics (#2853)
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main    #2853   +/-   ##
=======================================
  Coverage   71.86%   71.86%
=======================================
  Files         239      239
  Lines       29138    29138
  Branches     2875     2875
=======================================
  Hits        20941    20941
  Misses       7219     7219
  Partials      978      978
```

View full report in Codecov by Sentry.
@copilot add a comment inline for future readers
Added a comment (commit …).
**titaiwangms** left a comment:

> @copilot are there any existing tests that this fixes?
Pull request overview

This PR fixes the Torch `aten_gru` → ONNX `GRU` lowering to match PyTorch `nn.GRU` semantics by explicitly setting `linear_before_reset=1`, addressing a known numerical mismatch.

Changes:
- Add `linear_before_reset=1` to the biased `op.GRU` call in `aten_gru`.
- Add `linear_before_reset=1` to the bias-free `op.GRU` call in `aten_gru`.
- Add inline documentation explaining why `linear_before_reset=1` is required for PyTorch parity.
PyTorch `nn.GRU` applies the linear transformation before multiplying by the reset gate (`linear_before_reset=1`), but the `aten_gru` translation was emitting ONNX `GRU` ops with the default `linear_before_reset=0`, producing numerically wrong results (error ~0.1 vs. the expected ~1e-7).

Changes:
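The difference between the two attribute values comes down to where the recurrent matmul sits relative to the reset gate in the candidate-state computation. A minimal NumPy sketch (variable names and sizes are illustrative, not taken from the PR) shows that the two orderings generally produce different candidate activations, which is the source of the mismatch:

```python
import numpy as np

rng = np.random.default_rng(0)
H = 4  # illustrative hidden size

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Random reset gate, candidate recurrent weights/bias, and previous hidden state.
r = sigmoid(rng.standard_normal(H))
Rh = rng.standard_normal((H, H))
bRh = rng.standard_normal(H)
h_prev = rng.standard_normal(H)
wx = rng.standard_normal(H)  # stands in for the input contribution Wh @ x + bWh

# linear_before_reset=0 (ONNX default): matmul applied AFTER gating h_prev.
n_after = np.tanh(wx + Rh @ (r * h_prev) + bRh)
# linear_before_reset=1 (PyTorch nn.GRU): matmul applied BEFORE the reset gate.
n_before = np.tanh(wx + r * (Rh @ h_prev + bRh))

# The candidate states generally disagree, explaining the numerical mismatch.
print(float(np.max(np.abs(n_after - n_before))))
```

The second formula matches the `n_t = tanh(W_in x + b_in + r_t * (W_hn h + b_hn))` gate ordering documented for `torch.nn.GRU`, which is why the exporter must set `linear_before_reset=1`.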
- `onnxscript/function_libs/torch_lib/ops/core.py`: Add `linear_before_reset=1` to both `op.GRU` calls in `aten_gru` (the biased and unbiased variants).