transformers: KIVI-style KV-cache quantization (packed u8 storage, ~4× less memory than f32) #2329
@kali is this an area of interest? I was even thinking of SSD offload, but I'm not sure whether that goes too far and should be managed outside of tract.
Branch updated from c1fa13a to fc8c39d.
Rebased onto current.

Rebased!
[image: image.png]
On Wed, 17 Jun 2026 at 18:57, Mathieu Poumeyrol wrote:
*kali* left a comment (sonos/tract#2329): Rebased!
Best Regards
Ckristian Zoli
Rebased!

Bear with me :)
…e bits)

Training-free affine quantize<->dequantize for the KV cache: keep every token, but at fewer bits (configurable, 1..16). Keys are quantized per-CHANNEL (outlier channels get their own scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.)

Validated: round-trip error <= scale/2 and shrinks as bits grow; per-channel clearly beats per-token on outlier channels; 8-bit is near-lossless for the attention output. Real GPT-2 (harness/kv_quant_real.py): int8 gives ~0.5% attention deviation (near-lossless, 2x memory saving) and degrades gracefully down to int2; int4 per-channel-K beats per-token-K by 1.75-1.9x on early layers.

Memory = bits/16 of the f16 cache (int8 2x, int4 4x, int2 8x smaller). 3 tests, fmt+clippy clean.

Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing with the in-place (sonos#2321) / sliding-window (sonos#2327) caches; CommVQ codebook variant.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
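The affine quantize<->dequantize round trip described in this commit can be sketched as below. This is a minimal illustration under stated assumptions, not the PR's code: AffineParams, quantize and dequantize are hypothetical names, and it assumes min/max range quantization with round-to-nearest. Codes are held in u16 so any width in 1..=16 fits; with round-to-nearest, each reconstructed value lands within scale/2 of the original.

```rust
/// Hypothetical affine parameters for one quantization group
/// (one channel for Keys, one token for Values in the KIVI layout).
#[derive(Debug, Clone, Copy)]
pub struct AffineParams {
    pub scale: f32,
    pub min: f32,
}

/// Quantize `xs` to unsigned codes with `bits` bits (1..=16) using the group's
/// min/max range. Returns the codes plus the (scale, min) needed to dequantize.
pub fn quantize(xs: &[f32], bits: u32) -> (Vec<u16>, AffineParams) {
    assert!((1..=16).contains(&bits), "bits must be in 1..=16");
    let levels = ((1u32 << bits) - 1) as f32; // largest code, e.g. 255 for 8 bits
    let min = xs.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // A constant (or empty) group gets scale 1.0 so we never divide by zero.
    let scale = if max > min { (max - min) / levels } else { 1.0 };
    let codes = xs
        .iter()
        .map(|&x| ((x - min) / scale).round().clamp(0.0, levels) as u16)
        .collect();
    (codes, AffineParams { scale, min })
}

/// Map codes back to f32: x ≈ q * scale + min.
pub fn dequantize(codes: &[u16], p: AffineParams) -> Vec<f32> {
    codes.iter().map(|&q| q as f32 * p.scale + p.min).collect()
}
```

Using u16 keeps the sketch general over the whole 1..=16 range; an 8-bit cache would store u8 codes, which is where the 2x (vs f16) and 4x (vs f32) savings come from.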
…ansform
Completes the KIVI-style KV-cache quantization integration:
1. QuantKeyCache: per-channel u8 storage for Keys. Each of the D channels has a running
scale; new tokens are quantized under the current channel scale. Memory: T*D + D*8 bytes.
2. QuantValueCache: per-token u8 storage for Values. Each token takes D bytes + 2 f32 params.
Memory: T*D + T*8 bytes (~4x smaller than f32 at large D).
3. QuantizedKvSdpa: stateful fused op (Op/EvalOp/TypedOp + OpState + freeze) that
stores K/V in packed u8, dequantizes per-head on each decode step, and attends via
FlashSdpaOp (GQA handled). It holds real u8 bytes; this is not just a float round-trip quality test.
4. QuantizedKvSdpaTransform: auto-wires {cache(K), cache(V), Sdpa} -> QuantizedKvSdpa.
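A decode step in the spirit of item 3 can be sketched as: dequantize one KV head's cached codes (per-channel parameters for Keys, per-token parameters for Values), then attend with the new query. This is a hypothetical sketch, not QuantizedKvSdpa itself: the function names are made up, it uses plain softmax(q·k/√d)·V instead of tract's FlashSdpaOp, and the GQA mapping assumes the common contiguous grouping where n_q_heads is a multiple of n_kv_heads.

```rust
/// Per-channel Keys: `scale[c]` and `min[c]` for each of the D channels.
/// `codes` is row-major T x D.
pub fn dequant_keys(codes: &[u8], d: usize, scale: &[f32], min: &[f32]) -> Vec<f32> {
    codes
        .iter()
        .enumerate()
        .map(|(i, &q)| q as f32 * scale[i % d] + min[i % d])
        .collect()
}

/// Per-token Values: one (scale, min) pair per token. `codes` is row-major T x D.
pub fn dequant_values(codes: &[u8], d: usize, params: &[(f32, f32)]) -> Vec<f32> {
    codes
        .iter()
        .enumerate()
        .map(|(i, &q)| {
            let (s, m) = params[i / d];
            q as f32 * s + m
        })
        .collect()
}

/// Assumed GQA convention: consecutive query heads share one KV head.
pub fn kv_head_for(q_head: usize, n_q_heads: usize, n_kv_heads: usize) -> usize {
    q_head / (n_q_heads / n_kv_heads)
}

/// Single-query attention over T cached tokens (row-major T x D keys and values).
pub fn attend(q: &[f32], k: &[f32], v: &[f32], d: usize) -> Vec<f32> {
    let t = k.len() / d;
    let inv_sqrt_d = 1.0 / (d as f32).sqrt();
    let logits: Vec<f32> = (0..t)
        .map(|j| q.iter().zip(&k[j * d..(j + 1) * d]).map(|(a, b)| a * b).sum::<f32>() * inv_sqrt_d)
        .collect();
    // Subtract the max logit before exponentiating for numerical stability.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let total: f32 = weights.iter().sum();
    let mut out = vec![0.0f32; d];
    for (j, w) in weights.iter().enumerate() {
        for (o, x) in out.iter_mut().zip(&v[j * d..(j + 1) * d]) {
            *o += w / total * x;
        }
    }
    out
}
```

With GQA, a decode step would call kv_head_for for each query head to choose which dequantized KV head it attends over; dequantizing per head keeps the f32 working set at one head's T×D at a time.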
6 tests: quant quality (3 existing) + packed_u8_saves_memory_vs_f32 (>3x saving) +
quantized_kv_sdpa_runs_in_model (engine correctness: near-lossless vs f32 reference) +
transform_fuses_cache_sdpa_to_quantized (structural auto-wiring). fmt+clippy clean;
transformers tests 18 passed / 0 failed, no regression.
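The per-token Value layout (T×D + T×8 bytes) and the >3x saving checked by packed_u8_saves_memory_vs_f32 can be reproduced with a toy cache. ToyValueCache and its methods are hypothetical, written only for illustration: it is fixed at 8 bits and uses min/max affine quantization per token, which is an assumption about how the PR's QuantValueCache works rather than a copy of it.

```rust
/// Hypothetical per-token Value cache (illustration only, not the PR's
/// QuantValueCache): each appended token of D floats is stored as D u8 codes
/// plus one f32 (scale, min) pair, i.e. D + 8 bytes per token.
pub struct ToyValueCache {
    d: usize,
    codes: Vec<u8>,          // T * D bytes, token-major
    params: Vec<(f32, f32)>, // one (scale, min) per token: T * 8 bytes
}

impl ToyValueCache {
    pub fn new(d: usize) -> Self {
        ToyValueCache { d, codes: Vec::new(), params: Vec::new() }
    }

    /// Append one token, quantized to 8 bits over its own min/max range.
    pub fn push(&mut self, token: &[f32]) {
        assert_eq!(token.len(), self.d, "token must have D values");
        let min = token.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = token.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let scale = if max > min { (max - min) / 255.0 } else { 1.0 };
        self.codes.extend(
            token.iter().map(|&x| ((x - min) / scale).round().clamp(0.0, 255.0) as u8),
        );
        self.params.push((scale, min));
    }

    pub fn num_tokens(&self) -> usize {
        self.params.len()
    }

    /// Dequantize token `t` back to f32.
    pub fn token(&self, t: usize) -> Vec<f32> {
        let (scale, min) = self.params[t];
        self.codes[t * self.d..(t + 1) * self.d]
            .iter()
            .map(|&q| q as f32 * scale + min)
            .collect()
    }

    /// Bytes actually held: T*D codes plus T*8 bytes of parameters.
    pub fn bytes(&self) -> usize {
        self.codes.len() + self.params.len() * 2 * std::mem::size_of::<f32>()
    }
}
```

With D = 64 and 100 tokens this holds 7,200 bytes versus 25,600 bytes for f32, about 3.6x; the ratio 4D/(D + 8) approaches 4x as D grows, because the 8-byte per-token overhead is amortized over more channels.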
Configurable via the bits parameter (1..=16); int8 = near-lossless 4x vs f32 / 2x vs
f16. CommVQ codebook variant is the follow-on.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
tract_transformers_quantized_kv_sdpa primitive: axis + optional scale. Round-trip test: axis and scale survive write_to_tar -> model_for_read.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Rebased again. CI work is so frustrating.
🔴 Bench vs main: 1 speed regression(s)
Reference: main nightly, latest 2026-06-18 (0d old) · PR speed: evaltime · prefill · decode
🟢 3 improvement(s)
Lower is better except prefill/decode (tok/s) · adaptive thresholds (max(floor, k×noise) vs the series' own history) · single-shot vs nightly reference · full table in the run summary
hmmmmmmm, why that regression on parakeet_tdt_600m_v3_f32f32_preprocessor_1s??? PS: I like that nostalgic "hey_snips" model :P
Just noise, I think, honestly. We'll have to accept a bit of it until I tweak the threshold rules. As it is, this model swaps between CPU and GPU constantly, so I think it's heavily affected by scheduling.
yeah, I still have a couple of these around to play with :)
Training-free KV-cache quantization: store K and V in packed u8 bytes instead of f32, keeping every token at ~4× less memory — the "keep everything, write in shorthand" alternative to eviction.
The idea
The core asymmetry (Liu et al. 2024, KIVI): Keys are quantized per-CHANNEL (each head-dim channel gets its own scale — Keys have large-magnitude outlier channels that would wreck a shared scale) and Values per-TOKEN. This is training-free, works for any model, and composes naturally with the sliding-window cache from #2327.
Validated on real GPT-2 K/V activations:
The per-channel-K layout matters: int4 per-channel-K is 1.75–1.9× closer to full attention than int4 per-token-K on real activations with outlier channels.
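Why the per-channel layout matters can be reproduced on synthetic data with a small experiment. This is not the GPT-2 harness and the data below is made up for illustration; round_trip and key_error are hypothetical helpers. One of 16 channels carries a large offset, mimicking the large-magnitude Key outlier channels described above, and int4 min/max quantization is applied either per token (rows) or per channel (columns).

```rust
/// Quantize then dequantize one group with `bits`-bit min/max affine quantization.
fn round_trip(xs: &[f32], bits: u32) -> Vec<f32> {
    let levels = ((1u32 << bits) - 1) as f32;
    let min = xs.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = xs.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let scale = if max > min { (max - min) / levels } else { 1.0 };
    xs.iter()
        .map(|&x| ((x - min) / scale).round().clamp(0.0, levels) * scale + min)
        .collect()
}

/// Mean absolute reconstruction error of a row-major T x D key matrix when
/// grouping either per channel (one scale per column) or per token (per row).
pub fn key_error(k: &[f32], t: usize, d: usize, bits: u32, per_channel: bool) -> f32 {
    let mut out = vec![0.0f32; t * d];
    if per_channel {
        for c in 0..d {
            let col: Vec<f32> = (0..t).map(|r| k[r * d + c]).collect();
            for (r, v) in round_trip(&col, bits).into_iter().enumerate() {
                out[r * d + c] = v;
            }
        }
    } else {
        for r in 0..t {
            out[r * d..(r + 1) * d].copy_from_slice(&round_trip(&k[r * d..(r + 1) * d], bits));
        }
    }
    k.iter().zip(&out).map(|(a, b)| (a - b).abs()).sum::<f32>() / (t * d) as f32
}
```

Per-token grouping has to stretch one scale from the ordinary channels up to the outlier, so the ordinary channels collapse onto one or two levels; per-channel grouping gives the outlier its own scale and leaves the ordinary channels at full int4 resolution.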
What's in the PR
- QuantValueCache: per-token u8 storage; each token takes D bytes + 2 f32 params. Memory: T×D + T×8 bytes.
- QuantKeyCache: per-channel u8 storage; a running scale per channel, updated on each new token. Memory: T×D + D×8 bytes.
- QuantizedKvSdpa: stateful fused op that owns the packed K/V caches, dequantizes per-head on each decode step, and attends via FlashSdpaOp (GQA handled). Inputs [Q, K_new, V_new]; the output has Q's shape.
- QuantizedKvSdpaTransform: auto-wires {DynKeyValueCache(K), DynKeyValueCache(V), Sdpa} → QuantizedKvSdpa, so existing decode models adopt quantized storage transparently (mirrors the pattern from #2321, "transformers: in-place KV cache for decode via a fused InPlaceKvSdpa op", and #2327, "onnx,transformers: sliding-window attention, GQA window + bounded ring-buffer decode (#2323)").
- tract_transformers_quantized_kv_sdpa, registered.

Correctness & gates
- 7 tests: quality validation (round-trip bounded, per-channel beats per-token on outlier channels, 8-bit near-lossless for attention); packed_u8_saves_memory_vs_f32 (>3× measured); quantized_kv_sdpa_runs_in_model (runs through the engine, near-lossless vs f32 reference); transform_fuses_cache_sdpa_to_quantized (structural auto-wiring); NNEF round-trip.
- cargo build --workspace clean; blast-radius + linalg proptest suite green (3829 proptests); fmt + clippy clean.

Relationship to other PRs
~33 MB→~8 MB KV.

Research & prior art
🤖 Generated with Claude Code