linalg/arm64: NEON rms_norm_f32 kernel #11
Open
czoli1976 wants to merge 6 commits into
Pad had no GPU implementation, so any model using it (e.g. the Nemotron/Parakeet preprocessor's signal-centering and real-to-complex pads) bounced through the host mid-graph. Add GpuPad, a backend-agnostic copy-based op: broadcast the pad value across the output, then copy the input into the interior, both via the existing copy_nd. Constant mode only; Reflect/Edge stay on the host. The preprocessor signal path now runs entirely on GPU.
Add a linalg-side fused row-wise RmsNorm primitive (`tract_linalg::ops().rms_norm_f32`) that replaces tract-core's 4-call composition (`MeanOfSquares` + `Add` + `Rsqrt` + `Mul`) with a single two-pass kernel: sum-of-squares via 4 zmm FMA accumulators, scalar reduce + rsqrt, then multiply-back via 4 zmm broadcast-multiplies. Scalar tail handles the remainder when row_len % 64 != 0; vmovups is used throughout since per-row slices from a tensor are not guaranteed 64-byte aligned.
`core::ops::nn::RmsNorm::eval` gains a fast path for F32 / F16 inputs where the normalised axis is the last (contiguous) one: it iterates row by row and dispatches to the linalg primitive. Other shapes (non-trailing axis) keep the original composition.
Generic scalar fallback ships alongside the AVX-512 kernel; non-x86 and non-AVX-512 x86 keep the scalar version, which is itself ~equivalent to the composed path because both are memory-bandwidth bound.
CUDA and Metal already expose a fused `rms_norm` kernel (`cuda/src/kernels/nn/rms_norm.rs`, `metal/src/kernels/nn/rms_norm.rs`); this closes the CPU side of the same gap.
Measured on Cascade Lake (single-thread, kernel-level, throughput Gelem/s):
- row 1024: 0.77 (composed) -> 12.4 (AVX-512) 16.2x
- row 2048: 0.77 -> 13.8 17.9x
- row 4096: 0.77 -> 13.8 17.9x
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
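For orientation, the two-pass structure described in this commit message can be sketched in plain scalar Rust. This is only an illustration, not the generic fallback shipped in the PR: the function name `rms_norm_row_scalar` and the `eps` parameter (standing in for the `Add` step of the composed path) are assumptions.

```rust
/// Hypothetical scalar sketch of a row-wise RMS norm, normalising in place.
/// `eps` is an assumed stabiliser added to the mean of squares.
pub fn rms_norm_row_scalar(row: &mut [f32], eps: f32) {
    if row.is_empty() {
        return;
    }
    // Pass 1: sum of squares over the row.
    let sum_sq: f32 = row.iter().map(|x| x * x).sum();
    // Scalar step: mean of squares, add eps, reciprocal square root.
    let inv_rms = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();
    // Pass 2: multiply every element by the broadcast scale.
    for x in row.iter_mut() {
        *x *= inv_rms;
    }
}
```

With `eps = 0.0`, the mean of squares of the output row is 1 up to rounding, which is a convenient sanity check for any vectorised variant.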
Adds an aarch64 NEON implementation of `tract_linalg::ops().rms_norm_f32`,
mirroring the AVX-512 kernel from the parent RmsNorm PR. 16 f32 lanes per
inner loop iteration (4 v-registers of 4 lanes each):
Pass 1 — sum of squares via 4 fmla chains (v0..v3), 3-way fadd reduce,
then horizontal reduce to scalar via vaddvq_f32.
Pass 2 — broadcast inv_std into v0, multiply each 4-v-register chunk
in place.
Scalar tail handles (len % 16 != 0).
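The chunking above can be emulated portably to make the data flow easier to review. The sketch below is not the NEON kernel: it models the four v-registers as `[f32; 4]` arrays, uses `f32::mul_add` in place of `fmla`, and the function name and `eps` parameter are assumptions. Pass 2 is written as a plain loop rather than 16-element chunks, since the result is identical.

```rust
const LANES: usize = 4; // f32 lanes in one 128-bit v-register
const REGS: usize = 4; // independent accumulator chains (v0..v3)
const CHUNK: usize = LANES * REGS; // 16 elements per inner iteration

/// Hypothetical portable emulation of the 16-lane two-pass layout.
pub fn rms_norm_row_chunked(row: &mut [f32], eps: f32) {
    let n = row.len();
    if n == 0 {
        return;
    }
    let main = n - n % CHUNK;

    // Pass 1: four independent multiply-accumulate chains.
    let mut acc = [[0.0f32; LANES]; REGS];
    for chunk in row[..main].chunks_exact(CHUNK) {
        for r in 0..REGS {
            for l in 0..LANES {
                let x = chunk[r * LANES + l];
                acc[r][l] = x.mul_add(x, acc[r][l]);
            }
        }
    }
    // Three adds combine the four accumulators lane by lane.
    let mut v = [0.0f32; LANES];
    for l in 0..LANES {
        v[l] = (acc[0][l] + acc[1][l]) + (acc[2][l] + acc[3][l]);
    }
    // Horizontal reduce to a scalar, then fold in the scalar tail.
    let mut sum_sq: f32 = v.iter().sum();
    for &x in &row[main..] {
        sum_sq += x * x;
    }

    let inv_rms = 1.0 / (sum_sq / n as f32 + eps).sqrt();

    // Pass 2: multiply back in place.
    for x in row.iter_mut() {
        *x *= inv_rms;
    }
}
```

Keeping four separate accumulators means consecutive multiply-accumulates do not depend on each other, which is the usual reason for unrolling a reduction this way.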
Plugs into `Ops::rms_norm_f32` in `arm64::plug()`. The core-side fast path
in `core::ops::nn::RmsNorm::eval` (added by the parent PR) is already
arch-neutral and picks this up automatically — every model with a trailing-
axis F32/F16 RmsNorm now hits this kernel on Apple Silicon / Cortex-A /
Neoverse instead of the generic 4-call composition.
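A row-by-row dispatch over a contiguous trailing axis might look roughly like the following. This is a simplified sketch over a flat `f32` buffer, not the code in `RmsNorm::eval`: the `RowKernel` type, both function names, and the `eps` parameter are assumptions, and the real fast path works on tensors and also covers F16.

```rust
/// Hypothetical kernel signature: normalise one contiguous row in place.
pub type RowKernel = fn(&mut [f32], f32);

/// Hypothetical scalar kernel used here as the dispatch target.
pub fn rms_norm_scalar(row: &mut [f32], eps: f32) {
    let sum_sq: f32 = row.iter().map(|x| x * x).sum();
    let inv_rms = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();
    for x in row.iter_mut() {
        *x *= inv_rms;
    }
}

/// Apply `kernel` to each row of a buffer whose last axis has `row_len` elements.
pub fn rms_norm_last_axis(data: &mut [f32], row_len: usize, eps: f32, kernel: RowKernel) {
    assert!(row_len > 0 && data.len() % row_len == 0);
    // The trailing axis is contiguous, so each row is a plain sub-slice.
    for row in data.chunks_exact_mut(row_len) {
        kernel(row, eps);
    }
}
```

Because the caller only sees a function pointer, swapping the scalar kernel for an architecture-specific one needs no change on the calling side, which is the property the paragraph above relies on.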
Tests use the same scalar-reference pattern as the AVX-512 kernel:
trivial, prop-style sin/cos input at n=16, n=1024+7 (exercising the
scalar tail), and a sub-chunk n=8 (all-tail) case. NEON is mandatory on
aarch64 so no runtime feature detection is needed; the kernel is gated by
`#[target_feature(enable = "neon")]` only for the inline-asm + intrinsic
context.
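A scalar-reference comparison of this kind could be structured as below. The helper names, the exact sin/cos formula, and `strided_rms_norm` (a stand-in for the kernel under test) are all assumptions for illustration; the PR's actual tests are not reproduced here.

```rust
/// Deterministic sin/cos input; the exact formula is an assumption.
pub fn sin_cos_input(n: usize) -> Vec<f32> {
    (0..n)
        .map(|i| {
            let t = i as f32 * 0.37;
            t.sin() + 0.5 * (2.0 * t).cos()
        })
        .collect()
}

/// Plain scalar reference.
pub fn reference_rms_norm(row: &mut [f32], eps: f32) {
    let sum_sq: f32 = row.iter().map(|x| x * x).sum();
    let inv_rms = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();
    row.iter_mut().for_each(|x| *x *= inv_rms);
}

/// Stand-in for a kernel under test: four strided partial sums,
/// so the summation order differs from the reference.
pub fn strided_rms_norm(row: &mut [f32], eps: f32) {
    let mut partial = [0.0f32; 4];
    for (i, x) in row.iter().enumerate() {
        partial[i % 4] += x * x;
    }
    let sum_sq = (partial[0] + partial[1]) + (partial[2] + partial[3]);
    let inv_rms = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();
    row.iter_mut().for_each(|x| *x *= inv_rms);
}

/// Largest element-wise difference between `kernel` and the reference.
pub fn max_abs_diff(kernel: fn(&mut [f32], f32), n: usize, eps: f32) -> f32 {
    let mut got = sin_cos_input(n);
    let mut want = got.clone();
    kernel(&mut got, eps);
    reference_rms_norm(&mut want, eps);
    got.iter().zip(&want).map(|(a, b)| (a - b).abs()).fold(0.0, f32::max)
}
```

Calling `max_abs_diff` with n = 16, 1024 + 7, and 8 covers one full chunk, a long row with a tail, and an all-tail row. The comparison needs a tolerance rather than exact equality because reordering a floating-point sum changes the rounding.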
Cross-compile check: `cargo check --target aarch64-unknown-linux-gnu -p
tract-linalg` clean on the modified files. The x86_64 bench output is
unchanged (the kernel module is `#[cfg(target_arch = "aarch64")]`-only via
the `arm64` parent), and the rms_norm bench gains a "neon" column when
built for aarch64.
Dependencies: needs the parent RmsNorm PR (which adds the `Ops::rms_norm_f32`
slot and the `core::ops::nn::RmsNorm::eval` dispatcher). If the parent
lands first this rebases trivially.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
Adds an aarch64 NEON implementation of `tract_linalg::ops().rms_norm_f32`, mirroring the AVX-512 kernel from the parent RmsNorm PR (czoli1976#9 / sonos sonos#2311).
Stacked on `feat/avx512-rms-norm` because it needs:
- the `Ops::rms_norm_f32` slot from `tract-linalg`
- the dispatcher in `core::ops::nn::RmsNorm::eval`
both of which the parent PR adds. If the parent merges first this rebases to a single commit on top.
What
16 f32 lanes per inner loop iteration (4 v-registers × 4 lanes each):
- Pass 1: sum of squares via 4 `fmla` chains (v0..v3) → 3-way `fadd` tree → `vaddvq_f32` horizontal reduce to scalar.
- Pass 2: broadcast `inv_std` into v0, `fmul`/`st1` each 4-v-register chunk in place.
- Scalar tail handles the (`len % 16 != 0`) remainder.
Plugs into `Ops::rms_norm_f32` from `arm64::plug()` right after the existing softmax / max / sum plumbing. NEON is mandatory on aarch64 so no runtime feature detection is needed; `#[target_feature(enable = "neon")]` is only there to provide the asm + intrinsic context.
The core-side `RmsNorm::eval` fast path (added by the parent PR) is already arch-neutral, so every model with a trailing-axis F32/F16 RmsNorm now hits this kernel on Apple Silicon / Cortex-A / Neoverse instead of falling back to the 4-call composition.
Test plan
- `cargo check --target aarch64-unknown-linux-gnu -p tract-linalg`: clean
- `cargo check --target aarch64-unknown-linux-gnu -p tract-linalg --tests`: clean
- `cargo check -p tract-linalg` (host x86_64): unchanged
- `cargo test --release -p tract-linalg --lib rms_norm` on x86_64: 8 passed (existing generic + AVX-512 tests untouched)
- `cargo fmt --all -- --check` clean
- `cargo clippy` clean on the new file
- `linalg/benches/rms_norm.rs` gains a `neon` column under `cfg(target_arch = "aarch64")`; expected wins are similar to AVX-512's 16-18× over the composed path, scaled down by the lane-count ratio (NEON 4 vs AVX-512 16 = roughly 4× at kernel level).
Risk
- [...] `Ops` field semantics changed.
- [...] `rms_norm_f32` from `tract_linalg::generic` on non-aarch64 / non-AVX-512 hosts (unchanged).
Notes for review
This is the natural ARM mirror of the AVX-512 RmsNorm PR — same algorithm, same 2-pass structure, same scalar tail, just narrower SIMD (4 f32 vs 16). Submitting it now so the symmetry holds: f32/f16 RmsNorm gets accelerated on both x86 AVX-512 and aarch64 NEON paths from one core-side dispatch.
Co-Authored-By: Claude Opus 4.7 (1M context) noreply@anthropic.com
Generated by Claude Code