
metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×) #2320

Merged
kali merged 2 commits into sonos:main from czoli1976:feature/metal-flash-sdpa
Jun 15, 2026


@czoli1976

Contributor

@kali — opening this as much as an RFC as a PR, so please read the design question below before the diff.

I've seen you've been actively building out the GPU path lately: mirroring the CUDA kernels into Metal (gather, diag_gather), landing scaled_masked_softmax (bool mask + post-softmax mask) on both backends, and adding the transform pre-check plus the CPU fallback when a GPU op rejects a shape. So I wanted to surface something adjacent that has been sitting unused in the Metal crate and check the direction with you before investing further.

The opportunity

The vendored libMetalFlashAttention.metallib (Apple / Philip Turner's MetalFlashAttention, MIT — already in-tree) ships four functions: sgemm, hgemm, convolution, and attention. We only ever dispatch sgemm/hgemm. The attention entry is a fully-implemented fused flash-attention kernel (online softmax, never materializes the score matrix) — and it's simply never called. (For completeness: convolution in this metallib is an empty stub, so attention is the only unused-but-real kernel here.)

Meanwhile MetalTransform explodes every Sdpa into einsum → softmax → einsum, which materializes the full (B,H,Sq,Sk) score buffer in device memory and round-trips it through three kernels — the middle one being the scaled_masked_softmax you just landed.
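To make the difference concrete, here is a minimal CPU sketch (not the Metal kernel, and not tract code) of a single attention head computed both ways: once with the full score matrix materialized, as in the einsum → softmax → einsum path, and once with an online softmax that keeps only a running max, a running denominator and a `[d]` accumulator per query row. The function names, the row-major `[s, d]` layout and the `1/sqrt(d)` scale are assumptions made for illustration; a non-empty key sequence is assumed.

```rust
/// Reference single-head attention that materializes the whole [sq, sk] score
/// buffer. q is [sq, d], k and v are [sk, d], all row-major; returns [sq, d].
pub fn attention_materialized(q: &[f32], k: &[f32], v: &[f32], sq: usize, sk: usize, d: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    // The intermediate that a fused kernel never writes to memory.
    let mut scores = vec![0.0f32; sq * sk];
    for i in 0..sq {
        for j in 0..sk {
            let mut dot = 0.0f32;
            for c in 0..d {
                dot += q[i * d + c] * k[j * d + c];
            }
            scores[i * sk + j] = dot * scale;
        }
    }
    let mut out = vec![0.0f32; sq * d];
    for i in 0..sq {
        let row = &mut scores[i * sk..(i + 1) * sk];
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for s in row.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }
        for j in 0..sk {
            let p = row[j] / sum;
            for c in 0..d {
                out[i * d + c] += p * v[j * d + c];
            }
        }
    }
    out
}

/// Same computation with an online softmax: one pass over the keys per query
/// row, no [sq, sk] buffer.
pub fn attention_online(q: &[f32], k: &[f32], v: &[f32], sq: usize, sk: usize, d: usize) -> Vec<f32> {
    let scale = 1.0 / (d as f32).sqrt();
    let mut out = vec![0.0f32; sq * d];
    for i in 0..sq {
        let mut running_max = f32::NEG_INFINITY;
        let mut denom = 0.0f32;
        let acc = &mut out[i * d..(i + 1) * d];
        for j in 0..sk {
            let mut dot = 0.0f32;
            for c in 0..d {
                dot += q[i * d + c] * k[j * d + c];
            }
            let s = dot * scale;
            let new_max = running_max.max(s);
            // Rescale everything accumulated under the previous max.
            let correction = (running_max - new_max).exp();
            let p = (s - new_max).exp();
            denom = denom * correction + p;
            for c in 0..d {
                acc[c] = acc[c] * correction + p * v[j * d + c];
            }
            running_max = new_max;
        }
        for c in 0..d {
            acc[c] /= denom;
        }
    }
    out
}
```

Both functions should agree up to floating-point rounding; the second one only ever holds O(d) state per query row, which is the property the fused kernel exploits on the GPU.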

This PR wires the fused kernel and routes Sdpa to it.

The design question (why this is an RFC)

Dispatching a vendored, 2023-era macosx13 metallib that we don't own is a real commitment — and I noticed you just dropped the unused GgmlFlashAttn kernel library on the CUDA side (73dada812), which cuts the other way. So I'd rather ask directly than assume:

Is wiring this vendored Metal kernel the direction you want — or would you prefer a fresh, owned port (e.g. translating the MLX / ggml-metal flash-attention kernel into a .metal source we control)?

The case for wiring it now: it's already in-tree, fully implemented, MIT, and measures ~2× — a low-risk way to close the fused-Metal-SDPA gap today. The owned-.metal-port hedge stays open as a follow-up if metallib longevity on M3/M4 worries you. This PR is the "wire what's already there" option, fully validated, for you to accept or redirect.

What it does

  • dispatch_metal_mfa_attention — drives the vendored attention function. ABI (buffers / function-constants / grid geometry) reconstructed from the MFA v1.0.1 source + on-GPU pipeline reflection. f32 + f16.
  • mfa_attention_head_major — adapts tract's native [B,H,S,D] to MFA's layout (Q/O=[R,H,D], K=[H,D,C], V=[C,H,D]) on-device via copy_nd. The one unavoidable copy is the K transpose (candidate to fold later).
  • MetalMfaSdpa op + register_metal_op!(Sdpa) translator — routes a real Sdpa node to the fused kernel. Unsupported shapes return None and fall through to the CPU-fallback path you just added (85255fdb9). A Metal-local rewire_sdpa_metal flattens only the Sdpa nodes the kernel can't take, leaving fusable ones intact (CUDA keeps the shared rewire_sdpa untouched).
  • causal via an additive [Sq,Sk] mask — the metallib's triangular function-constant alone is a no-op (it computes full attention), pinned by a regression test.
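As an illustration of the last bullet, an additive causal mask can be built on the host roughly as follows. This is a sketch under assumptions, not the code in this PR: the function name is hypothetical, the mask uses `0.0` / `-inf` (a large negative finite value would also work), and queries are assumed to be aligned to the end of the key sequence so the same helper covers prefill (`sq == sk`) and decode against a longer cache (`sq < sk`).

```rust
/// Build a row-major additive [sq, sk] causal mask: 0.0 where query i may
/// attend key j, -inf elsewhere. The mask is added to the scaled scores
/// before the softmax.
///
/// Assumption: query i sits at absolute position (sk - sq + i).
pub fn causal_additive_mask(sq: usize, sk: usize) -> Vec<f32> {
    assert!(sq <= sk, "this sketch assumes at least as many keys as queries");
    let offset = sk - sq;
    let mut mask = vec![0.0f32; sq * sk];
    for i in 0..sq {
        // Keys strictly after the query's own position are masked out.
        for j in (i + offset + 1)..sk {
            mask[i * sk + j] = f32::NEG_INFINITY;
        }
    }
    mask
}
```

Because `exp(-inf)` is exactly zero, masked positions contribute nothing to the softmax numerator or denominator, so the kernel can run its full-attention code path and still produce causal output.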

Numbers (Apple M-series, f32, B=1 H=8 S=512 D=64)

| measurement | result |
|---|---|
| fused kernel vs explode path (both preallocated, dispatch_eval) | ~2×, growing with S (2.4× at S=2048) |
| 8-layer all-attention stack (amortizes host sync) | 2.70× on the attention portion (fused 6.6 ms/layer vs explode 17.9 ms/layer) |
| projected real-model e2e (2.70× × attention compute-share) | ~1.2–1.4× |
| (B,H,Sq,Sk) score buffer | eliminated (≈128 MB at S=2048/H=8) |

(A single-op model bench reports 3.9×, but that's overhead-inflated — the 2.70× multi-layer figure is the honest one, and it's consistent with the kernel-level ~2×.)
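The arithmetic behind two of these rows can be checked with a small sketch. The score-buffer size is just `B·H·Sq·Sk` elements times the element size. For the end-to-end projection, one reasonable reading of "2.70× scaled by attention's compute share" is Amdahl's law; that interpretation, and both function names, are assumptions of this sketch rather than something stated above.

```rust
/// Bytes needed to materialize a (B, H, Sq, Sk) score buffer.
pub fn score_buffer_bytes(b: usize, h: usize, sq: usize, sk: usize, elem_bytes: usize) -> usize {
    b * h * sq * sk * elem_bytes
}

/// Amdahl's-law estimate of whole-model speedup when only the attention
/// fraction `attention_share` (0.0..=1.0 of runtime) gets `attention_speedup`.
pub fn end_to_end_speedup(attention_share: f64, attention_speedup: f64) -> f64 {
    1.0 / ((1.0 - attention_share) + attention_share / attention_speedup)
}
```

With f32 elements, `score_buffer_bytes(1, 8, 2048, 2048, 4)` is 134,217,728 bytes, i.e. 128 MiB. Under the Amdahl reading, the quoted 1.2–1.4× range corresponds to attention taking roughly a quarter to just under half of total runtime at a 2.70× attention speedup.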

Correctness & gates

  • Bit-close to a CPU reference across f32/f16, head dims 16–128, masked, causal, multi-head, and the head-major adapter.
  • An e2e test builds a real Sdpa model, runs MetalTransform, asserts it routes to MetalMfaSdpa, and matches the CPU FlashSdpa output.
  • Full tract-metal suite 71/0; cargo build --workspace clean; fmt + clippy clean on the new code.

Orthogonal to #2319 (that one is the CPU FlashSdpa path; this is Metal).

Credits / prior art

🤖 Generated with Claude Code

czoli1976 added a commit to czoli1976/tract that referenced this pull request Jun 3, 2026
…rite + NNEF + resume

tract's DynKeyValueCache grows by TypedConcat([past, new]) each step, copying the
whole t-token past into a fresh buffer -> O(T^2) total copy over a T-token decode.
Apple Core ML "stateful in-place KV" lever. Pieces:

1. InPlaceKvCache: geometric-growth in-place cache. Buffer with spare capacity along
   `axis`, write each new chunk at the cursor (Tensor::assign_slice, strided-safe for
   any axis), double only when capacity is exceeded -> O(T) amortized copy.
   valid_view() exposes the live [0..len] region as a ZERO-COPY ndarray view (the path
   that realizes the win). For the seq axis of [B,H,S,D] a per-head slice of the
   capacity buffer is a contiguous prefix, so a consumer reads it at concat cost.

2. InPlaceKvSdpa: stateful fused op owning the K/V in-place caches, running the CPU
   SDPA (FlashSdpaOp::flash_attention_gqa) over the zero-copy views. tract Tensors
   cannot be zero-copy views ACROSS an op boundary (Tensor::slice copies), so keeping
   the buffers inside the consuming op is what makes the saving real. Drop-in for
   {kv_cache(K), kv_cache(V), Sdpa}; does GQA internally.

3. InPlaceKvSdpaTransform: rewrite pass that strips the GQA broadcast chain
   (fuse_kv_cache_broadcast_rule) then fuses {cache(K), cache(V), Sdpa} -> InPlaceKvSdpa
   so existing decode models adopt the in-place cache transparently.

4. NNEF ser/de: round-trips via tract_transformers_inplace_kv_sdpa (registered).

5. Resume: save_to/load_from checkpoint the cache as [K,V] tensors; freeze/unfreeze
   snapshot the running state in-process. Both bit-exact resume; snapshot is O(len).
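A stripped-down sketch of the geometric-growth idea in item 1, to show where the O(T^2) -> O(T) copy saving comes from. This is not the InPlaceKvCache implementation: the type name `GrowableCache` is hypothetical, the cache grows along the leading axis only (one row of `width` floats per token), and a `copied` counter is added purely to make the copy volume observable.

```rust
/// Toy in-place cache: `capacity` token rows of `width` f32 each, of which the
/// first `len` rows are live.
pub struct GrowableCache {
    buf: Vec<f32>,
    width: usize,
    len: usize,
    capacity: usize,
    /// Total elements copied so far (appends plus regrowth), for illustration.
    pub copied: usize,
}

impl GrowableCache {
    pub fn new(width: usize) -> Self {
        Self { buf: Vec::new(), width, len: 0, capacity: 0, copied: 0 }
    }

    /// Write one token at the cursor, doubling the buffer only when it is full.
    pub fn append(&mut self, token: &[f32]) {
        assert_eq!(token.len(), self.width);
        if self.len == self.capacity {
            let new_cap = (self.capacity * 2).max(1);
            let mut grown = vec![0.0f32; new_cap * self.width];
            let live = self.len * self.width;
            grown[..live].copy_from_slice(&self.buf[..live]);
            self.copied += live;
            self.buf = grown;
            self.capacity = new_cap;
        }
        let start = self.len * self.width;
        self.buf[start..start + self.width].copy_from_slice(token);
        self.copied += self.width;
        self.len += 1;
    }

    /// Borrow the live region without copying.
    pub fn valid_view(&self) -> &[f32] {
        &self.buf[..self.len * self.width]
    }
}

/// Elements copied by a concat-grow cache: step t copies the t-token past
/// plus the new token into a fresh buffer.
pub fn concat_grow_copied(tokens: usize, width: usize) -> usize {
    (0..tokens).map(|t| (t + 1) * width).sum()
}
```

For 1024 tokens of width 4, the doubling cache copies fewer than `3 * 1024 * 4` elements in total, while the concat-grow count is `4 * 1024 * 1025 / 2`, which is the gap the cache-update benchmark measures.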

Validated (11 tests): in-place bit-exact vs concat-grow; fused op matches concat-cache
+ FlashSdpaOp baseline (prefill+decode, GQA, causal/non-causal); runs end-to-end via a
persistent SimpleState; the rewrite fires + the rewritten model matches baseline; NNEF
round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized. fmt +
clippy clean; transformers lib 23/0 no-regression.

Benched (release, B=1 H=8 D=128):
  - cache-update only:      21x (T=256) -> 709x (T=4096), O(T^2) -> O(T)
  - end-to-end via the op:  1.10x (256) -> 1.63x (2048), 39% faster decode @2k
  - resume checkpoint:      O(len), 0.10ms (256) -> 1.76ms (4096), one-time

Follow-up: GPU coupling (sonos#2320 MFA kernel reading capacity buffer + length).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@kali

kali commented Jun 12, 2026

Collaborator

I'm happy with using what is already there on Metal, including wiring stuff that is not wired as you did.

At some point, CUDA took center stage. I'm happy to get Metal up to speed when there is low-hanging fruit. Plugging in stuff which is "nearly there" makes sense.

Convolution support is the big remaining gap, as there is no clear equivalent to cuDNN (which incidentally I don't like very much, as it's an extra runtime dep).

I got rid of the GGML vendoring because the in-house kernel, which I actually understand, is in the same ballpark on the models and architectures that matter to tract. Also, it was not vendored gracefully: we had to patch it pretty heavily (so basically forking it). But it was not because of any dependency policy.

@czoli1976

czoli1976 commented Jun 12, 2026 via email

Contributor Author

czoli1976 and others added 2 commits June 12, 2026 14:18
tract's metal crate already vendors libMetalFlashAttention (Apple / Philip
Turner's flash-attention kernels, MIT) but only used its sgemm/hgemm entry
points -- the fused `attention` kernel shipped inside the metallib was never
dispatched. The MetalTransform instead explodes every Sdpa into
einsum + softmax + einsum, materializing the full (B,H,Sq,Sk) score matrix in
device memory and round-tripping it through three separate kernels.

This wires the fused kernel:

  * `dispatch_metal_mfa_attention` drives the vendored `attention` function
    (online softmax / flash attention; ABI reconstructed from the v1.0.1
    source + on-GPU pipeline reflection). f32 and f16.
  * `mfa_attention_head_major` adapts tract's native [B,H,S,D] layout to MFA's
    (Q/O=[R,H,D], K=[H,D,C], V=[C,H,D]) on-device via copy_nd.
  * `MetalMfaSdpa` op + a `register_metal_op!(Sdpa)` translator route a real
    Sdpa node to the fused kernel; unsupported shapes fall back to the existing
    explode path. The new `rewire_sdpa_metal` only flattens the Sdpa nodes the
    kernel can't take, leaving fusable ones intact (cuda keeps the shared
    `rewire_sdpa`).
  * causal masking via an additive [Sq,Sk] mask -- the metallib's `triangular`
    function-constant alone is a no-op, pinned by a regression test.
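The layout adaptation in the second bullet is an axis permutation. A host-side sketch for a single batch element follows; the PR does this on-device via copy_nd, so the helper below is only an illustration, its name `permute3` is hypothetical, and treating each batch element separately is an assumption of the sketch.

```rust
/// Permute a row-major 3-D array so that output axis n is input axis perm[n].
/// Returns the permuted data and its shape.
pub fn permute3(src: &[f32], shape: [usize; 3], perm: [usize; 3]) -> (Vec<f32>, [usize; 3]) {
    assert_eq!(src.len(), shape[0] * shape[1] * shape[2]);
    let mut seen = [false; 3];
    for &p in &perm {
        assert!(p < 3 && !seen[p], "perm must be a permutation of 0..3");
        seen[p] = true;
    }
    let out_shape = [shape[perm[0]], shape[perm[1]], shape[perm[2]]];
    let in_strides = [shape[1] * shape[2], shape[2], 1];
    let mut out = Vec::with_capacity(src.len());
    for a in 0..out_shape[0] {
        for b in 0..out_shape[1] {
            for c in 0..out_shape[2] {
                let idx = a * in_strides[perm[0]] + b * in_strides[perm[1]] + c * in_strides[perm[2]];
                out.push(src[idx]);
            }
        }
    }
    (out, out_shape)
}

// With one batch element laid out as [H, S, D]:
//   Q, O and V: [H, S, D] -> [S, H, D]  uses perm = [1, 0, 2]
//   K:          [H, S, D] -> [H, D, S]  uses perm = [0, 2, 1]
```

The K case is the transpose of the two innermost axes, which is why it cannot be expressed as a cheap re-striding of the outer dimensions and shows up as a real copy.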

Eliminates the (B,H,Sq,Sk) intermediate and collapses three kernels to one.
Measured on M-series (f32, B=1 H=8 S=512 D=64): the kernel is ~2x the explode
path, and an 8-layer all-attention stack (amortizing host sync) runs 2.70x
faster on the attention portion -- so a real model's end-to-end gain is 2.70x
scaled by attention's compute share.

Correctness: bit-close to a CPU reference across f32/f16, head dims 16..128,
masked, causal, multi-head, and head-major layout; an e2e test builds a real
Sdpa model, runs the MetalTransform, asserts it routes to MetalMfaSdpa, and
matches the CPU FlashSdpa output.

Apple MetalFlashAttention: https://github.com/philipturner/metal-flash-attention
Prior art (fused-attention dispatch): llama.cpp ggml-metal flash-attn kernel;
candle-metal-kernels.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

main's new GQA suite (q_heads = r * kv_heads) exposed that
mfa_sdpa_supported never compared head counts: GQA Sdpa nodes fused
into the 2023 MFA kernel (which predates GQA) and computed garbage.
Require equal & concrete Q/K/V head counts so GQA falls back to the
explode path, assert the same invariant in eval, and pin the decline +
fallback-matches-CPU behavior with a regression test.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
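The head-count invariant described in this commit can be sketched as a small predicate. The real mfa_sdpa_supported is not reproduced here; the function below is a hypothetical stand-in that models a symbolic (non-concrete) head dimension as `None`.

```rust
/// Hypothetical stand-in for the head-count part of the fusion gate:
/// fuse only when Q, K and V head counts are all concrete and equal.
/// `None` models a head dimension that is still symbolic.
pub fn head_counts_fusable(q_heads: Option<usize>, k_heads: Option<usize>, v_heads: Option<usize>) -> bool {
    match (q_heads, k_heads, v_heads) {
        (Some(q), Some(k), Some(v)) => q == k && k == v,
        // Any unknown head count declines fusion and leaves the node to the
        // fallback path.
        _ => false,
    }
}
```

A GQA node with 8 query heads and 2 KV heads returns `false` here, so it would be flattened by the explode path instead of reaching a kernel that assumes one K/V head per Q head.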
@czoli1976 czoli1976 force-pushed the feature/metal-flash-sdpa branch from 51dbc4f to 9a39f8c Compare June 12, 2026 13:26
@czoli1976

Contributor Author

Great, proceeding as-is then. Pushed a rebase plus a fix: the gate never compared Q vs K/V head counts, so your new GQA tests hit the MFA kernel (which predates GQA) and got garbage — GQA now falls back to the explode path, with a regression test. Also FYI: the metallib's convolution entry is an empty stub, so MFA won't help on that front.

Follow-up ready on feature/metal-mlx-sdpa (will open once this lands): MLX's SDPA kernels ported as owned .metal source, same idiom as the in-tree MlxGemm. Native GQA, decode + prefill kernels, and it benches 5.7× vs explode — 2.2× faster than this metallib — so it would retire the binary blob. It also carries an optional commit wiring MLX's Metal-4 tensor-op kernel for the M5 neural accelerators (runtime-gated; correct only on M5 hardware, which I don't have — kernel is verbatim MLX, validated upstream).

Have a nice weekend!

@kali

kali commented Jun 12, 2026

Collaborator

/ci llm

trying the new toy :)

@github-actions


✅ CI / large-models: success

  • cli: success
  • foundation-llms: success
  • foundation-llm: success
  • parakeet-tdt-600m-v3: success
  • nemotron-speech-streaming-en-06b: success


@czoli1976

czoli1976 commented Jun 12, 2026 via email

Contributor Author

@kali

kali commented Jun 12, 2026

Collaborator

I wish... just an easier way to start extra validation jobs :)

@kali kali merged commit b9c8059 into sonos:main Jun 15, 2026
57 checks passed