diff --git a/examples/persistent_inference.py b/examples/persistent_inference.py new file mode 100644 index 0000000..9948343 --- /dev/null +++ b/examples/persistent_inference.py @@ -0,0 +1,198 @@ +"""Persistent-pipeline example: amortize model load + warmup across many prompts. + +Pattern borrowed from LLM serving (vLLM, SGLang): don't construct the engine +per request. Construct WanI2VFast once, call prewarm() once, then loop over +prompts. Cold-start cost (~100 s model load + ~7 s prewarm) is paid once; +each subsequent generate() runs at steady-state speed. + +Usage: + MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 \\ + torchrun --nproc_per_node=8 \\ + --master_addr=127.0.0.1 --master_port=29500 \\ + examples/persistent_inference.py \\ + --ckpt_dir lingbot-world-base-cam \\ + --image examples/03/image.jpg \\ + --action_path examples/03 \\ + --save_dir output + +Three prompts run back-to-back. Output: output/persistent_{0,1,2}.mp4. + +Wall-clock expectation (8xH100, 480*832, 81 frames): + Naive (3 separate `generate_fast.py` invocations): ~3 x 126 s ~= 378 s. + Persistent (this script): ~104 s + 7 s + 3 x 15.5 s ~= 158 s. +""" + +import argparse +import logging +import os +import sys +import time +from pathlib import Path + +# Allow running this file directly from the repo root: add the parent dir +# (repo root) to sys.path so `import wan` resolves to the in-tree package. +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + +import torch +import torch.distributed as dist +from PIL import Image + +import wan # noqa: E402 +from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS # noqa: E402 +from wan.distributed.util import init_distributed_group # noqa: E402 +from wan.utils.utils import save_video # noqa: E402 + + +PROMPTS = [ + "A serene lakeside scene with a lone tree standing in calm water, " + "surrounded by distant snow-capped mountains under a bright blue sky " + "with drifting white clouds — gentle ripples reflect the tree and " + "sky, creating a tranquil, meditative atmosphere.", + "A sweeping cinematic journey along the Great Wall of China, winding " + "through golden autumn hills under a brilliant blue sky — stone " + "pathways stretch into the distance, watchtowers stand sentinel.", + "Aerial flight over a vast tropical rainforest at dawn, mist rising " + "from the canopy, sunlight breaking through tall trees, a winding " + "river snaking through the green expanse below.", +] + + +def _parse_args(): + parser = argparse.ArgumentParser(description=__doc__.split("\n")[0]) + parser.add_argument("--task", default="i2v-A14B", + choices=list(WAN_CONFIGS.keys())) + parser.add_argument("--size", default="480*832") + parser.add_argument("--ckpt_dir", required=True) + parser.add_argument("--image", required=True) + parser.add_argument("--action_path", default=None) + parser.add_argument("--frame_num", type=int, default=81) + parser.add_argument("--chunk_size", type=int, default=3) + parser.add_argument("--base_seed", type=int, default=42) + parser.add_argument("--save_dir", default="output") + parser.add_argument("--ulysses_size", type=int, default=8) + return parser.parse_args() + + +def _init_distributed(): + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(local_rank) + if world_size > 1: + dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) + init_distributed_group() + if rank == 0: + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)], + ) + else: + logging.basicConfig(level=logging.ERROR) + return rank, local_rank, world_size + + +def main(): + args = _parse_args() + rank, local_rank, world_size = _init_distributed() + + cfg = WAN_CONFIGS[args.task] + img = Image.open(args.image).convert("RGB") + + if rank == 0: + logging.info(f"Persistent inference: {len(PROMPTS)} prompts, " + f"size={args.size}, frame_num={args.frame_num}") + + # ── Cold start: construct pipe once. Paid before the timed window. ── + t_construct = time.perf_counter() + pipe = wan.WanI2VFast( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=local_rank, + rank=rank, + t5_fsdp=True, + dit_fsdp=True, + use_sp=(args.ulysses_size > 1), + ) + construct_ms = (time.perf_counter() - t_construct) * 1000.0 + if rank == 0: + logging.info(f"WanI2VFast() construction: {construct_ms:.0f} ms") + + # ── One-time prewarm. After this, generate() runs at steady state. ── + t_warm = time.perf_counter() + pipe.prewarm( + img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + chunk_size=args.chunk_size, + ) + warm_ms = (time.perf_counter() - t_warm) * 1000.0 + if rank == 0: + logging.info(f"prewarm(): {warm_ms:.0f} ms") + + # ── Generate loop. Each call should run at amortized steady-state cost. ── + save_dir = Path(args.save_dir) + if rank == 0: + save_dir.mkdir(parents=True, exist_ok=True) + + per_call_ms = [] + for i, prompt in enumerate(PROMPTS): + if dist.is_initialized(): + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + + video = pipe.generate( + prompt, + img, + action_path=args.action_path, + chunk_size=args.chunk_size, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=cfg.sample_shift, + seed=args.base_seed + i, + offload_model=False, + ) + + if dist.is_initialized(): + torch.cuda.synchronize() + dist.barrier() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + per_call_ms.append(elapsed_ms) + + if rank == 0: + out_path = save_dir / f"persistent_{i}.mp4" + save_video( + tensor=video[None], + save_file=str(out_path), + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1), + ) + logging.info( + f"prompt[{i}]: generate() {elapsed_ms:.0f} ms -> {out_path}") + + if rank == 0: + total = construct_ms + warm_ms + sum(per_call_ms) + logging.info("=" * 60) + logging.info(f"SUMMARY prompts={len(PROMPTS)} hardware=8xH100") + logging.info(f" construct: {construct_ms:>8.0f} ms (once)") + logging.info(f" prewarm: {warm_ms:>8.0f} ms (once)") + for i, ms in enumerate(per_call_ms): + logging.info(f" generate[{i}]: {ms:>8.0f} ms") + avg = sum(per_call_ms) / len(per_call_ms) + logging.info(f" generate avg: {avg:>8.0f} ms") + logging.info(f" total wall-clock: {total:>8.0f} ms") + naive_est = construct_ms + warm_ms + len(PROMPTS) * (construct_ms + avg) + logging.info(f" naive estimate (separate invocations): {naive_est:>8.0f} ms") + logging.info(f" speedup vs naive: {naive_est / total:.2f}x") + logging.info("=" * 60) + + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/wan/distributed/sequence_parallel.py b/wan/distributed/sequence_parallel.py index f6d27cd..b25620a 100644 --- a/wan/distributed/sequence_parallel.py +++ b/wan/distributed/sequence_parallel.py @@ -258,6 +258,9 @@ def sp_dit_forward_causal( crossattn_cache=None, current_start=0, max_attention_size=1_000_000, + frame_seqlen=None, + cross_attn_first_call=None, + kv_write_index=None, ): """ x: A list of videos each with shape [C, T, H, W]. @@ -304,10 +307,14 @@ def sp_dit_forward_causal( assert seq_lens.max() <= seq_len x = torch.cat(x) - # Pad sequence to be divisible by world_size for SP chunking + # Pad sequence to be divisible by world_size for SP chunking. + # int(seq_lens) is one cudaStreamSynchronize; cache it once so each + # attention layer can reuse via the seq_lens_int kwarg instead of + # re-casting per-layer (32x). sp_size = get_world_size() - padded_seq_lens = int((seq_lens + sp_size - 1) // sp_size * sp_size) - sp_pad_len = padded_seq_lens - int(seq_lens) + seq_lens_int = int(seq_lens) + padded_seq_lens = ((seq_lens_int + sp_size - 1) // sp_size) * sp_size + sp_pad_len = padded_seq_lens - seq_lens_int if sp_pad_len > 0: x = torch.cat([x, x.new_zeros(x.size(0), sp_pad_len, x.size(2))], dim=1) @@ -378,7 +385,11 @@ def sp_dit_forward_causal( context=context, context_lens=context_lens, dit_cond_dict=dit_cond_dict, - max_attention_size=max_attention_size) + max_attention_size=max_attention_size, + frame_seqlen=frame_seqlen, + cross_attn_first_call=cross_attn_first_call, + seq_lens_int=seq_lens_int, + kv_write_index=kv_write_index) for block_index, block in enumerate(self.blocks): kwargs.update( @@ -410,7 +421,10 @@ def sp_attn_forward_causal( freqs, kv_cache=None, current_start=0, - max_attention_size=1_000_000): + max_attention_size=1_000_000, + frame_seqlen=None, + seq_lens_int=None, + kv_write_index=None): r""" Sequence-parallel causal self-attention using Ulysses all-to-all. @@ -457,9 +471,11 @@ def qkv_fn(x): # padded_seq_lens = s * sp_size may exceed seq_lens due to SP padding padded_seq_lens = s * sp_size - seq_lens_int = int(seq_lens) + if seq_lens_int is None: + seq_lens_int = int(seq_lens) - frame_seqlen = math.prod(grid_sizes[0][1:]).item() + if frame_seqlen is None: + frame_seqlen = math.prod(grid_sizes[0][1:]).item() current_start_frame = current_start // frame_seqlen # Apply causal RoPE on full (padded) sequence with local heads @@ -475,7 +491,23 @@ def qkv_fn(x): current_end = current_start + seq_lens kv_cache_size = kv_cache["k"].shape[1] - if self.local_attn_size != -1 and (current_end > kv_cache["global_end_index"].item()) and ( + if self.local_attn_size == -1: + # Fast path (no eviction possible — cache is global). Both indices + # advance identically every forward, so local_end_index == + # current_end and local_start_index == current_start. + local_end_index = current_start + seq_lens_int + local_start_index = current_start + if kv_write_index is not None: + # Graph-stable write: index is a tensor input (shape fixed, + # contents vary per chunk). Lets torch.compile capture this + # forward once and replay across chunks instead of recompiling + # per current_start. + kv_cache["k"].index_copy_(1, kv_write_index, key) + kv_cache["v"].index_copy_(1, kv_write_index, v) + else: + kv_cache["k"][:, local_start_index:local_end_index] = key + kv_cache["v"][:, local_start_index:local_end_index] = v + elif (current_end > kv_cache["global_end_index"].item()) and ( seq_lens + kv_cache["local_end_index"].item() > kv_cache_size): # Calculate the number of new tokens added in this step # Shift existing cache content left to discard oldest tokens diff --git a/wan/image2video_fast.py b/wan/image2video_fast.py index 4061f47..bd08026 100644 --- a/wan/image2video_fast.py +++ b/wan/image2video_fast.py @@ -1,4 +1,5 @@ import gc +import hashlib import logging import math import os @@ -142,6 +143,22 @@ def __init__( self.sample_neg_prompt = config.sample_neg_prompt + # T5 prompt-embedding cache. Same-prompt re-encodes hit this dict + # instead of re-running the umt5-xxl encoder (~360 ms/call). + # Keyed by sha256(prompt.utf8); value is the list returned by + # T5EncoderModel.__call__ (already device-resident). Unbounded; + # callers can clear via `pipe.clear_text_cache()` if needed. + self._t5_cache: dict[str, list] = {} + + # Reset per generate() and flipped True after the first DiT forward. + # Passed into model.forward as `cross_attn_first_call` to skip the + # crossattn_cache["is_init"].item() sync inside WanCrossAttention. + self._cross_attn_initialized: bool = False + + def clear_text_cache(self): + """Drop all cached T5 prompt embeddings. Frees ~4 MB per entry.""" + self._t5_cache.clear() + def prewarm( self, img, @@ -267,6 +284,10 @@ def _noop_no_sync(): crossattn_cache=warmup_cross_kv, current_start=0, max_attention_size=kv_size, + frame_seqlen=frame_seqlen, + kv_write_index=torch.arange( + 0, chunk_size * frame_seqlen, + device=self.device, dtype=torch.long), ) if dist.is_initialized(): @@ -432,6 +453,10 @@ def generate(self, max_seq_len = chunk_size * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2]) max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size + # Reset per-generate state: cross-attn K/V cache will be freshly + # initialized below; the first DiT forward must compute and store. + self._cross_attn_initialized = False + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) @@ -458,14 +483,22 @@ def generate(self, timesteps = self.scheduler.timesteps[timesteps_index] # preprocess - if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([input_prompt], self.device) - if offload_model: - self.text_encoder.model.cpu() + # T5 cache: skip the encoder entirely if we've seen this exact prompt + # before in this pipe instance. Bit-identical: cached tensor is the + # same object returned by the prior call. + cache_key = hashlib.sha256(input_prompt.encode('utf-8')).hexdigest() + if cache_key in self._t5_cache: + context = self._t5_cache[cache_key] else: - context = self.text_encoder([input_prompt], torch.device('cpu')) - context = [t.to(self.device) for t in context] + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([input_prompt], self.device) + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([input_prompt], torch.device('cpu')) + context = [t.to(self.device) for t in context] + self._t5_cache[cache_key] = context # cam preparation (only if action_path is provided) dit_cond_dict = None @@ -581,6 +614,15 @@ def noop_no_sync(): "c2ws_plucker_emb": current_c2ws_plucker_emb.chunk(1, dim=0), } + current_start_int = chunk_id * chunk_size * frame_seqlen + current_end_int = current_start_int + chunk_size * frame_seqlen + # Pre-built KV-write index. Shape [seq_lens] is fixed; only + # the contents shift per chunk. Lets the inner attention + # forward use index_copy_ instead of Python-int slice + # indexing, which is the gate to torch.compile / CUDA Graphs. + kv_write_index = torch.arange( + current_start_int, current_end_int, + device=self.device, dtype=torch.long) kwargs = { 'context': [context[0]], 'seq_len': max_seq_len, @@ -588,8 +630,10 @@ def noop_no_sync(): 'dit_cond_dict': dit_cond_dict, 'kv_cache': self_kv_cache, 'crossattn_cache': cross_kv_cache, - 'current_start': chunk_id * chunk_size * frame_seqlen, - 'max_attention_size': kv_size if max_attention_size is None else max_attention_size + 'current_start': current_start_int, + 'max_attention_size': kv_size if max_attention_size is None else max_attention_size, + 'frame_seqlen': frame_seqlen, + 'kv_write_index': kv_write_index, } if offload_model: @@ -602,7 +646,10 @@ def noop_no_sync(): timestep = torch.stack(current_timestep).to(self.device) noise_pred = self.model( - x=latent_model_input, t=timestep, **kwargs)[0] + x=latent_model_input, t=timestep, + cross_attn_first_call=not self._cross_attn_initialized, + **kwargs)[0] + self._cross_attn_initialized = True if offload_model: torch.cuda.empty_cache() @@ -626,7 +673,9 @@ def noop_no_sync(): # Update kv cache context_timestep = [timesteps[-1] * 0.0] timestep = torch.stack(context_timestep).to(self.device) - self.model(x=[x0], t=timestep, **kwargs) + self.model(x=[x0], t=timestep, + cross_attn_first_call=False, + **kwargs) pred_latent_chunks = torch.cat(pred_latent_chunks, dim=1) diff --git a/wan/modules/model_fast.py b/wan/modules/model_fast.py index 3c18ed2..0779b1a 100644 --- a/wan/modules/model_fast.py +++ b/wan/modules/model_fast.py @@ -85,15 +85,29 @@ def forward( freqs, kv_cache=None, current_start=0, - max_attention_size=1_000_000 + max_attention_size=1_000_000, + frame_seqlen=None, + seq_lens_int=None, + kv_write_index=None, ): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] - block_mask (BlockMask) + frame_seqlen(int, optional): Pre-computed H*W/(patch_h*patch_w). If + provided, skips a `.item()` sync on `grid_sizes`. + seq_lens_int(int, optional): Accepted for signature parity with + the SP path (sp_attn_forward_causal). Unused here. + kv_write_index(LongTensor, optional): Pre-built index tensor of + shape [seq_lens] holding positions [current_start ... + current_end-1]. When provided, the fast-path uses + `index_copy_` instead of Python-int slice indexing — graph- + stable across chunks (the tensor's *contents* vary, but its + shape is fixed, so torch.compile / CUDA Graphs treats it as + one input rather than triggering per-chunk recompiles). """ + del seq_lens_int b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim # query, key, value function @@ -105,7 +119,8 @@ def qkv_fn(x): q, k, v = qkv_fn(x) - frame_seqlen = math.prod(grid_sizes[0][1:]).item() + if frame_seqlen is None: + frame_seqlen = math.prod(grid_sizes[0][1:]).item() current_start_frame = current_start // frame_seqlen roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) roped_key = causal_rope_apply(k, grid_sizes, freqs, start_frame=current_start_frame).type_as(v) @@ -114,7 +129,24 @@ def qkv_fn(x): # If we are using local attention and the current KV cache size is larger than the local attention size, we need to truncate the KV cache kv_cache_size = kv_cache["k"].shape[1] num_new_tokens = roped_query.shape[1] - if self.local_attn_size != -1 and (current_end > kv_cache["global_end_index"].item()) and ( + if self.local_attn_size == -1: + # Fast path (no eviction possible — cache is global). Both + # indices start at 0 and advance identically every forward, so + # local_end_index == current_end and local_start_index == + # current_start. + local_end_index = current_end + local_start_index = current_start + if kv_write_index is not None: + # Graph-stable write: index is a tensor input. No Python int + # in the slice → torch.compile / CUDA Graphs can capture + # this forward once and replay across chunks. + kv_cache["k"].index_copy_(1, kv_write_index, roped_key) + kv_cache["v"].index_copy_(1, kv_write_index, v) + else: + # Eager fallback for callers that don't pass kv_write_index. + kv_cache["k"][:, local_start_index:local_end_index] = roped_key + kv_cache["v"][:, local_start_index:local_end_index] = v + elif (current_end > kv_cache["global_end_index"].item()) and ( num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size): # Calculate the number of new tokens added in this step # Shift existing cache content left to discard oldest tokens @@ -153,12 +185,17 @@ def qkv_fn(x): class WanCrossAttention(WanSelfAttention): - def forward(self, x, context, context_lens, crossattn_cache=None): + def forward(self, x, context, context_lens, crossattn_cache=None, + cross_attn_first_call=None): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] + cross_attn_first_call(bool, optional): If provided, used as the + "first call this generation" gate instead of reading + crossattn_cache["is_init"].item() (which forces a CPU↔GPU + sync). Caller (pipeline) tracks this as a Python bool. """ b, n, d = x.size(0), self.num_heads, self.head_dim @@ -166,7 +203,11 @@ def forward(self, x, context, context_lens, crossattn_cache=None): q = self.norm_q(self.q(x)).view(b, -1, n, d) if crossattn_cache is not None: - if crossattn_cache["is_init"].item() == 0: + if cross_attn_first_call is None: + is_first = crossattn_cache["is_init"].item() == 0 + else: + is_first = cross_attn_first_call + if is_first: crossattn_cache["is_init"].fill_(1) k = self.norm_k(self.k(context)).view(b, -1, n, d) v = self.v(context).view(b, -1, n, d) @@ -246,7 +287,11 @@ def forward( kv_cache=None, crossattn_cache=None, current_start=0, - max_attention_size=1_000_000 + max_attention_size=1_000_000, + frame_seqlen=None, + cross_attn_first_call=None, + seq_lens_int=None, + kv_write_index=None, ): r""" Args: @@ -262,7 +307,9 @@ def forward( # self-attention y = self.self_attn( self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), - seq_lens, grid_sizes, freqs, kv_cache, current_start, max_attention_size) + seq_lens, grid_sizes, freqs, kv_cache, current_start, max_attention_size, + frame_seqlen=frame_seqlen, seq_lens_int=seq_lens_int, + kv_write_index=kv_write_index) with torch.amp.autocast('cuda', dtype=torch.float32): x = x + y * e[2].squeeze(2) @@ -276,16 +323,19 @@ def forward( x = (1.0 + cam_scale) * x + cam_shift # cross-attention & ffn function - def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None): - x = x + self.cross_attn(self.norm3(x), context, context_lens, - crossattn_cache=crossattn_cache) + def cross_attn_ffn(x, context, context_lens, e, crossattn_cache=None, + cross_attn_first_call=None): + x = x + self.cross_attn(self.norm3(x), context, context_lens, + crossattn_cache=crossattn_cache, + cross_attn_first_call=cross_attn_first_call) y = self.ffn( self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) with torch.amp.autocast('cuda', dtype=torch.float32): x = x + y * e[5].squeeze(2) return x - x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache) + x = cross_attn_ffn(x, context, context_lens, e, crossattn_cache, + cross_attn_first_call=cross_attn_first_call) return x @@ -468,7 +518,10 @@ def forward( kv_cache=None, crossattn_cache=None, current_start=0, - max_attention_size=1_000_000 + max_attention_size=1_000_000, + frame_seqlen=None, + cross_attn_first_call=None, + kv_write_index=None, ): r""" Run the diffusion model with kv caching. @@ -580,7 +633,10 @@ def forward( context=context, context_lens=context_lens, dit_cond_dict=dit_cond_dict, - max_attention_size=max_attention_size) + max_attention_size=max_attention_size, + frame_seqlen=frame_seqlen, + cross_attn_first_call=cross_attn_first_call, + kv_write_index=kv_write_index) for block_index, block in enumerate(self.blocks): kwargs.update(