metal: fused Sdpa via the vendored MetalFlashAttention kernel (~2×) #2320
…rite + NNEF + resume
tract's DynKeyValueCache grows by TypedConcat([past, new]) each step, copying the
whole t-token past into a fresh buffer -> O(T^2) total copy over a T-token decode.
This applies Apple Core ML's "stateful in-place KV" lever. Pieces:
1. InPlaceKvCache: geometric-growth in-place cache. Buffer with spare capacity along
`axis`, write each new chunk at the cursor (Tensor::assign_slice, strided-safe for
any axis), double only when capacity is exceeded -> O(T) amortized copy.
valid_view() exposes the live [0..len] region as a ZERO-COPY ndarray view (the path
that realizes the win). For the seq axis of [B,H,S,D] a per-head slice of the
capacity buffer is a contiguous prefix, so a consumer reads it at concat cost.
2. InPlaceKvSdpa: stateful fused op owning the K/V in-place caches, running the CPU
SDPA (FlashSdpaOp::flash_attention_gqa) over the zero-copy views. tract Tensors
cannot be zero-copy views ACROSS an op boundary (Tensor::slice copies), so keeping
the buffers inside the consuming op is what makes the saving real. Drop-in for
{kv_cache(K), kv_cache(V), Sdpa}; does GQA internally.
3. InPlaceKvSdpaTransform: rewrite pass that strips the GQA broadcast chain
(fuse_kv_cache_broadcast_rule) then fuses {cache(K), cache(V), Sdpa} -> InPlaceKvSdpa
so existing decode models adopt the in-place cache transparently.
4. NNEF ser/de: round-trips via tract_transformers_inplace_kv_sdpa (registered).
5. Resume: save_to/load_from checkpoint the cache as [K,V] tensors; freeze/unfreeze
snapshot the running state in-process. Both bit-exact resume; snapshot is O(len).
Validated (11 tests): in-place bit-exact vs concat-grow; fused op matches concat-cache
+ FlashSdpaOp baseline (prefill+decode, GQA, causal/non-causal); runs end-to-end via a
persistent SimpleState; the rewrite fires + the rewritten model matches baseline; NNEF
round-trip; freeze/unfreeze and save/load resume bit-exact; growth amortized. fmt +
clippy clean; transformers lib 23/0 no-regression.
Benched (release, B=1 H=8 D=128):
- cache-update only: 21x (T=256) -> 709x (T=4096), O(T^2) -> O(T)
- end-to-end via the op: 1.10x (256) -> 1.63x (2048), 39% faster decode @2k
- resume checkpoint: O(len), 0.10ms (256) -> 1.76ms (4096), one-time
Follow-up: GPU coupling (sonos#2320 MFA kernel reading capacity buffer + length).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
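A minimal sketch of the growth argument in the commit above, on a flat `f32` buffer rather than a strided tensor. `GrowCache` and `concat_grow_copies` are hypothetical names for illustration, not tract's `InPlaceKvCache` API; the sketch only counts how many elements each strategy copies.

```rust
/// Concat-grow: every append builds a fresh buffer holding past + new chunk.
/// Returns the total number of elements written into fresh buffers.
pub fn concat_grow_copies(steps: usize, chunk: usize) -> usize {
    let mut past: Vec<f32> = Vec::new();
    let mut copied = 0;
    for _ in 0..steps {
        let mut next = Vec::with_capacity(past.len() + chunk);
        next.extend_from_slice(&past); // re-copies the whole past
        next.resize(past.len() + chunk, 0.0); // stand-in for the new chunk
        copied += next.len();
        past = next;
    }
    copied
}

/// Hypothetical in-place cache: spare capacity, a cursor, doubling on overflow.
pub struct GrowCache {
    buf: Vec<f32>, // capacity buffer; only buf[..len] is live
    len: usize,    // cursor: number of live elements
    pub copied: usize,
}

impl GrowCache {
    pub fn new(capacity: usize) -> Self {
        GrowCache { buf: vec![0.0; capacity.max(1)], len: 0, copied: 0 }
    }

    pub fn append(&mut self, chunk: &[f32]) {
        let needed = self.len + chunk.len();
        if needed > self.buf.len() {
            // Double until the chunk fits, then move the live region once.
            let mut cap = self.buf.len();
            while cap < needed {
                cap *= 2;
            }
            let mut bigger = vec![0.0; cap];
            bigger[..self.len].copy_from_slice(&self.buf[..self.len]);
            self.copied += self.len;
            self.buf = bigger;
        }
        // Write the new chunk at the cursor; the past is not touched.
        self.buf[self.len..needed].copy_from_slice(chunk);
        self.copied += chunk.len();
        self.len = needed;
    }

    /// Borrow the live region without copying.
    pub fn valid_view(&self) -> &[f32] {
        &self.buf[..self.len]
    }
}
```

With one element per step, 1024 appends copy 524800 elements under concat-grow and 2047 under doubling, which is the O(T^2) versus O(T) gap in miniature.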
|
I'm happy with using what is already there on Metal, including wiring up stuff that is not wired, as you did. At some point, CUDA took center stage. I'm happy to get Metal up to speed when there is low-hanging fruit. Plugging in stuff which is "nearly there" makes sense. Convolution support is the big remaining gap, as there is no clear equivalent to cuDNN (which incidentally I don't like very much, as it's an extra runtime dep). I got rid of the GGML vendoring because the in-house kernel, which I actually understand, is in the same ballpark on the models and architectures that matter to tract. Also, it was not vendored gracefully: we had to patch it pretty heavily (so basically forking it). But it was not for any dependency policy. |
|
Roger that! I can do the rebase, the gate fix, and the GQA broadcast
--
Best Regards
Ckristian Zoli
Email: ***@***.***
|
tract's metal crate already vendors libMetalFlashAttention (Apple / Philip
Turner's flash-attention kernels, MIT) but only used its sgemm/hgemm entry
points -- the fused `attention` kernel shipped inside the metallib was never
dispatched. The MetalTransform instead explodes every Sdpa into
einsum + softmax + einsum, materializing the full (B,H,Sq,Sk) score matrix in
device memory and round-tripping it through three separate kernels.
This wires the fused kernel:
* `dispatch_metal_mfa_attention` drives the vendored `attention` function
(online softmax / flash attention; ABI reconstructed from the v1.0.1
source + on-GPU pipeline reflection). f32 and f16.
* `mfa_attention_head_major` adapts tract's native [B,H,S,D] layout to MFA's
(Q/O=[R,H,D], K=[H,D,C], V=[C,H,D]) on-device via copy_nd.
* `MetalMfaSdpa` op + a `register_metal_op!(Sdpa)` translator route a real
Sdpa node to the fused kernel; unsupported shapes fall back to the existing
explode path. The new `rewire_sdpa_metal` only flattens the Sdpa nodes the
kernel can't take, leaving fusable ones intact (cuda keeps the shared
`rewire_sdpa`).
* causal masking via an additive [Sq,Sk] mask -- the metallib's `triangular`
function-constant alone is a no-op, pinned by a regression test.
Eliminates the (B,H,Sq,Sk) intermediate and collapses three kernels to one.
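To illustrate what "online softmax" buys, here is a single-query CPU sketch, not the Metal kernel itself. `attention_naive` materializes the score row the way the explode path does; `attention_online` keeps only a running max, a running sum and an accumulator. Both function names are illustrative, and the sketch assumes finite scores and non-empty K/V.

```rust
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Reference path: full score row, then softmax, then weighted sum of V rows.
pub fn attention_naive(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let scores: Vec<f32> = k.iter().map(|kj| scale * dot(q, kj)).collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut out = vec![0.0f32; v[0].len()];
    for (e, vj) in exps.iter().zip(v) {
        for (o, x) in out.iter_mut().zip(vj) {
            *o += e / sum * x;
        }
    }
    out
}

/// Online-softmax path: one pass over K/V, no score row is stored.
pub fn attention_online(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running max of the scores seen so far
    let mut l = 0.0f32; // running sum of exp(score - m)
    let mut acc = vec![0.0f32; v[0].len()]; // running unnormalized output
    for (kj, vj) in k.iter().zip(v) {
        let s = scale * dot(q, kj);
        let m_new = m.max(s);
        // Rescale what was accumulated under the old max (exp(-inf) = 0 on step one).
        let correction = (m - m_new).exp();
        let p = (s - m_new).exp();
        l = l * correction + p;
        for (a, x) in acc.iter_mut().zip(vj) {
            *a = *a * correction + p * x;
        }
        m = m_new;
    }
    acc.iter().map(|a| a / l).collect()
}
```

The two functions agree up to floating-point rounding; the online variant needs O(D) working memory per query instead of O(Sk).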
Measured on M-series (f32, B=1 H=8 S=512 D=64): the kernel is ~2x the explode
path, and an 8-layer all-attention stack (amortizing host sync) runs 2.70x
faster on the attention portion -- so a real model's end-to-end gain is 2.70x
scaled by attention's compute share.
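The scaling in the last sentence is Amdahl's law. Writing $p$ for the fraction of end-to-end runtime spent in attention (an illustrative symbol, not a measured quantity from this PR) and $s = 2.70$ for the attention speedup:

```latex
\[
  \text{speedup}_{\text{e2e}} = \frac{1}{(1 - p) + \dfrac{p}{s}},
  \qquad
  \text{e.g. } p = 0.5,\ s = 2.70:\quad
  \frac{1}{0.5 + 0.5/2.70} \approx 1.46 .
\]
```

The value $p = 0.5$ is a made-up example; the actual share depends on the model and sequence length.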
Correctness: bit-close to a CPU reference across f32/f16, head dims 16..128,
masked, causal, multi-head, and head-major layout; an e2e test builds a real
Sdpa model, runs the MetalTransform, asserts it routes to MetalMfaSdpa, and
matches the CPU FlashSdpa output.
Apple MetalFlashAttention: https://github.com/philipturner/metal-flash-attention
Prior art (fused-attention dispatch): llama.cpp ggml-metal flash-attn kernel;
candle-metal-kernels.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
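The additive `[Sq,Sk]` causal mask mentioned in the commit above could look roughly like the sketch below. Two things are assumptions here, not taken from the PR: blocked positions are encoded as negative infinity (a large negative finite value is another common choice), and when Sq < Sk the queries are aligned to the end of the key sequence, as in decode with a KV cache. The function name is hypothetical.

```rust
/// Build a row-major [sq, sk] additive mask: 0.0 where attention is allowed,
/// -inf where key position k lies in the future of query position q.
/// Assumption: query q sits at absolute position q + (sk - sq).
pub fn additive_causal_mask(sq: usize, sk: usize) -> Vec<f32> {
    let offset = sk.saturating_sub(sq);
    let mut mask = vec![0.0f32; sq * sk];
    for q in 0..sq {
        for k in 0..sk {
            if k > q + offset {
                mask[q * sk + k] = f32::NEG_INFINITY;
            }
        }
    }
    mask
}
```

The mask is added to the scaled scores before softmax, so blocked positions contribute exp(-inf) = 0 to the output.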
main's new GQA suite (q_heads = r * kv_heads) exposed that mfa_sdpa_supported never compared head counts: GQA Sdpa nodes fused into the 2023 MFA kernel (which predates GQA) and computed garbage. Require equal & concrete Q/K/V head counts so GQA falls back to the explode path, assert the same invariant in eval, and pin the decline + fallback-matches-CPU behavior with a regression test.
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
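A sketch of the kind of gate this fix describes, assuming head counts are modeled as `Option<usize>` with `None` standing for a symbolic (non-concrete) dimension. `heads_fusable`, `Route` and `route` are hypothetical names for illustration; the real `mfa_sdpa_supported` signature is not shown in this thread.

```rust
/// Where an Sdpa node ends up after the gate.
#[derive(Debug, PartialEq)]
pub enum Route {
    Fused,   // dispatched to the fused attention kernel
    Explode, // falls back to einsum + softmax + einsum
}

/// True only when all three head counts are concrete and equal.
/// GQA (q_heads = r * kv_heads with r > 1) therefore declines.
pub fn heads_fusable(q: Option<usize>, k: Option<usize>, v: Option<usize>) -> bool {
    matches!((q, k, v), (Some(a), Some(b), Some(c)) if a == b && b == c)
}

pub fn route(q: Option<usize>, k: Option<usize>, v: Option<usize>) -> Route {
    if heads_fusable(q, k, v) {
        Route::Fused
    } else {
        Route::Explode
    }
}
```

Declining on a symbolic dimension is the conservative choice: the gate cannot prove the counts match, so the node takes the path that is known to be correct.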
51dbc4f to 9a39f8c
|
Great, proceeding as-is then. Pushed a rebase plus a fix: the gate never compared Q vs K/V head counts, so your new GQA tests hit the MFA kernel (which predates GQA) and got garbage — GQA now falls back to the explode path, with a regression test. Also FYI: the metallib's Follow-up ready on Have a nice weekend! |
|
/ci llm trying the new toy :) |
✅ CI / large-models: success
|
|
M5 Machine?
|
|
i wish... just an easier way to start extra validation jobs :) |
@kali: opening this as much as an RFC as a PR, so please read the design question below before the diff.
I've seen you've been actively building out the GPU path lately: mirroring the CUDA kernels into Metal (`gather`, `diag_gather`), landing `scaled_masked_softmax` (bool mask + post-softmax mask) on both backends, and adding the transform pre-check plus the CPU fallback when a GPU op rejects a shape. So I wanted to surface something adjacent that's been sitting unused in the Metal crate and check the direction with you before investing further.
The opportunity
The vendored `libMetalFlashAttention.metallib` (Apple / Philip Turner's MetalFlashAttention, MIT, already in-tree) ships four functions: `sgemm`, `hgemm`, `convolution`, and `attention`. We only ever dispatch `sgemm`/`hgemm`. The `attention` entry is a fully implemented fused flash-attention kernel (online softmax, never materializes the score matrix), and it's simply never called. (For completeness: `convolution` in this metallib is an empty stub, so `attention` is the only unused-but-real kernel here.)
Meanwhile `MetalTransform` explodes every `Sdpa` into `einsum → softmax → einsum`, which materializes the full `(B,H,Sq,Sk)` score buffer in device memory and round-trips it through three kernels, the middle one being the `scaled_masked_softmax` you just landed.
This PR wires the fused kernel and routes `Sdpa` to it.
The design question (why this is an RFC)
Dispatching a vendored, 2023-era `macosx13` metallib that we don't own is a real commitment, and I noticed you just dropped the unused `GgmlFlashAttn` kernel library on the CUDA side (73dada812), which cuts the other way. So I'd rather ask directly than assume:
Is wiring this vendored Metal kernel the direction you want, or would you prefer a fresh, owned port (e.g. translating the MLX / `ggml-metal` flash-attention kernel into a `.metal` source we control)?
The case for wiring it now: it's already in-tree, fully implemented, MIT, and measures ~2×, so it is a low-risk way to close the fused-Metal-SDPA gap today. The owned `.metal` port hedge stays open as a follow-up if metallib longevity on M3/M4 worries you. This PR is the "wire what's already there" option, fully validated, for you to accept or redirect.
What it does
* `dispatch_metal_mfa_attention`: drives the vendored `attention` function. ABI (buffers / function-constants / grid geometry) reconstructed from the MFA v1.0.1 source + on-GPU pipeline reflection. f32 + f16.
* `mfa_attention_head_major`: adapts tract's native `[B,H,S,D]` to MFA's layout (`Q/O=[R,H,D]`, `K=[H,D,C]`, `V=[C,H,D]`) on-device via `copy_nd`. The one unavoidable copy is the `K` transpose (candidate to fold later).
* `MetalMfaSdpa` op + `register_metal_op!(Sdpa)` translator: routes a real `Sdpa` node to the fused kernel. Unsupported shapes return `None` and fall through to the CPU-fallback path you just added (85255fdb9). A Metal-local `rewire_sdpa_metal` flattens only the `Sdpa` nodes the kernel can't take, leaving fusable ones intact (CUDA keeps the shared `rewire_sdpa` untouched).
* Causal masking via an additive `[Sq,Sk]` mask: the metallib's `triangular` function-constant alone is a no-op (it computes full attention), pinned by a regression test.
Numbers (Apple M-series, f32, B=1 H=8 S=512 D=64)
[Results table lost in extraction; the surviving cell fragments mention `dispatch_eval` and the `(B,H,Sq,Sk)` score buffer.]
(A single-op model bench reports 3.9×, but that's overhead-inflated; the 2.70× multi-layer figure is the honest one, and it's consistent with the kernel-level ~2×.)
Correctness & gates
* An e2e test builds a real `Sdpa` model, runs `MetalTransform`, asserts it routes to `MetalMfaSdpa`, and matches the CPU `FlashSdpa` output.
* `tract-metal` suite 71/0; `cargo build --workspace` clean; fmt + clippy clean on the new code.
Orthogonal to #2319 (that one is the CPU `FlashSdpa` path; this is Metal).
Credits / prior art
* Prior art (fused-attention dispatch): llama.cpp `ggml-metal` flash-attn; `candle-metal-kernels`.
🤖 Generated with Claude Code