Skip to content

refactor(kv): index_copy_ for KV-cache writes (enabler for multi-user batching + CUDA Graphs)#53

Open
prashant182 wants to merge 12 commits into
Robbyant:mainfrom
prashant182:feat/e1-index-copy-kv-write
Open

refactor(kv): index_copy_ for KV-cache writes (enabler for multi-user batching + CUDA Graphs)#53
prashant182 wants to merge 12 commits into
Robbyant:mainfrom
prashant182:feat/e1-index-copy-kv-write

Conversation

@prashant182
Copy link
Copy Markdown
Contributor

Summary

Replaces kv_cache[\"k\"][:, current_start:current_end] = roped_key with kv_cache[\"k\"].index_copy_(1, kv_write_index, roped_key) in the local_attn_size == -1 fast-path. kv_write_index is a [seq_lens]-shape tensor built once per chunk in generate() and threaded as a kwarg. Mirrored in both non-SP (CausalWanSelfAttention.forward) and SP (sp_attn_forward_causal) paths.

Bit-identical at B=1 (MD5 ed2f82628308a3f8acd9b7935bb84401), works correctly at B=2 (verified in an isolated test).

What this is and isn't

This PR ships no perf win on its own at B=1 — it's a refactor that swaps one correct primitive for another. The reason it's worth landing now is that it opens two doors that the slice-indexing form keeps shut:

  1. Multi-user batching (B>1). Slice-assign with a Python-int current_start works at B=1 but doesn't generalize to multi-user inference, where each batch element may be writing the same positions of its own KV slab. index_copy_ does generalize — the same kv_write_index is correct for all batch elements in homogeneous batching (different users on the same chunk timeline). Heterogeneous batching needs PagedAttention; explicitly out of scope here.
  2. CUDA Graphs. Removes one of two Python-int slice sources that were keeping the DiT forward out of torch.compile(mode='reduce-overhead'). The cache READ slice (cache[\"k\"][:, :local_end_index]) is still Python-int; addressing that is a separate experiment.

Verification

  • Isolated test mirrors CausalWanSelfAttention.forward's control flow with tiny dims (no FSDP/SP). Three cases:
    • B=1, 7-chunk sequence — outputs and final cache state bit-equal to slice path.
    • B=2, 7-chunk sequence — outputs and final cache state bit-equal (validates the multi-user precondition).
    • torch.compile probe — dynamic=True collapses unique graph count from 5 → 2 (read-slice work pending for a follow-up).
  • End-to-end: bench MD5 ed2f82628308a3f8acd9b7935bb84401 matches the locked baseline; generate() 13414 ms is within run-to-run noise of the post-B3 baseline (13523 ms).

Non-regressive design

Every new kwarg (kv_write_index) defaults to None. With None, the slice path remains the eager fallback. External callers of WanModelFast.forward, sp_dit_forward_causal, or the attention forwards keep working unchanged.

Stack

Stacked on #51 (B3). Reported diff shrinks once that lands.

prashant182 and others added 12 commits May 17, 2026 17:35
Add WanI2VFast.prewarm() to hide kernel-autotune tax (~30% generate() speedup, bit-identical)
Adds an example script that constructs WanI2VFast() once, calls
prewarm() once, then runs generate() in a loop over multiple prompts.
Borrows the "don't recreate the engine per request" pattern from LLM
serving (vLLM continuous batching, SGLang).

The cold-start cost (~129s model load + ~7s prewarm = ~136s on 8xH100)
is paid once. Each subsequent generate() runs at amortized steady-state
~15.5s. For users running multiple generations from the same starting
image, this is a large multiplier on top of the existing prewarm() win
from #1.

Measured on 8xH100, 3 prompts back-to-back, 480*832 / 81 frames:

  construct:                 128932 ms (once)
  prewarm:                     7020 ms (once)
  generate[0]:                15587 ms
  generate[1]:                15202 ms
  generate[2]:                15160 ms
  total wall-clock:          181901 ms

  naive (3 separate invocations of generate_fast.py):  ~568s
  speedup with persistent pipe:                         3.13x

Speedup grows with more prompts: at 10 calls ≈5x, at 100 calls ≈7x.

Library code is unchanged. The example is pure documentation of how to
use the existing public API (WanI2VFast, prewarm, generate) correctly
for the multi-call case.

Files:
  + examples/persistent_inference.py  (new, 174 lines)

Tier A: outputs are bit-identical to the canonical baseline at the same
seed; the script is functionally equivalent to running generate_fast.py
3 times sequentially, just with model load + prewarm amortized.
Adds a per-pipe-instance cache of T5 encoder outputs keyed by
sha256(prompt). Same-prompt re-encodes hit the dict instead of re-
running umt5-xxl, saving ~360-430 ms per repeat call.

Pattern borrowed from SGLang's RadixAttention: when the same prompt
appears across calls, the encoder output is identical, so cache and
return the prior tensor.

Implementation:
  - `self._t5_cache: dict[str, list[Tensor]]` initialized in __init__
  - `clear_text_cache()` public method to free the dict
  - In generate(), check cache before invoking text_encoder; populate
    after first compute.

Bit-identical: cached tensor IS the prior call's tensor (no copy, no
quantization). MD5 preserved across cached vs uncached calls.

Measured on 8xH100, 480*832/81 frames, same prompt, twice:

  call[0] (cold T5):  15520.9 ms  md5=ed2f8262… (cache_size: 1)
  call[1] (cached):   15089.4 ms  md5=ed2f8262… (cache_size: 1)
  delta:              +431.5 ms saved

Bigger when paired with examples/persistent_inference.py (same-prompt
loop). Effectively zero when prompts vary every call.

Δ generate_ms (repeat call): 15521 → 15089  (-3%)
MD5 unchanged.
Tier: A (bit-identical).
docs(examples): add persistent_inference.py for multi-call amortization (3.13x speedup over naive)
feat: add T5 prompt-embedding cache (~430ms/repeat-call, bit-identical)
Step 2.1 of B3 (eliminate .item() syncs). Adds an optional frame_seqlen
kwarg to WanModelFast.forward, sp_dit_forward_causal, CausalWanAttentionBlock,
CausalWanSelfAttention, and sp_attn_forward_causal. When provided, skips
the math.prod(grid_sizes[0][1:]).item() sync at the start of each attention
forward (model_fast.py:108, sequence_parallel.py:462) — caller already has
the value as a Python int.

Default frame_seqlen=None falls back to the original .item() path for any
external callers. Bit-identical otherwise.

Pipeline (image2video_fast.py) passes frame_seqlen from prewarm() and the
generate() chunk loop where the value is already computed.
Step 2.2 of B3. Threads cross_attn_first_call kwarg from WanI2VFast.generate()
through WanModelFast.forward / sp_dit_forward_causal /
CausalWanAttentionBlock.forward into WanCrossAttention.forward.

WanI2VFast tracks _cross_attn_initialized as a Python bool, reset at the
top of generate() and flipped True after the first DiT forward. When the
kwarg is provided, WanCrossAttention uses it as the gate; otherwise it
falls back to the existing crossattn_cache["is_init"].item() check, so
external callers (and the prewarm() throwaway forward) keep working.

The .fill_(1) on the tensor is preserved to keep cache state consistent
for any caller still relying on it.
Step 2.3 of B3. Adds a no-eviction branch in both
CausalWanSelfAttention.forward (model_fast.py) and sp_attn_forward_causal
(sequence_parallel.py) for the local_attn_size == -1 case (global cache,
the default and only path our shipped models use).

In that path, both kv_cache["global_end_index"] and kv_cache["local_end_index"]
start at 0 and advance by current_end - current_start every forward, so
local_end_index always equals current_end and local_start_index equals
current_start — both already available as Python ints. Eliminates the
two .item() syncs in the previous else-branch.

The sliding-window eviction logic (local_attn_size > 0) is preserved
verbatim in the elif/else for any caller that re-enables it.
Step 2.4 of B3. In sp_dit_forward_causal, computes seq_lens_int = int(seq_lens)
once at the top (replacing two separate int(seq_lens) casts on lines 309-310)
and threads it through kwargs into sp_attn_forward_causal.

Previously, sp_attn_forward_causal did `seq_lens_int = int(seq_lens)` per
attention layer (~32 syncs per forward). Now the per-layer cast is gone;
the value arrives as a Python int via the new kwarg.

CausalWanSelfAttention.forward and CausalWanAttentionBlock.forward accept
the same kwarg for signature parity (ignored in the non-SP path).
perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())
E1 of the autoresearch sequence — see /workspace/lingbot-world-artifacts/EXPERIMENTS.md.

Replaces `kv_cache["k"][:, current_start:current_end] = roped_key` with
`kv_cache["k"].index_copy_(1, kv_write_index, roped_key)` in the
local_attn_size == -1 fast-path. kv_write_index is a [seq_lens]-shape
tensor built once per chunk in generate() via torch.arange and threaded
through as a kwarg.

Mirrored in both the non-SP (CausalWanSelfAttention.forward) and SP
(sp_attn_forward_causal) paths.

Why
- Multi-tenant prerequisite: at B>1 different users can be batched into
  one forward, each writing the same positions of their own KV slab.
  Slice-assign with a Python-int range works at B=1 but doesn't
  generalize cleanly; index_copy_ does.
- Graph capture: removes one of two Python-int slice sources keeping the
  DiT forward out of CUDA Graphs. (The cache READ slice
  `cache["k"][:, :local_end_index]` is still Python-int — that's the next
  experiment, E2.)

Verification
- Isolated test (test_e1_index_copy.py): B=1 and B=2, 7-chunk sequences,
  output and cache state bit-equal to slice path.
- End-to-end bench MD5 ed2f82628308a3f8acd9b7935bb84401 (locked).
- generate() 13414 ms — within noise of the post-B3 baseline (13523 ms).
  This commit ships no perf gain on its own; it's the enabler.

Defaults preserved: every new kwarg defaults to None; the slice path
remains the eager fallback for callers that don't pass kv_write_index.
@prashant182
Copy link
Copy Markdown
Contributor Author

prashant182 commented May 21, 2026

Nudge for review when you have a moment please cc @Robbyant @JingyeChen @qiuyu96

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant