Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions examples/persistent_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""Persistent-pipeline example: amortize model load + warmup across many prompts.

Pattern borrowed from LLM serving (vLLM, SGLang): don't construct the engine
per request. Construct WanI2VFast once, call prewarm() once, then loop over
prompts. Cold-start cost (~100 s model load + ~7 s prewarm) is paid once;
each subsequent generate() runs at steady-state speed.

Usage:
MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 \\
torchrun --nproc_per_node=8 \\
--master_addr=127.0.0.1 --master_port=29500 \\
examples/persistent_inference.py \\
--ckpt_dir lingbot-world-base-cam \\
--image examples/03/image.jpg \\
--action_path examples/03 \\
--save_dir output

Three prompts run back-to-back. Output: output/persistent_{0,1,2}.mp4.

Wall-clock expectation (8xH100, 480*832, 81 frames):
Naive (3 separate `generate_fast.py` invocations): ~3 x 126 s ~= 378 s.
Persistent (this script): ~104 s + 7 s + 3 x 15.5 s ~= 158 s.
"""

import argparse
import logging
import os
import sys
import time
from pathlib import Path

# Allow running this file directly from the repo root: add the parent dir
# (repo root) to sys.path so `import wan` resolves to the in-tree package.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch
import torch.distributed as dist
from PIL import Image

import wan # noqa: E402
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS # noqa: E402
from wan.distributed.util import init_distributed_group # noqa: E402
from wan.utils.utils import save_video # noqa: E402


# Prompts rendered back-to-back by main(); the clip for entry i is written
# to persistent_{i}.mp4 inside --save_dir.
PROMPTS = [
    (
        "A serene lakeside scene with a lone tree standing in calm water, "
        "surrounded by distant snow-capped mountains under a bright blue sky "
        "with drifting white clouds — gentle ripples reflect the tree and "
        "sky, creating a tranquil, meditative atmosphere."
    ),
    (
        "A sweeping cinematic journey along the Great Wall of China, winding "
        "through golden autumn hills under a brilliant blue sky — stone "
        "pathways stretch into the distance, watchtowers stand sentinel."
    ),
    (
        "Aerial flight over a vast tropical rainforest at dawn, mist rising "
        "from the canopy, sunlight breaking through tall trees, a winding "
        "river snaking through the green expanse below."
    ),
]


def _parse_args():
    """Parse the command-line options for the persistent-inference example.

    ``--task`` is restricted to the keys of ``WAN_CONFIGS`` and ``--size`` to
    the keys of ``MAX_AREA_CONFIGS``. Validating ``--size`` here means an
    unsupported resolution is rejected by argparse immediately instead of
    surfacing as a ``KeyError`` only after the model has been loaded.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    # __doc__ is None when Python runs with -OO (docstrings stripped); fall
    # back to an empty description rather than crashing on None.split().
    summary = (__doc__ or "").split("\n")[0]
    parser = argparse.ArgumentParser(description=summary)
    parser.add_argument("--task", default="i2v-A14B",
                        choices=list(WAN_CONFIGS.keys()))
    parser.add_argument("--size", default="480*832",
                        choices=list(MAX_AREA_CONFIGS.keys()))
    parser.add_argument("--ckpt_dir", required=True)
    parser.add_argument("--image", required=True)
    parser.add_argument("--action_path", default=None)
    parser.add_argument("--frame_num", type=int, default=81)
    parser.add_argument("--chunk_size", type=int, default=3)
    parser.add_argument("--base_seed", type=int, default=42)
    parser.add_argument("--save_dir", default="output")
    parser.add_argument("--ulysses_size", type=int, default=8)
    return parser.parse_args()


def _init_distributed():
    """Select this process's GPU, join the process group, configure logging.

    RANK, WORLD_SIZE and LOCAL_RANK are read from the environment, falling
    back to 0, 1 and 0 respectively when unset. The NCCL process group and
    the project's distributed group are set up only when WORLD_SIZE > 1.
    Rank 0 logs at INFO level to stdout; every other rank logs at ERROR only.

    Returns:
        tuple: ``(rank, local_rank, world_size)`` as ints.
    """
    def _env_int(name, fallback):
        return int(os.environ.get(name, fallback))

    rank = _env_int("RANK", 0)
    world_size = _env_int("WORLD_SIZE", 1)
    local_rank = _env_int("LOCAL_RANK", 0)

    torch.cuda.set_device(local_rank)
    if world_size > 1:
        dist.init_process_group(
            backend="nccl", rank=rank, world_size=world_size)
        init_distributed_group()

    if rank != 0:
        # Keep non-zero ranks quiet so output is not duplicated per process.
        logging.basicConfig(level=logging.ERROR)
    else:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)],
        )
    return rank, local_rank, world_size


