Skip to content

linalg/arm64: NEON rms_norm_f32 kernel#11

Open
czoli1976 wants to merge 6 commits into
feat/avx512-rms-normfrom
feat/arm64-neon-rms-norm
Open

linalg/arm64: NEON rms_norm_f32 kernel#11
czoli1976 wants to merge 6 commits into
feat/avx512-rms-normfrom
feat/arm64-neon-rms-norm

Conversation

@czoli1976

Copy link
Copy Markdown
Owner

Summary

Adds an aarch64 NEON implementation of tract_linalg::ops().rms_norm_f32, mirroring the AVX-512 kernel from the parent RmsNorm PR (czoli1976#9 / sonos sonos#2311).

Stacked on feat/avx512-rms-norm because it needs:

  • the new Ops::rms_norm_f32 slot from tract-linalg
  • the new arch-neutral fast path in core::ops::nn::RmsNorm::eval

both of which the parent PR adds. If parent merges first this rebases to a single commit on top.

What

16 f32 lanes per inner loop iteration (4 v-registers × 4 lanes each):

  • Pass 1 — sum of squares via 4 parallel fmla chains (v0..v3) → 3-way fadd tree → vaddvq_f32 horizontal reduce to scalar.
  • Pass 2 — broadcast inv_std into v0, fmul/st1 each 4-v-register chunk in place.
  • Scalar tail handles the (len % 16 != 0) remainder.

Plugs into Ops::rms_norm_f32 from arm64::plug() right after the existing softmax / max / sum plumbing. NEON is mandatory on aarch64 so no runtime feature detection is needed; #[target_feature(enable = "neon")] is only there to provide the asm + intrinsic context.

The core-side RmsNorm::eval fast path (added by the parent PR) is already arch-neutral, so every model with a trailing-axis F32/F16 RmsNorm now hits this kernel on Apple Silicon / Cortex-A / Neoverse instead of falling back to the 4-call composition.

Test plan

  • cargo check --target aarch64-unknown-linux-gnu -p tract-linalg — clean
  • cargo check --target aarch64-unknown-linux-gnu -p tract-linalg --tests — clean
  • cargo check -p tract-linalg (host x86_64) — unchanged
  • cargo test --release -p tract-linalg --lib rms_norm on x86_64 — 8 passed (existing generic + AVX-512 tests untouched)
  • cargo fmt --all -- --check clean
  • cargo clippy clean on the new file
  • Runtime tests on aarch64 hardware — needs a real aarch64 box to verify the 4 new frame tests (trivial, n=1024+7 with tail, n=8 all-tail, empty). I didn't have an aarch64 runner here; the kernel compiles cross-target and matches the AVX-512 algorithm structure exactly, but please run on Apple Silicon / Cortex-A before merging.
  • End-to-end bench on aarch64 — the existing linalg/benches/rms_norm.rs gains a neon column under cfg(target_arch = "aarch64"); expected wins are similar to AVX-512's 16-18× over the composed path, scaled down by the lane-count ratio (NEON 4 vs AVX-512 16 = roughly 4× at kernel level).

Risk

  • NEON is always available on aarch64, so no host gating issue.
  • Pure addition: no existing kernel modified, no Ops field semantics changed.
  • Falls back to the generic scalar rms_norm_f32 from tract_linalg::generic on non-aarch64 / non-AVX-512 hosts (unchanged).

Notes for review

This is the natural ARM mirror of the AVX-512 RmsNorm PR — same algorithm, same 2-pass structure, same scalar tail, just narrower SIMD (4 f32 vs 16). Submitting it now so the symmetry holds: f32/f16 RmsNorm gets accelerated on both x86 AVX-512 and aarch64 NEON paths from one core-side dispatch.

Co-Authored-By: Claude Opus 4.7 (1M context) noreply@anthropic.com


Generated by Claude Code

@czoli1976 czoli1976 force-pushed the feat/arm64-neon-rms-norm branch from 956aeab to 4ce3ff2 Compare May 28, 2026 20:07
@czoli1976 czoli1976 force-pushed the feat/avx512-rms-norm branch from ded79fd to 855f563 Compare May 28, 2026 20:13
@czoli1976 czoli1976 force-pushed the feat/arm64-neon-rms-norm branch from 4ce3ff2 to 459d38b Compare May 28, 2026 20:13
@czoli1976 czoli1976 force-pushed the feat/avx512-rms-norm branch from 855f563 to 6227823 Compare May 29, 2026 08:13
@czoli1976 czoli1976 force-pushed the feat/arm64-neon-rms-norm branch from 459d38b to 94ffd7b Compare May 29, 2026 08:13
@kali kali force-pushed the feat/avx512-rms-norm branch from 6227823 to ed8dfb5 Compare June 5, 2026 11:58
kali and others added 6 commits June 8, 2026 09:36
Pad had no GPU implementation, so any model using it (e.g. the Nemotron/Parakeet
preprocessor's signal-centering and real-to-complex pads) bounced through the host
mid-graph. Add GpuPad, a backend-agnostic copy-based op: broadcast the pad value
across the output, then copy the input into the interior, both via the existing
copy_nd. Constant mode only; Reflect/Edge stay on the host. The preprocessor signal
path now runs entirely on GPU.
Add a linalg-side fused row-wise RmsNorm primitive
(`tract_linalg::ops().rms_norm_f32`) that replaces tract-core's 4-call
composition (`MeanOfSquares` + `Add` + `Rsqrt` + `Mul`) with a single
two-pass kernel: sum-of-squares via 4 zmm FMA accumulators, scalar reduce
+ rsqrt, then multiply-back via 4 zmm broadcast-multiplies. Scalar tail
handles the remainder when row_len % 64 != 0; vmovups is used throughout
since per-row slices from a tensor are not guaranteed 64-byte aligned.

`core::ops::nn::RmsNorm::eval` gains a fast path for F32 / F16 inputs
where the normalised axis is the last (contiguous) one — it iterates row
by row and dispatches to the linalg primitive. Other shapes (non-trailing
axis) keep the original composition. Generic scalar fallback ships
alongside the AVX-512 kernel; non-x86 and non-AVX-512 x86 keep the scalar
version, which is itself ~equivalent to the composed path because both
are memory-bandwidth bound.

CUDA and Metal already expose a fused `rms_norm` kernel
(`cuda/src/kernels/nn/rms_norm.rs`, `metal/src/kernels/nn/rms_norm.rs`);
this closes the CPU side of the same gap.

Measured on Cascade Lake (single-thread, kernel-level, throughput Gelem/s):
  - row 1024:  0.77 (composed) -> 12.4 (AVX-512)   16.2x
  - row 2048:  0.77            -> 13.8             17.9x
  - row 4096:  0.77            -> 13.8             17.9x

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an aarch64 NEON implementation of `tract_linalg::ops().rms_norm_f32`,
mirroring the AVX-512 kernel from the parent RmsNorm PR. 16 f32 lanes per
inner loop iteration (4 v-registers of 4 lanes each):

  Pass 1 — sum of squares via 4 fmla chains (v0..v3), 3-way fadd reduce,
           then horizontal reduce to scalar via vaddvq_f32.
  Pass 2 — broadcast inv_std into v0, multiply each 4-v-register chunk
           in place.
  Scalar tail handles (len % 16 != 0).

Plugs into `Ops::rms_norm_f32` in `arm64::plug()`. The core-side fast path
in `core::ops::nn::RmsNorm::eval` (added by the parent PR) is already
arch-neutral and picks this up automatically — every model with a trailing-
axis F32/F16 RmsNorm now hits this kernel on Apple Silicon / Cortex-A /
Neoverse instead of the generic 4-call composition.

Tests use the same scalar-reference pattern as the AVX-512 kernel:
trivial, prop-style sin/cos input at n=16, n=1024+7 (exercising the
scalar tail), and a sub-chunk n=8 (all-tail) case. NEON is mandatory on
aarch64 so no runtime feature detection is needed; the kernel is gated by
`#[target_feature(enable = "neon")]` only for the inline-asm + intrinsic
context.

Cross-compile check: `cargo check --target aarch64-unknown-linux-gnu -p
tract-linalg` clean on the modified files. The x86_64 bench output is
unchanged (the kernel module is `#[cfg(target_arch = "aarch64")]`-only via
the `arm64` parent), and the rms_norm bench gains a "neon" column when
built for aarch64.

Dependencies: needs the parent RmsNorm PR (which adds the `Ops::rms_norm_f32`
slot and the `core::ops::nn::RmsNorm::eval` dispatcher). If the parent
lands first this rebases trivially.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@kali kali force-pushed the feat/arm64-neon-rms-norm branch from 94ffd7b to 129b6d2 Compare June 8, 2026 11:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants