diff --git a/README.md b/README.md index e2700b54..46dcda9b 100755 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ -------------------------------------------------------------------------------- +> This repository is a fork of **LightX2V** with [MagiCompiler](https://github.com/SandAI-org/MagiCompiler) integrated. Try it out and check the [MagiCompiler Documentation](README_MagiCompiler.md) for details! + **LightX2V** is an advanced lightweight image/video generation inference framework engineered to deliver efficient, high-performance image/video synthesis solutions. This unified platform integrates multiple state-of-the-art image/video generation techniques, supporting diverse generation tasks including text-to-video (T2V), image-to-video (I2V), text-to-image (T2I), image-editing (I2I). **X2V represents the transformation of different input modalities (X, such as text or images) into vision output (Vision)**. > 🌐 **Try it online now!** Experience LightX2V without installation: **[LightX2V Online Service](https://x2v.light-ai.top/login)** - Free, lightweight, and fast AI digital human video generation platform. diff --git a/README_MagiCompiler.md b/README_MagiCompiler.md new file mode 100644 index 00000000..2f2a315e --- /dev/null +++ b/README_MagiCompiler.md @@ -0,0 +1,123 @@ +
+ +# LightX2V-MagiCompiler + +
+ + +[MagiCompiler](https://github.com/SandAI-org/MagiCompiler.git) is an advanced compiler and runtime augmentation framework built on top of `torch.compile`. Designed specifically for large-scale Transformer-like architectures, it addresses the critical bottlenecks of memory walls and operator overheads. + +By stepping beyond traditional local operator optimization, MagiCompiler introduces system-level optimizations, seamlessly accelerating both training and multi-modality inference workloads with minimal code intrusion. + +### 🚀 Using MagiCompiler in LightX2V + +To accelerate LightX2V with MagiCompiler, you only need to add minimal code changes to register custom operators and decorate the main inference function: + +**1. Register Custom Attention Operators** +Use `@magi_register_custom_op` to register attention functions (like FlashAttention or SageAttention) so they can be recognized and optimized by MagiCompiler. + +```python +import torch +from magi_compiler import magi_register_custom_op + +# Example: Registering Flash Attention +@magi_register_custom_op("magi_compiler::flash_attn", infer_output_meta_fn=["q"], is_subgraph_boundary=True) +def flash_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + # Your attention implementation (e.g., using flash_attn_interface) + pass + +``` + +**2. Decorate the Inference Function** +Use the `@magi_compile` decorator on the core transformer loop (usually the `infer_without_offload` method) and specify dynamic shape dimensions to enable TorchDynamo tracing and graph optimization. 
+ +```python +from magi_compiler import magi_compile + +class TransformerInfer(BaseTransformerInfer): + # Specify dynamic dimensions for tensors that change shape (e.g., sequence length) + @magi_compile(dynamic_arg_dims={ + "x": 0, + "pre_infer_out.embed": 0, + "pre_infer_out.x": 0, + "pre_infer_out.cos_sin": 0 + }) + def infer_without_offload(self, blocks, x, pre_infer_out): + for block_idx in range(len(blocks)): + x = self.infer_block(blocks[block_idx], x, pre_infer_out) + return x +``` + +With just these simple decorators, MagiCompiler can perform graph-level optimizations, fuse operators, and significantly improve the inference speed of LightX2V models like HunyuanVideo and Wan2.2. + +## 💡 Quick Start + +### Option 1: Installation via Docker (Recommended) +Using Docker is the simplest and fastest way to set up the environment, avoiding tedious environment dependency configurations. + +```bash +# 1. Pull the latest MagiCompiler Docker image +docker pull sandai/magi-compiler:latest + +# 2. Run and enter the container +# (Please replace /path/to/models with your local models directory) +docker run -it --gpus all -v /path/to/models:/models sandai/magi-compiler:latest bash + +# 3. Clone and install MagiCompiler +git clone https://github.com/SandAI-org/MagiCompiler.git +cd MagiCompiler +pip install -r requirements.txt +pip install . +# pip install -e . --no-build-isolation --config-settings editable_mode=compat # Developer / editable +cd .. + +# 4. Clone and install LightX2V-MagiCompiler +git clone https://github.com/SandAI-org/LightX2V-MagiCompiler.git +cd LightX2V-MagiCompiler +pip install -v -e . +``` + +### Option 2: Installation via Conda +If you prefer a local environment, you can create an isolated virtual environment using Conda for source installation. + +```bash +# 1. Create and activate a Conda environment (Python 3.12 or higher is recommended) +conda create -n lightx2v python=3.12 +conda activate lightx2v + +# 2. 
Install PyTorch +pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0 + +# 3. Install FlashAttention-3 (built from the `hopper` directory, for Hopper GPUs) +git clone https://github.com/Dao-AILab/flash-attention +cd flash-attention/hopper +python setup.py install +cd ../.. + +# 4. Install MagiCompiler +git clone https://github.com/SandAI-org/MagiCompiler.git +cd MagiCompiler +pip install -r requirements.txt +pip install . +# pip install -e . --no-build-isolation --config-settings editable_mode=compat # Developer / editable +cd .. + +# 5. Clone the source code and install project dependencies +git clone https://github.com/SandAI-org/LightX2V-MagiCompiler.git +cd LightX2V-MagiCompiler +pip install -r requirements.txt +pip install -v -e . +``` + + +## 🚀 Run LightX2V-MagiCompiler Examples + +**Run Wan2.2-TI2V-5B** +```bash +bash ./magi_scripts/run_wan.sh +``` + +**Run HunyuanVideo-1.5 480p_t2v_distilled** +```bash +bash ./magi_scripts/run_hunyuan.sh +``` diff --git a/examples/hunyuan_video/hunyuan_t2v_distill.py b/examples/hunyuan_video/hunyuan_t2v_distill.py index d7c7cdd5..1d93c927 100755 --- a/examples/hunyuan_video/hunyuan_t2v_distill.py +++ b/examples/hunyuan_video/hunyuan_t2v_distill.py @@ -3,16 +3,27 @@ This example demonstrates how to use LightX2V with HunyuanVideo-1.5 4-step distilled model for T2V generation. 
""" +import os +from datetime import datetime + from lightx2v import LightX2VPipeline +CP_SIZE = int(os.environ.get("CP_SIZE", 1)) +CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "false") + # Initialize pipeline for HunyuanVideo-1.5 pipe = LightX2VPipeline( - model_path="/path/to/ckpts/hunyuanvideo-1.5/", + model_path="path/to/HunyuanVideo-1.5/", model_cls="hunyuan_video_1.5", transformer_model_name="480p_t2v", task="t2v", # 4-step distilled model ckpt - dit_original_ckpt="/path/to/hy1.5_t2v_480p_lightx2v_4step.safetensors", + dit_original_ckpt="path/to/HunyuanVideo-1.5/transformer/480p_t2v_distilled/diffusion_pytorch_model.safetensors", +) + +pipe.enable_parallel( + seq_p_size=CP_SIZE, # Sequence parallel size + seq_p_attn_type="ulysses", # Sequence parallel attention type ) # Alternative: create generator from config JSON file @@ -20,31 +31,24 @@ # Enable offloading to significantly reduce VRAM usage with minimal speed impact # Suitable for RTX 30/40/50 consumer GPUs -pipe.enable_offload( - cpu_offload=True, - offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported - text_encoder_offload=True, - image_encoder_offload=False, - vae_offload=False, -) - -# Use lighttae -pipe.enable_lightvae( - use_tae=True, - tae_path="/path/to/lighttaehy1_5.safetensors", - use_lightvae=False, - vae_path=None, -) +if CPU_OFFLOAD == "true": + pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported + text_encoder_offload=True, + image_encoder_offload=True, + vae_offload=True, + ) # Create generator with specified parameters -pipe.create_generator(attn_mode="sage_attn2", infer_steps=4, num_frames=81, guidance_scale=1, sample_shift=9.0, aspect_ratio="16:9", fps=16, denoising_step_list=[1000, 750, 500, 250]) - +pipe.create_generator(attn_mode="flash_attn3", infer_steps=10, num_frames=121, guidance_scale=1, sample_shift=5.0, aspect_ratio="16:9", fps=24, denoising_step_list=[1000, 750, 500, 250]) 
+negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" # Generation parameters seed = 123 prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style." -negative_prompt = "" -save_result_path = "/data/nvme0/gushiqiao/LightX2V/save_results/output.mp4" +suffix = datetime.now().strftime("%Y%m%d_%H%M%S") +save_result_path = f"./output_hunyuan_t2v_distill_{suffix}.mp4" # Generate video pipe.generate( diff --git a/examples/wan/wan_ti2v.py b/examples/wan/wan_ti2v.py new file mode 100644 index 00000000..950b0cd0 --- /dev/null +++ b/examples/wan/wan_ti2v.py @@ -0,0 +1,70 @@ +""" +Wan2.2 image-to-video generation example. +This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation. 
+""" + +import os +from datetime import datetime + +from lightx2v import LightX2VPipeline + +CP_SIZE = int(os.environ.get("CP_SIZE", 1)) + + +def env_is_true(env_name: str) -> bool: + return str(os.environ.get(env_name, "0")).lower() in {"1", "true", "yes", "y", "on", "enabled"} + + +CPU_OFFLOAD = env_is_true("CPU_OFFLOAD") + + +# Initialize pipeline for Wan2.2-TI2V-5B + +pipe = LightX2VPipeline( + model_path="path/to/Wan2.2-TI2V-5B", + model_cls="wan2.2", + task="i2v", +) +pipe.enable_parallel( + seq_p_size=CP_SIZE, # Sequence parallel size + seq_p_attn_type="ulysses", # Sequence parallel attention type +) + +if CPU_OFFLOAD: + print("Enabling CPU offload") + pipe.enable_offload( + cpu_offload=True, + offload_granularity="block", # Block-level offload granularity + text_encoder_offload=True, + image_encoder_offload=True, + vae_offload=True, + ) + +pipe.create_generator( + attn_mode="flash_attn3", + infer_steps=10, + height=704, # Can be set to 720 for higher resolution + width=1280, # Can be set to 1280 for higher resolution + num_frames=121, + fps=24, + guidance_scale=5.0, # For wan2.1, guidance_scale is a scalar (e.g., 5.0) + sample_shift=5.0, + rope_type="torch", + # config_json="../../configs/wan22/wan_ti2v_i2v.json" +) + + +seed = 42 +prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." 
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" +image_path = "path/to/i2v_input.JPG" +suffix = datetime.now().strftime("%Y%m%d_%H%M%S") +save_result_path = f"./output_wan_ti2v_{suffix}.mp4" + +pipe.generate( + seed=seed, + image_path=image_path, + prompt=prompt, + negative_prompt=negative_prompt, + save_result_path=save_result_path, +) diff --git a/lightx2v/common/ops/attn/flash_attn.py b/lightx2v/common/ops/attn/flash_attn.py index dcf1a407..98d3741d 100755 --- a/lightx2v/common/ops/attn/flash_attn.py +++ b/lightx2v/common/ops/attn/flash_attn.py @@ -1,23 +1,31 @@ -from loguru import logger - -try: - import flash_attn # noqa: F401 - from flash_attn.flash_attn_interface import flash_attn_varlen_func -except ImportError: - logger.info("flash_attn_varlen_func not found, please install flash_attn2 first") - flash_attn_varlen_func = None - -try: - from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3 -except ImportError: - logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first") - flash_attn_varlen_func_v3 = None +import torch +from magi_compiler import magi_register_custom_op from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER from .template import AttnWeightTemplate +@magi_register_custom_op("magi_compiler::flash_attn", infer_output_meta_fn=["q"], is_subgraph_boundary=True) +def flash_attn_varlen_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_q: int, max_seqlen_kv: int) -> torch.Tensor: + try: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as fa2 + + return fa2(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + except ImportError: + raise ImportError("flash_attn_varlen_func not found, please install flash_attn2 first") + + 
+@magi_register_custom_op("magi_compiler::flash_attn_v3", infer_output_meta_fn=["q"], is_subgraph_boundary=True) +def flash_attn_varlen_func_v3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_q: int, max_seqlen_kv: int) -> torch.Tensor: + try: + from flash_attn_interface import flash_attn_varlen_func as fa3 + + return fa3(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) + except ImportError: + raise ImportError("flash_attn_varlen_func_v3 not found, please install flash_attn3 first") + + @ATTN_WEIGHT_REGISTER("flash_attn2") class FlashAttn2Weight(AttnWeightTemplate): def __init__(self): diff --git a/lightx2v/common/ops/attn/sage_attn.py b/lightx2v/common/ops/attn/sage_attn.py index 3d9a7e15..eecd2933 100755 --- a/lightx2v/common/ops/attn/sage_attn.py +++ b/lightx2v/common/ops/attn/sage_attn.py @@ -1,23 +1,12 @@ import torch from loguru import logger +from magi_compiler import magi_register_custom_op from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER from .template import AttnWeightTemplate capability = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else None -if capability in [(8, 9), (12, 0)]: - try: - from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn - except ImportError: - logger.info("sageattn not found, please install sageattention first") - sageattn = None -else: - try: - from sageattention import sageattn - except ImportError: - logger.info("sageattn not found, please install sageattention first") - sageattn = None try: from sageattn3 import sageattn3_blackwell @@ -32,6 +21,24 @@ sageattn3_sparse_blackwell = None +@magi_register_custom_op("magi_compiler::sage_attn", infer_output_meta_fn=["q"], is_subgraph_boundary=True) +def sageattn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tensor_layout: str = "NHD") -> torch.Tensor: + if capability in [(8, 9), (12, 0)]: + try: + from sageattention import 
sageattn_qk_int8_pv_fp16_triton + + return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout) + except ImportError: + raise ImportError("sageattn_qk_int8_pv_fp16_triton not found, please install sageattention first") + else: + try: + from sageattention import sageattn as sageattn2 + + return sageattn2(q, k, v, tensor_layout=tensor_layout) + except ImportError: + raise ImportError("sageattn not found, please install sageattention first") + + @ATTN_WEIGHT_REGISTER("sage_attn2") class SageAttn2Weight(AttnWeightTemplate): def __init__(self): diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py index 303e8e42..7afe918c 100755 --- a/lightx2v/common/ops/attn/ulysses_attn.py +++ b/lightx2v/common/ops/attn/ulysses_attn.py @@ -70,10 +70,10 @@ def apply( if img_first: img_qkv_len = slice_qkv_len if len(cu_seqlens_qkv) == 3: - txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len # length of the text query, key, and value - txt_mask_len = cu_seqlens_qkv[2] - slice_qkv_len # text mask length + txt_qkv_len = q.shape[1] - slice_qkv_len # length of the text query, key, and value + txt_mask_len = q.shape[1] - slice_qkv_len # text mask length elif len(cu_seqlens_qkv) == 2: - txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len # length of the text query, key, and value + txt_qkv_len = q.shape[0] - slice_qkv_len # length of the text query, key, and value txt_mask_len = None else: # assert len(cu_seqlens_qkv) == 2 @@ -389,7 +389,6 @@ def apply( return attn # return the final attention result - @torch.compiler.disable def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm): img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # reshape the image attention result @@ -679,7 +678,6 @@ def apply( return attn # return the final attention result - @torch.compiler.disable def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm): cur_rank = dist.get_rank(seq_p_group) global_world_size = dist.get_world_size() diff --git a/lightx2v/common/ops/attn/utils/all2all.py 
b/lightx2v/common/ops/attn/utils/all2all.py index 757ce74e..037eb00b 100644 --- a/lightx2v/common/ops/attn/utils/all2all.py +++ b/lightx2v/common/ops/attn/utils/all2all.py @@ -1,9 +1,7 @@ import torch -import torch._dynamo as dynamo import torch.distributed as dist -@dynamo.disable def all2all_seq2head(input, group=None): """ Convert the input tensor from the [seq_len/N, heads, hidden_dims] layout to the [seq_len, heads/N, hidden_dims] layout. @@ -44,7 +42,6 @@ def all2all_seq2head(input, group=None): return output # return the converted output tensor -@dynamo.disable def all2all_head2seq(input, group=None): """ Convert the input tensor from the [seq_len, heads/N, hidden_dims] layout to the [seq_len/N, heads, hidden_dims] layout. diff --git a/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py b/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py index d34d23b6..09a1b898 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py +++ b/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py @@ -82,6 +82,11 @@ def __init__(self, config): self.cos_sin = None self.grid_sizes = (0, 0, 0) # (t, h, w) + if self.config.get("seq_parallel", False): + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + def set_scheduler(self, scheduler): self.scheduler = scheduler diff --git a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py index 84872922..d925378b 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py +++ b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F from einops import rearrange +from magi_compiler import magi_compile try: from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace @@ -96,8 +97,11 @@ def _apply_rope(x: torch.Tensor) -> torch.Tensor: return xq_out, xk_out -class HunyuanVideo15TransformerInfer(BaseTransformerInfer): +class HunyuanVideo15TransformerInfer(BaseTransformerInfer, 
torch.nn.Module): + __constants__ = ["heads_num"] + def __init__(self, config): + torch.nn.Module.__init__(self) self.config = config self.double_blocks_num = config["mm_double_blocks_depth"] self.heads_num = config["heads_num"] @@ -112,11 +116,11 @@ def __init__(self, config): self.seq_p_fp4_comm = False self.enable_head_parallel = False self.infer_func = self.infer_without_offload - if self.config.get("modulate_type", "triton") == "triton": + if self.config.get("modulate_type", "torch") == "triton": self.modulate_func = fuse_scale_shift_kernel else: self.modulate_func = modulate - if self.config.get("rope_type", "flashinfer") == "flashinfer": + if self.config.get("rope_type", "torch") == "flashinfer": self.apply_rope_func = apply_hunyuan_rope_with_flashinfer else: self.apply_rope_func = apply_hunyuan_rope_with_torch @@ -132,6 +136,7 @@ def infer(self, weights, infer_module_out): return x @torch.no_grad() + @magi_compile(dynamic_arg_dims={"infer_module_out.img": 1, "infer_module_out.txt": 1, "infer_module_out.cos_sin": 0}) def infer_without_offload(self, weights, infer_module_out): for i in range(self.double_blocks_num): infer_module_out.img, infer_module_out.txt = self.infer_double_block(weights.double_blocks[i], infer_module_out) diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py index febfa1b8..833030a6 100755 --- a/lightx2v/models/networks/wan/infer/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/transformer_infer.py @@ -1,6 +1,7 @@ from functools import partial import torch +from magi_compiler import magi_compile from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.utils.envs import * @@ -11,11 +12,14 @@ def modulate(x, scale, shift): - return x * (1 + scale.squeeze()) + shift.squeeze() + return x * (1 + scale.squeeze(1)) + shift.squeeze(1) -class WanTransformerInfer(BaseTransformerInfer): +class 
WanTransformerInfer(BaseTransformerInfer, torch.nn.Module): + __constants__ = ["num_heads", "head_dim"] + def __init__(self, config): + torch.nn.Module.__init__(self) self.config = config self.task = config["task"] self.attention_type = config.get("attention_type", "flash_attn2") @@ -29,7 +33,7 @@ def __init__(self, config): self.head_dim = config["dim"] // config["num_heads"] self.window_size = config.get("window_size", (-1, -1)) self.parallel_attention = None - if self.config.get("modulate_type", "triton") == "triton": + if self.config.get("modulate_type", "torch") == "triton": self.modulate_func = fuse_scale_shift_kernel else: self.modulate_func = modulate @@ -38,7 +42,7 @@ def __init__(self, config): "torch": apply_wan_rope_with_torch, "torch_naive": apply_wan_rope_with_torch_naive, } - rope_type = self.config.get("rope_type", "flashinfer") + rope_type = self.config.get("rope_type", "torch") # Try to get rope function from registry first (for platform-specific implementations) if rope_type in ROPE_REGISTER: rope_class = ROPE_REGISTER[rope_type] @@ -90,7 +94,6 @@ def reset_infer_states(self): @torch.no_grad() def infer(self, weights, pre_infer_out): - self.cos_sin = pre_infer_out.cos_sin self.reset_infer_states() x = self.infer_main_blocks(weights.blocks, pre_infer_out) return self.infer_non_blocks(weights, x, pre_infer_out.embed) @@ -123,7 +126,9 @@ def infer_non_blocks(self, weights, x, e): torch.cuda.empty_cache() return x + @magi_compile(dynamic_arg_dims={"x": 0, "pre_infer_out.embed": 0, "pre_infer_out.x": 0, "pre_infer_out.embed0": 0, "pre_infer_out.context": 0, "pre_infer_out.cos_sin": 0}) def infer_without_offload(self, blocks, x, pre_infer_out): + self.cos_sin = pre_infer_out.cos_sin for block_idx in range(len(blocks)): self.block_idx = block_idx x = self.infer_block(blocks[block_idx], x, pre_infer_out) @@ -186,10 +191,10 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): if self.sensitive_layer_dtype != self.infer_dtype: norm1_out = 
norm1_out.to(self.infer_dtype) - s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim - q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d) - k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d) - v = phase.self_attn_v.apply(norm1_out).view(s, n, d) + n, d = self.num_heads, self.head_dim + q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(-1, n, d) + k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(-1, n, d) + v = phase.self_attn_v.apply(norm1_out).view(-1, n, d) q, k = self.apply_rope_func(q, k, cos_sin) img_qkv_len = q.shape[0] if self.self_attn_cu_seqlens_qkv is None: @@ -247,7 +252,7 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa): if self.sensitive_layer_dtype != self.infer_dtype: x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze() else: - x.add_(y_out * gate_msa.squeeze()) + x = x + (y_out * gate_msa.squeeze(1)) norm3_out = phase.norm3.apply(x) if self.task in ["i2v", "flf2v", "animate", "s2v", "rs2v"] and self.config.get("use_image_encoder", True): @@ -356,7 +361,7 @@ def post_process(self, x, y, c_gate_msa, pre_infer_out=None): if self.sensitive_layer_dtype != self.infer_dtype: x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze() else: - x.add_(y * c_gate_msa.squeeze()) + x = x + (y * c_gate_msa.squeeze(1)) if self.clean_cuda_cache: del y, c_gate_msa diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py index 5b15b9b0..4b676ff4 100755 --- a/lightx2v/models/networks/wan/model.py +++ b/lightx2v/models/networks/wan/model.py @@ -138,9 +138,10 @@ def _seq_parallel_pre_process(self, pre_infer_out): if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v", "rs2v"]: embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0 - padding_size = (world_size - 
(embed.shape[0] % world_size)) % world_size + # Reuse the padding_size computed above for both embed and embed0 if padding_size > 0: embed = F.pad(embed, (0, 0, 0, padding_size)) + # Note: F.pad takes (left, right) pairs starting from the last dimension, so for a 3D embed0 the leading zeros leave dims 2 and 1 unpadded and only dim 0 (the sequence dimension) is padded embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size)) pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank] diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py index 8c2e2989..4f7db118 100755 --- a/lightx2v/models/runners/default_runner.py +++ b/lightx2v/models/runners/default_runner.py @@ -348,9 +348,14 @@ def init_run(self): self.model = self.load_transformer() self.model.set_scheduler(self.scheduler) - self.model.scheduler.prepare( - seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, infer_steps=self.model.scheduler.infer_steps, image_encoder_output=self.inputs["image_encoder_output"] - ) + import inspect + + prepare_kwargs = {"seed": self.input_info.seed, "latent_shape": self.input_info.latent_shape, "image_encoder_output": self.inputs.get("image_encoder_output", None)} + sig = inspect.signature(self.model.scheduler.prepare) + if "infer_steps" in sig.parameters: + prepare_kwargs["infer_steps"] = self.model.scheduler.infer_steps + + self.model.scheduler.prepare(**prepare_kwargs) if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: self.inputs["image_encoder_output"]["vae_encoder_out"] = None diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index 5c92284e..e9c31e5a 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -100,6 +100,7 @@ def __init__( elif self.model_cls in ["wan2.2"]: self.vae_stride = (4, 16, 16) self.num_channels_latents = 48 + self.use_image_encoder = False elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: self.vae_stride = (4, 16, 16) self.num_channels_latents = 32 diff --git 
a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index 3017c9f6..082b1ee8 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -37,7 +37,20 @@ def get_default_config(): def set_args2config(args): config = get_default_config() - config.update({k: v for k, v in vars(args).items() if k not in ALL_INPUT_INFO_KEYS}) + + task = getattr(args, "task", None) + if task: + from lightx2v.utils.input_info import init_empty_input_info + + try: + input_info = init_empty_input_info(task) + exclude_keys = set(input_info.__dataclass_fields__.keys()) + except Exception: + exclude_keys = ALL_INPUT_INFO_KEYS + else: + exclude_keys = ALL_INPUT_INFO_KEYS + + config.update({k: v for k, v in vars(args).items() if k not in exclude_keys}) return config diff --git a/magi_scripts/run_hunyuan.sh b/magi_scripts/run_hunyuan.sh new file mode 100644 index 00000000..255d4ab5 --- /dev/null +++ b/magi_scripts/run_hunyuan.sh @@ -0,0 +1,23 @@ +### Distributed args ### +MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-29501} + +# GPUS_PER_NODE=${GPUS_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} +GPUS_PER_NODE=${GPUS_PER_NODE:-1} +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) +DISTRIBUTED_ARGS="--nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=$GPUS_PER_NODE --rdzv-backend=c10d --rdzv-endpoint=$MASTER_ADDR:$MASTER_PORT" + +CP_SIZE=${CP_SIZE:-$GPUS_PER_NODE} +export CP_SIZE=$CP_SIZE + +CPU_OFFLOAD=${CPU_OFFLOAD:-false} +export CPU_OFFLOAD=$CPU_OFFLOAD + +### Log dir ### +current_time=$(date "+%Y-%m-%d_%H:%M:%S") +LOG_FILE=${LOG_FILE:-log_${current_time}.log} + +torchrun $DISTRIBUTED_ARGS examples/hunyuan_video/hunyuan_t2v_distill.py \ + ${POST_ARGS} 2>&1 | tee $LOG_FILE diff --git a/magi_scripts/run_wan.sh b/magi_scripts/run_wan.sh new file mode 100644 index 00000000..a8b966d9 --- /dev/null +++ b/magi_scripts/run_wan.sh @@ -0,0 +1,23 @@ +### Distributed args ### 
+MASTER_ADDR=${MASTER_ADDR:-localhost} +MASTER_PORT=${MASTER_PORT:-29501} + +# GPUS_PER_NODE=${GPUS_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)} +GPUS_PER_NODE=${GPUS_PER_NODE:-1} +NNODES=${NNODES:-1} +NODE_RANK=${NODE_RANK:-0} +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) +DISTRIBUTED_ARGS="--nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=$GPUS_PER_NODE --rdzv-backend=c10d --rdzv-endpoint=$MASTER_ADDR:$MASTER_PORT" + +CP_SIZE=${CP_SIZE:-$GPUS_PER_NODE} +export CP_SIZE=$CP_SIZE + +CPU_OFFLOAD=${CPU_OFFLOAD:-false} +export CPU_OFFLOAD=$CPU_OFFLOAD + +### Log dir ### +current_time=$(date "+%Y-%m-%d_%H:%M:%S") +LOG_FILE=${LOG_FILE:-log_${current_time}.log} + +$LAUNCH_PREFIX torchrun $DISTRIBUTED_ARGS examples/wan/wan_ti2v.py \ + ${POST_ARGS} 2>&1 | tee $LOG_FILE diff --git a/pyproject.toml b/pyproject.toml index c62264b4..d500fd9c 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,7 @@ classifiers = [ dependencies = [ "numpy", "scipy", - "torch<=2.8.0", - "torchvision<=0.23.0", - "torchaudio<=2.8.0", + "pydantic_settings", "diffusers", "transformers", "tokenizers", diff --git a/requirements.txt b/requirements.txt index e83cb4f5..765121d5 100755 --- a/requirements.txt +++ b/requirements.txt @@ -2,9 +2,9 @@ packaging ninja numpy scipy -torch<=2.8.0 -torchvision<=0.23.0 -torchaudio<=2.8.0 +torch==2.9.0 +torchvision==0.24.0 +torchaudio==2.9.0 torchao diffusers transformers
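
Review note on the `default_runner.py` hunk: `init_run` now inspects the signature of `scheduler.prepare` and passes `infer_steps` only when the scheduler declares that parameter. The pattern can be tried in isolation with the standard library alone. The sketch below is an illustration, not code from this patch: `filter_kwargs`, `StepScheduler`, and `PlainScheduler` are hypothetical names, and the helper generalizes the patch's single `infer_steps` check to every candidate keyword.

```python
import inspect


def filter_kwargs(func, **candidates):
    """Keep only the keyword arguments that `func` declares by name.

    Hypothetical helper: the patch itself only tests for `infer_steps`.
    """
    params = inspect.signature(func).parameters
    # A callee that declares **kwargs accepts every candidate unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidates)
    return {name: value for name, value in candidates.items() if name in params}


class StepScheduler:
    """Hypothetical scheduler whose prepare() takes infer_steps."""

    def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None):
        return ("with_steps", infer_steps)


class PlainScheduler:
    """Hypothetical scheduler whose prepare() has no infer_steps parameter."""

    def prepare(self, seed, latent_shape, image_encoder_output=None):
        return ("plain", None)


candidates = {
    "seed": 42,
    "latent_shape": (1, 2, 3),
    "infer_steps": 10,
    "image_encoder_output": None,
}

for scheduler in (StepScheduler(), PlainScheduler()):
    # inspect.signature on a bound method already omits `self`.
    kwargs = filter_kwargs(scheduler.prepare, **candidates)
    print(type(scheduler).__name__, scheduler.prepare(**kwargs))
```

Filtering by signature keeps one call site working across schedulers with different `prepare` parameters, at the cost of silently dropping a misspelled keyword, so an explicit check like the one in the patch may be the safer choice when only one argument varies.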