def _sync_ranks():
    """Wait for queued CUDA work, then align all ranks if distributed.

    Called on both sides of every timed region so the measured wall-clock
    delta covers finished GPU work instead of just kernel-launch time. The
    barrier is skipped when no process group has been initialised.
    """
    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()


def _log_summary(construct_ms, warm_ms, per_call_ms, hardware):
    """Log one-time costs, per-call timings and the speedup estimate.

    Args:
        construct_ms: pipeline construction time in milliseconds.
        warm_ms: prewarm time in milliseconds.
        per_call_ms: list of per-prompt generate() times in milliseconds.
        hardware: free-form description of the devices used, for the header.
    """
    n_calls = len(per_call_ms)
    total = construct_ms + warm_ms + sum(per_call_ms)
    rule = "=" * 60

    logging.info(rule)
    logging.info(f"SUMMARY prompts={n_calls} hardware={hardware}")
    logging.info(f" construct: {construct_ms:>8.0f} ms (once)")
    logging.info(f" prewarm: {warm_ms:>8.0f} ms (once)")
    for i, ms in enumerate(per_call_ms):
        logging.info(f" generate[{i}]: {ms:>8.0f} ms")

    if not per_call_ms:
        # No generate() call ran: average and speedup are undefined.
        logging.info(f" total wall-clock: {total:>8.0f} ms")
        logging.info(rule)
        return

    avg = sum(per_call_ms) / n_calls
    logging.info(f" generate avg: {avg:>8.0f} ms")
    logging.info(f" total wall-clock: {total:>8.0f} ms")
    # Separate invocations each pay construction + prewarm + one generate,
    # so every one of the N runs carries the full cold-start cost.
    naive_est = n_calls * (construct_ms + warm_ms + avg)
    logging.info(f" naive estimate (separate invocations): {naive_est:>8.0f} ms")
    logging.info(f" speedup vs naive: {naive_est / total:.2f}x")
    logging.info(rule)


def main():
    """Build the pipeline once, prewarm once, then render every prompt.

    Construction, prewarm and each generate() call are timed separately.
    Rank 0 writes one ``persistent_{i}.mp4`` per prompt into ``--save_dir``
    and logs a summary comparing the measured total against an estimate for
    running each prompt as a separate invocation.
    """
    args = _parse_args()
    rank, local_rank, world_size = _init_distributed()
    is_main = rank == 0

    cfg = WAN_CONFIGS[args.task]
    # Resolve the resolution before the slow model load so an unsupported
    # --size fails immediately rather than after construction.
    max_area = MAX_AREA_CONFIGS[args.size]
    img = Image.open(args.image).convert("RGB")

    if is_main:
        logging.info(f"Persistent inference: {len(PROMPTS)} prompts, "
                     f"size={args.size}, frame_num={args.frame_num}")

    # ── Cold start: construct pipe once. Paid before the timed window. ──
    t_construct = time.perf_counter()
    pipe = wan.WanI2VFast(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=local_rank,
        rank=rank,
        t5_fsdp=True,
        dit_fsdp=True,
        use_sp=(args.ulysses_size > 1),
    )
    # Drain pending GPU work so it is charged to construction, not prewarm.
    _sync_ranks()
    construct_ms = (time.perf_counter() - t_construct) * 1000.0
    if is_main:
        logging.info(f"WanI2VFast() construction: {construct_ms:.0f} ms")

    # ── One-time prewarm. After this, generate() runs at steady state. ──
    t_warm = time.perf_counter()
    pipe.prewarm(
        img,
        max_area=max_area,
        frame_num=args.frame_num,
        chunk_size=args.chunk_size,
    )
    _sync_ranks()
    warm_ms = (time.perf_counter() - t_warm) * 1000.0
    if is_main:
        logging.info(f"prewarm(): {warm_ms:.0f} ms")

    # ── Generate loop. Each call should run at amortized steady-state cost. ──
    save_dir = Path(args.save_dir)
    if is_main:
        save_dir.mkdir(parents=True, exist_ok=True)

    per_call_ms = []
    for i, prompt in enumerate(PROMPTS):
        _sync_ranks()
        t0 = time.perf_counter()

        video = pipe.generate(
            prompt,
            img,
            action_path=args.action_path,
            chunk_size=args.chunk_size,
            max_area=max_area,
            frame_num=args.frame_num,
            shift=cfg.sample_shift,
            seed=args.base_seed + i,
            offload_model=False,
        )

        _sync_ranks()
        elapsed_ms = (time.perf_counter() - t0) * 1000.0
        per_call_ms.append(elapsed_ms)

        # Saving happens outside the timed window and only on rank 0.
        if is_main:
            out_path = save_dir / f"persistent_{i}.mp4"
            save_video(
                tensor=video[None],
                save_file=str(out_path),
                fps=cfg.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1),
            )
            logging.info(
                f"prompt[{i}]: generate() {elapsed_ms:.0f} ms -> {out_path}")

    if is_main:
        # Report the devices actually used instead of a fixed label.
        hardware = f"{world_size}x{torch.cuda.get_device_name(local_rank)}"
        _log_summary(construct_ms, warm_ms, per_call_ms, hardware)

    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()


