refactor(kv): index_copy_ for KV-cache writes (enabler for multi-user batching + CUDA Graphs)#53
Open
prashant182 wants to merge 12 commits into
Open
refactor(kv): index_copy_ for KV-cache writes (enabler for multi-user batching + CUDA Graphs)#53prashant182 wants to merge 12 commits into
prashant182 wants to merge 12 commits into
Conversation
Add WanI2VFast.prewarm() to hide kernel-autotune tax (~30% generate() speedup, bit-identical)
Adds an example script that constructs WanI2VFast() once, calls prewarm() once, then runs generate() in a loop over multiple prompts. Borrows the "don't recreate the engine per request" pattern from LLM serving (vLLM continuous batching, SGLang). The cold-start cost (~129s model load + ~7s prewarm = ~136s on 8xH100) is paid once. Each subsequent generate() runs at amortized steady-state ~15.5s. For users running multiple generations from the same starting image, this is a large multiplier on top of the existing prewarm() win from #1. Measured on 8xH100, 3 prompts back-to-back, 480*832 / 81 frames: construct: 128932 ms (once) prewarm: 7020 ms (once) generate[0]: 15587 ms generate[1]: 15202 ms generate[2]: 15160 ms total wall-clock: 181901 ms naive (3 separate invocations of generate_fast.py): ~568s speedup with persistent pipe: 3.13x Speedup grows with more prompts: at 10 calls ≈5x, at 100 calls ≈7x. Library code is unchanged. The example is pure documentation of how to use the existing public API (WanI2VFast, prewarm, generate) correctly for the multi-call case. Files: + examples/persistent_inference.py (new, 174 lines) Tier A: outputs are bit-identical to the canonical baseline at the same seed; the script is functionally equivalent to running generate_fast.py 3 times sequentially, just with model load + prewarm amortized.
Adds a per-pipe-instance cache of T5 encoder outputs keyed by
sha256(prompt). Same-prompt re-encodes hit the dict instead of re-
running umt5-xxl, saving ~360-430 ms per repeat call.
Pattern borrowed from SGLang's RadixAttention: when the same prompt
appears across calls, the encoder output is identical, so cache and
return the prior tensor.
Implementation:
- `self._t5_cache: dict[str, list[Tensor]]` initialized in __init__
- `clear_text_cache()` public method to free the dict
- In generate(), check cache before invoking text_encoder; populate
after first compute.
Bit-identical: cached tensor IS the prior call's tensor (no copy, no
quantization). MD5 preserved across cached vs uncached calls.
Measured on 8xH100, 480*832/81 frames, same prompt, twice:
call[0] (cold T5): 15520.9 ms md5=ed2f8262… (cache_size: 1)
call[1] (cached): 15089.4 ms md5=ed2f8262… (cache_size: 1)
delta: +431.5 ms saved
Bigger when paired with examples/persistent_inference.py (same-prompt
loop). Effectively zero when prompts vary every call.
Δ generate_ms (repeat call): 15521 → 15089 (-3%)
MD5 unchanged.
Tier: A (bit-identical).
docs(examples): add persistent_inference.py for multi-call amortization (3.13x speedup over naive)
feat: add T5 prompt-embedding cache (~430ms/repeat-call, bit-identical)
Step 2.1 of B3 (eliminate .item() syncs). Adds an optional frame_seqlen kwarg to WanModelFast.forward, sp_dit_forward_causal, CausalWanAttentionBlock, CausalWanSelfAttention, and sp_attn_forward_causal. When provided, skips the math.prod(grid_sizes[0][1:]).item() sync at the start of each attention forward (model_fast.py:108, sequence_parallel.py:462) — caller already has the value as a Python int. Default frame_seqlen=None falls back to the original .item() path for any external callers. Bit-identical otherwise. Pipeline (image2video_fast.py) passes frame_seqlen from prewarm() and the generate() chunk loop where the value is already computed.
Step 2.2 of B3. Threads cross_attn_first_call kwarg from WanI2VFast.generate() through WanModelFast.forward / sp_dit_forward_causal / CausalWanAttentionBlock.forward into WanCrossAttention.forward. WanI2VFast tracks _cross_attn_initialized as a Python bool, reset at the top of generate() and flipped True after the first DiT forward. When the kwarg is provided, WanCrossAttention uses it as the gate; otherwise it falls back to the existing crossattn_cache["is_init"].item() check, so external callers (and the prewarm() throwaway forward) keep working. The .fill_(1) on the tensor is preserved to keep cache state consistent for any caller still relying on it.
Step 2.3 of B3. Adds a no-eviction branch in both CausalWanSelfAttention.forward (model_fast.py) and sp_attn_forward_causal (sequence_parallel.py) for the local_attn_size == -1 case (global cache, the default and only path our shipped models use). In that path, both kv_cache["global_end_index"] and kv_cache["local_end_index"] start at 0 and advance by current_end - current_start every forward, so local_end_index always equals current_end and local_start_index equals current_start — both already available as Python ints. Eliminates the two .item() syncs in the previous else-branch. The sliding-window eviction logic (local_attn_size > 0) is preserved verbatim in the elif/else for any caller that re-enables it.
Step 2.4 of B3. In sp_dit_forward_causal, computes seq_lens_int = int(seq_lens) once at the top (replacing two separate int(seq_lens) casts on lines 309-310) and threads it through kwargs into sp_attn_forward_causal. Previously, sp_attn_forward_causal did `seq_lens_int = int(seq_lens)` per attention layer (~32 syncs per forward). Now the per-layer cast is gone; the value arrives as a Python int via the new kwarg. CausalWanSelfAttention.forward and CausalWanAttentionBlock.forward accept the same kwarg for signature parity (ignored in the non-SP path).
perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())
E1 of the autoresearch sequence — see /workspace/lingbot-world-artifacts/EXPERIMENTS.md. Replaces `kv_cache["k"][:, current_start:current_end] = roped_key` with `kv_cache["k"].index_copy_(1, kv_write_index, roped_key)` in the local_attn_size == -1 fast-path. kv_write_index is a [seq_lens]-shape tensor built once per chunk in generate() via torch.arange and threaded through as a kwarg. Mirrored in both the non-SP (CausalWanSelfAttention.forward) and SP (sp_attn_forward_causal) paths. Why - Multi-tenant prerequisite: at B>1 different users can be batched into one forward, each writing the same positions of their own KV slab. Slice-assign with a Python-int range works at B=1 but doesn't generalize cleanly; index_copy_ does. - Graph capture: removes one of two Python-int slice sources keeping the DiT forward out of CUDA Graphs. (The cache READ slice `cache["k"][:, :local_end_index]` is still Python-int — that's the next experiment, E2.) Verification - Isolated test (test_e1_index_copy.py): B=1 and B=2, 7-chunk sequences, output and cache state bit-equal to slice path. - End-to-end bench MD5 ed2f82628308a3f8acd9b7935bb84401 (locked). - generate() 13414 ms — within noise of the post-B3 baseline (13523 ms). This commit ships no perf gain on its own; it's the enabler. Defaults preserved: every new kwarg defaults to None; the slice path remains the eager fallback for callers that don't pass kv_write_index.
This was referenced May 20, 2026
Contributor
Author
|
Nudge for review when you have a moment please cc @Robbyant @JingyeChen @qiuyu96 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Replaces
kv_cache[\"k\"][:, current_start:current_end] = roped_keywithkv_cache[\"k\"].index_copy_(1, kv_write_index, roped_key)in thelocal_attn_size == -1fast-path.kv_write_indexis a[seq_lens]-shape tensor built once per chunk ingenerate()and threaded as a kwarg. Mirrored in both non-SP (CausalWanSelfAttention.forward) and SP (sp_attn_forward_causal) paths.Bit-identical at B=1 (MD5
ed2f82628308a3f8acd9b7935bb84401), works correctly at B=2 (verified in an isolated test).What this is and isn't
This PR ships no perf win on its own at B=1 — it's a refactor that swaps one correct primitive for another. The reason it's worth landing now is that it opens two doors that the slice-indexing form keeps shut:
current_startworks at B=1 but doesn't generalize to multi-user inference, where each batch element may be writing the same positions of its own KV slab.index_copy_does generalize — the samekv_write_indexis correct for all batch elements in homogeneous batching (different users on the same chunk timeline). Heterogeneous batching needs PagedAttention; explicitly out of scope here.torch.compile(mode='reduce-overhead'). The cache READ slice (cache[\"k\"][:, :local_end_index]) is still Python-int; addressing that is a separate experiment.Verification
CausalWanSelfAttention.forward's control flow with tiny dims (no FSDP/SP). Three cases:torch.compileprobe —dynamic=Truecollapses unique graph count from 5 → 2 (read-slice work pending for a follow-up).ed2f82628308a3f8acd9b7935bb84401matches the locked baseline;generate()13414 ms is within run-to-run noise of the post-B3 baseline (13523 ms).Non-regressive design
Every new kwarg (
kv_write_index) defaults toNone. WithNone, the slice path remains the eager fallback. External callers ofWanModelFast.forward,sp_dit_forward_causal, or the attention forwards keep working unchanged.Stack
Stacked on #51 (B3). Reported diff shrinks once that lands.