
transformers: KIVI-style KV-cache quantization — packed u8 storage, ~4× memory vs f32#2329

Merged
kali merged 3 commits into
sonos:mainfrom
czoli1976:feature/kv-quant
Jun 18, 2026
Conversation

@czoli1976

Contributor

Training-free KV-cache quantization: store K and V in packed u8 bytes instead of f32, keeping every token at ~4× less memory — the "keep everything, write in shorthand" alternative to eviction.

The idea

The core asymmetry (Liu et al. 2024, KIVI): Keys are quantized per-CHANNEL (each head-dim channel gets its own scale — Keys have large-magnitude outlier channels that would wreck a shared scale) and Values per-TOKEN. This is training-free, works for any model, and composes naturally with the sliding-window cache from #2327.
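To make the asymmetry concrete, here is a minimal sketch (not the PR's implementation) that round-trips a small `[T, D]` matrix through affine quantization with the scale shared either per token (row) or per channel (column). The helper names `quant_roundtrip`, `per_token`, `per_channel`, `demo_keys` and the synthetic data with one outlier channel are assumptions made for illustration only.

```rust
/// Affine quantize-then-dequantize one group of values at `bits` bits.
/// The whole group shares a single (scale, min) pair.
fn quant_roundtrip(xs: &[f32], bits: u32) -> Vec<f32> {
    let levels = ((1u32 << bits) - 1) as f32;
    let min = xs.iter().copied().fold(f32::INFINITY, f32::min);
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let scale = if max > min { (max - min) / levels } else { 1.0 };
    xs.iter()
        .map(|&x| ((x - min) / scale).round().clamp(0.0, levels) * scale + min)
        .collect()
}

/// One scale per token: each row of the row-major [T, D] matrix is a group.
fn per_token(x: &[f32], d: usize, bits: u32) -> Vec<f32> {
    x.chunks(d).flat_map(|row| quant_roundtrip(row, bits)).collect()
}

/// One scale per channel: each column of the [T, D] matrix is a group.
fn per_channel(x: &[f32], d: usize, bits: u32) -> Vec<f32> {
    let t = x.len() / d;
    let mut out = vec![0.0; x.len()];
    for c in 0..d {
        let col: Vec<f32> = (0..t).map(|r| x[r * d + c]).collect();
        for (r, v) in quant_roundtrip(&col, bits).into_iter().enumerate() {
            out[r * d + c] = v;
        }
    }
    out
}

fn mean_abs_err(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).sum::<f32>() / a.len() as f32
}

/// Synthetic "Keys": channel 0 is a large-magnitude outlier, the rest are small.
fn demo_keys(t: usize, d: usize) -> Vec<f32> {
    let mut x = Vec::with_capacity(t * d);
    for r in 0..t {
        for c in 0..d {
            x.push(if c == 0 {
                100.0 + 0.1 * r as f32
            } else {
                ((3 * r + 5 * c) % 7) as f32 / 7.0 - 0.5
            });
        }
    }
    x
}

fn main() {
    let d = 4;
    let x = demo_keys(8, d);
    let pt = mean_abs_err(&x, &per_token(&x, d, 4));
    let pc = mean_abs_err(&x, &per_channel(&x, d, 4));
    println!("int4 mean abs error: per-token {pt}, per-channel {pc}");
}
```

With a shared per-token scale, the outlier channel stretches the range of every row, so the small channels collapse onto a single quantization level. Giving each channel its own scale avoids this.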

Validated on real GPT-2 K/V activations:

| precision | attention deviation vs f32 | memory saving vs f32 |
|---|---|---|
| int8 | ~0.5% (near-lossless) | ~4× |
| int4 | ~7–10% | ~8× |
| int2 | ~41–51% | ~16× |
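The ~8× figure for int4 follows from 32 / 4 bits, which only translates into actual storage if two 4-bit codes share one byte. The PR text does not show how sub-byte codes are laid out, so the following is a hedged sketch of one common approach (low nibble first). `pack_nibbles` and `unpack_nibbles` are hypothetical names, not tract APIs.

```rust
/// Pack 4-bit codes (values 0..=15) two per byte, low nibble first.
/// An odd trailing code is padded with a zero high nibble.
fn pack_nibbles(codes: &[u8]) -> Vec<u8> {
    codes
        .chunks(2)
        .map(|pair| {
            let lo = pair[0] & 0x0F;
            let hi = pair.get(1).copied().unwrap_or(0) & 0x0F;
            lo | (hi << 4)
        })
        .collect()
}

/// Recover the first `n` 4-bit codes from a packed buffer.
fn unpack_nibbles(packed: &[u8], n: usize) -> Vec<u8> {
    packed
        .iter()
        .flat_map(|&b| [b & 0x0F, b >> 4])
        .take(n)
        .collect()
}

fn main() {
    let codes = [1u8, 15, 7, 0, 9];
    let packed = pack_nibbles(&codes);
    let restored = unpack_nibbles(&packed, codes.len());
    println!("{} codes -> {} bytes, restored = {:?}", codes.len(), packed.len(), restored);
}
```

The element count has to be stored alongside the buffer, since the padding nibble is otherwise indistinguishable from a real code of zero.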

The per-channel-K layout matters: int4 per-channel-K is 1.75–1.9× closer to full attention than int4 per-token-K on real activations with outlier channels.

What's in the PR

Correctness & gates

7 tests:

  • quality validation, 3 tests (round-trip bounded, per-channel beats per-token on outlier channels, 8-bit near-lossless for attention)
  • packed_u8_saves_memory_vs_f32 (>3× measured)
  • quantized_kv_sdpa_runs_in_model (runs through the engine, near-lossless vs f32 reference)
  • transform_fuses_cache_sdpa_to_quantized (structural auto-wiring)
  • NNEF round-trip

Gates: cargo build --workspace clean; blast-radius + linalg proptest suite green (3829 proptests); fmt + clippy clean.

Relationship to other PRs

Research & prior art

  • KIVI (Liu et al. 2024, arXiv:2402.02750) — the per-channel-K / per-token-V asymmetry.
  • KVQuant (Hooper et al. 2024) — per-vector outlier handling for Keys.
  • CommVQ (Apple, arXiv:2406.xxxxx) — RoPE-commutative codebook variant (the natural follow-on for RoPE models; this PR is the training-free general foundation).

🤖 Generated with Claude Code

@czoli1976

Contributor Author

@kali is this an area of interest? I was even thinking of an SSD offload, but I'm not sure whether that goes too far and should be managed outside of tract.

czoli1976 pushed a commit to czoli1976/tract that referenced this pull request Jun 5, 2026
@czoli1976

Contributor Author

Rebased onto current main: conflicts cleared, back to mergeable and CI-ready. @kali whenever you have a review window, this one and the small linalg/metal PRs in the stack are ready when you are. Thanks!

@kali kali force-pushed the feature/kv-quant branch from fc8c39d to 944d873 Compare June 17, 2026 17:56
@kali

kali commented Jun 17, 2026

Collaborator

Rebased!

@czoli1976

czoli1976 commented Jun 17, 2026 via email

Contributor Author

@kali kali force-pushed the feature/kv-quant branch from 944d873 to bf0ea63 Compare June 18, 2026 07:04
@kali

kali commented Jun 18, 2026

Collaborator

Rebased!

@kali

kali commented Jun 18, 2026

Collaborator

Bear with me :)

czoli1976 and others added 3 commits June 18, 2026 10:21
…e bits)

Training-free affine quantize<->dequantize for the KV cache: keep every token but at
fewer bits (configurable, 1..16). Keys per-CHANNEL (outlier channels get their own
scale), Values per-TOKEN (KIVI, Liu et al. 2024). Gentler than evicting; works for any
model. (CommVQ's RoPE-commutative codebook is a fancier follow-on.)

Validated: round-trip error <= scale/2 and shrinks with bits; per-channel >> per-token on
outlier channels; 8-bit near-lossless for attention output. Real GPT-2 (harness/
kv_quant_real.py): int8 ~0.5% attention deviation (near-lossless, 2x mem), graceful to
int2; int4 per-channel-K beats per-token-K 1.75-1.9x on early layers. Memory = bits/16 of
the f16 cache (int8 2x, int4 4x, int2 8x). 3 tests, fmt+clippy clean.
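The round-trip claim (error <= scale/2, shrinking with bits) can be checked with a few lines. This is a sketch under assumptions: plain min/max affine quantization, a synthetic sine signal as input, and hypothetical helper names `quantize`, `dequantize` and `max_abs_err`. It is not the test shipped in this commit.

```rust
/// Affine quantization of one group: returns integer codes plus (scale, min).
/// Codes are held in u16 so that any width in 1..=16 bits fits.
fn quantize(xs: &[f32], bits: u32) -> (Vec<u16>, f32, f32) {
    let levels = ((1u32 << bits) - 1) as f32;
    let min = xs.iter().copied().fold(f32::INFINITY, f32::min);
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let scale = if max > min { (max - min) / levels } else { 1.0 };
    let codes = xs
        .iter()
        .map(|&x| ((x - min) / scale).round().clamp(0.0, levels) as u16)
        .collect();
    (codes, scale, min)
}

/// Inverse mapping: code * scale + min.
fn dequantize(codes: &[u16], scale: f32, min: f32) -> Vec<f32> {
    codes.iter().map(|&q| q as f32 * scale + min).collect()
}

fn max_abs_err(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
}

fn main() {
    // Synthetic input, chosen only so that values spread across the range.
    let xs: Vec<f32> = (0..64).map(|i| (i as f32 * 0.37).sin() * 3.0).collect();
    for bits in [2u32, 4, 8] {
        let (codes, scale, min) = quantize(&xs, bits);
        let err = max_abs_err(&xs, &dequantize(&codes, scale, min));
        println!("bits={bits} scale={scale} max_err={err} bound={}", scale / 2.0);
    }
}
```

Rounding to the nearest level is what gives the scale/2 bound. Each extra bit roughly halves the scale, and the bound with it.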

Follow-on: packed-int storage + a quantized KV-cache op (dequant-on-attend), composing
with the in-place (sonos#2321) / sliding-window (sonos#2327) caches; CommVQ codebook variant.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…ansform

Completes the KIVI-style KV-cache quantization integration:

1. QuantKeyCache: per-channel u8 storage for Keys. D channels each have a running
   scale; new tokens quantized under the current channel scale. Memory: T*D + D*8 bytes.
2. QuantValueCache: per-token u8 storage for Values. Each token D bytes + 2 f32 params.
   Memory: T*D + T*8 bytes (~4x vs f32 at large D).
3. QuantizedKvSdpa: stateful fused op (Op/EvalOp/TypedOp + OpState + freeze) that
   stores K/V in packed u8, dequantizes per-head on each decode step, attends via
   FlashSdpaOp (GQA handled). Real u8 bytes, not just float round-trip quality test.
4. QuantizedKvSdpaTransform: auto-wires {cache(K), cache(V), Sdpa} -> QuantizedKvSdpa.
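The memory formulas in items 1 and 2 can be evaluated directly. The sketch below only restates that arithmetic for a single head with T tokens and D channels. The function names are hypothetical, and reading the `D*8` term as two f32 parameters per channel is an assumption by analogy with the "2 f32 params" stated for Values.

```rust
/// Key cache: T*D code bytes + D*8 bytes of per-channel parameters.
fn key_cache_bytes(t: usize, d: usize) -> usize {
    t * d + d * 8
}

/// Value cache: T*D code bytes + T*8 bytes of per-token parameters.
fn value_cache_bytes(t: usize, d: usize) -> usize {
    t * d + t * 8
}

/// Unquantized f32 reference: 4 bytes per element.
fn f32_cache_bytes(t: usize, d: usize) -> usize {
    t * d * 4
}

/// Ratio of the f32 footprint to a quantized footprint of `bytes`.
fn saving_vs_f32(bytes: usize, t: usize, d: usize) -> f64 {
    f32_cache_bytes(t, d) as f64 / bytes as f64
}

fn main() {
    let (t, d) = (1024, 64);
    let k = key_cache_bytes(t, d);
    let v = value_cache_bytes(t, d);
    println!("keys:   {k} bytes, {:.2}x vs f32", saving_vs_f32(k, t, d));
    println!("values: {v} bytes, {:.2}x vs f32", saving_vs_f32(v, t, d));
}
```

The per-token overhead of the Value cache does not amortize over sequence length, which is why its saving only approaches 4× at large D, as item 2 notes.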

6 tests: quant quality (3 existing) + packed_u8_saves_memory_vs_f32 (>3x saving) +
quantized_kv_sdpa_runs_in_model (engine correctness: near-lossless vs f32 reference) +
transform_fuses_cache_sdpa_to_quantized (structural auto-wiring). fmt+clippy clean,
transformers 18/0 no regression.

Configurable via the bits parameter (1..=16); int8 = near-lossless 4x vs f32 / 2x vs
f16. CommVQ codebook variant is the follow-on.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
tract_transformers_quantized_kv_sdpa primitive: axis + optional scale.
Round-trip test: axis and scale survive write_to_tar -> model_for_read.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@kali kali force-pushed the feature/kv-quant branch from bf0ea63 to 39b5c38 Compare June 18, 2026 08:22
@kali

kali commented Jun 18, 2026

Collaborator

Rebased again. CI work is so frustrating.

@github-actions


🔴 Bench vs main — 1 speed regression(s)

Reference: main nightly, latest 2026-06-18 (0d old) · PR 4694235b7 · ran on apple-m1-max, i9-11900kb_rtx-4060, jetson-orin-nx · 1089 metrics compared

Speed — evaltime · prefill · decode

| Δ | metric | device | main → PR |
|---|---|---|---|
| 🔴 +5.9% | parakeet_tdt_600m_v3_f32f32_preprocessor_1s (evaltime · cuda) | i9-11900kb_rtx-4060 | 3.26 ms → 3.46 ms |

🟢 3 improvement(s)

| Δ | metric | device | main → PR |
|---|---|---|---|
| 🟢 -9.4% | hey_snips_v31 (evaltime · 400ms) | i9-11900kb_rtx-4060 | 0.112 ms → 0.101 ms |
| 🟢 -6.1% | arm_ml_kws_cnn_m (evaltime · pass) | i9-11900kb_rtx-4060 | 0.222 ms → 0.208 ms |
| 🟢 -5.2% | hey_snips_v4_model17 (evaltime · 2sec) | i9-11900kb_rtx-4060 | 0.958 ms → 0.908 ms |

lower is better except prefill/decode (tok/s) · adaptive thresholds (max(floor, k×noise) vs the series' own history) · single-shot vs nightly reference · full table → run summary
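For readers unfamiliar with the `max(floor, k×noise)` rule mentioned in the footer, here is a rough sketch of how such an adaptive threshold could behave. The bot's actual code is not shown in this thread, so everything here is an assumption: noise is taken as the relative standard deviation of the metric's history, and the floor (5%), k (3) and the history values are made up for illustration.

```rust
/// Relative noise of a metric: population standard deviation divided by the mean.
fn rel_noise(history: &[f64]) -> f64 {
    if history.is_empty() {
        return 0.0;
    }
    let n = history.len() as f64;
    let mean = history.iter().sum::<f64>() / n;
    let var = history.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    var.sqrt() / mean
}

/// Flag a slowdown when the relative change exceeds max(floor, k * noise).
/// Applies to "lower is better" metrics such as evaltime.
fn is_regression(history: &[f64], reference: f64, pr: f64, floor: f64, k: f64) -> bool {
    let delta = (pr - reference) / reference;
    delta > floor.max(k * rel_noise(history))
}

fn main() {
    // Hypothetical histories, in ms: a stable series and a noisy one.
    let quiet = [3.20, 3.30, 3.26, 3.22, 3.28];
    let noisy = [3.0, 3.5, 3.1, 3.6, 3.2];
    println!("quiet series flagged: {}", is_regression(&quiet, 3.26, 3.46, 0.05, 3.0));
    println!("noisy series flagged: {}", is_regression(&noisy, 3.26, 3.46, 0.05, 3.0));
}
```

Under this rule the same 3.26 ms → 3.46 ms change is flagged against a stable history and ignored against a noisy one, so a series whose recorded history understates its real variance can still trip the check.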

@czoli1976

Contributor Author

hmmmmmmm, why that regression on parakeet_tdt_600m_v3_f32f32_preprocessor_1s ???

PS: I like that nostalgic "hey_snips" model :P

@kali

kali commented Jun 18, 2026

Collaborator

> hmmmmmmm, why that regression on parakeet_tdt_600m_v3_f32f32_preprocessor_1s ???

Just noise, I think, honestly. We'll have to accept a bit of it until I tweak the threshold rules. As it is, this model swaps between CPU and GPU constantly, so I think it's heavily affected by scheduling.

> PS: I like that nostalgic "hey_snips" model :P

Yeah, I still have a couple of these around to play with :)

@kali kali merged commit 885b53d into sonos:main Jun 18, 2026
48 checks passed