# Script entry point: run the example only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
48 changes: 40 additions & 8 deletions wan/distributed/sequence_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ def sp_dit_forward_causal(
crossattn_cache=None,
current_start=0,
max_attention_size=1_000_000,
frame_seqlen=None,
cross_attn_first_call=None,
kv_write_index=None,
):
"""
x: A list of videos each with shape [C, T, H, W].
Expand Down Expand Up @@ -304,10 +307,14 @@ def sp_dit_forward_causal(
assert seq_lens.max() <= seq_len
x = torch.cat(x)

# Pad sequence to be divisible by world_size for SP chunking
# Pad sequence to be divisible by world_size for SP chunking.
# int(seq_lens) is one cudaStreamSynchronize; cache it once so each
# attention layer can reuse via the seq_lens_int kwarg instead of
# re-casting per-layer (32x).
sp_size = get_world_size()
padded_seq_lens = int((seq_lens + sp_size - 1) // sp_size * sp_size)
sp_pad_len = padded_seq_lens - int(seq_lens)
seq_lens_int = int(seq_lens)
padded_seq_lens = ((seq_lens_int + sp_size - 1) // sp_size) * sp_size
sp_pad_len = padded_seq_lens - seq_lens_int
if sp_pad_len > 0:
x = torch.cat([x, x.new_zeros(x.size(0), sp_pad_len, x.size(2))], dim=1)

Expand Down Expand Up @@ -378,7 +385,11 @@ def sp_dit_forward_causal(
context=context,
context_lens=context_lens,
dit_cond_dict=dit_cond_dict,
max_attention_size=max_attention_size)
max_attention_size=max_attention_size,
frame_seqlen=frame_seqlen,
cross_attn_first_call=cross_attn_first_call,
seq_lens_int=seq_lens_int,
kv_write_index=kv_write_index)

for block_index, block in enumerate(self.blocks):
kwargs.update(
Expand Down Expand Up @@ -410,7 +421,10 @@ def sp_attn_forward_causal(
freqs,
kv_cache=None,
current_start=0,
max_attention_size=1_000_000):
max_attention_size=1_000_000,
frame_seqlen=None,
seq_lens_int=None,
kv_write_index=None):
r"""
Sequence-parallel causal self-attention using Ulysses all-to-all.

Expand Down Expand Up @@ -457,9 +471,11 @@ def qkv_fn(x):

# padded_seq_lens = s * sp_size may exceed seq_lens due to SP padding
padded_seq_lens = s * sp_size
seq_lens_int = int(seq_lens)
if seq_lens_int is None:
seq_lens_int = int(seq_lens)

frame_seqlen = math.prod(grid_sizes[0][1:]).item()
if frame_seqlen is None:
frame_seqlen = math.prod(grid_sizes[0][1:]).item()
current_start_frame = current_start // frame_seqlen

# Apply causal RoPE on full (padded) sequence with local heads
Expand All @@ -475,7 +491,23 @@ def qkv_fn(x):

current_end = current_start + seq_lens
kv_cache_size = kv_cache["k"].shape[1]
if self.local_attn_size != -1 and (current_end > kv_cache["global_end_index"].item()) and (
if self.local_attn_size == -1:
# Fast path (no eviction possible — cache is global). Both indices
# advance identically every forward, so local_end_index ==
# current_end and local_start_index == current_start.
local_end_index = current_start + seq_lens_int
local_start_index = current_start
if kv_write_index is not None:
# Graph-stable write: index is a tensor input (shape fixed,
# contents vary per chunk). Lets torch.compile capture this
# forward once and replay across chunks instead of recompiling
# per current_start.
kv_cache["k"].index_copy_(1, kv_write_index, key)
kv_cache["v"].index_copy_(1, kv_write_index, v)
else:
kv_cache["k"][:, local_start_index:local_end_index] = key
kv_cache["v"][:, local_start_index:local_end_index] = v
elif (current_end > kv_cache["global_end_index"].item()) and (
seq_lens + kv_cache["local_end_index"].item() > kv_cache_size):
# Calculate the number of new tokens added in this step
# Shift existing cache content left to discard oldest tokens
Expand Down
Loading