diff --git a/README.md b/README.md
index e2700b54..46dcda9b 100755
--- a/README.md
+++ b/README.md
@@ -16,6 +16,8 @@
--------------------------------------------------------------------------------
+> This repository is a fork of **LightX2V** with [MagiCompiler](https://github.com/SandAI-org/MagiCompiler) integrated. Try it out and check the [MagiCompiler Documentation](README_MagiCompiler.md) for details!
+
**LightX2V** is an advanced lightweight image/video generation inference framework engineered to deliver efficient, high-performance image/video synthesis solutions. This unified platform integrates multiple state-of-the-art image/video generation techniques, supporting diverse generation tasks including text-to-video (T2V), image-to-video (I2V), text-to-image (T2I), image-editing (I2I). **X2V represents the transformation of different input modalities (X, such as text or images) into vision output (Vision)**.
> 🌐 **Try it online now!** Experience LightX2V without installation: **[LightX2V Online Service](https://x2v.light-ai.top/login)** - Free, lightweight, and fast AI digital human video generation platform.
diff --git a/README_MagiCompiler.md b/README_MagiCompiler.md
new file mode 100644
index 00000000..2f2a315e
--- /dev/null
+++ b/README_MagiCompiler.md
@@ -0,0 +1,123 @@
+
+
+# LightX2V-MagiCompiler
+
+
+
+
+[MagiCompiler](https://github.com/SandAI-org/MagiCompiler.git) is an advanced compiler and runtime augmentation framework built on top of `torch.compile`. Designed specifically for large-scale Transformer-like architectures, it addresses the critical bottlenecks of memory walls and operator overheads.
+
+By stepping beyond traditional local operator optimization, MagiCompiler introduces system-level optimizations, seamlessly accelerating both training and multi-modality inference workloads with minimal code intrusion.
+
+### 🚀 Using MagiCompiler in LightX2V
+
+To accelerate LightX2V with MagiCompiler, you only need to add minimal code changes to register custom operators and decorate the main inference function:
+
+**1. Register Custom Attention Operators**
+Use `@magi_register_custom_op` to register attention functions (like FlashAttention or SageAttention) so they can be recognized and optimized by MagiCompiler.
+
+```python
+import torch
+from magi_compiler import magi_register_custom_op
+
+# Example: Registering Flash Attention
+@magi_register_custom_op("magi_compiler::flash_attn", infer_output_meta_fn=["q"], is_subgraph_boundary=True)
+def flash_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ # Your attention implementation (e.g., using flash_attn_interface)
+ pass
+
+```
+
+**2. Decorate the Inference Function**
+Use the `@magi_compile` decorator on the core transformer loop (usually the `infer_without_offload` method) and specify dynamic shape dimensions to enable TorchDynamo tracing and graph optimization.
+
+```python
+from magi_compiler import magi_compile
+
+class TransformerInfer(BaseTransformerInfer):
+ # Specify dynamic dimensions for tensors that change shape (e.g., sequence length)
+ @magi_compile(dynamic_arg_dims={
+ "x": 0,
+ "pre_infer_out.embed": 0,
+ "pre_infer_out.x": 0,
+ "pre_infer_out.cos_sin": 0
+ })
+ def infer_without_offload(self, blocks, x, pre_infer_out):
+ for block_idx in range(len(blocks)):
+ x = self.infer_block(blocks[block_idx], x, pre_infer_out)
+ return x
+```
+
+With just these simple decorators, MagiCompiler can perform graph-level optimizations, fuse operators, and significantly improve the inference speed of LightX2V models like HunyuanVideo and Wan2.2.
+
+## 💡 Quick Start
+
+### Option 1: Installation via Docker (Recommended)
+Using Docker is the simplest and fastest way to set up the environment, avoiding tedious environment dependency configurations.
+
+```bash
+# 1. Pull the latest MagiCompiler Docker image
+docker pull sandai/magi-compiler:latest
+
+# 2. Run and enter the container
+# (Please replace /path/to/models with your local models directory)
+docker run -it --gpus all -v /path/to/models:/models sandai/magi-compiler:latest bash
+
+# 3. Clone and install MagiCompiler
+git clone https://github.com/SandAI-org/MagiCompiler.git
+cd MagiCompiler
+pip install -r requirements.txt
+pip install .
+# pip install -e . --no-build-isolation --config-settings editable_mode=compat # Developer / editable
+cd ..
+
+# 4. Clone and install LightX2V-MagiCompiler
+git clone https://github.com/SandAI-org/LightX2V-MagiCompiler.git
+cd LightX2V-MagiCompiler
+pip install -v -e .
+```
+
+### Option 2: Installation via Conda
+If you prefer a local environment, you can create an isolated virtual environment using Conda for source installation.
+
+```bash
+# 1. Create and activate a Conda environment (Python 3.12 or higher is recommended)
+conda create -n lightx2v python=3.12
+conda activate lightx2v
+
+# 2. Install PyTorch
+pip install torch==2.9.0 torchvision==0.24.0 torchaudio==2.9.0
+
+# 3. Install Flash Attention In Hopper
+git clone https://github.com/Dao-AILab/flash-attention
+cd flash-attention/hopper
+python setup.py install
+cd ../..
+
+# 4. Install MagiCompiler
+git clone https://github.com/SandAI-org/MagiCompiler.git
+cd MagiCompiler
+pip install -r requirements.txt
+pip install .
+# pip install -e . --no-build-isolation --config-settings editable_mode=compat # Developer / editable
+cd ..
+
+# 5. Clone the source code and install project dependencies
+git clone https://github.com/SandAI-org/LightX2V-MagiCompiler.git
+cd LightX2V-MagiCompiler
+pip install -r requirements.txt
+pip install -v -e .
+```
+
+
+## 🚀 Run LightX2V-MagiCompiler Examples
+
+**Run Wan2.2TI2V-5B**
+```bash
+bash ./magi_scripts/run_wan.sh
+```
+
+**Run Hunyuan1.5 480p_t2v_distilled**
+```bash
+bash ./magi_scripts/run_hunyuan.sh
+```
diff --git a/examples/hunyuan_video/hunyuan_t2v_distill.py b/examples/hunyuan_video/hunyuan_t2v_distill.py
index d7c7cdd5..1d93c927 100755
--- a/examples/hunyuan_video/hunyuan_t2v_distill.py
+++ b/examples/hunyuan_video/hunyuan_t2v_distill.py
@@ -3,16 +3,27 @@
This example demonstrates how to use LightX2V with HunyuanVideo-1.5 4-step distilled model for T2V generation.
"""
+import os
+from datetime import datetime
+
from lightx2v import LightX2VPipeline
+CP_SIZE = int(os.environ.get("CP_SIZE", 1))
+CPU_OFFLOAD = os.environ.get("CPU_OFFLOAD", "false")
+
# Initialize pipeline for HunyuanVideo-1.5
pipe = LightX2VPipeline(
- model_path="/path/to/ckpts/hunyuanvideo-1.5/",
+ model_path="path/to/HunyuanVideo-1.5/",
model_cls="hunyuan_video_1.5",
transformer_model_name="480p_t2v",
task="t2v",
# 4-step distilled model ckpt
- dit_original_ckpt="/path/to/hy1.5_t2v_480p_lightx2v_4step.safetensors",
+ dit_original_ckpt="path/to/HunyuanVideo-1.5/transformer/480p_t2v_distilled/diffusion_pytorch_model.safetensors",
+)
+
+pipe.enable_parallel(
+ seq_p_size=CP_SIZE, # Sequence parallel size
+ seq_p_attn_type="ulysses", # Sequence parallel attention type
)
# Alternative: create generator from config JSON file
@@ -20,31 +31,24 @@
# Enable offloading to significantly reduce VRAM usage with minimal speed impact
# Suitable for RTX 30/40/50 consumer GPUs
-pipe.enable_offload(
- cpu_offload=True,
- offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
- text_encoder_offload=True,
- image_encoder_offload=False,
- vae_offload=False,
-)
-
-# Use lighttae
-pipe.enable_lightvae(
- use_tae=True,
- tae_path="/path/to/lighttaehy1_5.safetensors",
- use_lightvae=False,
- vae_path=None,
-)
+if CPU_OFFLOAD == "true":
+ pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
+ text_encoder_offload=True,
+ image_encoder_offload=True,
+ vae_offload=True,
+ )
# Create generator with specified parameters
-pipe.create_generator(attn_mode="sage_attn2", infer_steps=4, num_frames=81, guidance_scale=1, sample_shift=9.0, aspect_ratio="16:9", fps=16, denoising_step_list=[1000, 750, 500, 250])
-
+pipe.create_generator(attn_mode="flash_attn3", infer_steps=10, num_frames=121, guidance_scale=1, sample_shift=5.0, aspect_ratio="16:9", fps=24, denoising_step_list=[1000, 750, 500, 250])
+negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
# Generation parameters
seed = 123
prompt = "A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
-negative_prompt = ""
-save_result_path = "/data/nvme0/gushiqiao/LightX2V/save_results/output.mp4"
+suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
+save_result_path = f"./output_hunyuan_t2v_distill_{suffix}.mp4"
# Generate video
pipe.generate(
diff --git a/examples/wan/wan_ti2v.py b/examples/wan/wan_ti2v.py
new file mode 100644
index 00000000..950b0cd0
--- /dev/null
+++ b/examples/wan/wan_ti2v.py
@@ -0,0 +1,70 @@
+"""
+Wan2.2 image-to-video generation example.
+This example demonstrates how to use LightX2V with Wan2.2 model for I2V generation.
+"""
+
+import os
+from datetime import datetime
+
+from lightx2v import LightX2VPipeline
+
+CP_SIZE = int(os.environ.get("CP_SIZE", 1))
+
+
+def env_is_true(env_name: str) -> bool:
+ return str(os.environ.get(env_name, "0")).lower() in {"1", "true", "yes", "y", "on", "enabled"}
+
+
+CPU_OFFLOAD = env_is_true("CPU_OFFLOAD")
+
+
+# Generate video
+
+pipe = LightX2VPipeline(
+ model_path="path/to/Wan2.2-TI2V-5B",
+ model_cls="wan2.2",
+ task="i2v",
+)
+pipe.enable_parallel(
+ seq_p_size=CP_SIZE, # Sequence parallel size
+ seq_p_attn_type="ulysses", # Sequence parallel attention type
+)
+
+if CPU_OFFLOAD:
+ print("Enabling CPU offload")
+ pipe.enable_offload(
+ cpu_offload=True,
+ offload_granularity="block", # For HunyuanVideo-1.5, only "block" is supported
+ text_encoder_offload=True,
+ image_encoder_offload=True,
+ vae_offload=True,
+ )
+
+pipe.create_generator(
+ attn_mode="flash_attn3",
+ infer_steps=10,
+ height=704, # Can be set to 720 for higher resolution
+ width=1280, # Can be set to 1280 for higher resolution
+ num_frames=121,
+ fps=24,
+ guidance_scale=5.0, # For wan2.1, guidance_scale is a scalar (e.g., 5.0)
+ sample_shift=5.0,
+ rope_type="torch",
+ # config_json="../../configs/wan22/wan_ti2v_i2v.json"
+)
+
+
+seed = 42
+prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+negative_prompt = "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+image_path = "path/to/i2v_input.JPG"
+suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
+save_result_path = f"./output_wan_ti2v_{suffix}.mp4"
+
+pipe.generate(
+ seed=seed,
+ image_path=image_path,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ save_result_path=save_result_path,
+)
diff --git a/lightx2v/common/ops/attn/flash_attn.py b/lightx2v/common/ops/attn/flash_attn.py
index dcf1a407..98d3741d 100755
--- a/lightx2v/common/ops/attn/flash_attn.py
+++ b/lightx2v/common/ops/attn/flash_attn.py
@@ -1,23 +1,31 @@
-from loguru import logger
-
-try:
- import flash_attn # noqa: F401
- from flash_attn.flash_attn_interface import flash_attn_varlen_func
-except ImportError:
- logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
- flash_attn_varlen_func = None
-
-try:
- from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
-except ImportError:
- logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
- flash_attn_varlen_func_v3 = None
+import torch
+from magi_compiler import magi_register_custom_op
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
+@magi_register_custom_op("magi_compiler::flash_attn", infer_output_meta_fn=["q"], is_subgraph_boundary=True)
+def flash_attn_varlen_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_q: int, max_seqlen_kv: int) -> torch.Tensor:
+ try:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as fa2
+
+ return fa2(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+ except ImportError:
+ raise ImportError("flash_attn_varlen_func not found, please install flash_attn2 first")
+
+
+@magi_register_custom_op("magi_compiler::flash_attn_v3", infer_output_meta_fn=["q"], is_subgraph_boundary=True)
+def flash_attn_varlen_func_v3(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_kv: torch.Tensor, max_seqlen_q: int, max_seqlen_kv: int) -> torch.Tensor:
+ try:
+ from flash_attn_interface import flash_attn_varlen_func as fa3
+
+ return fa3(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
+ except ImportError:
+ raise ImportError("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
+
+
@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
def __init__(self):
diff --git a/lightx2v/common/ops/attn/sage_attn.py b/lightx2v/common/ops/attn/sage_attn.py
index 3d9a7e15..eecd2933 100755
--- a/lightx2v/common/ops/attn/sage_attn.py
+++ b/lightx2v/common/ops/attn/sage_attn.py
@@ -1,23 +1,12 @@
import torch
from loguru import logger
+from magi_compiler import magi_register_custom_op
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
capability = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else None
-if capability in [(8, 9), (12, 0)]:
- try:
- from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
- except ImportError:
- logger.info("sageattn not found, please install sageattention first")
- sageattn = None
-else:
- try:
- from sageattention import sageattn
- except ImportError:
- logger.info("sageattn not found, please install sageattention first")
- sageattn = None
try:
from sageattn3 import sageattn3_blackwell
@@ -32,6 +21,24 @@
sageattn3_sparse_blackwell = None
+@magi_register_custom_op("magi_compiler::sage_attn", infer_output_meta_fn=["q"], is_subgraph_boundary=True)
+def sageattn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tensor_layout: str = "NHD") -> torch.Tensor:
+ if capability in [(8, 9), (12, 0)]:
+ try:
+ from sageattention import sageattn_qk_int8_pv_fp16_triton
+
+ return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout)
+ except ImportError:
+ raise ImportError("sageattn_qk_int8_pv_fp16_triton not found, please install sageattention first")
+ else:
+ try:
+ from sageattention import sageattn as sageattn2
+
+ return sageattn2(q, k, v, tensor_layout=tensor_layout)
+ except ImportError:
+ raise ImportError("sageattn not found, please install sageattention first")
+
+
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
def __init__(self):
diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py
index 303e8e42..7afe918c 100755
--- a/lightx2v/common/ops/attn/ulysses_attn.py
+++ b/lightx2v/common/ops/attn/ulysses_attn.py
@@ -70,10 +70,10 @@ def apply(
if img_first:
img_qkv_len = slice_qkv_len
if len(cu_seqlens_qkv) == 3:
- txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len # 文本查询、键和值的长度
- txt_mask_len = cu_seqlens_qkv[2] - slice_qkv_len # 文本掩码长度
+ txt_qkv_len = q.shape[1] - slice_qkv_len # 文本查询、键和值的长度
+ txt_mask_len = q.shape[1] - slice_qkv_len # 文本掩码长度
elif len(cu_seqlens_qkv) == 2:
- txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len # 文本查询、键和值的长度
+ txt_qkv_len = q.shape[0] - slice_qkv_len # 文本查询、键和值的长度
txt_mask_len = None
else:
# assert len(cu_seqlens_qkv) == 2
@@ -389,7 +389,6 @@ def apply(
return attn # 返回最终的注意力结果
- @torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims) # 重塑图像注意力结果
@@ -679,7 +678,6 @@ def apply(
return attn # 返回最终的注意力结果
- @torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
diff --git a/lightx2v/common/ops/attn/utils/all2all.py b/lightx2v/common/ops/attn/utils/all2all.py
index 757ce74e..037eb00b 100644
--- a/lightx2v/common/ops/attn/utils/all2all.py
+++ b/lightx2v/common/ops/attn/utils/all2all.py
@@ -1,9 +1,7 @@
import torch
-import torch._dynamo as dynamo
import torch.distributed as dist
-@dynamo.disable
def all2all_seq2head(input, group=None):
"""
将输入张量从 [seq_len/N, heads, hidden_dims] 转换为 [seq_len, heads/N, hidden_dims] 的格式。
@@ -44,7 +42,6 @@ def all2all_seq2head(input, group=None):
return output # 返回转换后的输出张量
-@dynamo.disable
def all2all_head2seq(input, group=None):
"""
将输入张量从 [seq_len, heads/N, hidden_dims] 转换为 [seq_len/N, heads, hidden_dims] 的格式。
diff --git a/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py b/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py
index d34d23b6..09a1b898 100755
--- a/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py
+++ b/lightx2v/models/networks/hunyuan_video/infer/pre_infer.py
@@ -82,6 +82,11 @@ def __init__(self, config):
self.cos_sin = None
self.grid_sizes = (0, 0, 0) # (t, h, w)
+ if self.config.get("seq_parallel", False):
+ self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+ else:
+ self.seq_p_group = None
+
def set_scheduler(self, scheduler):
self.scheduler = scheduler
diff --git a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py
index 84872922..d925378b 100755
--- a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py
+++ b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py
@@ -3,6 +3,7 @@
import torch
import torch.nn.functional as F
from einops import rearrange
+from magi_compiler import magi_compile
try:
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
@@ -96,8 +97,11 @@ def _apply_rope(x: torch.Tensor) -> torch.Tensor:
return xq_out, xk_out
-class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
+class HunyuanVideo15TransformerInfer(BaseTransformerInfer, torch.nn.Module):
+ __constants__ = ["heads_num"]
+
def __init__(self, config):
+ torch.nn.Module.__init__(self)
self.config = config
self.double_blocks_num = config["mm_double_blocks_depth"]
self.heads_num = config["heads_num"]
@@ -112,11 +116,11 @@ def __init__(self, config):
self.seq_p_fp4_comm = False
self.enable_head_parallel = False
self.infer_func = self.infer_without_offload
- if self.config.get("modulate_type", "triton") == "triton":
+ if self.config.get("modulate_type", "torch") == "triton":
self.modulate_func = fuse_scale_shift_kernel
else:
self.modulate_func = modulate
- if self.config.get("rope_type", "flashinfer") == "flashinfer":
+ if self.config.get("rope_type", "torch") == "flashinfer":
self.apply_rope_func = apply_hunyuan_rope_with_flashinfer
else:
self.apply_rope_func = apply_hunyuan_rope_with_torch
@@ -132,6 +136,7 @@ def infer(self, weights, infer_module_out):
return x
@torch.no_grad()
+ @magi_compile(dynamic_arg_dims={"infer_module_out.img": 1, "infer_module_out.txt": 1, "infer_module_out.cos_sin": 0})
def infer_without_offload(self, weights, infer_module_out):
for i in range(self.double_blocks_num):
infer_module_out.img, infer_module_out.txt = self.infer_double_block(weights.double_blocks[i], infer_module_out)
diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py
index febfa1b8..833030a6 100755
--- a/lightx2v/models/networks/wan/infer/transformer_infer.py
+++ b/lightx2v/models/networks/wan/infer/transformer_infer.py
@@ -1,6 +1,7 @@
from functools import partial
import torch
+from magi_compiler import magi_compile
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.utils.envs import *
@@ -11,11 +12,14 @@
def modulate(x, scale, shift):
- return x * (1 + scale.squeeze()) + shift.squeeze()
+ return x * (1 + scale.squeeze(1)) + shift.squeeze(1)
-class WanTransformerInfer(BaseTransformerInfer):
+class WanTransformerInfer(BaseTransformerInfer, torch.nn.Module):
+ __constants__ = ["num_heads", "head_dim"]
+
def __init__(self, config):
+ torch.nn.Module.__init__(self)
self.config = config
self.task = config["task"]
self.attention_type = config.get("attention_type", "flash_attn2")
@@ -29,7 +33,7 @@ def __init__(self, config):
self.head_dim = config["dim"] // config["num_heads"]
self.window_size = config.get("window_size", (-1, -1))
self.parallel_attention = None
- if self.config.get("modulate_type", "triton") == "triton":
+ if self.config.get("modulate_type", "torch") == "triton":
self.modulate_func = fuse_scale_shift_kernel
else:
self.modulate_func = modulate
@@ -38,7 +42,7 @@ def __init__(self, config):
"torch": apply_wan_rope_with_torch,
"torch_naive": apply_wan_rope_with_torch_naive,
}
- rope_type = self.config.get("rope_type", "flashinfer")
+ rope_type = self.config.get("rope_type", "torch")
# Try to get rope function from registry first (for platform-specific implementations)
if rope_type in ROPE_REGISTER:
rope_class = ROPE_REGISTER[rope_type]
@@ -90,7 +94,6 @@ def reset_infer_states(self):
@torch.no_grad()
def infer(self, weights, pre_infer_out):
- self.cos_sin = pre_infer_out.cos_sin
self.reset_infer_states()
x = self.infer_main_blocks(weights.blocks, pre_infer_out)
return self.infer_non_blocks(weights, x, pre_infer_out.embed)
@@ -123,7 +126,9 @@ def infer_non_blocks(self, weights, x, e):
torch.cuda.empty_cache()
return x
+ @magi_compile(dynamic_arg_dims={"x": 0, "pre_infer_out.embed": 0, "pre_infer_out.x": 0, "pre_infer_out.embed0": 0, "pre_infer_out.context": 0, "pre_infer_out.cos_sin": 0})
def infer_without_offload(self, blocks, x, pre_infer_out):
+ self.cos_sin = pre_infer_out.cos_sin
for block_idx in range(len(blocks)):
self.block_idx = block_idx
x = self.infer_block(blocks[block_idx], x, pre_infer_out)
@@ -186,10 +191,10 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa):
if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.infer_dtype)
- s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
- q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d)
- k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
- v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
+ n, d = self.num_heads, self.head_dim
+ q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(-1, n, d)
+ k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(-1, n, d)
+ v = phase.self_attn_v.apply(norm1_out).view(-1, n, d)
q, k = self.apply_rope_func(q, k, cos_sin)
img_qkv_len = q.shape[0]
if self.self_attn_cu_seqlens_qkv is None:
@@ -247,7 +252,7 @@ def infer_cross_attn(self, phase, x, context, y_out, gate_msa):
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze()
else:
- x.add_(y_out * gate_msa.squeeze())
+ x = x + (y_out * gate_msa.squeeze(1))
norm3_out = phase.norm3.apply(x)
if self.task in ["i2v", "flf2v", "animate", "s2v", "rs2v"] and self.config.get("use_image_encoder", True):
@@ -356,7 +361,7 @@ def post_process(self, x, y, c_gate_msa, pre_infer_out=None):
if self.sensitive_layer_dtype != self.infer_dtype:
x = x.to(self.sensitive_layer_dtype) + y.to(self.sensitive_layer_dtype) * c_gate_msa.squeeze()
else:
- x.add_(y * c_gate_msa.squeeze())
+ x = x + (y * c_gate_msa.squeeze(1))
if self.clean_cuda_cache:
del y, c_gate_msa
diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py
index 5b15b9b0..4b676ff4 100755
--- a/lightx2v/models/networks/wan/model.py
+++ b/lightx2v/models/networks/wan/model.py
@@ -138,9 +138,10 @@ def _seq_parallel_pre_process(self, pre_infer_out):
if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v", "rs2v"]:
embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
- padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
+ # Fix: reuse the padding_size calculated above for embed and embed0
if padding_size > 0:
embed = F.pad(embed, (0, 0, 0, padding_size))
+ # Note: if embed0 is a 3D tensor, the first few 0s pad the later dimensions, only pad the 0th dimension (Sequence dimension)
embed0 = F.pad(embed0, (0, 0, 0, 0, 0, padding_size))
pre_infer_out.embed = torch.chunk(embed, world_size, dim=0)[cur_rank]
diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py
index 8c2e2989..4f7db118 100755
--- a/lightx2v/models/runners/default_runner.py
+++ b/lightx2v/models/runners/default_runner.py
@@ -348,9 +348,14 @@ def init_run(self):
self.model = self.load_transformer()
self.model.set_scheduler(self.scheduler)
- self.model.scheduler.prepare(
- seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, infer_steps=self.model.scheduler.infer_steps, image_encoder_output=self.inputs["image_encoder_output"]
- )
+ import inspect
+
+ prepare_kwargs = {"seed": self.input_info.seed, "latent_shape": self.input_info.latent_shape, "image_encoder_output": self.inputs.get("image_encoder_output", None)}
+ sig = inspect.signature(self.model.scheduler.prepare)
+ if "infer_steps" in sig.parameters:
+ prepare_kwargs["infer_steps"] = self.model.scheduler.infer_steps
+
+ self.model.scheduler.prepare(**prepare_kwargs)
if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]:
self.inputs["image_encoder_output"]["vae_encoder_out"] = None
diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py
index 5c92284e..e9c31e5a 100755
--- a/lightx2v/pipeline.py
+++ b/lightx2v/pipeline.py
@@ -100,6 +100,7 @@ def __init__(
elif self.model_cls in ["wan2.2"]:
self.vae_stride = (4, 16, 16)
self.num_channels_latents = 48
+ self.use_image_encoder = False
elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
self.vae_stride = (4, 16, 16)
self.num_channels_latents = 32
diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py
index 3017c9f6..082b1ee8 100755
--- a/lightx2v/utils/set_config.py
+++ b/lightx2v/utils/set_config.py
@@ -37,7 +37,20 @@ def get_default_config():
def set_args2config(args):
config = get_default_config()
- config.update({k: v for k, v in vars(args).items() if k not in ALL_INPUT_INFO_KEYS})
+
+ task = getattr(args, "task", None)
+ if task:
+ from lightx2v.utils.input_info import init_empty_input_info
+
+ try:
+ input_info = init_empty_input_info(task)
+ exclude_keys = set(input_info.__dataclass_fields__.keys())
+ except Exception:
+ exclude_keys = ALL_INPUT_INFO_KEYS
+ else:
+ exclude_keys = ALL_INPUT_INFO_KEYS
+
+ config.update({k: v for k, v in vars(args).items() if k not in exclude_keys})
return config
diff --git a/magi_scripts/run_hunyuan.sh b/magi_scripts/run_hunyuan.sh
new file mode 100644
index 00000000..255d4ab5
--- /dev/null
+++ b/magi_scripts/run_hunyuan.sh
@@ -0,0 +1,23 @@
+### Distributed args ###
+MASTER_ADDR=${MASTER_ADDR:-localhost}
+MASTER_PORT=${MASTER_PORT:-29501}
+
+# GPUS_PER_NODE=${GPUS_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
+GPUS_PER_NODE=${GPUS_PER_NODE:-1}
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+DISTRIBUTED_ARGS="--nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=$GPUS_PER_NODE --rdzv-backend=c10d --rdzv-endpoint=$MASTER_ADDR:$MASTER_PORT"
+
+CP_SIZE=${CP_SIZE:-$GPUS_PER_NODE}
+export CP_SIZE=$CP_SIZE
+
+CPU_OFFLOAD=${CPU_OFFLOAD:-false}
+export CPU_OFFLOAD=$CPU_OFFLOAD
+
+### Log dir ###
+current_time=$(date "+%Y-%m-%d_%H:%M:%S")
+LOG_FILE=${LOG_FILE:-log_${current_time}.log}
+
+torchrun $DISTRIBUTED_ARGS examples/hunyuan_video/hunyuan_t2v_distill.py \
+ ${POST_ARGS} 2>&1 | tee $LOG_FILE
diff --git a/magi_scripts/run_wan.sh b/magi_scripts/run_wan.sh
new file mode 100644
index 00000000..a8b966d9
--- /dev/null
+++ b/magi_scripts/run_wan.sh
@@ -0,0 +1,23 @@
+### Distributed args ###
+MASTER_ADDR=${MASTER_ADDR:-localhost}
+MASTER_PORT=${MASTER_PORT:-29501}
+
+# GPUS_PER_NODE=${GPUS_PER_NODE:-$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)}
+GPUS_PER_NODE=${GPUS_PER_NODE:-1}
+NNODES=${NNODES:-1}
+NODE_RANK=${NODE_RANK:-0}
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
+DISTRIBUTED_ARGS="--nnodes=$NNODES --node_rank=$NODE_RANK --nproc_per_node=$GPUS_PER_NODE --rdzv-backend=c10d --rdzv-endpoint=$MASTER_ADDR:$MASTER_PORT"
+
+CP_SIZE=${CP_SIZE:-$GPUS_PER_NODE}
+export CP_SIZE=$CP_SIZE
+
+CPU_OFFLOAD=${CPU_OFFLOAD:-false}
+export CPU_OFFLOAD=$CPU_OFFLOAD
+
+### Log dir ###
+current_time=$(date "+%Y-%m-%d_%H:%M:%S")
+LOG_FILE=${LOG_FILE:-log_${current_time}.log}
+
+$LAUNCH_PREFIX torchrun $DISTRIBUTED_ARGS examples/wan/wan_ti2v.py \
+ ${POST_ARGS} 2>&1 | tee $LOG_FILE
diff --git a/pyproject.toml b/pyproject.toml
index c62264b4..d500fd9c 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,9 +33,7 @@ classifiers = [
dependencies = [
"numpy",
"scipy",
- "torch<=2.8.0",
- "torchvision<=0.23.0",
- "torchaudio<=2.8.0",
+ "pydantic_settings",
"diffusers",
"transformers",
"tokenizers",
diff --git a/requirements.txt b/requirements.txt
index e83cb4f5..765121d5 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,9 +2,9 @@ packaging
ninja
numpy
scipy
-torch<=2.8.0
-torchvision<=0.23.0
-torchaudio<=2.8.0
+torch==2.9.0
+torchvision==0.24.0
+torchaudio==2.9.0
torchao
diffusers
transformers