Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- README.md +6 -5
- app.py +387 -0
- examples/i2v_input.JPG +3 -0
- generate.py +411 -0
- requirements.txt +15 -0
- wan/__init__.py +5 -0
- wan/__pycache__/__init__.cpython-310.pyc +0 -0
- wan/__pycache__/image2video.cpython-310.pyc +0 -0
- wan/__pycache__/text2video.cpython-310.pyc +0 -0
- wan/__pycache__/textimage2video.cpython-310.pyc +0 -0
- wan/configs/__init__.py +39 -0
- wan/configs/__pycache__/__init__.cpython-310.pyc +0 -0
- wan/configs/__pycache__/shared_config.cpython-310.pyc +0 -0
- wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc +0 -0
- wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc +0 -0
- wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc +0 -0
- wan/configs/shared_config.py +20 -0
- wan/configs/wan_i2v_A14B.py +37 -0
- wan/configs/wan_t2v_A14B.py +37 -0
- wan/configs/wan_ti2v_5B.py +36 -0
- wan/distributed/__init__.py +1 -0
- wan/distributed/__pycache__/__init__.cpython-310.pyc +0 -0
- wan/distributed/__pycache__/fsdp.cpython-310.pyc +0 -0
- wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc +0 -0
- wan/distributed/__pycache__/ulysses.cpython-310.pyc +0 -0
- wan/distributed/__pycache__/util.cpython-310.pyc +0 -0
- wan/distributed/fsdp.py +43 -0
- wan/distributed/sequence_parallel.py +176 -0
- wan/distributed/ulysses.py +47 -0
- wan/distributed/util.py +51 -0
- wan/image2video.py +431 -0
- wan/modules/__init__.py +19 -0
- wan/modules/__pycache__/__init__.cpython-310.pyc +0 -0
- wan/modules/__pycache__/attention.cpython-310.pyc +0 -0
- wan/modules/__pycache__/model.cpython-310.pyc +0 -0
- wan/modules/__pycache__/t5.cpython-310.pyc +0 -0
- wan/modules/__pycache__/tokenizers.cpython-310.pyc +0 -0
- wan/modules/__pycache__/vae2_1.cpython-310.pyc +0 -0
- wan/modules/__pycache__/vae2_2.cpython-310.pyc +0 -0
- wan/modules/attention.py +179 -0
- wan/modules/model.py +546 -0
- wan/modules/t5.py +513 -0
- wan/modules/tokenizers.py +82 -0
- wan/modules/vae2_1.py +663 -0
- wan/modules/vae2_2.py +1051 -0
- wan/text2video.py +378 -0
- wan/textimage2video.py +619 -0
- wan/utils/__init__.py +12 -0
- wan/utils/__pycache__/__init__.cpython-310.pyc +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,12 +1,13 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: wan2.2_enhanced_amd
+emoji: 🚀
+colorFrom: pink
+colorTo: indigo
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.39.0
 app_file: app.py
 pinned: false
+short_description: Wan 2.2 5B
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,387 @@
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

#import subprocess
#subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# wan2.2-main/gradio_ti2v.py
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
import random
import numpy as np
import spaces

import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video

import gc

# --- 1. Global Setup and Model Loading ---

print("Starting Gradio App for Wan 2.2 TI2V-5B...")

# Download model snapshots from Hugging Face Hub
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")

# Load the model configuration
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121

# Instantiate the pipeline in the global scope
print("Initializing WanTI2V pipeline...")
device = "cuda" if torch.cuda.is_available() else "cpu"
device_id = 0 if torch.cuda.is_available() else -1
pipeline = wan.WanTI2V(
    config=cfg,
    checkpoint_dir=ckpt_dir,
    device_id=device_id,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_sp=False,
    t5_cpu=False,
    init_on_cpu=False,
    convert_model_dtype=True,
)
print("Pipeline initialized and ready.")

# --- Helper Functions ---
def clear_gpu_memory():
    """Clear GPU memory more thoroughly"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()

def select_best_size_for_image(image, available_sizes):
    """Select the size option with aspect ratio closest to the input image."""
    if image is None:
        return available_sizes[0]  # Return first option if no image

    img_width, img_height = image.size
    img_aspect_ratio = img_height / img_width

    best_size = available_sizes[0]
    best_diff = float('inf')

    for size_str in available_sizes:
        # Parse size string like "704*1280"
        height, width = map(int, size_str.split('*'))
        size_aspect_ratio = height / width
        diff = abs(img_aspect_ratio - size_aspect_ratio)

        if diff < best_diff:
            best_diff = diff
            best_size = size_str

    return best_size

def handle_image_upload(image):
    """Handle image upload and return the best matching size."""
    if image is None:
        return gr.update()

    pil_image = Image.fromarray(image).convert("RGB")
    available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
    best_size = select_best_size_for_image(pil_image, available_sizes)

    return gr.update(value=best_size)

def validate_inputs(image, prompt, duration_seconds):
    """Validate user inputs"""
    errors = []

    if not prompt or len(prompt.strip()) < 5:
        errors.append("Prompt must be at least 5 characters long.")

    if image is not None:
        img = Image.fromarray(image)
        if img.size[0] * img.size[1] > 4096 * 4096:
            errors.append("Image size is too large (maximum 4096x4096).")

    if duration_seconds > 5.0 and image is None:
        errors.append("Videos longer than 5 seconds require an input image.")

    return errors

def get_duration(image,
                 prompt,
                 size,
                 duration_seconds,
                 sampling_steps,
                 guide_scale,
                 shift,
                 seed,
                 progress):
    """Calculate dynamic GPU duration based on parameters."""
    if sampling_steps > 35 and duration_seconds >= 2:
        return 120
    elif sampling_steps < 35 or duration_seconds < 2:
        return 105
    else:
        return 90

def apply_template(template, current_prompt):
    """Apply prompt template"""
    if "{subject}" in template:
        # Extract the main subject from current prompt (simple heuristic)
        subject = current_prompt.split(",")[0] if "," in current_prompt else current_prompt
        return template.replace("{subject}", subject)
    return template + " " + current_prompt

# --- 2. Gradio Inference Function ---
@spaces.GPU(duration=get_duration)
def generate_video(
    image,
    prompt,
    size,
    duration_seconds,
    sampling_steps,
    guide_scale,
    shift,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """The main function to generate video, called by the Gradio interface."""
    # Validate inputs
    errors = validate_inputs(image, prompt, duration_seconds)
    if errors:
        raise gr.Error("\n".join(errors))

    progress(0, desc="Setting up...")

    if seed == -1:
        seed = random.randint(0, sys.maxsize)

    progress(0.1, desc="Processing image...")

    input_image = None
    if image is not None:
        input_image = Image.fromarray(image).convert("RGB")
        # Resize image to match selected size
        target_height, target_width = map(int, size.split('*'))
        input_image = input_image.resize((target_width, target_height))

    # Calculate number of frames based on duration
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    progress(0.2, desc="Generating video...")

    try:
        video_tensor = pipeline.generate(
            input_prompt=prompt,
            img=input_image,  # Pass None for T2V, Image for I2V
            size=SIZE_CONFIGS[size],
            max_area=MAX_AREA_CONFIGS[size],
            frame_num=num_frames,  # Use calculated frames instead of cfg.frame_num
            shift=shift,
            sample_solver='unipc',
            sampling_steps=int(sampling_steps),
            guide_scale=guide_scale,
            seed=seed,
            offload_model=True
        )

        progress(0.9, desc="Saving video...")

        # Save the video to a temporary file
        video_path = cache_video(
            tensor=video_tensor[None],  # Add a batch dimension
            save_file=None,  # cache_video will create a temp file
            fps=cfg.sample_fps,
            normalize=True,
            value_range=(-1, 1)
        )

        progress(1.0, desc="Complete!")

    except torch.cuda.OutOfMemoryError:
        clear_gpu_memory()
        raise gr.Error("GPU out of memory. Please try with lower settings.")
    except Exception as e:
        raise gr.Error(f"Video generation failed: {str(e)}")
    finally:
        if 'video_tensor' in locals():
            del video_tensor
        clear_gpu_memory()

    return video_path


# --- 3. Gradio Interface ---
css = """
.gradio-container {max-width: 1100px !important; margin: 0 auto}
#output_video {height: 500px;}
#input_image {height: 500px;}
.template-btn {margin: 2px !important;}
"""

# Default prompt with motion emphasis
DEFAULT_PROMPT = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."

# Prompt templates
templates = {
    "Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
    "Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
    "Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
    "Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
    "Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
}

with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
    gr.Markdown("""
# Wan 2.2 TI2V Enhanced running on AMD MI355

Generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**
[[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)

### 💡 Tips for best results:
- 🖼️ Upload an image for better control over the video content
- ⏱️ Longer videos require more processing time
- 🎯 Be specific and descriptive in your prompts
- 🎬 Include motion-related keywords for dynamic videos
    """)

    with gr.Row():
        with gr.Column(scale=2):
            image_input = gr.Image(type="numpy", label="Input Image (Optional)", elem_id="input_image")
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Describe the video you want to generate..."
            )

            # Prompt templates section
            with gr.Accordion("Prompt Templates", open=False):
                gr.Markdown("Click a template to apply it to your prompt:")
                with gr.Row():
                    template_buttons = {}
                    for name, template in templates.items():
                        btn = gr.Button(name, size="sm", elem_classes=["template-btn"])
                        template_buttons[name] = (btn, template)

            # Connect template buttons
            for name, (btn, template) in template_buttons.items():
                btn.click(
                    fn=lambda t=template, p=prompt_input: apply_template(t, p),
                    inputs=[prompt_input],
                    outputs=prompt_input
                )

            duration_input = gr.Slider(
                minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
                maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
                step=0.1,
                value=2.0,
                label="Duration (seconds)",
                info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
            )
            size_input = gr.Dropdown(
                label="Output Resolution",
                choices=list(SUPPORTED_SIZES[TASK_NAME]),
                value="704*1280"
            )

        with gr.Column(scale=2):
            video_output = gr.Video(label="Generated Video", elem_id="output_video")

            # Status indicators
            with gr.Row():
                status_text = gr.Textbox(
                    label="Status",
                    value="Ready",
                    interactive=False,
                    max_lines=1
                )

            with gr.Accordion("Advanced Settings", open=False):
                steps_input = gr.Slider(
                    label="Sampling Steps",
                    minimum=10,
                    maximum=50,
                    value=38,
                    step=1,
                    info="Higher values = better quality but slower"
                )
                scale_input = gr.Slider(
                    label="Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    value=cfg.sample_guide_scale,
                    step=0.1,
                    info="Higher values = closer to prompt but less creative"
                )
                shift_input = gr.Slider(
                    label="Sample Shift",
                    minimum=1.0,
                    maximum=20.0,
                    value=cfg.sample_shift,
                    step=0.1,
                    info="Affects the sampling process dynamics"
                )
                seed_input = gr.Number(
                    label="Seed (-1 for random)",
                    value=-1,
                    precision=0,
                    info="Use same seed for reproducible results"
                )

    run_button = gr.Button("Generate Video", variant="primary", size="lg")

    # Add image upload handler
    image_input.upload(
        fn=handle_image_upload,
        inputs=[image_input],
        outputs=[size_input]
    )

    image_input.clear(
        fn=handle_image_upload,
        inputs=[image_input],
        outputs=[size_input]
    )

    # Update status when generating
    def update_status_and_generate(*args):
        status_text.value = "Generating..."
        try:
            result = generate_video(*args)
            status_text.value = "Complete!"
            return result
        except Exception as e:
            status_text.value = "Error occurred"
            raise e

    example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
    gr.Examples(
        examples=[
            [None, "Golden hour, soft lighting, warm colors, saturated colors, wide shot, left-heavy composition. A weathered gondolier stands in a flat-bottomed boat, propelling it forward with a long wooden pole through the flooded ruins of Venice. The decaying buildings on either side are cloaked in creeping vines and marked by rusted metalwork, their once-proud facades now crumbling into the water. The camera moves slowly forward and tilts left, revealing behind him the majestic remnants of the city bathed in the amber glow of the setting sun. Silhouettes of collapsed archways and broken domes rise against the golden skyline, while the still water reflects the warm hues of the sky and surrounding structures.", "1280*704", 4.0],
            [None, "In a surreal video, four miniature skiers glide down a winding, three-dimensional trail of thick white paint on a plain white canvas-like background. The textured paint mimics snow, with visible brushstrokes and uneven edges, enhanced by light and shadow. The skiers, in colorful gear, are posed dynamically from top to bottom, each casting a shadow that heightens the illusion of depth. This scene miniaturizes a grand outdoor sport into a vivid, imaginative artwork.", "1280*704", 2.0],
            [None, "In a time-lapse video, a crane slowly lifts a steel beam on a construction site. The camera pulls back slowly from a close-up, revealing details of the crane and the steel beam. The skyline transitions from day to night, with buildings and machinery in the background constantly operating. As the camera pulls further back, the busy scene of the entire construction site comes into view; cranes and other equipment continue working under the night sky, shaping the city's outline.", "704*1280", 2.5],
            [None, "Cinematic racetrack scene: Low-angle medium long shot of jockey-horse leap. High-contrast backlighting, warm tones, silhouettes. Slow-motion freeze with dust for dynamic tension. Scoreboard detail. Optimized for immersive video generation.", "1280*704", 3.0],
        ],
        inputs=[image_input, prompt_input, size_input, duration_input],
        outputs=video_output,
        fn=generate_video,
        cache_examples=False,
    )

    run_button.click(
        fn=generate_video,
        inputs=[image_input, prompt_input, size_input, duration_input, steps_input, scale_input, shift_input, seed_input],
        outputs=video_output
    )

if __name__ == "__main__":
    demo.launch()
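As a quick illustration of the size-selection helper above, the following standalone sketch mirrors the aspect-ratio comparison that select_best_size_for_image performs. Importing app.py directly would trigger the checkpoint download, so the comparison is repeated here; the 768x1024 portrait input image is an invented example, not part of this commit.

from PIL import Image

available_sizes = ['704*1280', '1280*704']  # SUPPORTED_SIZES['ti2v-5B']

image = Image.new("RGB", (768, 1024))  # portrait example: width=768, height=1024
img_width, img_height = image.size
img_aspect_ratio = img_height / img_width  # ~1.33

def parse_ratio(size_str):
    # app.py parses "H*W" strings as height, width
    height, width = map(int, size_str.split('*'))
    return height / width

best_size = min(available_sizes, key=lambda s: abs(img_aspect_ratio - parse_ratio(s)))
print(best_size)  # '1280*704' for this portrait input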
examples/i2v_input.JPG
ADDED
Git LFS Details
generate.py
ADDED
@@ -0,0 +1,411 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime

warnings.filterwarnings('ignore')

import random

import torch
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_video, str2bool

EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "i2v-A14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
    "ti2v-5B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
}


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"

    if args.prompt is None:
        args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
    if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
        args.image = EXAMPLE_PROMPT[args.task]["image"]

    if args.task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    cfg = WAN_CONFIGS[args.task]

    if args.sample_steps is None:
        args.sample_steps = cfg.sample_steps

    if args.sample_shift is None:
        args.sample_shift = cfg.sample_shift

    if args.sample_guide_scale is None:
        args.sample_guide_scale = cfg.sample_guide_scale

    if args.frame_num is None:
        args.frame_num = cfg.frame_num

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    assert args.size in SUPPORTED_SIZES[
        args.task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"


def _parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a image or video from a text prompt or image using Wan"
    )
    parser.add_argument(
        "--task",
        type=str,
        default="t2v-A14B",
        choices=list(WAN_CONFIGS.keys()),
        help="The task to run.")
    parser.add_argument(
        "--size",
        type=str,
        default="1280*720",
        choices=list(SIZE_CONFIGS.keys()),
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--frame_num",
        type=int,
        default=None,
        help="How many frames of video are generated. The number should be 4n+1"
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--offload_model",
        type=str2bool,
        default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
    )
    parser.add_argument(
        "--ulysses_size",
        type=int,
        default=1,
        help="The size of the ulysses parallelism in DiT.")
    parser.add_argument(
        "--t5_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for T5.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place T5 model on CPU.")
    parser.add_argument(
        "--dit_fsdp",
        action="store_true",
        default=False,
        help="Whether to use FSDP for DiT.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated video to.")
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="The prompt to generate the video from.")
    parser.add_argument(
        "--use_prompt_extend",
        action="store_true",
        default=False,
        help="Whether to use prompt extend.")
    parser.add_argument(
        "--prompt_extend_method",
        type=str,
        default="local_qwen",
        choices=["dashscope", "local_qwen"],
        help="The prompt extend method to use.")
    parser.add_argument(
        "--prompt_extend_model",
        type=str,
        default=None,
        help="The prompt extend model to use.")
    parser.add_argument(
        "--prompt_extend_target_lang",
        type=str,
        default="zh",
        choices=["zh", "en"],
        help="The target language of prompt extend.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=-1,
        help="The seed to use for generating the video.")
    parser.add_argument(
        "--image",
        type=str,
        default=None,
        help="The image to generate the video from.")
    parser.add_argument(
        "--sample_solver",
        type=str,
        default='unipc',
        choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    parser.add_argument(
        "--sample_steps", type=int, default=None, help="The sampling steps.")
    parser.add_argument(
        "--sample_shift",
        type=float,
        default=None,
        help="Sampling shift factor for flow matching schedulers.")
    parser.add_argument(
        "--sample_guide_scale",
        type=float,
        default=None,
        help="Classifier free guidance scale.")
    parser.add_argument(
        "--convert_model_dtype",
        action="store_true",
        default=False,
        help="Whether to convert model paramerters dtype.")

    args = parser.parse_args()

    _validate_args(args)

    return args


def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def generate(args):
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1
        ), f"sequence parallel are not supported in non-distributed environments."

    if args.ulysses_size > 1:
        assert args.ulysses_size == world_size, f"The number of ulysses_size should be equal to the world size."
        init_distributed_group()

    if args.use_prompt_extend:
        if args.prompt_extend_method == "dashscope":
            prompt_expander = DashScopePromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None)
        elif args.prompt_extend_method == "local_qwen":
            prompt_expander = QwenPromptExpander(
                model_name=args.prompt_extend_model,
                task=args.task,
                is_vl=args.image is not None,
                device=rank)
        else:
            raise NotImplementedError(
                f"Unsupport prompt_extend_method: {args.prompt_extend_method}")

    cfg = WAN_CONFIGS[args.task]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    logging.info(f"Input prompt: {args.prompt}")
    img = None
    if args.image is not None:
        img = Image.open(args.image).convert("RGB")
        logging.info(f"Input image: {args.image}")

    # prompt extend
    if args.use_prompt_extend:
        logging.info("Extending prompt ...")
        if rank == 0:
            prompt_output = prompt_expander(
                args.prompt,
                image=img,
                tar_lang=args.prompt_extend_target_lang,
                seed=args.base_seed)
            if prompt_output.status == False:
                logging.info(
                    f"Extending prompt failed: {prompt_output.message}")
                logging.info("Falling back to original prompt.")
                input_prompt = args.prompt
            else:
                input_prompt = prompt_output.prompt
            input_prompt = [input_prompt]
        else:
            input_prompt = [None]
        if dist.is_initialized():
            dist.broadcast_object_list(input_prompt, src=0)
        args.prompt = input_prompt[0]
        logging.info(f"Extended prompt: {args.prompt}")

    if "t2v" in args.task:
        logging.info("Creating WanT2V pipeline.")
        wan_t2v = wan.WanT2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info(f"Generating video ...")
        video = wan_t2v.generate(
            args.prompt,
            size=SIZE_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    elif "ti2v" in args.task:
        logging.info("Creating WanTI2V pipeline.")
        wan_ti2v = wan.WanTI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info(f"Generating video ...")
        video = wan_ti2v.generate(
            args.prompt,
            img=img,
            size=SIZE_CONFIGS[args.size],
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
    else:
        logging.info("Creating WanI2V pipeline.")
        wan_i2v = wan.WanI2V(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            device_id=device,
            rank=rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
        )

        logging.info("Generating video ...")
        video = wan_i2v.generate(
            args.prompt,
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=args.sample_steps,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)

    if rank == 0:
        if args.save_file is None:
            formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
            formatted_prompt = args.prompt.replace(" ", "_").replace("/",
                                                                     "_")[:50]
            suffix = '.mp4'
            args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix

        logging.info(f"Saving generated video to {args.save_file}")
        cache_video(
            tensor=video[None],
            save_file=args.save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
    del video

    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")


if __name__ == "__main__":
    args = _parse_args()
    generate(args)
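A hypothetical invocation sketch for generate.py, built only from the argparse flags it defines above; the checkpoint directory and prompt are placeholders rather than values taken from this commit, so adjust them before running.

import subprocess

subprocess.run(
    [
        "python", "generate.py",
        "--task", "ti2v-5B",                # a key of WAN_CONFIGS
        "--size", "1280*704",               # must appear in SUPPORTED_SIZES["ti2v-5B"]
        "--ckpt_dir", "./Wan2.2-TI2V-5B",   # placeholder checkpoint directory
        "--prompt", "A cat surfing a wave at sunset.",  # placeholder prompt
        "--offload_model", "True",          # parsed by str2bool
    ],
    check=True,
)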
requirements.txt
ADDED
@@ -0,0 +1,15 @@
torch>=2.4.0
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers>=0.31.0
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl
numpy>=1.23.5,<2
wan/__init__.py
ADDED
@@ -0,0 +1,5 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from . import configs, distributed, modules
from .image2video import WanI2V
from .text2video import WanT2V
from .textimage2video import WanTI2V
wan/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (333 Bytes)

wan/__pycache__/image2video.cpython-310.pyc
ADDED
Binary file (12.3 kB)

wan/__pycache__/text2video.cpython-310.pyc
ADDED
Binary file (11.1 kB)

wan/__pycache__/textimage2video.cpython-310.pyc
ADDED
Binary file (17.5 kB)
wan/configs/__init__.py
ADDED
@@ -0,0 +1,39 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_A14B import i2v_A14B
from .wan_t2v_A14B import t2v_A14B
from .wan_ti2v_5B import ti2v_5B

WAN_CONFIGS = {
    't2v-A14B': t2v_A14B,
    'i2v-A14B': i2v_A14B,
    'ti2v-5B': ti2v_5B,
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '704*1280': (704, 1280),
    '1280*704': (1280, 704)
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
    '704*1280': 704 * 1280,
    '1280*704': 1280 * 704,
}

SUPPORTED_SIZES = {
    't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'ti2v-5B': ('704*1280', '1280*704'),
}
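A minimal sketch of how these lookup tables are consumed elsewhere in this upload (for example by app.py and generate.py), assuming the wan package from this repository is importable; the printed values come from the config files shown below.

from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

cfg = WAN_CONFIGS['ti2v-5B']
print(cfg.sample_fps, cfg.sample_steps)   # 24 50 (set in wan_ti2v_5B.py)
print(SIZE_CONFIGS['704*1280'])           # (704, 1280)
print(SUPPORTED_SIZES['ti2v-5B'])         # ('704*1280', '1280*704')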
wan/configs/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (737 Bytes)

wan/configs/__pycache__/shared_config.cpython-310.pyc
ADDED
Binary file (848 Bytes)

wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc
ADDED
Binary file (968 Bytes)

wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc
ADDED
Binary file (955 Bytes)

wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc
ADDED
Binary file (863 Bytes)
wan/configs/shared_config.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
wan_shared_cfg.frame_num = 81
wan/configs/wan_i2v_A14B.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V A14B ------------------------#

i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)

i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_A14B.vae_stride = (4, 8, 8)

# transformer
i2v_A14B.patch_size = (1, 2, 2)
i2v_A14B.dim = 5120
i2v_A14B.ffn_dim = 13824
i2v_A14B.freq_dim = 256
i2v_A14B.num_heads = 40
i2v_A14B.num_layers = 40
i2v_A14B.window_size = (-1, -1)
i2v_A14B.qk_norm = True
i2v_A14B.cross_attn_norm = True
i2v_A14B.eps = 1e-6
i2v_A14B.low_noise_checkpoint = 'low_noise_model'
i2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
i2v_A14B.sample_shift = 5.0
i2v_A14B.sample_steps = 40
i2v_A14B.boundary = 0.900
i2v_A14B.sample_guide_scale = (3.5, 3.5)  # low noise, high noise
wan/configs/wan_t2v_A14B.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V A14B ------------------------#

t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
t2v_A14B.update(wan_shared_cfg)

# t5
t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_A14B.vae_stride = (4, 8, 8)

# transformer
t2v_A14B.patch_size = (1, 2, 2)
t2v_A14B.dim = 5120
t2v_A14B.ffn_dim = 13824
t2v_A14B.freq_dim = 256
t2v_A14B.num_heads = 40
t2v_A14B.num_layers = 40
t2v_A14B.window_size = (-1, -1)
t2v_A14B.qk_norm = True
t2v_A14B.cross_attn_norm = True
t2v_A14B.eps = 1e-6
t2v_A14B.low_noise_checkpoint = 'low_noise_model'
t2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
t2v_A14B.sample_shift = 12.0
t2v_A14B.sample_steps = 40
t2v_A14B.boundary = 0.875
t2v_A14B.sample_guide_scale = (3.0, 4.0)  # low noise, high noise
wan/configs/wan_ti2v_5B.py
ADDED
@@ -0,0 +1,36 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan TI2V 5B ------------------------#

ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
ti2v_5B.update(wan_shared_cfg)

# t5
ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
ti2v_5B.t5_tokenizer = 'google/umt5-xxl'

# vae
ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
ti2v_5B.vae_stride = (4, 16, 16)

# transformer
ti2v_5B.patch_size = (1, 2, 2)
ti2v_5B.dim = 3072
ti2v_5B.ffn_dim = 14336
ti2v_5B.freq_dim = 256
ti2v_5B.num_heads = 24
ti2v_5B.num_layers = 30
ti2v_5B.window_size = (-1, -1)
ti2v_5B.qk_norm = True
ti2v_5B.cross_attn_norm = True
ti2v_5B.eps = 1e-6

# inference
ti2v_5B.sample_fps = 24
ti2v_5B.sample_shift = 5.0
ti2v_5B.sample_steps = 50
ti2v_5B.sample_guide_scale = 5.0
ti2v_5B.frame_num = 121
wan/distributed/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan/distributed/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (143 Bytes)

wan/distributed/__pycache__/fsdp.cpython-310.pyc
ADDED
Binary file (1.36 kB)

wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc
ADDED
Binary file (5.24 kB)

wan/distributed/__pycache__/ulysses.cpython-310.pyc
ADDED
Binary file (1.23 kB)

wan/distributed/__pycache__/util.cpython-310.pyc
ADDED
Binary file (1.93 kB)
wan/distributed/fsdp.py
ADDED
@@ -0,0 +1,43 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model


def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
wan/distributed/sequence_parallel.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.cuda.amp as amp
+
+from ..modules.model import sinusoidal_embedding_1d
+from .ulysses import distributed_attention
+from .util import gather_forward, get_rank, get_world_size
+
+
+def pad_freqs(original_tensor, target_len):
+    seq_len, s1, s2 = original_tensor.shape
+    pad_size = target_len - seq_len
+    padding_tensor = torch.ones(
+        pad_size,
+        s1,
+        s2,
+        dtype=original_tensor.dtype,
+        device=original_tensor.device)
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+    return padded_tensor
+
+
+@torch.amp.autocast('cuda', enabled=False)
+def rope_apply(x, grid_sizes, freqs):
+    """
+    x: [B, L, N, C].
+    grid_sizes: [B, 3].
+    freqs: [M, C // 2].
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    # split freqs
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
+            s, n, -1, 2))
+        freqs_i = torch.cat([
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ],
+                            dim=-1).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        sp_size = get_world_size()
+        sp_rank = get_rank()
+        freqs_i = pad_freqs(freqs_i, s * sp_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+                                                       s_per_rank), :, :]
+        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+        x_i = torch.cat([x_i, x[i, s:]])
+
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output).float()
+
+
+def sp_dit_forward(
+    self,
+    x,
+    t,
+    context,
+    seq_len,
+    y=None,
+):
+    """
+    x: A list of videos each with shape [C, T, H, W].
+    t: [B].
+    context: A list of text embeddings each with shape [L, C].
+    """
+    if self.model_type == 'i2v':
+        assert y is not None
+    # params
+    device = self.patch_embedding.weight.device
+    if self.freqs.device != device:
+        self.freqs = self.freqs.to(device)
+
+    if y is not None:
+        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
+
+    # embeddings
+    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+    grid_sizes = torch.stack(
+        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
+    x = [u.flatten(2).transpose(1, 2) for u in x]
+    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
+    assert seq_lens.max() <= seq_len
+    x = torch.cat([
+        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
+        for u in x
+    ])
+
+    # time embeddings
+    if t.dim() == 1:
+        t = t.expand(t.size(0), seq_len)
+    with torch.amp.autocast('cuda', dtype=torch.float32):
+        bt = t.size(0)
+        t = t.flatten()
+        e = self.time_embedding(
+            sinusoidal_embedding_1d(self.freq_dim,
+                                    t).unflatten(0, (bt, seq_len)).float())
+        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+        assert e.dtype == torch.float32 and e0.dtype == torch.float32
+
+    # context
+    context_lens = None
+    context = self.text_embedding(
+        torch.stack([
+            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+            for u in context
+        ]))
+
+    # Context Parallel
+    x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
+    e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
+    e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
+
+    # arguments
+    kwargs = dict(
+        e=e0,
+        seq_lens=seq_lens,
+        grid_sizes=grid_sizes,
+        freqs=self.freqs,
+        context=context,
+        context_lens=context_lens)
+
+    for block in self.blocks:
+        x = block(x, **kwargs)
+
+    # head
+    x = self.head(x, e)
+
+    # Context Parallel
+    x = gather_forward(x, dim=1)
+
+    # unpatchify
+    x = self.unpatchify(x, grid_sizes)
+    return [u.float() for u in x]
+
+
+def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
+    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+    half_dtypes = (torch.float16, torch.bfloat16)
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # query, key, value function
+    def qkv_fn(x):
+        q = self.norm_q(self.q(x)).view(b, s, n, d)
+        k = self.norm_k(self.k(x)).view(b, s, n, d)
+        v = self.v(x).view(b, s, n, d)
+        return q, k, v
+
+    q, k, v = qkv_fn(x)
+    q = rope_apply(q, grid_sizes, freqs)
+    k = rope_apply(k, grid_sizes, freqs)
+
+    x = distributed_attention(
+        half(q),
+        half(k),
+        half(v),
+        seq_lens,
+        window_size=self.window_size,
+    )
+
+    # output
+    x = x.flatten(2)
+    x = self.o(x)
+    return x
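
The non-obvious step in sp_attn_forward/rope_apply above is that each rank rotates only its own shard of the token sequence: the 3D RoPE table is padded to world_size * s rows with pad_freqs and then indexed by rank. A stand-alone sketch of that padding and slicing with made-up sizes (illustrative only, not part of the uploaded file):

# Hypothetical sketch, not part of the diff: per-rank frequency slicing.
import torch

sp_size, sp_rank = 4, 1          # pretend world size and this rank's index
s_per_rank = 6                   # tokens held by each rank after chunking along dim=1
full_len = 20                    # f * h * w for the sample, not a multiple of sp_size * s

freqs = torch.randn(full_len, 1, 8)                       # stand-in for the RoPE table
pad = torch.ones(sp_size * s_per_rank - full_len, 1, 8)   # what pad_freqs() appends
freqs_padded = torch.cat([freqs, pad], dim=0)

# each rank keeps only the rows matching its local sequence shard
local = freqs_padded[sp_rank * s_per_rank:(sp_rank + 1) * s_per_rank]
assert local.shape == (s_per_rank, 1, 8)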
wan/distributed/ulysses.py
ADDED
@@ -0,0 +1,47 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.distributed as dist
+
+from ..modules.attention import flash_attention
+from .util import all_to_all
+
+
+def distributed_attention(
+    q,
+    k,
+    v,
+    seq_lens,
+    window_size=(-1, -1),
+):
+    """
+    Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
+    please refer to https://arxiv.org/pdf/2309.14509
+
+    Args:
+        q: [B, Lq // p, Nq, C1].
+        k: [B, Lk // p, Nk, C1].
+        v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
+        seq_lens: [B], length of each sequence in batch
+        window_size: (left right). If not (-1, -1), apply sliding window local attention.
+    """
+    if not dist.is_initialized():
+        raise ValueError("distributed group should be initialized.")
+    b = q.shape[0]
+
+    # gather q/k/v sequence
+    q = all_to_all(q, scatter_dim=2, gather_dim=1)
+    k = all_to_all(k, scatter_dim=2, gather_dim=1)
+    v = all_to_all(v, scatter_dim=2, gather_dim=1)
+
+    # apply attention
+    x = flash_attention(
+        q,
+        k,
+        v,
+        k_lens=seq_lens,
+        window_size=window_size,
+    )
+
+    # scatter q/k/v sequence
+    x = all_to_all(x, scatter_dim=1, gather_dim=2)
+    return x
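
distributed_attention follows the Ulysses pattern: the first all_to_all trades the head split for a full sequence on every rank, attention then runs over whole sequences with 1/p of the heads, and the second all_to_all restores the original sharding. A shape-only sketch under assumed sizes (illustrative, not from the file; real runs need torch.distributed and flash-attn):

# Shape bookkeeping only, with made-up sizes.
B, L, N, D, p = 1, 16, 8, 64, 4

q_local = (B, L // p, N, D)            # input shard held by one rank
after_first_a2a = (B, L, N // p, D)    # scatter heads (dim 2), gather sequence (dim 1)
after_attention = (B, L, N // p, D)    # attention over the full sequence, fewer heads
after_second_a2a = (B, L // p, N, D)   # scatter sequence, gather heads: back to the shard

assert after_second_a2a == q_local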
wan/distributed/util.py
ADDED
@@ -0,0 +1,51 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.distributed as dist
+
+
+def init_distributed_group():
+    """r initialize sequence parallel group.
+    """
+    if not dist.is_initialized():
+        dist.init_process_group(backend='nccl')
+
+
+def get_rank():
+    return dist.get_rank()
+
+
+def get_world_size():
+    return dist.get_world_size()
+
+
+def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
+    """
+    `scatter` along one dimension and `gather` along another.
+    """
+    world_size = get_world_size()
+    if world_size > 1:
+        inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
+        outputs = [torch.empty_like(u) for u in inputs]
+        dist.all_to_all(outputs, inputs, group=group, **kwargs)
+        x = torch.cat(outputs, dim=gather_dim).contiguous()
+    return x
+
+
+def all_gather(tensor):
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return [tensor]
+    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
+    torch.distributed.all_gather(tensor_list, tensor)
+    return tensor_list
+
+
+def gather_forward(input, dim):
+    # skip if world_size == 1
+    world_size = dist.get_world_size()
+    if world_size == 1:
+        return input
+
+    # gather sequence
+    output = all_gather(input)
+    return torch.cat(output, dim=dim).contiguous()
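
With world_size == 1, all_to_all and gather_forward above are passthroughs. What the multi-rank branch of all_to_all does to shapes can be mimicked locally with chunk/cat; a tiny single-process sketch with made-up sizes (in a real run the chunks are exchanged across ranks rather than kept local):

# Local mimic of the scatter/gather shape effect, not a real collective.
import torch

world_size = 4
x = torch.randn(2, 8, 4, 16)                                   # [B, L_shard, N, D]
pieces = [u.contiguous() for u in x.chunk(world_size, dim=2)]  # scatter_dim=2 (heads)
out = torch.cat(pieces, dim=1)                                 # gather_dim=1 (sequence)
assert out.shape == (2, 32, 1, 16)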
wan/image2video.py
ADDED
@@ -0,0 +1,431 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import gc
+import logging
+import math
+import os
+import random
+import sys
+import types
+from contextlib import contextmanager
+from functools import partial
+
+import numpy as np
+import torch
+import torch.cuda.amp as amp
+import torch.distributed as dist
+import torchvision.transforms.functional as TF
+from tqdm import tqdm
+
+from .distributed.fsdp import shard_model
+from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
+from .distributed.util import get_world_size
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae2_1 import Wan2_1_VAE
+from .utils.fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+
+class WanI2V:
+
+    def __init__(
+        self,
+        config,
+        checkpoint_dir,
+        device_id=0,
+        rank=0,
+        t5_fsdp=False,
+        dit_fsdp=False,
+        use_sp=False,
+        t5_cpu=False,
+        init_on_cpu=True,
+        convert_model_dtype=False,
+    ):
+        r"""
+        Initializes the image-to-video generation model components.
+
+        Args:
+            config (EasyDict):
+                Object containing model parameters initialized from config.py
+            checkpoint_dir (`str`):
+                Path to directory containing model checkpoints
+            device_id (`int`, *optional*, defaults to 0):
+                Id of target GPU device
+            rank (`int`, *optional*, defaults to 0):
+                Process rank for distributed training
+            t5_fsdp (`bool`, *optional*, defaults to False):
+                Enable FSDP sharding for T5 model
+            dit_fsdp (`bool`, *optional*, defaults to False):
+                Enable FSDP sharding for DiT model
+            use_sp (`bool`, *optional*, defaults to False):
+                Enable distribution strategy of sequence parallel.
+            t5_cpu (`bool`, *optional*, defaults to False):
+                Whether to place T5 model on CPU. Only works without t5_fsdp.
+            init_on_cpu (`bool`, *optional*, defaults to True):
+                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+            convert_model_dtype (`bool`, *optional*, defaults to False):
+                Convert DiT model parameters dtype to 'config.param_dtype'.
+                Only works without FSDP.
+        """
+        self.device = torch.device(f"cuda:{device_id}")
+        self.config = config
+        self.rank = rank
+        self.t5_cpu = t5_cpu
+        self.init_on_cpu = init_on_cpu
+
+        self.num_train_timesteps = config.num_train_timesteps
+        self.boundary = config.boundary
+        self.param_dtype = config.param_dtype
+
+        if t5_fsdp or dit_fsdp or use_sp:
+            self.init_on_cpu = False
+
+        shard_fn = partial(shard_model, device_id=device_id)
+        self.text_encoder = T5EncoderModel(
+            text_len=config.text_len,
+            dtype=config.t5_dtype,
+            device=torch.device('cpu'),
+            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
+            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
+            shard_fn=shard_fn if t5_fsdp else None,
+        )
+
+        self.vae_stride = config.vae_stride
+        self.patch_size = config.patch_size
+        self.vae = Wan2_1_VAE(
+            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+            device=self.device)
+
+        logging.info(f"Creating WanModel from {checkpoint_dir}")
+        self.low_noise_model = WanModel.from_pretrained(
+            checkpoint_dir, subfolder=config.low_noise_checkpoint)
+        self.low_noise_model = self._configure_model(
+            model=self.low_noise_model,
+            use_sp=use_sp,
+            dit_fsdp=dit_fsdp,
+            shard_fn=shard_fn,
+            convert_model_dtype=convert_model_dtype)
+
+        self.high_noise_model = WanModel.from_pretrained(
+            checkpoint_dir, subfolder=config.high_noise_checkpoint)
+        self.high_noise_model = self._configure_model(
+            model=self.high_noise_model,
+            use_sp=use_sp,
+            dit_fsdp=dit_fsdp,
+            shard_fn=shard_fn,
+            convert_model_dtype=convert_model_dtype)
+        if use_sp:
+            self.sp_size = get_world_size()
+        else:
+            self.sp_size = 1
+
+        self.sample_neg_prompt = config.sample_neg_prompt
+
+    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
+                         convert_model_dtype):
+        """
+        Configures a model object. This includes setting evaluation modes,
+        applying distributed parallel strategy, and handling device placement.
+
+        Args:
+            model (torch.nn.Module):
+                The model instance to configure.
+            use_sp (`bool`):
+                Enable distribution strategy of sequence parallel.
+            dit_fsdp (`bool`):
+                Enable FSDP sharding for DiT model.
+            shard_fn (callable):
+                The function to apply FSDP sharding.
+            convert_model_dtype (`bool`):
+                Convert DiT model parameters dtype to 'config.param_dtype'.
+                Only works without FSDP.
+
+        Returns:
+            torch.nn.Module:
+                The configured model.
+        """
+        model.eval().requires_grad_(False)
+
+        if use_sp:
+            for block in model.blocks:
+                block.self_attn.forward = types.MethodType(
+                    sp_attn_forward, block.self_attn)
+            model.forward = types.MethodType(sp_dit_forward, model)
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        if dit_fsdp:
+            model = shard_fn(model)
+        else:
+            if convert_model_dtype:
+                model.to(self.param_dtype)
+            if not self.init_on_cpu:
+                model.to(self.device)
+
+        return model
+
+    def _prepare_model_for_timestep(self, t, boundary, offload_model):
+        r"""
+        Prepares and returns the required model for the current timestep.
+
+        Args:
+            t (torch.Tensor):
+                current timestep.
+            boundary (`int`):
+                The timestep threshold. If `t` is at or above this value,
+                the `high_noise_model` is considered as the required model.
+            offload_model (`bool`):
+                A flag intended to control the offloading behavior.
+
+        Returns:
+            torch.nn.Module:
+                The active model on the target device for the current timestep.
+        """
+        if t.item() >= boundary:
+            required_model_name = 'high_noise_model'
+            offload_model_name = 'low_noise_model'
+        else:
+            required_model_name = 'low_noise_model'
+            offload_model_name = 'high_noise_model'
+        if offload_model or self.init_on_cpu:
+            if next(getattr(
+                    self,
+                    offload_model_name).parameters()).device.type == 'cuda':
+                getattr(self, offload_model_name).to('cpu')
+            if next(getattr(
+                    self,
+                    required_model_name).parameters()).device.type == 'cpu':
+                getattr(self, required_model_name).to(self.device)
+        return getattr(self, required_model_name)
+
+    def generate(self,
+                 input_prompt,
+                 img,
+                 max_area=720 * 1280,
+                 frame_num=81,
+                 shift=5.0,
+                 sample_solver='unipc',
+                 sampling_steps=40,
+                 guide_scale=5.0,
+                 n_prompt="",
+                 seed=-1,
+                 offload_model=True):
+        r"""
+        Generates video frames from input image and text prompt using diffusion process.
+
+        Args:
+            input_prompt (`str`):
+                Text prompt for content generation.
+            img (PIL.Image.Image):
+                Input image tensor. Shape: [3, H, W]
+            max_area (`int`, *optional*, defaults to 720*1280):
+                Maximum pixel area for latent space calculation. Controls video resolution scaling
+            frame_num (`int`, *optional*, defaults to 81):
+                How many frames to sample from a video. The number should be 4n+1
+            shift (`float`, *optional*, defaults to 5.0):
+                Noise schedule shift parameter. Affects temporal dynamics
+                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
+            sample_solver (`str`, *optional*, defaults to 'unipc'):
+                Solver used to sample the video.
+            sampling_steps (`int`, *optional*, defaults to 40):
+                Number of diffusion sampling steps. Higher values improve quality but slow generation
+            guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
+                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
+                If tuple, the first guide_scale will be used for low noise model and
+                the second guide_scale will be used for high noise model.
+            n_prompt (`str`, *optional*, defaults to ""):
+                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+            seed (`int`, *optional*, defaults to -1):
+                Random seed for noise generation. If -1, use random seed
+            offload_model (`bool`, *optional*, defaults to True):
+                If True, offloads models to CPU during generation to save VRAM
+
+        Returns:
+            torch.Tensor:
+                Generated video frames tensor. Dimensions: (C, N H, W) where:
+                - C: Color channels (3 for RGB)
+                - N: Number of frames (81)
+                - H: Frame height (from max_area)
+                - W: Frame width from max_area)
+        """
+        # preprocess
+        guide_scale = (guide_scale, guide_scale) if isinstance(
+            guide_scale, float) else guide_scale
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
+
+        F = frame_num
+        h, w = img.shape[1:]
+        aspect_ratio = h / w
+        lat_h = round(
+            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
+            self.patch_size[1] * self.patch_size[1])
+        lat_w = round(
+            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
+            self.patch_size[2] * self.patch_size[2])
+        h = lat_h * self.vae_stride[1]
+        w = lat_w * self.vae_stride[2]
+
+        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
+            self.patch_size[1] * self.patch_size[2])
+        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
+
+        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+        seed_g = torch.Generator(device=self.device)
+        seed_g.manual_seed(seed)
+        noise = torch.randn(
+            16,
+            21,
+            lat_h,
+            lat_w,
+            dtype=torch.float32,
+            generator=seed_g,
+            device=self.device)
+
+        msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
+        msk[:, 1:] = 0
+        msk = torch.concat([
+            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
+        ],
+                           dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+        msk = msk.transpose(1, 2)[0]
+
+        if n_prompt == "":
+            n_prompt = self.sample_neg_prompt
+
+        # preprocess
+        if not self.t5_cpu:
+            self.text_encoder.model.to(self.device)
+            context = self.text_encoder([input_prompt], self.device)
+            context_null = self.text_encoder([n_prompt], self.device)
+            if offload_model:
+                self.text_encoder.model.cpu()
+        else:
+            context = self.text_encoder([input_prompt], torch.device('cpu'))
+            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            context = [t.to(self.device) for t in context]
+            context_null = [t.to(self.device) for t in context_null]
+
+        y = self.vae.encode([
+            torch.concat([
+                torch.nn.functional.interpolate(
+                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(
+                        0, 1),
+                torch.zeros(3, 80, h, w)
+            ],
+                         dim=1).to(self.device)
+        ])[0]
+        y = torch.concat([msk, y])
+
+        @contextmanager
+        def noop_no_sync():
+            yield
+
+        no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
+                                    noop_no_sync)
+        no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
+                                     noop_no_sync)
+
+        # evaluation mode
+        with (
+                torch.amp.autocast('cuda', dtype=self.param_dtype),
+                torch.no_grad(),
+                no_sync_low_noise(),
+                no_sync_high_noise(),
+        ):
+            boundary = self.boundary * self.num_train_timesteps
+
+            if sample_solver == 'unipc':
+                sample_scheduler = FlowUniPCMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sample_scheduler.set_timesteps(
+                    sampling_steps, device=self.device, shift=shift)
+                timesteps = sample_scheduler.timesteps
+            elif sample_solver == 'dpm++':
+                sample_scheduler = FlowDPMSolverMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+                timesteps, _ = retrieve_timesteps(
+                    sample_scheduler,
+                    device=self.device,
+                    sigmas=sampling_sigmas)
+            else:
+                raise NotImplementedError("Unsupported solver.")
+
+            # sample videos
+            latent = noise
+
+            arg_c = {
+                'context': [context[0]],
+                'seq_len': max_seq_len,
+                'y': [y],
+            }
+
+            arg_null = {
+                'context': context_null,
+                'seq_len': max_seq_len,
+                'y': [y],
+            }
+
+            if offload_model:
+                torch.cuda.empty_cache()
+
+            for _, t in enumerate(tqdm(timesteps)):
+                latent_model_input = [latent.to(self.device)]
+                timestep = [t]
+
+                timestep = torch.stack(timestep).to(self.device)
+
+                model = self._prepare_model_for_timestep(
+                    t, boundary, offload_model)
+                sample_guide_scale = guide_scale[1] if t.item(
+                ) >= boundary else guide_scale[0]
+
+                noise_pred_cond = model(
+                    latent_model_input, t=timestep, **arg_c)[0]
+                if offload_model:
+                    torch.cuda.empty_cache()
+                noise_pred_uncond = model(
+                    latent_model_input, t=timestep, **arg_null)[0]
+                if offload_model:
+                    torch.cuda.empty_cache()
+                noise_pred = noise_pred_uncond + sample_guide_scale * (
+                    noise_pred_cond - noise_pred_uncond)
+
+                temp_x0 = sample_scheduler.step(
+                    noise_pred.unsqueeze(0),
+                    t,
+                    latent.unsqueeze(0),
+                    return_dict=False,
+                    generator=seed_g)[0]
+                latent = temp_x0.squeeze(0)
+
+                x0 = [latent]
+                del latent_model_input, timestep
+
+            if offload_model:
+                self.low_noise_model.cpu()
+                self.high_noise_model.cpu()
+                torch.cuda.empty_cache()
+
+            if self.rank == 0:
+                videos = self.vae.decode(x0)
+
+        del noise, latent, x0
+        del sample_scheduler
+        if offload_model:
+            gc.collect()
+            torch.cuda.synchronize()
+        if dist.is_initialized():
+            dist.barrier()
+
+        return videos[0] if self.rank == 0 else None
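
A hedged usage sketch for WanI2V.generate, not part of the upload: it assumes a CUDA device, locally downloaded Wan2.2 I2V-A14B checkpoints under the path shown, and that wan.configs exposes a WAN_CONFIGS dict keyed by task name (that accessor name and key are assumptions; adjust to the actual config module):

# Illustrative driver only; checkpoint path and config accessor are assumptions.
from PIL import Image
from wan.configs import WAN_CONFIGS            # assumed accessor
from wan.image2video import WanI2V

cfg = WAN_CONFIGS['i2v-A14B']                  # assumed key for the I2V config
pipe = WanI2V(cfg, checkpoint_dir='./Wan2.2-I2V-A14B', device_id=0, rank=0,
              convert_model_dtype=True)

video = pipe.generate(
    "a cat surfing a wave at sunset",
    Image.open('examples/i2v_input.JPG'),
    max_area=480 * 832,            # 480p-class output
    shift=3.0,                     # recommended for 480p per the docstring
    sampling_steps=40,
    guide_scale=(3.5, 3.5),
    seed=42)
# video: [3, 81, H, W] tensor on rank 0, None on other ranks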
wan/modules/__init__.py
ADDED
@@ -0,0 +1,19 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from .attention import flash_attention
+from .model import WanModel
+from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
+from .tokenizers import HuggingfaceTokenizer
+from .vae2_1 import Wan2_1_VAE
+from .vae2_2 import Wan2_2_VAE
+
+__all__ = [
+    'Wan2_1_VAE',
+    'Wan2_2_VAE',
+    'WanModel',
+    'T5Model',
+    'T5Encoder',
+    'T5Decoder',
+    'T5EncoderModel',
+    'HuggingfaceTokenizer',
+    'flash_attention',
+]
wan/modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (528 Bytes)
wan/modules/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (3.95 kB)
wan/modules/__pycache__/model.cpython-310.pyc
ADDED
Binary file (16.9 kB)
wan/modules/__pycache__/t5.cpython-310.pyc
ADDED
Binary file (12.9 kB)
wan/modules/__pycache__/tokenizers.cpython-310.pyc
ADDED
Binary file (2.55 kB)
wan/modules/__pycache__/vae2_1.cpython-310.pyc
ADDED
Binary file (16.9 kB)
wan/modules/__pycache__/vae2_2.cpython-310.pyc
ADDED
Binary file (22.1 kB)
wan/modules/attention.py
ADDED
@@ -0,0 +1,179 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+try:
+    import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_2_AVAILABLE = False
+
+import warnings
+
+__all__ = [
+    'flash_attention',
+    'attention',
+]
+
+
+def flash_attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    version=None,
+):
+    """
+    q: [B, Lq, Nq, C1].
+    k: [B, Lk, Nk, C1].
+    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+    q_lens: [B].
+    k_lens: [B].
+    dropout_p: float. Dropout probability.
+    softmax_scale: float. The scaling of QK^T before applying softmax.
+    causal: bool. Whether to apply causal attention mask.
+    window_size: (left right). If not (-1, -1), apply sliding window local attention.
+    deterministic: bool. If True, slightly slower and uses more memory.
+    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
+    """
+    half_dtypes = (torch.float16, torch.bfloat16)
+    assert dtype in half_dtypes
+    assert q.device.type == 'cuda' and q.size(-1) <= 256
+
+    # params
+    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # preprocess query
+    if q_lens is None:
+        q = half(q.flatten(0, 1))
+        q_lens = torch.tensor(
+            [lq] * b, dtype=torch.int32).to(
+                device=q.device, non_blocking=True)
+    else:
+        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+    # preprocess key, value
+    if k_lens is None:
+        k = half(k.flatten(0, 1))
+        v = half(v.flatten(0, 1))
+        k_lens = torch.tensor(
+            [lk] * b, dtype=torch.int32).to(
+                device=k.device, non_blocking=True)
+    else:
+        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+    q = q.to(v.dtype)
+    k = k.to(v.dtype)
+
+    if q_scale is not None:
+        q = q * q_scale
+
+    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+        warnings.warn(
+            'Flash attention 3 is not available, use flash attention 2 instead.'
+        )
+
+    # apply attention
+    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
+        # Note: dropout_p, window_size are not supported in FA3 now.
+        x = flash_attn_interface.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            seqused_q=None,
+            seqused_k=None,
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            deterministic=deterministic)[0].unflatten(0, (b, lq))
+    else:
+        assert FLASH_ATTN_2_AVAILABLE
+        x = flash_attn.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic).unflatten(0, (b, lq))
+
+    # output
+    return x.type(out_dtype)
+
+
+def attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    fa_version=None,
+):
+    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+        return flash_attention(
+            q=q,
+            k=k,
+            v=v,
+            q_lens=q_lens,
+            k_lens=k_lens,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            q_scale=q_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic,
+            dtype=dtype,
+            version=fa_version,
+        )
+    else:
+        if q_lens is not None or k_lens is not None:
+            warnings.warn(
+                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
+            )
+        attn_mask = None
+
+        q = q.transpose(1, 2).to(dtype)
+        k = k.transpose(1, 2).to(dtype)
+        v = v.transpose(1, 2).to(dtype)
+
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
+
+        out = out.transpose(1, 2).contiguous()
+        return out
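
When neither flash-attn package imports, attention() above falls back to torch SDPA and has to move heads from dim 2 to dim 1 before the call and back afterwards. A small CPU-runnable sketch of just that layout conversion, calling SDPA directly instead of importing the module:

# Illustrative sketch of the [B, L, N, D] <-> [B, N, L, D] shuffle in the fallback path.
import torch
import torch.nn.functional as F

B, L, N, D = 1, 16, 8, 64
q = torch.randn(B, L, N, D)
k = torch.randn(B, L, N, D)
v = torch.randn(B, L, N, D)

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))  # heads first for SDPA
out = out.transpose(1, 2).contiguous()                        # back to heads in dim 2
assert out.shape == (B, L, N, D)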
wan/modules/model.py
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 7 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 8 |
+
|
| 9 |
+
from .attention import flash_attention
|
| 10 |
+
|
| 11 |
+
__all__ = ['WanModel']
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 15 |
+
# preprocess
|
| 16 |
+
assert dim % 2 == 0
|
| 17 |
+
half = dim // 2
|
| 18 |
+
position = position.type(torch.float64)
|
| 19 |
+
|
| 20 |
+
# calculation
|
| 21 |
+
sinusoid = torch.outer(
|
| 22 |
+
position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
|
| 23 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 24 |
+
return x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@torch.amp.autocast('cuda', enabled=False)
|
| 28 |
+
def rope_params(max_seq_len, dim, theta=10000):
|
| 29 |
+
assert dim % 2 == 0
|
| 30 |
+
freqs = torch.outer(
|
| 31 |
+
torch.arange(max_seq_len),
|
| 32 |
+
1.0 / torch.pow(theta,
|
| 33 |
+
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
|
| 34 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
| 35 |
+
return freqs
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@torch.amp.autocast('cuda', enabled=False)
|
| 39 |
+
def rope_apply(x, grid_sizes, freqs):
|
| 40 |
+
n, c = x.size(2), x.size(3) // 2
|
| 41 |
+
|
| 42 |
+
# split freqs
|
| 43 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
|
| 44 |
+
|
| 45 |
+
# loop over samples
|
| 46 |
+
output = []
|
| 47 |
+
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
|
| 48 |
+
seq_len = f * h * w
|
| 49 |
+
|
| 50 |
+
# precompute multipliers
|
| 51 |
+
x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
|
| 52 |
+
seq_len, n, -1, 2))
|
| 53 |
+
freqs_i = torch.cat([
|
| 54 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 55 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 56 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 57 |
+
],
|
| 58 |
+
dim=-1).reshape(seq_len, 1, -1)
|
| 59 |
+
|
| 60 |
+
# apply rotary embedding
|
| 61 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
| 62 |
+
x_i = torch.cat([x_i, x[i, seq_len:]])
|
| 63 |
+
|
| 64 |
+
# append to collection
|
| 65 |
+
output.append(x_i)
|
| 66 |
+
return torch.stack(output).float()
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class WanRMSNorm(nn.Module):
|
| 70 |
+
|
| 71 |
+
def __init__(self, dim, eps=1e-5):
|
| 72 |
+
super().__init__()
|
| 73 |
+
self.dim = dim
|
| 74 |
+
self.eps = eps
|
| 75 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 76 |
+
|
| 77 |
+
def forward(self, x):
|
| 78 |
+
r"""
|
| 79 |
+
Args:
|
| 80 |
+
x(Tensor): Shape [B, L, C]
|
| 81 |
+
"""
|
| 82 |
+
return self._norm(x.float()).type_as(x) * self.weight
|
| 83 |
+
|
| 84 |
+
def _norm(self, x):
|
| 85 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class WanLayerNorm(nn.LayerNorm):
|
| 89 |
+
|
| 90 |
+
def __init__(self, dim, eps=1e-6, elementwise_affine=False):
|
| 91 |
+
super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
|
| 92 |
+
|
| 93 |
+
def forward(self, x):
|
| 94 |
+
r"""
|
| 95 |
+
Args:
|
| 96 |
+
x(Tensor): Shape [B, L, C]
|
| 97 |
+
"""
|
| 98 |
+
return super().forward(x.float()).type_as(x)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class WanSelfAttention(nn.Module):
|
| 102 |
+
|
| 103 |
+
def __init__(self,
|
| 104 |
+
dim,
|
| 105 |
+
num_heads,
|
| 106 |
+
window_size=(-1, -1),
|
| 107 |
+
qk_norm=True,
|
| 108 |
+
eps=1e-6):
|
| 109 |
+
assert dim % num_heads == 0
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.dim = dim
|
| 112 |
+
self.num_heads = num_heads
|
| 113 |
+
self.head_dim = dim // num_heads
|
| 114 |
+
self.window_size = window_size
|
| 115 |
+
self.qk_norm = qk_norm
|
| 116 |
+
self.eps = eps
|
| 117 |
+
|
| 118 |
+
# layers
|
| 119 |
+
self.q = nn.Linear(dim, dim)
|
| 120 |
+
self.k = nn.Linear(dim, dim)
|
| 121 |
+
self.v = nn.Linear(dim, dim)
|
| 122 |
+
self.o = nn.Linear(dim, dim)
|
| 123 |
+
self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 124 |
+
self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 125 |
+
|
| 126 |
+
def forward(self, x, seq_lens, grid_sizes, freqs):
|
| 127 |
+
r"""
|
| 128 |
+
Args:
|
| 129 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
| 130 |
+
seq_lens(Tensor): Shape [B]
|
| 131 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 132 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 133 |
+
"""
|
| 134 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 135 |
+
|
| 136 |
+
# query, key, value function
|
| 137 |
+
def qkv_fn(x):
|
| 138 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 139 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 140 |
+
v = self.v(x).view(b, s, n, d)
|
| 141 |
+
return q, k, v
|
| 142 |
+
|
| 143 |
+
q, k, v = qkv_fn(x)
|
| 144 |
+
|
| 145 |
+
x = flash_attention(
|
| 146 |
+
q=rope_apply(q, grid_sizes, freqs),
|
| 147 |
+
k=rope_apply(k, grid_sizes, freqs),
|
| 148 |
+
v=v,
|
| 149 |
+
k_lens=seq_lens,
|
| 150 |
+
window_size=self.window_size)
|
| 151 |
+
|
| 152 |
+
# output
|
| 153 |
+
x = x.flatten(2)
|
| 154 |
+
x = self.o(x)
|
| 155 |
+
return x
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class WanCrossAttention(WanSelfAttention):
|
| 159 |
+
|
| 160 |
+
def forward(self, x, context, context_lens):
|
| 161 |
+
r"""
|
| 162 |
+
Args:
|
| 163 |
+
x(Tensor): Shape [B, L1, C]
|
| 164 |
+
context(Tensor): Shape [B, L2, C]
|
| 165 |
+
context_lens(Tensor): Shape [B]
|
| 166 |
+
"""
|
| 167 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 168 |
+
|
| 169 |
+
# compute query, key, value
|
| 170 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
| 171 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
| 172 |
+
v = self.v(context).view(b, -1, n, d)
|
| 173 |
+
|
| 174 |
+
# compute attention
|
| 175 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
| 176 |
+
|
| 177 |
+
# output
|
| 178 |
+
x = x.flatten(2)
|
| 179 |
+
x = self.o(x)
|
| 180 |
+
return x
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class WanAttentionBlock(nn.Module):
|
| 184 |
+
|
| 185 |
+
def __init__(self,
|
| 186 |
+
dim,
|
| 187 |
+
ffn_dim,
|
| 188 |
+
num_heads,
|
| 189 |
+
window_size=(-1, -1),
|
| 190 |
+
qk_norm=True,
|
| 191 |
+
cross_attn_norm=False,
|
| 192 |
+
eps=1e-6):
|
| 193 |
+
super().__init__()
|
| 194 |
+
self.dim = dim
|
| 195 |
+
self.ffn_dim = ffn_dim
|
| 196 |
+
self.num_heads = num_heads
|
| 197 |
+
self.window_size = window_size
|
| 198 |
+
self.qk_norm = qk_norm
|
| 199 |
+
self.cross_attn_norm = cross_attn_norm
|
| 200 |
+
self.eps = eps
|
| 201 |
+
|
| 202 |
+
# layers
|
| 203 |
+
self.norm1 = WanLayerNorm(dim, eps)
|
| 204 |
+
self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
|
| 205 |
+
eps)
|
| 206 |
+
self.norm3 = WanLayerNorm(
|
| 207 |
+
dim, eps,
|
| 208 |
+
elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
| 209 |
+
self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
|
| 210 |
+
eps)
|
| 211 |
+
self.norm2 = WanLayerNorm(dim, eps)
|
| 212 |
+
self.ffn = nn.Sequential(
|
| 213 |
+
nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
|
| 214 |
+
nn.Linear(ffn_dim, dim))
|
| 215 |
+
|
| 216 |
+
# modulation
|
| 217 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 218 |
+
|
| 219 |
+
def forward(
|
| 220 |
+
self,
|
| 221 |
+
x,
|
| 222 |
+
e,
|
| 223 |
+
seq_lens,
|
| 224 |
+
grid_sizes,
|
| 225 |
+
freqs,
|
| 226 |
+
context,
|
| 227 |
+
context_lens,
|
| 228 |
+
):
|
| 229 |
+
r"""
|
| 230 |
+
Args:
|
| 231 |
+
x(Tensor): Shape [B, L, C]
|
| 232 |
+
e(Tensor): Shape [B, L1, 6, C]
|
| 233 |
+
seq_lens(Tensor): Shape [B], length of each sequence in batch
|
| 234 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 235 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 236 |
+
"""
|
| 237 |
+
assert e.dtype == torch.float32
|
| 238 |
+
with torch.amp.autocast('cuda', dtype=torch.float32):
|
| 239 |
+
e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
|
| 240 |
+
assert e[0].dtype == torch.float32
|
| 241 |
+
|
| 242 |
+
# self-attention
|
| 243 |
+
y = self.self_attn(
|
| 244 |
+
self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
|
| 245 |
+
seq_lens, grid_sizes, freqs)
|
| 246 |
+
with torch.amp.autocast('cuda', dtype=torch.float32):
|
| 247 |
+
x = x + y * e[2].squeeze(2)
|
| 248 |
+
|
| 249 |
+
# cross-attention & ffn function
|
| 250 |
+
def cross_attn_ffn(x, context, context_lens, e):
|
| 251 |
+
x = x + self.cross_attn(self.norm3(x), context, context_lens)
|
| 252 |
+
y = self.ffn(
|
| 253 |
+
self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
|
| 254 |
+
with torch.amp.autocast('cuda', dtype=torch.float32):
|
| 255 |
+
x = x + y * e[5].squeeze(2)
|
| 256 |
+
return x
|
| 257 |
+
|
| 258 |
+
x = cross_attn_ffn(x, context, context_lens, e)
|
| 259 |
+
return x
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class Head(nn.Module):
|
| 263 |
+
|
| 264 |
+
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.dim = dim
|
| 267 |
+
self.out_dim = out_dim
|
| 268 |
+
self.patch_size = patch_size
|
| 269 |
+
self.eps = eps
|
| 270 |
+
|
| 271 |
+
# layers
|
| 272 |
+
out_dim = math.prod(patch_size) * out_dim
|
| 273 |
+
self.norm = WanLayerNorm(dim, eps)
|
| 274 |
+
self.head = nn.Linear(dim, out_dim)
|
| 275 |
+
|
| 276 |
+
# modulation
|
| 277 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 278 |
+
|
| 279 |
+
def forward(self, x, e):
|
| 280 |
+
r"""
|
| 281 |
+
Args:
|
| 282 |
+
x(Tensor): Shape [B, L1, C]
|
| 283 |
+
e(Tensor): Shape [B, L1, C]
|
| 284 |
+
"""
|
| 285 |
+
assert e.dtype == torch.float32
|
| 286 |
+
with torch.amp.autocast('cuda', dtype=torch.float32):
|
| 287 |
+
e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
|
| 288 |
+
x = (
|
| 289 |
+
self.head(
|
| 290 |
+
self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))
|
| 291 |
+
return x
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class WanModel(ModelMixin, ConfigMixin):
|
| 295 |
+
r"""
|
| 296 |
+
Wan diffusion backbone supporting both text-to-video and image-to-video.
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
ignore_for_config = [
|
| 300 |
+
'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
|
| 301 |
+
]
|
| 302 |
+
_no_split_modules = ['WanAttentionBlock']
|
| 303 |
+
|
| 304 |
+
@register_to_config
|
| 305 |
+
def __init__(self,
|
| 306 |
+
model_type='t2v',
|
| 307 |
+
patch_size=(1, 2, 2),
|
| 308 |
+
text_len=512,
|
| 309 |
+
in_dim=16,
|
| 310 |
+
dim=2048,
|
| 311 |
+
ffn_dim=8192,
|
| 312 |
+
freq_dim=256,
|
| 313 |
+
text_dim=4096,
|
| 314 |
+
out_dim=16,
|
| 315 |
+
num_heads=16,
|
| 316 |
+
num_layers=32,
|
| 317 |
+
window_size=(-1, -1),
|
| 318 |
+
qk_norm=True,
|
| 319 |
+
cross_attn_norm=True,
|
| 320 |
+
eps=1e-6):
|
| 321 |
+
r"""
|
| 322 |
+
Initialize the diffusion model backbone.
|
| 323 |
+
|
| 324 |
+
Args:
|
| 325 |
+
model_type (`str`, *optional*, defaults to 't2v'):
|
| 326 |
+
Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
|
| 327 |
+
patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
|
| 328 |
+
3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
|
| 329 |
+
text_len (`int`, *optional*, defaults to 512):
|
| 330 |
+
Fixed length for text embeddings
|
| 331 |
+
in_dim (`int`, *optional*, defaults to 16):
|
| 332 |
+
Input video channels (C_in)
|
| 333 |
+
dim (`int`, *optional*, defaults to 2048):
|
| 334 |
+
Hidden dimension of the transformer
|
| 335 |
+
ffn_dim (`int`, *optional*, defaults to 8192):
|
| 336 |
+
Intermediate dimension in feed-forward network
|
| 337 |
+
freq_dim (`int`, *optional*, defaults to 256):
|
| 338 |
+
Dimension for sinusoidal time embeddings
|
| 339 |
+
text_dim (`int`, *optional*, defaults to 4096):
|
| 340 |
+
Input dimension for text embeddings
|
| 341 |
+
out_dim (`int`, *optional*, defaults to 16):
|
| 342 |
+
Output video channels (C_out)
|
| 343 |
+
num_heads (`int`, *optional*, defaults to 16):
|
| 344 |
+
Number of attention heads
|
| 345 |
+
num_layers (`int`, *optional*, defaults to 32):
|
| 346 |
+
Number of transformer blocks
|
| 347 |
+
window_size (`tuple`, *optional*, defaults to (-1, -1)):
|
| 348 |
+
Window size for local attention (-1 indicates global attention)
|
| 349 |
+
qk_norm (`bool`, *optional*, defaults to True):
|
| 350 |
+
Enable query/key normalization
|
| 351 |
+
cross_attn_norm (`bool`, *optional*, defaults to False):
|
| 352 |
+
Enable cross-attention normalization
|
| 353 |
+
eps (`float`, *optional*, defaults to 1e-6):
|
| 354 |
+
Epsilon value for normalization layers
|
| 355 |
+
"""
|
| 356 |
+
|
| 357 |
+
super().__init__()
|
| 358 |
+
|
| 359 |
+
assert model_type in ['t2v', 'i2v', 'ti2v']
|
| 360 |
+
self.model_type = model_type
|
| 361 |
+
|
| 362 |
+
self.patch_size = patch_size
|
| 363 |
+
self.text_len = text_len
|
| 364 |
+
self.in_dim = in_dim
|
| 365 |
+
self.dim = dim
|
| 366 |
+
self.ffn_dim = ffn_dim
|
| 367 |
+
self.freq_dim = freq_dim
|
| 368 |
+
self.text_dim = text_dim
|
| 369 |
+
self.out_dim = out_dim
|
| 370 |
+
self.num_heads = num_heads
|
| 371 |
+
self.num_layers = num_layers
|
| 372 |
+
self.window_size = window_size
|
| 373 |
+
self.qk_norm = qk_norm
|
| 374 |
+
self.cross_attn_norm = cross_attn_norm
|
| 375 |
+
self.eps = eps
|
| 376 |
+
|
| 377 |
+
# embeddings
|
| 378 |
+
self.patch_embedding = nn.Conv3d(
|
| 379 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 380 |
+
self.text_embedding = nn.Sequential(
|
| 381 |
+
nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
|
| 382 |
+
nn.Linear(dim, dim))
|
| 383 |
+
|
| 384 |
+
self.time_embedding = nn.Sequential(
|
| 385 |
+
nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
                              cross_attn_norm, eps) for _ in range(num_layers)
        ])

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ],
                               dim=1)

        # initialize weights
        self.init_weights()

    def forward(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        if self.model_type == 'i2v':
            assert y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        if t.dim() == 1:
            t = t.expand(t.size(0), seq_len)
        with torch.amp.autocast('cuda', dtype=torch.float32):
            bt = t.size(0)
            t = t.flatten()
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim,
                                        t).unflatten(0, (bt, seq_len)).float())
            e0 = self.time_projection(e).unflatten(2, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                    shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
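Illustrative note (not part of the uploaded files): the forward pass above pads every sample to a common seq_len, and unpatchify reverses the patch flattening per sample using its grid_sizes. Below is a minimal standalone sketch of that unpatchify round trip; the patch size (1, 2, 2) and the toy tensor sizes are assumptions chosen only for illustration, not values read from this diff.

import math
import torch

# toy settings (assumptions for illustration only)
patch_size = (1, 2, 2)          # (pt, ph, pw), hypothetical
c_out = 4                       # channels produced by the head
f, h, w = 3, 8, 8               # latent frames / height / width
grid = (f // patch_size[0], h // patch_size[1], w // patch_size[2])
seq_len = math.prod(grid)       # number of tokens for this sample

# tokens as the head would emit them: [L, C_out * prod(patch_size)]
tokens = torch.randn(seq_len, c_out * math.prod(patch_size))

# the same view/einsum/reshape sequence used by WanModel.unpatchify
u = tokens.view(*grid, *patch_size, c_out)
u = torch.einsum('fhwpqrc->cfphqwr', u)
video = u.reshape(c_out, *[g * p for g, p in zip(grid, patch_size)])
print(video.shape)              # torch.Size([4, 3, 8, 8])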
wan/modules/t5.py
ADDED
|
@@ -0,0 +1,513 @@
# Modified from transformers.models.t5.modeling_t5
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .tokenizers import HuggingfaceTokenizer

__all__ = [
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
]


def fp16_clamp(x):
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5Model):
        nn.init.normal_(m.token_embedding.weight, std=1.0)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)


class GELU(nn.Module):

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                            self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:       [B, L1, C].
        context: [B, L2, C] or None.
        mask:    [B, L2] or [B, L1, L2] or None.
        """
        # check inputs
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1,
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x


class T5FeedForward(nn.Module):

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class T5SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False)

    def forward(self,
                x,
                mask=None,
                encoder_states=None,
                encoder_mask=None,
                pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets


class T5Encoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Encoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                             shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):

    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, encoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, decoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs={},
        dtype=torch.float32,
        device='cpu',
        **kwargs):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        _ = kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        _ = kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:
        from .tokenizers import HuggingfaceTokenizer
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1)
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)


class T5EncoderModel:

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=torch.cuda.current_device(),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    def __call__(self, texts, device):
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]
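Illustrative note (not part of the uploaded files): a minimal sketch of how the encoder classes above fit together, using a tiny, randomly initialised T5Encoder just to show the tensor contract (token ids plus mask in, per-token embeddings out). The toy sizes are arbitrary assumptions; the real checkpoint uses the umt5_xxl() configuration, and the import path assumes the wan package from this Space is on PYTHONPATH.

import torch
from wan.modules.t5 import T5Encoder

# tiny encoder with per-layer relative position embeddings (shared_pos=False,
# as in the umt5_xxl() config); sizes here are toy values for illustration
encoder = T5Encoder(
    vocab=100, dim=32, dim_attn=32, dim_ffn=64,
    num_heads=4, num_layers=2, num_buckets=8, shared_pos=False).eval()

ids = torch.randint(0, 100, (2, 10))        # [B, L] token ids
mask = torch.ones(2, 10, dtype=torch.long)  # [B, L] attention mask
out = encoder(ids, mask)
print(out.shape)                            # torch.Size([2, 10, 32])

T5EncoderModel wraps the same forward pass, then trims each sequence back to its unpadded length, which is why it returns a list of [L_i, C] tensors rather than one padded batch.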
wan/modules/tokenizers.py
ADDED
|
@@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ['HuggingfaceTokenizer']


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
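Illustrative note (not part of the uploaded files): a minimal usage sketch of HuggingfaceTokenizer. The 'google/umt5-xxl' identifier is only an example and assumes the tokenizer files are available locally or downloadable; the Space itself passes whatever tokenizer_path it ships with the checkpoints.

from wan.modules.tokenizers import HuggingfaceTokenizer

# assumes 'google/umt5-xxl' (or a local copy) is reachable by
# AutoTokenizer.from_pretrained; seq_len pads/truncates to a fixed length
tokenizer = HuggingfaceTokenizer(
    name='google/umt5-xxl', seq_len=512, clean='whitespace')
ids, mask = tokenizer(['a cat playing piano'], return_mask=True)
print(ids.shape, mask.shape)   # both [1, 512] after padding/truncation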
wan/modules/vae2_1.py
ADDED
|
@@ -0,0 +1,663 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

__all__ = [
    'Wan2_1_VAE',
]

CACHE_T = 2


class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)


class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else
                    -1)) * self.scale * self.gamma + self.bias


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = 'Rep'
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] != 'Rep':
                        # cache last frame of last two chunk
                        cache_x = torch.cat([
                            feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                                cache_x.device), cache_x
                        ],
                                            dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[
                            idx] is not None and feat_cache[idx] == 'Rep':
                        cache_x = torch.cat([
                            torch.zeros_like(cache_x).to(cache_x.device),
                            cache_x
                        ],
                                            dim=2)
                    if feat_cache[idx] == 'Rep':
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1

                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                                    3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.resample(x)
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)

        if self.mode == 'downsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:

                    cache_x = x[:, :, -1:, :, :].clone()
                    # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
                    #     # cache last frame of last two chunk
                    #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)

                    x = self.time_conv(
                        torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x

    def init_weight(self, conv):
        conv_weight = conv.weight
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        one_matrix = torch.eye(c1, c2)
        init_matrix = one_matrix
        nn.init.zeros_(conv_weight)
        #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
        conv_weight.data[:, :, 1, 0, 0] = init_matrix  #* 0.5
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)

    def init_weight2(self, conv):
        conv_weight = conv.weight.data
        nn.init.zeros_(conv_weight)
        c1, c2, t, h, w = conv_weight.size()
        init_matrix = torch.eye(c1 // 2, c2)
        #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
        conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
        conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
        conv.weight.data.copy_(conv_weight)
        nn.init.zeros_(conv.bias.data)


class ResidualBlock(nn.Module):

    def __init__(self, in_dim, out_dim, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # layers
        self.residual = nn.Sequential(
            RMS_norm(in_dim, images=False), nn.SiLU(),
            CausalConv3d(in_dim, out_dim, 3, padding=1),
            RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
            CausalConv3d(out_dim, out_dim, 3, padding=1))
        self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
            if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        h = self.shortcut(x)
        for layer in self.residual:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x + h


class AttentionBlock(nn.Module):
    """
    Causal self-attention with a single head.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # layers
        self.norm = RMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the last layer params
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        identity = x
        b, c, t, h, w = x.size()
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.norm(x)
        # compute query, key, value
        q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
                                         -1).permute(0, 1, 3,
                                                     2).contiguous().chunk(
                                                         3, dim=-1)

        # apply attention
        x = F.scaled_dot_product_attention(
            q,
            k,
            v,
        )
        x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)

        # output
        x = self.proj(x)
        x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
        return x + identity


class Encoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample

        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        scale = 1.0

        # init block
        self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)

        # downsample blocks
        downsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    downsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # downsample block
            if i != len(dim_mult) - 1:
                mode = 'downsample3d' if temperal_downsample[
                    i] else 'downsample2d'
                downsamples.append(Resample(out_dim, mode=mode))
                scale /= 2.0
        self.downsamples = nn.Sequential(*downsamples)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
            ResidualBlock(out_dim, out_dim, dropout))

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, z_dim, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## downsamples
        for layer in self.downsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


class Decoder3d(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_upsample=[False, True, True],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        # dimensions
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2**(len(dim_mult) - 2)

        # init block
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = nn.Sequential(
            ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
            ResidualBlock(dims[0], dims[0], dropout))

        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i == 1 or i == 2 or i == 3:
                in_dim = in_dim // 2
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False), nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1))

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([
                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                        cache_x.device), cache_x
                ],
                                    dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([
                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
                            cache_x.device), cache_x
                    ],
                                        dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class WanVAE_(nn.Module):

    def __init__(self,
                 dim=128,
                 z_dim=4,
                 dim_mult=[1, 2, 4, 4],
                 num_res_blocks=2,
                 attn_scales=[],
                 temperal_downsample=[True, True, False],
                 dropout=0.0):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]

        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
                                 attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def encode(self, x, scale):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        ## split the x fed to encode into chunks of 1, 4, 4, 4, ... along time
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(
                    x[:, :, :1, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        self.clear_cache()
        return mu

    def decode(self, z, scale):
        self.clear_cache()
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
                1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(
                    x[:, :, i:i + 1, :, :],
                    feat_cache=self._feat_map,
                    feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        self.clear_cache()
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False):
        mu, log_var = self.encode(imgs)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num


def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0)
    cfg.update(**kwargs)

    # init model
    with torch.device('meta'):
        model = WanVAE_(**cfg)

    # load checkpoint
    logging.info(f'loading {pretrained_path}')
    model.load_state_dict(
        torch.load(pretrained_path, map_location=device), assign=True)

    return model


class Wan2_1_VAE:

    def __init__(self,
                 z_dim=16,
                 vae_pth='cache/vae_step_411000.pth',
                 dtype=torch.float,
                 device="cuda"):
        self.dtype = dtype
        self.device = device

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.std = torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, 1.0 / self.std]

        # init model
        self.model = _video_vae(
            pretrained_path=vae_pth,
            z_dim=z_dim,
        ).eval().requires_grad_(False).to(device)

    def encode(self, videos):
        """
        videos: A list of videos each with shape [C, T, H, W].
        """
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
                for u in videos
            ]

    def decode(self, zs):
        with amp.autocast(dtype=self.dtype):
            return [
                self.model.decode(u.unsqueeze(0),
                                  self.scale).float().clamp_(-1, 1).squeeze(0)
                for u in zs
            ]
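Illustrative note (not part of the uploaded files): a minimal encode/decode sketch for Wan2_1_VAE, assuming the VAE checkpoint has already been downloaded to the default cache path and that a CUDA device is available. The shapes in the comments are what the 4x temporal / 8x spatial compression of this autoencoder would give for the toy input; they are for illustration, not asserted by the diff.

import torch
from wan.modules.vae2_1 import Wan2_1_VAE

# assumes ./cache/vae_step_411000.pth exists and CUDA is available
vae = Wan2_1_VAE(vae_pth='cache/vae_step_411000.pth', device='cuda')

# one RGB clip in [-1, 1]: [C, T, H, W], with T of the form 4k + 1
video = torch.rand(3, 13, 64, 64, device='cuda') * 2 - 1
latent = vae.encode([video])[0]   # expected [16, 4, 8, 8]
recon = vae.decode([latent])[0]   # expected [3, 13, 64, 64], clamped to [-1, 1]
print(latent.shape, recon.shape)

The list-in/list-out interface mirrors how the pipelines above batch per-sample tensors, and the chunked causal encoding (1, 4, 4, ... frames per step) keeps memory bounded for long clips.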
wan/modules/vae2_2.py
ADDED
@@ -0,0 +1,1051 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.cuda.amp as amp
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"Wan2_2_VAE",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
CACHE_T = 2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CausalConv3d(nn.Conv3d):
|
| 18 |
+
"""
|
| 19 |
+
Causal 3D convolution.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, *args, **kwargs):
|
| 23 |
+
super().__init__(*args, **kwargs)
|
| 24 |
+
self._padding = (
|
| 25 |
+
self.padding[2],
|
| 26 |
+
self.padding[2],
|
| 27 |
+
self.padding[1],
|
| 28 |
+
self.padding[1],
|
| 29 |
+
2 * self.padding[0],
|
| 30 |
+
0,
|
| 31 |
+
)
|
| 32 |
+
self.padding = (0, 0, 0)
|
| 33 |
+
|
| 34 |
+
def forward(self, x, cache_x=None):
|
| 35 |
+
padding = list(self._padding)
|
| 36 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 37 |
+
cache_x = cache_x.to(x.device)
|
| 38 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 39 |
+
padding[4] -= cache_x.shape[2]
|
| 40 |
+
x = F.pad(x, padding)
|
| 41 |
+
|
| 42 |
+
return super().forward(x)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class RMS_norm(nn.Module):
|
| 46 |
+
|
| 47 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
| 48 |
+
super().__init__()
|
| 49 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 50 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 51 |
+
|
| 52 |
+
self.channel_first = channel_first
|
| 53 |
+
self.scale = dim**0.5
|
| 54 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 55 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
|
| 59 |
+
self.scale * self.gamma + self.bias)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class Upsample(nn.Upsample):
|
| 63 |
+
|
| 64 |
+
def forward(self, x):
|
| 65 |
+
"""
|
| 66 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
| 67 |
+
"""
|
| 68 |
+
return super().forward(x.float()).type_as(x)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class Resample(nn.Module):
|
| 72 |
+
|
| 73 |
+
def __init__(self, dim, mode):
|
| 74 |
+
assert mode in (
|
| 75 |
+
"none",
|
| 76 |
+
"upsample2d",
|
| 77 |
+
"upsample3d",
|
| 78 |
+
"downsample2d",
|
| 79 |
+
"downsample3d",
|
| 80 |
+
)
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.dim = dim
|
| 83 |
+
self.mode = mode
|
| 84 |
+
|
| 85 |
+
# layers
|
| 86 |
+
if mode == "upsample2d":
|
| 87 |
+
self.resample = nn.Sequential(
|
| 88 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 89 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 90 |
+
)
|
| 91 |
+
elif mode == "upsample3d":
|
| 92 |
+
self.resample = nn.Sequential(
|
| 93 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 94 |
+
nn.Conv2d(dim, dim, 3, padding=1),
|
| 95 |
+
# nn.Conv2d(dim, dim//2, 3, padding=1)
|
| 96 |
+
)
|
| 97 |
+
self.time_conv = CausalConv3d(
|
| 98 |
+
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 99 |
+
elif mode == "downsample2d":
|
| 100 |
+
self.resample = nn.Sequential(
|
| 101 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 102 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 103 |
+
elif mode == "downsample3d":
|
| 104 |
+
self.resample = nn.Sequential(
|
| 105 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 106 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 107 |
+
self.time_conv = CausalConv3d(
|
| 108 |
+
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
| 109 |
+
else:
|
| 110 |
+
self.resample = nn.Identity()
|
| 111 |
+
|
| 112 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 113 |
+
b, c, t, h, w = x.size()
|
| 114 |
+
if self.mode == "upsample3d":
|
| 115 |
+
if feat_cache is not None:
|
| 116 |
+
idx = feat_idx[0]
|
| 117 |
+
if feat_cache[idx] is None:
|
| 118 |
+
feat_cache[idx] = "Rep"
|
| 119 |
+
feat_idx[0] += 1
|
| 120 |
+
else:
|
| 121 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 122 |
+
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
| 123 |
+
feat_cache[idx] != "Rep"):
|
| 124 |
+
# cache last frame of last two chunk
|
| 125 |
+
cache_x = torch.cat(
|
| 126 |
+
[
|
| 127 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 128 |
+
cache_x.device),
|
| 129 |
+
cache_x,
|
| 130 |
+
],
|
| 131 |
+
dim=2,
|
| 132 |
+
)
|
| 133 |
+
if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
|
| 134 |
+
feat_cache[idx] == "Rep"):
|
| 135 |
+
cache_x = torch.cat(
|
| 136 |
+
[
|
| 137 |
+
torch.zeros_like(cache_x).to(cache_x.device),
|
| 138 |
+
cache_x
|
| 139 |
+
],
|
| 140 |
+
dim=2,
|
| 141 |
+
)
|
| 142 |
+
if feat_cache[idx] == "Rep":
|
| 143 |
+
x = self.time_conv(x)
|
| 144 |
+
else:
|
| 145 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 146 |
+
feat_cache[idx] = cache_x
|
| 147 |
+
feat_idx[0] += 1
|
| 148 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 149 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
| 150 |
+
3)
|
| 151 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 152 |
+
t = x.shape[2]
|
| 153 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 154 |
+
x = self.resample(x)
|
| 155 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
| 156 |
+
|
| 157 |
+
if self.mode == "downsample3d":
|
| 158 |
+
if feat_cache is not None:
|
| 159 |
+
idx = feat_idx[0]
|
| 160 |
+
if feat_cache[idx] is None:
|
| 161 |
+
feat_cache[idx] = x.clone()
|
| 162 |
+
feat_idx[0] += 1
|
| 163 |
+
else:
|
| 164 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 165 |
+
x = self.time_conv(
|
| 166 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
| 167 |
+
feat_cache[idx] = cache_x
|
| 168 |
+
feat_idx[0] += 1
|
| 169 |
+
return x
|
| 170 |
+
|
| 171 |
+
def init_weight(self, conv):
|
| 172 |
+
conv_weight = conv.weight.detach().clone()
|
| 173 |
+
nn.init.zeros_(conv_weight)
|
| 174 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 175 |
+
one_matrix = torch.eye(c1, c2)
|
| 176 |
+
init_matrix = one_matrix
|
| 177 |
+
nn.init.zeros_(conv_weight)
|
| 178 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
| 179 |
+
conv.weight = nn.Parameter(conv_weight)
|
| 180 |
+
nn.init.zeros_(conv.bias.data)
|
| 181 |
+
|
| 182 |
+
def init_weight2(self, conv):
|
| 183 |
+
conv_weight = conv.weight.data.detach().clone()
|
| 184 |
+
nn.init.zeros_(conv_weight)
|
| 185 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 186 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
| 187 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
| 188 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
| 189 |
+
conv.weight = nn.Parameter(conv_weight)
|
| 190 |
+
nn.init.zeros_(conv.bias.data)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class ResidualBlock(nn.Module):
|
| 194 |
+
|
| 195 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
| 196 |
+
super().__init__()
|
| 197 |
+
self.in_dim = in_dim
|
| 198 |
+
self.out_dim = out_dim
|
| 199 |
+
|
| 200 |
+
# layers
|
| 201 |
+
self.residual = nn.Sequential(
|
| 202 |
+
RMS_norm(in_dim, images=False),
|
| 203 |
+
nn.SiLU(),
|
| 204 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
| 205 |
+
RMS_norm(out_dim, images=False),
|
| 206 |
+
nn.SiLU(),
|
| 207 |
+
nn.Dropout(dropout),
|
| 208 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
| 209 |
+
)
|
| 210 |
+
self.shortcut = (
|
| 211 |
+
CausalConv3d(in_dim, out_dim, 1)
|
| 212 |
+
if in_dim != out_dim else nn.Identity())
|
| 213 |
+
|
| 214 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 215 |
+
h = self.shortcut(x)
|
| 216 |
+
for layer in self.residual:
|
| 217 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 218 |
+
idx = feat_idx[0]
|
| 219 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 220 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 221 |
+
# cache last frame of last two chunk
|
| 222 |
+
cache_x = torch.cat(
|
| 223 |
+
[
|
| 224 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 225 |
+
cache_x.device),
|
| 226 |
+
cache_x,
|
| 227 |
+
],
|
| 228 |
+
dim=2,
|
| 229 |
+
)
|
| 230 |
+
x = layer(x, feat_cache[idx])
|
| 231 |
+
feat_cache[idx] = cache_x
|
| 232 |
+
feat_idx[0] += 1
|
| 233 |
+
else:
|
| 234 |
+
x = layer(x)
|
| 235 |
+
return x + h
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
class AttentionBlock(nn.Module):
|
| 239 |
+
"""
|
| 240 |
+
Causal self-attention with a single head.
|
| 241 |
+
"""
|
| 242 |
+
|
| 243 |
+
def __init__(self, dim):
|
| 244 |
+
super().__init__()
|
| 245 |
+
self.dim = dim
|
| 246 |
+
|
| 247 |
+
# layers
|
| 248 |
+
self.norm = RMS_norm(dim)
|
| 249 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 250 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
| 251 |
+
|
| 252 |
+
# zero out the last layer params
|
| 253 |
+
nn.init.zeros_(self.proj.weight)
|
| 254 |
+
|
| 255 |
+
def forward(self, x):
|
| 256 |
+
identity = x
|
| 257 |
+
b, c, t, h, w = x.size()
|
| 258 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 259 |
+
x = self.norm(x)
|
| 260 |
+
# compute query, key, value
|
| 261 |
+
q, k, v = (
|
| 262 |
+
self.to_qkv(x).reshape(b * t, 1, c * 3,
|
| 263 |
+
-1).permute(0, 1, 3,
|
| 264 |
+
2).contiguous().chunk(3, dim=-1))
|
| 265 |
+
|
| 266 |
+
# apply attention
|
| 267 |
+
x = F.scaled_dot_product_attention(
|
| 268 |
+
q,
|
| 269 |
+
k,
|
| 270 |
+
v,
|
| 271 |
+
)
|
| 272 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
| 273 |
+
|
| 274 |
+
# output
|
| 275 |
+
x = self.proj(x)
|
| 276 |
+
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
|
| 277 |
+
return x + identity
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def patchify(x, patch_size):
|
| 281 |
+
if patch_size == 1:
|
| 282 |
+
return x
|
| 283 |
+
if x.dim() == 4:
|
| 284 |
+
x = rearrange(
|
| 285 |
+
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
|
| 286 |
+
elif x.dim() == 5:
|
| 287 |
+
x = rearrange(
|
| 288 |
+
x,
|
| 289 |
+
"b c f (h q) (w r) -> b (c r q) f h w",
|
| 290 |
+
q=patch_size,
|
| 291 |
+
r=patch_size,
|
| 292 |
+
)
|
| 293 |
+
else:
|
| 294 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 295 |
+
|
| 296 |
+
return x
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def unpatchify(x, patch_size):
|
| 300 |
+
if patch_size == 1:
|
| 301 |
+
return x
|
| 302 |
+
|
| 303 |
+
if x.dim() == 4:
|
| 304 |
+
x = rearrange(
|
| 305 |
+
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
|
| 306 |
+
elif x.dim() == 5:
|
| 307 |
+
x = rearrange(
|
| 308 |
+
x,
|
| 309 |
+
"b (c r q) f h w -> b c f (h q) (w r)",
|
| 310 |
+
q=patch_size,
|
| 311 |
+
r=patch_size,
|
| 312 |
+
)
|
| 313 |
+
return x
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class AvgDown3D(nn.Module):
|
| 317 |
+
|
| 318 |
+
def __init__(
|
| 319 |
+
self,
|
| 320 |
+
in_channels,
|
| 321 |
+
out_channels,
|
| 322 |
+
factor_t,
|
| 323 |
+
factor_s=1,
|
| 324 |
+
):
|
| 325 |
+
super().__init__()
|
| 326 |
+
self.in_channels = in_channels
|
| 327 |
+
self.out_channels = out_channels
|
| 328 |
+
self.factor_t = factor_t
|
| 329 |
+
self.factor_s = factor_s
|
| 330 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 331 |
+
|
| 332 |
+
assert in_channels * self.factor % out_channels == 0
|
| 333 |
+
self.group_size = in_channels * self.factor // out_channels
|
| 334 |
+
|
| 335 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 336 |
+
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
| 337 |
+
pad = (0, 0, 0, 0, pad_t, 0)
|
| 338 |
+
x = F.pad(x, pad)
|
| 339 |
+
B, C, T, H, W = x.shape
|
| 340 |
+
x = x.view(
|
| 341 |
+
B,
|
| 342 |
+
C,
|
| 343 |
+
T // self.factor_t,
|
| 344 |
+
self.factor_t,
|
| 345 |
+
H // self.factor_s,
|
| 346 |
+
self.factor_s,
|
| 347 |
+
W // self.factor_s,
|
| 348 |
+
self.factor_s,
|
| 349 |
+
)
|
| 350 |
+
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
| 351 |
+
x = x.view(
|
| 352 |
+
B,
|
| 353 |
+
C * self.factor,
|
| 354 |
+
T // self.factor_t,
|
| 355 |
+
H // self.factor_s,
|
| 356 |
+
W // self.factor_s,
|
| 357 |
+
)
|
| 358 |
+
x = x.view(
|
| 359 |
+
B,
|
| 360 |
+
self.out_channels,
|
| 361 |
+
self.group_size,
|
| 362 |
+
T // self.factor_t,
|
| 363 |
+
H // self.factor_s,
|
| 364 |
+
W // self.factor_s,
|
| 365 |
+
)
|
| 366 |
+
x = x.mean(dim=2)
|
| 367 |
+
return x
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class DupUp3D(nn.Module):
|
| 371 |
+
|
| 372 |
+
def __init__(
|
| 373 |
+
self,
|
| 374 |
+
in_channels: int,
|
| 375 |
+
out_channels: int,
|
| 376 |
+
factor_t,
|
| 377 |
+
factor_s=1,
|
| 378 |
+
):
|
| 379 |
+
super().__init__()
|
| 380 |
+
self.in_channels = in_channels
|
| 381 |
+
self.out_channels = out_channels
|
| 382 |
+
|
| 383 |
+
self.factor_t = factor_t
|
| 384 |
+
self.factor_s = factor_s
|
| 385 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 386 |
+
|
| 387 |
+
assert out_channels * self.factor % in_channels == 0
|
| 388 |
+
self.repeats = out_channels * self.factor // in_channels
|
| 389 |
+
|
| 390 |
+
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
| 391 |
+
x = x.repeat_interleave(self.repeats, dim=1)
|
| 392 |
+
x = x.view(
|
| 393 |
+
x.size(0),
|
| 394 |
+
self.out_channels,
|
| 395 |
+
self.factor_t,
|
| 396 |
+
self.factor_s,
|
| 397 |
+
self.factor_s,
|
| 398 |
+
x.size(2),
|
| 399 |
+
x.size(3),
|
| 400 |
+
x.size(4),
|
| 401 |
+
)
|
| 402 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
| 403 |
+
x = x.view(
|
| 404 |
+
x.size(0),
|
| 405 |
+
self.out_channels,
|
| 406 |
+
x.size(2) * self.factor_t,
|
| 407 |
+
x.size(4) * self.factor_s,
|
| 408 |
+
x.size(6) * self.factor_s,
|
| 409 |
+
)
|
| 410 |
+
if first_chunk:
|
| 411 |
+
x = x[:, :, self.factor_t - 1:, :, :]
|
| 412 |
+
return x
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class Down_ResidualBlock(nn.Module):
|
| 416 |
+
|
| 417 |
+
def __init__(self,
|
| 418 |
+
in_dim,
|
| 419 |
+
out_dim,
|
| 420 |
+
dropout,
|
| 421 |
+
mult,
|
| 422 |
+
temperal_downsample=False,
|
| 423 |
+
down_flag=False):
|
| 424 |
+
super().__init__()
|
| 425 |
+
|
| 426 |
+
# Shortcut path with downsample
|
| 427 |
+
self.avg_shortcut = AvgDown3D(
|
| 428 |
+
in_dim,
|
| 429 |
+
out_dim,
|
| 430 |
+
factor_t=2 if temperal_downsample else 1,
|
| 431 |
+
factor_s=2 if down_flag else 1,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Main path with residual blocks and downsample
|
| 435 |
+
downsamples = []
|
| 436 |
+
for _ in range(mult):
|
| 437 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 438 |
+
in_dim = out_dim
|
| 439 |
+
|
| 440 |
+
# Add the final downsample block
|
| 441 |
+
if down_flag:
|
| 442 |
+
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
| 443 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
| 444 |
+
|
| 445 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 446 |
+
|
| 447 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 448 |
+
x_copy = x.clone()
|
| 449 |
+
for module in self.downsamples:
|
| 450 |
+
x = module(x, feat_cache, feat_idx)
|
| 451 |
+
|
| 452 |
+
return x + self.avg_shortcut(x_copy)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class Up_ResidualBlock(nn.Module):
|
| 456 |
+
|
| 457 |
+
def __init__(self,
|
| 458 |
+
in_dim,
|
| 459 |
+
out_dim,
|
| 460 |
+
dropout,
|
| 461 |
+
mult,
|
| 462 |
+
temperal_upsample=False,
|
| 463 |
+
up_flag=False):
|
| 464 |
+
super().__init__()
|
| 465 |
+
# Shortcut path with upsample
|
| 466 |
+
if up_flag:
|
| 467 |
+
self.avg_shortcut = DupUp3D(
|
| 468 |
+
in_dim,
|
| 469 |
+
out_dim,
|
| 470 |
+
factor_t=2 if temperal_upsample else 1,
|
| 471 |
+
factor_s=2 if up_flag else 1,
|
| 472 |
+
)
|
| 473 |
+
else:
|
| 474 |
+
self.avg_shortcut = None
|
| 475 |
+
|
| 476 |
+
# Main path with residual blocks and upsample
|
| 477 |
+
upsamples = []
|
| 478 |
+
for _ in range(mult):
|
| 479 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 480 |
+
in_dim = out_dim
|
| 481 |
+
|
| 482 |
+
# Add the final upsample block
|
| 483 |
+
if up_flag:
|
| 484 |
+
mode = "upsample3d" if temperal_upsample else "upsample2d"
|
| 485 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
| 486 |
+
|
| 487 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 488 |
+
|
| 489 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 490 |
+
x_main = x.clone()
|
| 491 |
+
for module in self.upsamples:
|
| 492 |
+
x_main = module(x_main, feat_cache, feat_idx)
|
| 493 |
+
if self.avg_shortcut is not None:
|
| 494 |
+
x_shortcut = self.avg_shortcut(x, first_chunk)
|
| 495 |
+
return x_main + x_shortcut
|
| 496 |
+
else:
|
| 497 |
+
return x_main
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class Encoder3d(nn.Module):
|
| 501 |
+
|
| 502 |
+
def __init__(
|
| 503 |
+
self,
|
| 504 |
+
dim=128,
|
| 505 |
+
z_dim=4,
|
| 506 |
+
dim_mult=[1, 2, 4, 4],
|
| 507 |
+
num_res_blocks=2,
|
| 508 |
+
attn_scales=[],
|
| 509 |
+
temperal_downsample=[True, True, False],
|
| 510 |
+
dropout=0.0,
|
| 511 |
+
):
|
| 512 |
+
super().__init__()
|
| 513 |
+
self.dim = dim
|
| 514 |
+
self.z_dim = z_dim
|
| 515 |
+
self.dim_mult = dim_mult
|
| 516 |
+
self.num_res_blocks = num_res_blocks
|
| 517 |
+
self.attn_scales = attn_scales
|
| 518 |
+
self.temperal_downsample = temperal_downsample
|
| 519 |
+
|
| 520 |
+
# dimensions
|
| 521 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 522 |
+
scale = 1.0
|
| 523 |
+
|
| 524 |
+
# init block
|
| 525 |
+
self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
|
| 526 |
+
|
| 527 |
+
# downsample blocks
|
| 528 |
+
downsamples = []
|
| 529 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 530 |
+
t_down_flag = (
|
| 531 |
+
temperal_downsample[i]
|
| 532 |
+
if i < len(temperal_downsample) else False)
|
| 533 |
+
downsamples.append(
|
| 534 |
+
Down_ResidualBlock(
|
| 535 |
+
in_dim=in_dim,
|
| 536 |
+
out_dim=out_dim,
|
| 537 |
+
dropout=dropout,
|
| 538 |
+
mult=num_res_blocks,
|
| 539 |
+
temperal_downsample=t_down_flag,
|
| 540 |
+
down_flag=i != len(dim_mult) - 1,
|
| 541 |
+
))
|
| 542 |
+
scale /= 2.0
|
| 543 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 544 |
+
|
| 545 |
+
# middle blocks
|
| 546 |
+
self.middle = nn.Sequential(
|
| 547 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 548 |
+
AttentionBlock(out_dim),
|
| 549 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
# output blocks
|
| 553 |
+
self.head = nn.Sequential(
|
| 554 |
+
RMS_norm(out_dim, images=False),
|
| 555 |
+
nn.SiLU(),
|
| 556 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 560 |
+
|
| 561 |
+
if feat_cache is not None:
|
| 562 |
+
idx = feat_idx[0]
|
| 563 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 564 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 565 |
+
cache_x = torch.cat(
|
| 566 |
+
[
|
| 567 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 568 |
+
cache_x.device),
|
| 569 |
+
cache_x,
|
| 570 |
+
],
|
| 571 |
+
dim=2,
|
| 572 |
+
)
|
| 573 |
+
x = self.conv1(x, feat_cache[idx])
|
| 574 |
+
feat_cache[idx] = cache_x
|
| 575 |
+
feat_idx[0] += 1
|
| 576 |
+
else:
|
| 577 |
+
x = self.conv1(x)
|
| 578 |
+
|
| 579 |
+
## downsamples
|
| 580 |
+
for layer in self.downsamples:
|
| 581 |
+
if feat_cache is not None:
|
| 582 |
+
x = layer(x, feat_cache, feat_idx)
|
| 583 |
+
else:
|
| 584 |
+
x = layer(x)
|
| 585 |
+
|
| 586 |
+
## middle
|
| 587 |
+
for layer in self.middle:
|
| 588 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 589 |
+
x = layer(x, feat_cache, feat_idx)
|
| 590 |
+
else:
|
| 591 |
+
x = layer(x)
|
| 592 |
+
|
| 593 |
+
## head
|
| 594 |
+
for layer in self.head:
|
| 595 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 596 |
+
idx = feat_idx[0]
|
| 597 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 598 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 599 |
+
cache_x = torch.cat(
|
| 600 |
+
[
|
| 601 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 602 |
+
cache_x.device),
|
| 603 |
+
cache_x,
|
| 604 |
+
],
|
| 605 |
+
dim=2,
|
| 606 |
+
)
|
| 607 |
+
x = layer(x, feat_cache[idx])
|
| 608 |
+
feat_cache[idx] = cache_x
|
| 609 |
+
feat_idx[0] += 1
|
| 610 |
+
else:
|
| 611 |
+
x = layer(x)
|
| 612 |
+
|
| 613 |
+
return x
|
| 614 |
+
|
| 615 |
+
|
| 616 |
+
class Decoder3d(nn.Module):
|
| 617 |
+
|
| 618 |
+
def __init__(
|
| 619 |
+
self,
|
| 620 |
+
dim=128,
|
| 621 |
+
z_dim=4,
|
| 622 |
+
dim_mult=[1, 2, 4, 4],
|
| 623 |
+
num_res_blocks=2,
|
| 624 |
+
attn_scales=[],
|
| 625 |
+
temperal_upsample=[False, True, True],
|
| 626 |
+
dropout=0.0,
|
| 627 |
+
):
|
| 628 |
+
super().__init__()
|
| 629 |
+
self.dim = dim
|
| 630 |
+
self.z_dim = z_dim
|
| 631 |
+
self.dim_mult = dim_mult
|
| 632 |
+
self.num_res_blocks = num_res_blocks
|
| 633 |
+
self.attn_scales = attn_scales
|
| 634 |
+
self.temperal_upsample = temperal_upsample
|
| 635 |
+
|
| 636 |
+
# dimensions
|
| 637 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 638 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
| 639 |
+
# init block
|
| 640 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 641 |
+
|
| 642 |
+
# middle blocks
|
| 643 |
+
self.middle = nn.Sequential(
|
| 644 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 645 |
+
AttentionBlock(dims[0]),
|
| 646 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
# upsample blocks
|
| 650 |
+
upsamples = []
|
| 651 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 652 |
+
t_up_flag = temperal_upsample[i] if i < len(
|
| 653 |
+
temperal_upsample) else False
|
| 654 |
+
upsamples.append(
|
| 655 |
+
Up_ResidualBlock(
|
| 656 |
+
in_dim=in_dim,
|
| 657 |
+
out_dim=out_dim,
|
| 658 |
+
dropout=dropout,
|
| 659 |
+
mult=num_res_blocks + 1,
|
| 660 |
+
temperal_upsample=t_up_flag,
|
| 661 |
+
up_flag=i != len(dim_mult) - 1,
|
| 662 |
+
))
|
| 663 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 664 |
+
|
| 665 |
+
# output blocks
|
| 666 |
+
self.head = nn.Sequential(
|
| 667 |
+
RMS_norm(out_dim, images=False),
|
| 668 |
+
nn.SiLU(),
|
| 669 |
+
CausalConv3d(out_dim, 12, 3, padding=1),
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 673 |
+
if feat_cache is not None:
|
| 674 |
+
idx = feat_idx[0]
|
| 675 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 676 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 677 |
+
cache_x = torch.cat(
|
| 678 |
+
[
|
| 679 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 680 |
+
cache_x.device),
|
| 681 |
+
cache_x,
|
| 682 |
+
],
|
| 683 |
+
dim=2,
|
| 684 |
+
)
|
| 685 |
+
x = self.conv1(x, feat_cache[idx])
|
| 686 |
+
feat_cache[idx] = cache_x
|
| 687 |
+
feat_idx[0] += 1
|
| 688 |
+
else:
|
| 689 |
+
x = self.conv1(x)
|
| 690 |
+
|
| 691 |
+
for layer in self.middle:
|
| 692 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 693 |
+
x = layer(x, feat_cache, feat_idx)
|
| 694 |
+
else:
|
| 695 |
+
x = layer(x)
|
| 696 |
+
|
| 697 |
+
## upsamples
|
| 698 |
+
for layer in self.upsamples:
|
| 699 |
+
if feat_cache is not None:
|
| 700 |
+
x = layer(x, feat_cache, feat_idx, first_chunk)
|
| 701 |
+
else:
|
| 702 |
+
x = layer(x)
|
| 703 |
+
|
| 704 |
+
## head
|
| 705 |
+
for layer in self.head:
|
| 706 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 707 |
+
idx = feat_idx[0]
|
| 708 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 709 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 710 |
+
cache_x = torch.cat(
|
| 711 |
+
[
|
| 712 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 713 |
+
cache_x.device),
|
| 714 |
+
cache_x,
|
| 715 |
+
],
|
| 716 |
+
dim=2,
|
| 717 |
+
)
|
| 718 |
+
x = layer(x, feat_cache[idx])
|
| 719 |
+
feat_cache[idx] = cache_x
|
| 720 |
+
feat_idx[0] += 1
|
| 721 |
+
else:
|
| 722 |
+
x = layer(x)
|
| 723 |
+
return x
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
def count_conv3d(model):
|
| 727 |
+
count = 0
|
| 728 |
+
for m in model.modules():
|
| 729 |
+
if isinstance(m, CausalConv3d):
|
| 730 |
+
count += 1
|
| 731 |
+
return count
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
class WanVAE_(nn.Module):
|
| 735 |
+
|
| 736 |
+
def __init__(
|
| 737 |
+
self,
|
| 738 |
+
dim=160,
|
| 739 |
+
dec_dim=256,
|
| 740 |
+
z_dim=16,
|
| 741 |
+
dim_mult=[1, 2, 4, 4],
|
| 742 |
+
num_res_blocks=2,
|
| 743 |
+
attn_scales=[],
|
| 744 |
+
temperal_downsample=[True, True, False],
|
| 745 |
+
dropout=0.0,
|
| 746 |
+
):
|
| 747 |
+
super().__init__()
|
| 748 |
+
self.dim = dim
|
| 749 |
+
self.z_dim = z_dim
|
| 750 |
+
self.dim_mult = dim_mult
|
| 751 |
+
self.num_res_blocks = num_res_blocks
|
| 752 |
+
self.attn_scales = attn_scales
|
| 753 |
+
self.temperal_downsample = temperal_downsample
|
| 754 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 755 |
+
|
| 756 |
+
# modules
|
| 757 |
+
self.encoder = Encoder3d(
|
| 758 |
+
dim,
|
| 759 |
+
z_dim * 2,
|
| 760 |
+
dim_mult,
|
| 761 |
+
num_res_blocks,
|
| 762 |
+
attn_scales,
|
| 763 |
+
self.temperal_downsample,
|
| 764 |
+
dropout,
|
| 765 |
+
)
|
| 766 |
+
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 767 |
+
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
| 768 |
+
self.decoder = Decoder3d(
|
| 769 |
+
dec_dim,
|
| 770 |
+
z_dim,
|
| 771 |
+
dim_mult,
|
| 772 |
+
num_res_blocks,
|
| 773 |
+
attn_scales,
|
| 774 |
+
self.temperal_upsample,
|
| 775 |
+
dropout,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
def forward(self, x, scale=[0, 1]):
|
| 779 |
+
mu = self.encode(x, scale)
|
| 780 |
+
x_recon = self.decode(mu, scale)
|
| 781 |
+
return x_recon, mu
|
| 782 |
+
|
| 783 |
+
def encode(self, x, scale):
|
| 784 |
+
self.clear_cache()
|
| 785 |
+
x = patchify(x, patch_size=2)
|
| 786 |
+
t = x.shape[2]
|
| 787 |
+
iter_ = 1 + (t - 1) // 4
|
| 788 |
+
for i in range(iter_):
|
| 789 |
+
self._enc_conv_idx = [0]
|
| 790 |
+
if i == 0:
|
| 791 |
+
out = self.encoder(
|
| 792 |
+
x[:, :, :1, :, :],
|
| 793 |
+
feat_cache=self._enc_feat_map,
|
| 794 |
+
feat_idx=self._enc_conv_idx,
|
| 795 |
+
)
|
| 796 |
+
else:
|
| 797 |
+
out_ = self.encoder(
|
| 798 |
+
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
| 799 |
+
feat_cache=self._enc_feat_map,
|
| 800 |
+
feat_idx=self._enc_conv_idx,
|
| 801 |
+
)
|
| 802 |
+
out = torch.cat([out, out_], 2)
|
| 803 |
+
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
| 804 |
+
if isinstance(scale[0], torch.Tensor):
|
| 805 |
+
mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
|
| 806 |
+
1, self.z_dim, 1, 1, 1)
|
| 807 |
+
else:
|
| 808 |
+
mu = (mu - scale[0]) * scale[1]
|
| 809 |
+
self.clear_cache()
|
| 810 |
+
return mu
|
| 811 |
+
|
| 812 |
+
def decode(self, z, scale):
|
| 813 |
+
self.clear_cache()
|
| 814 |
+
if isinstance(scale[0], torch.Tensor):
|
| 815 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 816 |
+
1, self.z_dim, 1, 1, 1)
|
| 817 |
+
else:
|
| 818 |
+
z = z / scale[1] + scale[0]
|
| 819 |
+
iter_ = z.shape[2]
|
| 820 |
+
x = self.conv2(z)
|
| 821 |
+
for i in range(iter_):
|
| 822 |
+
self._conv_idx = [0]
|
| 823 |
+
if i == 0:
|
| 824 |
+
out = self.decoder(
|
| 825 |
+
x[:, :, i:i + 1, :, :],
|
| 826 |
+
feat_cache=self._feat_map,
|
| 827 |
+
feat_idx=self._conv_idx,
|
| 828 |
+
first_chunk=True,
|
| 829 |
+
)
|
| 830 |
+
else:
|
| 831 |
+
out_ = self.decoder(
|
| 832 |
+
x[:, :, i:i + 1, :, :],
|
| 833 |
+
feat_cache=self._feat_map,
|
| 834 |
+
feat_idx=self._conv_idx,
|
| 835 |
+
)
|
| 836 |
+
out = torch.cat([out, out_], 2)
|
| 837 |
+
out = unpatchify(out, patch_size=2)
|
| 838 |
+
self.clear_cache()
|
| 839 |
+
return out
|
| 840 |
+
|
| 841 |
+
def reparameterize(self, mu, log_var):
|
| 842 |
+
std = torch.exp(0.5 * log_var)
|
| 843 |
+
eps = torch.randn_like(std)
|
| 844 |
+
return eps * std + mu
|
| 845 |
+
|
| 846 |
+
def sample(self, imgs, deterministic=False):
|
| 847 |
+
mu, log_var = self.encode(imgs)
|
| 848 |
+
if deterministic:
|
| 849 |
+
return mu
|
| 850 |
+
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
| 851 |
+
return mu + std * torch.randn_like(std)
|
| 852 |
+
|
| 853 |
+
def clear_cache(self):
|
| 854 |
+
self._conv_num = count_conv3d(self.decoder)
|
| 855 |
+
self._conv_idx = [0]
|
| 856 |
+
self._feat_map = [None] * self._conv_num
|
| 857 |
+
# cache encode
|
| 858 |
+
self._enc_conv_num = count_conv3d(self.encoder)
|
| 859 |
+
self._enc_conv_idx = [0]
|
| 860 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
|
| 864 |
+
# params
|
| 865 |
+
cfg = dict(
|
| 866 |
+
dim=dim,
|
| 867 |
+
z_dim=z_dim,
|
| 868 |
+
dim_mult=[1, 2, 4, 4],
|
| 869 |
+
num_res_blocks=2,
|
| 870 |
+
attn_scales=[],
|
| 871 |
+
temperal_downsample=[True, True, True],
|
| 872 |
+
dropout=0.0,
|
| 873 |
+
)
|
| 874 |
+
cfg.update(**kwargs)
|
| 875 |
+
|
| 876 |
+
# init model
|
| 877 |
+
with torch.device("meta"):
|
| 878 |
+
model = WanVAE_(**cfg)
|
| 879 |
+
|
| 880 |
+
# load checkpoint
|
| 881 |
+
logging.info(f"loading {pretrained_path}")
|
| 882 |
+
model.load_state_dict(
|
| 883 |
+
torch.load(pretrained_path, map_location=device), assign=True)
|
| 884 |
+
|
| 885 |
+
return model
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
class Wan2_2_VAE:
|
| 889 |
+
|
| 890 |
+
def __init__(
|
| 891 |
+
self,
|
| 892 |
+
z_dim=48,
|
| 893 |
+
c_dim=160,
|
| 894 |
+
vae_pth=None,
|
| 895 |
+
dim_mult=[1, 2, 4, 4],
|
| 896 |
+
temperal_downsample=[False, True, True],
|
| 897 |
+
dtype=torch.float,
|
| 898 |
+
device="cuda",
|
| 899 |
+
):
|
| 900 |
+
|
| 901 |
+
self.dtype = dtype
|
| 902 |
+
self.device = device
|
| 903 |
+
|
| 904 |
+
mean = torch.tensor(
|
| 905 |
+
[
|
| 906 |
+
-0.2289,
|
| 907 |
+
-0.0052,
|
| 908 |
+
-0.1323,
|
| 909 |
+
-0.2339,
|
| 910 |
+
-0.2799,
|
| 911 |
+
0.0174,
|
| 912 |
+
0.1838,
|
| 913 |
+
0.1557,
|
| 914 |
+
-0.1382,
|
| 915 |
+
0.0542,
|
| 916 |
+
0.2813,
|
| 917 |
+
0.0891,
|
| 918 |
+
0.1570,
|
| 919 |
+
-0.0098,
|
| 920 |
+
0.0375,
|
| 921 |
+
-0.1825,
|
| 922 |
+
-0.2246,
|
| 923 |
+
-0.1207,
|
| 924 |
+
-0.0698,
|
| 925 |
+
0.5109,
|
| 926 |
+
0.2665,
|
| 927 |
+
-0.2108,
|
| 928 |
+
-0.2158,
|
| 929 |
+
0.2502,
|
| 930 |
+
-0.2055,
|
| 931 |
+
-0.0322,
|
| 932 |
+
0.1109,
|
| 933 |
+
0.1567,
|
| 934 |
+
-0.0729,
|
| 935 |
+
0.0899,
|
| 936 |
+
-0.2799,
|
| 937 |
+
-0.1230,
|
| 938 |
+
-0.0313,
|
| 939 |
+
-0.1649,
|
| 940 |
+
0.0117,
|
| 941 |
+
0.0723,
|
| 942 |
+
-0.2839,
|
| 943 |
+
-0.2083,
|
| 944 |
+
-0.0520,
|
| 945 |
+
0.3748,
|
| 946 |
+
0.0152,
|
| 947 |
+
0.1957,
|
| 948 |
+
0.1433,
|
| 949 |
+
-0.2944,
|
| 950 |
+
0.3573,
|
| 951 |
+
-0.0548,
|
| 952 |
+
-0.1681,
|
| 953 |
+
-0.0667,
|
| 954 |
+
],
|
| 955 |
+
dtype=dtype,
|
| 956 |
+
device=device,
|
| 957 |
+
)
|
| 958 |
+
std = torch.tensor(
|
| 959 |
+
[
|
| 960 |
+
0.4765,
|
| 961 |
+
1.0364,
|
| 962 |
+
0.4514,
|
| 963 |
+
1.1677,
|
| 964 |
+
0.5313,
|
| 965 |
+
0.4990,
|
| 966 |
+
0.4818,
|
| 967 |
+
0.5013,
|
| 968 |
+
0.8158,
|
| 969 |
+
1.0344,
|
| 970 |
+
0.5894,
|
| 971 |
+
1.0901,
|
| 972 |
+
0.6885,
|
| 973 |
+
0.6165,
|
| 974 |
+
0.8454,
|
| 975 |
+
0.4978,
|
| 976 |
+
0.5759,
|
| 977 |
+
0.3523,
|
| 978 |
+
0.7135,
|
| 979 |
+
0.6804,
|
| 980 |
+
0.5833,
|
| 981 |
+
1.4146,
|
| 982 |
+
0.8986,
|
| 983 |
+
0.5659,
|
| 984 |
+
0.7069,
|
| 985 |
+
0.5338,
|
| 986 |
+
0.4889,
|
| 987 |
+
0.4917,
|
| 988 |
+
0.4069,
|
| 989 |
+
0.4999,
|
| 990 |
+
0.6866,
|
| 991 |
+
0.4093,
|
| 992 |
+
0.5709,
|
| 993 |
+
0.6065,
|
| 994 |
+
0.6415,
|
| 995 |
+
0.4944,
|
| 996 |
+
0.5726,
|
| 997 |
+
1.2042,
|
| 998 |
+
0.5458,
|
| 999 |
+
1.6887,
|
| 1000 |
+
0.3971,
|
| 1001 |
+
1.0600,
|
| 1002 |
+
0.3943,
|
| 1003 |
+
0.5537,
|
| 1004 |
+
0.5444,
|
| 1005 |
+
0.4089,
|
| 1006 |
+
0.7468,
|
| 1007 |
+
0.7744,
|
| 1008 |
+
],
|
| 1009 |
+
dtype=dtype,
|
| 1010 |
+
device=device,
|
| 1011 |
+
)
|
| 1012 |
+
self.scale = [mean, 1.0 / std]
|
| 1013 |
+
|
| 1014 |
+
# init model
|
| 1015 |
+
self.model = (
|
| 1016 |
+
_video_vae(
|
| 1017 |
+
pretrained_path=vae_pth,
|
| 1018 |
+
z_dim=z_dim,
|
| 1019 |
+
dim=c_dim,
|
| 1020 |
+
dim_mult=dim_mult,
|
| 1021 |
+
temperal_downsample=temperal_downsample,
|
| 1022 |
+
).eval().requires_grad_(False).to(device))
|
| 1023 |
+
|
| 1024 |
+
def encode(self, videos):
|
| 1025 |
+
try:
|
| 1026 |
+
if not isinstance(videos, list):
|
| 1027 |
+
raise TypeError("videos should be a list")
|
| 1028 |
+
with amp.autocast(dtype=self.dtype):
|
| 1029 |
+
return [
|
| 1030 |
+
self.model.encode(u.unsqueeze(0),
|
| 1031 |
+
self.scale).float().squeeze(0)
|
| 1032 |
+
for u in videos
|
| 1033 |
+
]
|
| 1034 |
+
except TypeError as e:
|
| 1035 |
+
logging.info(e)
|
| 1036 |
+
return None
|
| 1037 |
+
|
| 1038 |
+
def decode(self, zs):
|
| 1039 |
+
try:
|
| 1040 |
+
if not isinstance(zs, list):
|
| 1041 |
+
raise TypeError("zs should be a list")
|
| 1042 |
+
with amp.autocast(dtype=self.dtype):
|
| 1043 |
+
return [
|
| 1044 |
+
self.model.decode(u.unsqueeze(0),
|
| 1045 |
+
self.scale).float().clamp_(-1,
|
| 1046 |
+
1).squeeze(0)
|
| 1047 |
+
for u in zs
|
| 1048 |
+
]
|
| 1049 |
+
except TypeError as e:
|
| 1050 |
+
logging.info(e)
|
| 1051 |
+
return None
|
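Similarly, a minimal sketch of the Wan2_2_VAE wrapper defined in this file; the checkpoint path and the CUDA device are again assumptions. Note that encode and decode expect lists and log-and-return None otherwise.

import torch
from wan.modules.vae2_2 import Wan2_2_VAE

# assumption: the Wan 2.2 VAE checkpoint lives at this path
vae = Wan2_2_VAE(vae_pth="./ckpts/Wan2.2_VAE.pth", device="cuda")

# dummy clip in [-1, 1]; patchify(2) plus the downsample stack give a 4x temporal / 16x spatial stride
clip = torch.rand(3, 21, 704, 1280, device="cuda") * 2 - 1
latents = vae.encode([clip])   # list with one latent of z_dim=48 channels
frames = vae.decode(latents)   # list with one reconstructed [3, T, H, W] clip
print(latents[0].shape, frames[0].shape)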
wan/text2video.py
ADDED
@@ -0,0 +1,378 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import gc
|
| 3 |
+
import logging
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import sys
|
| 8 |
+
import types
|
| 9 |
+
from contextlib import contextmanager
|
| 10 |
+
from functools import partial
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.cuda.amp as amp
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
from tqdm import tqdm
|
| 16 |
+
|
| 17 |
+
from .distributed.fsdp import shard_model
|
| 18 |
+
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
|
| 19 |
+
from .distributed.util import get_world_size
|
| 20 |
+
from .modules.model import WanModel
|
| 21 |
+
from .modules.t5 import T5EncoderModel
|
| 22 |
+
from .modules.vae2_1 import Wan2_1_VAE
|
| 23 |
+
from .utils.fm_solvers import (
|
| 24 |
+
FlowDPMSolverMultistepScheduler,
|
| 25 |
+
get_sampling_sigmas,
|
| 26 |
+
retrieve_timesteps,
|
| 27 |
+
)
|
| 28 |
+
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class WanT2V:
|
| 32 |
+
|
| 33 |
+
def __init__(
|
| 34 |
+
self,
|
| 35 |
+
config,
|
| 36 |
+
checkpoint_dir,
|
| 37 |
+
device_id=0,
|
| 38 |
+
rank=0,
|
| 39 |
+
t5_fsdp=False,
|
| 40 |
+
dit_fsdp=False,
|
| 41 |
+
use_sp=False,
|
| 42 |
+
t5_cpu=False,
|
| 43 |
+
init_on_cpu=True,
|
| 44 |
+
convert_model_dtype=False,
|
| 45 |
+
):
|
| 46 |
+
r"""
|
| 47 |
+
Initializes the Wan text-to-video generation model components.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
config (EasyDict):
|
| 51 |
+
Object containing model parameters initialized from config.py
|
| 52 |
+
checkpoint_dir (`str`):
|
| 53 |
+
Path to directory containing model checkpoints
|
| 54 |
+
device_id (`int`, *optional*, defaults to 0):
|
| 55 |
+
Id of target GPU device
|
| 56 |
+
rank (`int`, *optional*, defaults to 0):
|
| 57 |
+
Process rank for distributed training
|
| 58 |
+
t5_fsdp (`bool`, *optional*, defaults to False):
|
| 59 |
+
Enable FSDP sharding for T5 model
|
| 60 |
+
dit_fsdp (`bool`, *optional*, defaults to False):
|
| 61 |
+
Enable FSDP sharding for DiT model
|
| 62 |
+
use_sp (`bool`, *optional*, defaults to False):
|
| 63 |
+
Enable the sequence-parallel distribution strategy.
|
| 64 |
+
t5_cpu (`bool`, *optional*, defaults to False):
|
| 65 |
+
Whether to place T5 model on CPU. Only works without t5_fsdp.
|
| 66 |
+
init_on_cpu (`bool`, *optional*, defaults to True):
|
| 67 |
+
Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
|
| 68 |
+
convert_model_dtype (`bool`, *optional*, defaults to False):
|
| 69 |
+
Convert DiT model parameters dtype to 'config.param_dtype'.
|
| 70 |
+
Only works without FSDP.
|
| 71 |
+
"""
|
| 72 |
+
self.device = torch.device(f"cuda:{device_id}")
|
| 73 |
+
self.config = config
|
| 74 |
+
self.rank = rank
|
| 75 |
+
self.t5_cpu = t5_cpu
|
| 76 |
+
self.init_on_cpu = init_on_cpu
|
| 77 |
+
|
| 78 |
+
self.num_train_timesteps = config.num_train_timesteps
|
| 79 |
+
self.boundary = config.boundary
|
| 80 |
+
self.param_dtype = config.param_dtype
|
| 81 |
+
|
| 82 |
+
if t5_fsdp or dit_fsdp or use_sp:
|
| 83 |
+
self.init_on_cpu = False
|
| 84 |
+
|
| 85 |
+
shard_fn = partial(shard_model, device_id=device_id)
|
| 86 |
+
self.text_encoder = T5EncoderModel(
|
| 87 |
+
text_len=config.text_len,
|
| 88 |
+
dtype=config.t5_dtype,
|
| 89 |
+
device=torch.device('cpu'),
|
| 90 |
+
checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
|
| 91 |
+
tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
|
| 92 |
+
shard_fn=shard_fn if t5_fsdp else None)
|
| 93 |
+
|
| 94 |
+
self.vae_stride = config.vae_stride
|
| 95 |
+
self.patch_size = config.patch_size
|
| 96 |
+
self.vae = Wan2_1_VAE(
|
| 97 |
+
vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
|
| 98 |
+
device=self.device)
|
| 99 |
+
|
| 100 |
+
logging.info(f"Creating WanModel from {checkpoint_dir}")
|
| 101 |
+
self.low_noise_model = WanModel.from_pretrained(
|
| 102 |
+
checkpoint_dir, subfolder=config.low_noise_checkpoint)
|
| 103 |
+
self.low_noise_model = self._configure_model(
|
| 104 |
+
model=self.low_noise_model,
|
| 105 |
+
use_sp=use_sp,
|
| 106 |
+
dit_fsdp=dit_fsdp,
|
| 107 |
+
shard_fn=shard_fn,
|
| 108 |
+
convert_model_dtype=convert_model_dtype)
|
| 109 |
+
|
| 110 |
+
self.high_noise_model = WanModel.from_pretrained(
|
| 111 |
+
checkpoint_dir, subfolder=config.high_noise_checkpoint)
|
| 112 |
+
self.high_noise_model = self._configure_model(
|
| 113 |
+
model=self.high_noise_model,
|
| 114 |
+
use_sp=use_sp,
|
| 115 |
+
dit_fsdp=dit_fsdp,
|
| 116 |
+
shard_fn=shard_fn,
|
| 117 |
+
convert_model_dtype=convert_model_dtype)
|
| 118 |
+
if use_sp:
|
| 119 |
+
self.sp_size = get_world_size()
|
| 120 |
+
else:
|
| 121 |
+
self.sp_size = 1
|
| 122 |
+
|
| 123 |
+
self.sample_neg_prompt = config.sample_neg_prompt
|
| 124 |
+
|
| 125 |
+
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
|
| 126 |
+
convert_model_dtype):
|
| 127 |
+
"""
|
| 128 |
+
Configures a model object. This includes setting evaluation modes,
|
| 129 |
+
applying distributed parallel strategy, and handling device placement.
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
model (torch.nn.Module):
|
| 133 |
+
The model instance to configure.
|
| 134 |
+
use_sp (`bool`):
|
| 135 |
+
Enable the sequence-parallel distribution strategy.
|
| 136 |
+
dit_fsdp (`bool`):
|
| 137 |
+
Enable FSDP sharding for DiT model.
|
| 138 |
+
shard_fn (callable):
|
| 139 |
+
The function to apply FSDP sharding.
|
| 140 |
+
convert_model_dtype (`bool`):
|
| 141 |
+
Convert DiT model parameters dtype to 'config.param_dtype'.
|
| 142 |
+
Only works without FSDP.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
torch.nn.Module:
|
| 146 |
+
The configured model.
|
| 147 |
+
"""
|
| 148 |
+
model.eval().requires_grad_(False)
|
| 149 |
+
|
| 150 |
+
if use_sp:
|
| 151 |
+
for block in model.blocks:
|
| 152 |
+
block.self_attn.forward = types.MethodType(
|
| 153 |
+
sp_attn_forward, block.self_attn)
|
| 154 |
+
model.forward = types.MethodType(sp_dit_forward, model)
|
| 155 |
+
|
| 156 |
+
if dist.is_initialized():
|
| 157 |
+
dist.barrier()
|
| 158 |
+
|
| 159 |
+
if dit_fsdp:
|
| 160 |
+
model = shard_fn(model)
|
| 161 |
+
else:
|
| 162 |
+
if convert_model_dtype:
|
| 163 |
+
model.to(self.param_dtype)
|
| 164 |
+
if not self.init_on_cpu:
|
| 165 |
+
model.to(self.device)
|
| 166 |
+
|
| 167 |
+
return model
|
| 168 |
+
|
| 169 |
+
def _prepare_model_for_timestep(self, t, boundary, offload_model):
|
| 170 |
+
r"""
|
| 171 |
+
Prepares and returns the required model for the current timestep.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
t (torch.Tensor):
|
| 175 |
+
current timestep.
|
| 176 |
+
boundary (`int`):
|
| 177 |
+
The timestep threshold. If `t` is at or above this value,
|
| 178 |
+
the `high_noise_model` is considered as the required model.
|
| 179 |
+
offload_model (`bool`):
|
| 180 |
+
A flag intended to control the offloading behavior.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
torch.nn.Module:
|
| 184 |
+
The active model on the target device for the current timestep.
|
| 185 |
+
"""
|
| 186 |
+
if t.item() >= boundary:
|
| 187 |
+
required_model_name = 'high_noise_model'
|
| 188 |
+
offload_model_name = 'low_noise_model'
|
| 189 |
+
else:
|
| 190 |
+
required_model_name = 'low_noise_model'
|
| 191 |
+
offload_model_name = 'high_noise_model'
|
| 192 |
+
if offload_model or self.init_on_cpu:
|
| 193 |
+
if next(getattr(
|
| 194 |
+
self,
|
| 195 |
+
offload_model_name).parameters()).device.type == 'cuda':
|
| 196 |
+
getattr(self, offload_model_name).to('cpu')
|
| 197 |
+
if next(getattr(
|
| 198 |
+
self,
|
| 199 |
+
required_model_name).parameters()).device.type == 'cpu':
|
| 200 |
+
getattr(self, required_model_name).to(self.device)
|
| 201 |
+
return getattr(self, required_model_name)
|
| 202 |
+
|
| 203 |
+
def generate(self,
|
| 204 |
+
input_prompt,
|
| 205 |
+
size=(1280, 720),
|
| 206 |
+
frame_num=81,
|
| 207 |
+
shift=5.0,
|
| 208 |
+
sample_solver='unipc',
|
| 209 |
+
sampling_steps=50,
|
| 210 |
+
guide_scale=5.0,
|
| 211 |
+
n_prompt="",
|
| 212 |
+
seed=-1,
|
| 213 |
+
offload_model=True):
|
| 214 |
+
r"""
|
| 215 |
+
Generates video frames from text prompt using diffusion process.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
input_prompt (`str`):
|
| 219 |
+
Text prompt for content generation
|
| 220 |
+
size (`tuple[int]`, *optional*, defaults to (1280,720)):
|
| 221 |
+
Controls video resolution, (width,height).
|
| 222 |
+
frame_num (`int`, *optional*, defaults to 81):
|
| 223 |
+
How many frames to sample from a video. The number should be 4n+1.
|
| 224 |
+
shift (`float`, *optional*, defaults to 5.0):
|
| 225 |
+
Noise schedule shift parameter. Affects temporal dynamics
|
| 226 |
+
sample_solver (`str`, *optional*, defaults to 'unipc'):
|
| 227 |
+
Solver used to sample the video.
|
| 228 |
+
sampling_steps (`int`, *optional*, defaults to 50):
|
| 229 |
+
Number of diffusion sampling steps. Higher values improve quality but slow generation
|
| 230 |
+
guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):
|
| 231 |
+
Classifier-free guidance scale. Controls prompt adherence vs. creativity.
|
| 232 |
+
If tuple, the first guide_scale will be used for low noise model and
|
| 233 |
+
the second guide_scale will be used for high noise model.
|
| 234 |
+
n_prompt (`str`, *optional*, defaults to ""):
|
| 235 |
+
Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
|
| 236 |
+
seed (`int`, *optional*, defaults to -1):
|
| 237 |
+
Random seed for noise generation. If -1, use random seed.
|
| 238 |
+
offload_model (`bool`, *optional*, defaults to True):
|
| 239 |
+
If True, offloads models to CPU during generation to save VRAM
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
torch.Tensor:
|
| 243 |
+
Generated video frames tensor. Dimensions: (C, N, H, W) where:
|
| 244 |
+
- C: Color channels (3 for RGB)
|
| 245 |
+
- N: Number of frames (81)
|
| 246 |
+
- H: Frame height (from size)
|
| 247 |
+
- W: Frame width (from size)
|
| 248 |
+
"""
|
| 249 |
+
# preprocess
|
| 250 |
+
guide_scale = (guide_scale, guide_scale) if isinstance(
|
| 251 |
+
guide_scale, float) else guide_scale
|
| 252 |
+
F = frame_num
|
| 253 |
+
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
| 254 |
+
size[1] // self.vae_stride[1],
|
| 255 |
+
size[0] // self.vae_stride[2])
|
| 256 |
+
|
| 257 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
| 258 |
+
(self.patch_size[1] * self.patch_size[2]) *
|
| 259 |
+
target_shape[1] / self.sp_size) * self.sp_size
|
| 260 |
+
|
| 261 |
+
if n_prompt == "":
|
| 262 |
+
n_prompt = self.sample_neg_prompt
|
| 263 |
+
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
| 264 |
+
seed_g = torch.Generator(device=self.device)
|
| 265 |
+
seed_g.manual_seed(seed)
|
| 266 |
+
|
| 267 |
+
if not self.t5_cpu:
|
| 268 |
+
self.text_encoder.model.to(self.device)
|
| 269 |
+
context = self.text_encoder([input_prompt], self.device)
|
| 270 |
+
context_null = self.text_encoder([n_prompt], self.device)
|
| 271 |
+
if offload_model:
|
| 272 |
+
self.text_encoder.model.cpu()
|
| 273 |
+
else:
|
| 274 |
+
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
| 275 |
+
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
| 276 |
+
context = [t.to(self.device) for t in context]
|
| 277 |
+
context_null = [t.to(self.device) for t in context_null]
|
| 278 |
+
|
| 279 |
+
noise = [
|
| 280 |
+
torch.randn(
|
| 281 |
+
target_shape[0],
|
| 282 |
+
target_shape[1],
|
| 283 |
+
target_shape[2],
|
| 284 |
+
target_shape[3],
|
| 285 |
+
dtype=torch.float32,
|
| 286 |
+
device=self.device,
|
| 287 |
+
generator=seed_g)
|
| 288 |
+
]
|
| 289 |
+
|
| 290 |
+
@contextmanager
|
| 291 |
+
def noop_no_sync():
|
| 292 |
+
yield
|
| 293 |
+
|
| 294 |
+
no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
|
| 295 |
+
noop_no_sync)
|
| 296 |
+
no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
|
| 297 |
+
noop_no_sync)
|
| 298 |
+
|
| 299 |
+
# evaluation mode
|
| 300 |
+
with (
|
| 301 |
+
torch.amp.autocast('cuda', dtype=self.param_dtype),
|
| 302 |
+
torch.no_grad(),
|
| 303 |
+
no_sync_low_noise(),
|
| 304 |
+
no_sync_high_noise(),
|
| 305 |
+
):
|
| 306 |
+
boundary = self.boundary * self.num_train_timesteps
|
| 307 |
+
|
| 308 |
+
if sample_solver == 'unipc':
|
| 309 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
| 310 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 311 |
+
shift=1,
|
| 312 |
+
use_dynamic_shifting=False)
|
| 313 |
+
sample_scheduler.set_timesteps(
|
| 314 |
+
sampling_steps, device=self.device, shift=shift)
|
| 315 |
+
timesteps = sample_scheduler.timesteps
|
| 316 |
+
elif sample_solver == 'dpm++':
|
| 317 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
| 318 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 319 |
+
shift=1,
|
| 320 |
+
use_dynamic_shifting=False)
|
| 321 |
+
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
| 322 |
+
timesteps, _ = retrieve_timesteps(
|
| 323 |
+
sample_scheduler,
|
| 324 |
+
device=self.device,
|
| 325 |
+
sigmas=sampling_sigmas)
|
| 326 |
+
else:
|
| 327 |
+
raise NotImplementedError("Unsupported solver.")
|
| 328 |
+
|
| 329 |
+
# sample videos
|
| 330 |
+
latents = noise
|
| 331 |
+
|
| 332 |
+
arg_c = {'context': context, 'seq_len': seq_len}
|
| 333 |
+
arg_null = {'context': context_null, 'seq_len': seq_len}
|
| 334 |
+
|
| 335 |
+
for _, t in enumerate(tqdm(timesteps)):
|
| 336 |
+
latent_model_input = latents
|
| 337 |
+
timestep = [t]
|
| 338 |
+
|
| 339 |
+
timestep = torch.stack(timestep)
|
| 340 |
+
|
| 341 |
+
model = self._prepare_model_for_timestep(
|
| 342 |
+
t, boundary, offload_model)
|
| 343 |
+
sample_guide_scale = guide_scale[1] if t.item(
|
| 344 |
+
) >= boundary else guide_scale[0]
|
| 345 |
+
|
| 346 |
+
noise_pred_cond = model(
|
| 347 |
+
latent_model_input, t=timestep, **arg_c)[0]
|
| 348 |
+
noise_pred_uncond = model(
|
| 349 |
+
latent_model_input, t=timestep, **arg_null)[0]
|
| 350 |
+
|
| 351 |
+
noise_pred = noise_pred_uncond + sample_guide_scale * (
|
| 352 |
+
noise_pred_cond - noise_pred_uncond)
|
| 353 |
+
|
| 354 |
+
temp_x0 = sample_scheduler.step(
|
| 355 |
+
noise_pred.unsqueeze(0),
|
| 356 |
+
t,
|
| 357 |
+
latents[0].unsqueeze(0),
|
| 358 |
+
return_dict=False,
|
| 359 |
+
generator=seed_g)[0]
|
| 360 |
+
latents = [temp_x0.squeeze(0)]
|
| 361 |
+
|
| 362 |
+
x0 = latents
|
| 363 |
+
if offload_model:
|
| 364 |
+
self.low_noise_model.cpu()
|
| 365 |
+
self.high_noise_model.cpu()
|
| 366 |
+
torch.cuda.empty_cache()
|
| 367 |
+
if self.rank == 0:
|
| 368 |
+
videos = self.vae.decode(x0)
|
| 369 |
+
|
| 370 |
+
del noise, latents
|
| 371 |
+
del sample_scheduler
|
| 372 |
+
if offload_model:
|
| 373 |
+
gc.collect()
|
| 374 |
+
torch.cuda.synchronize()
|
| 375 |
+
if dist.is_initialized():
|
| 376 |
+
dist.barrier()
|
| 377 |
+
|
| 378 |
+
return videos[0] if self.rank == 0 else None
|
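Finally, a rough end-to-end sketch of the WanT2V pipeline above. The WAN_CONFIGS mapping and its 't2v-A14B' key are assumptions about wan/configs/__init__.py, and the checkpoint directory is a placeholder for locally downloaded weights.

from wan.configs import WAN_CONFIGS      # assumed export of wan/configs/__init__.py
from wan.text2video import WanT2V

cfg = WAN_CONFIGS["t2v-A14B"]            # hypothetical config key
pipe = WanT2V(
    config=cfg,
    checkpoint_dir="./Wan2.2-T2V-A14B",  # assumption: weights downloaded here
    device_id=0,
    convert_model_dtype=True)

video = pipe.generate(
    "a corgi surfing a small wave at sunset",
    size=(1280, 720),
    frame_num=81,
    sampling_steps=40,
    guide_scale=(3.0, 4.0),              # (low-noise scale, high-noise scale)
    seed=42,
    offload_model=True)                  # (C, N, H, W) tensor in [-1, 1]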
wan/textimage2video.py
ADDED
@@ -0,0 +1,619 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import gc
+import logging
+import math
+import os
+import random
+import sys
+import types
+from contextlib import contextmanager
+from functools import partial
+
+import torch
+import torch.cuda.amp as amp
+import torch.distributed as dist
+import torchvision.transforms.functional as TF
+from PIL import Image
+from tqdm import tqdm
+
+from .distributed.fsdp import shard_model
+from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
+from .distributed.util import get_world_size
+from .modules.model import WanModel
+from .modules.t5 import T5EncoderModel
+from .modules.vae2_2 import Wan2_2_VAE
+from .utils.fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
+from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
+from .utils.utils import best_output_size, masks_like
+
+
+class WanTI2V:
+
+    def __init__(
+        self,
+        config,
+        checkpoint_dir,
+        device_id=0,
+        rank=0,
+        t5_fsdp=False,
+        dit_fsdp=False,
+        use_sp=False,
+        t5_cpu=False,
+        init_on_cpu=True,
+        convert_model_dtype=False,
+    ):
+        r"""
+        Initializes the Wan text-to-video generation model components.
+
+        Args:
+            config (EasyDict):
+                Object containing model parameters initialized from config.py
+            checkpoint_dir (`str`):
+                Path to directory containing model checkpoints
+            device_id (`int`, *optional*, defaults to 0):
+                Id of target GPU device
+            rank (`int`, *optional*, defaults to 0):
+                Process rank for distributed training
+            t5_fsdp (`bool`, *optional*, defaults to False):
+                Enable FSDP sharding for T5 model
+            dit_fsdp (`bool`, *optional*, defaults to False):
+                Enable FSDP sharding for DiT model
+            use_sp (`bool`, *optional*, defaults to False):
+                Enable distribution strategy of sequence parallel.
+            t5_cpu (`bool`, *optional*, defaults to False):
+                Whether to place T5 model on CPU. Only works without t5_fsdp.
+            init_on_cpu (`bool`, *optional*, defaults to True):
+                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
+            convert_model_dtype (`bool`, *optional*, defaults to False):
+                Convert DiT model parameters dtype to 'config.param_dtype'.
+                Only works without FSDP.
+        """
+        self.device = torch.device(f"cuda:{device_id}")
+        self.config = config
+        self.rank = rank
+        self.t5_cpu = t5_cpu
+        self.init_on_cpu = init_on_cpu
+
+        self.num_train_timesteps = config.num_train_timesteps
+        self.param_dtype = config.param_dtype
+
+        if t5_fsdp or dit_fsdp or use_sp:
+            self.init_on_cpu = False
+
+        shard_fn = partial(shard_model, device_id=device_id)
+        self.text_encoder = T5EncoderModel(
+            text_len=config.text_len,
+            dtype=config.t5_dtype,
+            device=torch.device('cpu'),
+            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
+            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
+            shard_fn=shard_fn if t5_fsdp else None)
+
+        self.vae_stride = config.vae_stride
+        self.patch_size = config.patch_size
+        self.vae = Wan2_2_VAE(
+            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+            device=self.device)
+
+        logging.info(f"Creating WanModel from {checkpoint_dir}")
+        self.model = WanModel.from_pretrained(checkpoint_dir)
+        self.model = self._configure_model(
+            model=self.model,
+            use_sp=use_sp,
+            dit_fsdp=dit_fsdp,
+            shard_fn=shard_fn,
+            convert_model_dtype=convert_model_dtype)
+
+        if use_sp:
+            self.sp_size = get_world_size()
+        else:
+            self.sp_size = 1
+
+        self.sample_neg_prompt = config.sample_neg_prompt
+
+    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
+                         convert_model_dtype):
+        """
+        Configures a model object. This includes setting evaluation modes,
+        applying distributed parallel strategy, and handling device placement.
+
+        Args:
+            model (torch.nn.Module):
+                The model instance to configure.
+            use_sp (`bool`):
+                Enable distribution strategy of sequence parallel.
+            dit_fsdp (`bool`):
+                Enable FSDP sharding for DiT model.
+            shard_fn (callable):
+                The function to apply FSDP sharding.
+            convert_model_dtype (`bool`):
+                Convert DiT model parameters dtype to 'config.param_dtype'.
+                Only works without FSDP.
+
+        Returns:
+            torch.nn.Module:
+                The configured model.
+        """
+        model.eval().requires_grad_(False)
+
+        if use_sp:
+            for block in model.blocks:
+                block.self_attn.forward = types.MethodType(
+                    sp_attn_forward, block.self_attn)
+            model.forward = types.MethodType(sp_dit_forward, model)
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        if dit_fsdp:
+            model = shard_fn(model)
+        else:
+            if convert_model_dtype:
+                model.to(self.param_dtype)
+            if not self.init_on_cpu:
+                model.to(self.device)
+
+        return model
+
+    def generate(self,
+                 input_prompt,
+                 img=None,
+                 size=(1280, 704),
+                 max_area=704 * 1280,
+                 frame_num=81,
+                 shift=5.0,
+                 sample_solver='unipc',
+                 sampling_steps=50,
+                 guide_scale=5.0,
+                 n_prompt="",
+                 seed=-1,
+                 offload_model=True):
+        r"""
+        Generates video frames from text prompt using diffusion process.
+
+        Args:
+            input_prompt (`str`):
+                Text prompt for content generation
+            img (PIL.Image.Image):
+                Input image tensor. Shape: [3, H, W]
+            size (`tuple[int]`, *optional*, defaults to (1280,704)):
+                Controls video resolution, (width,height).
+            max_area (`int`, *optional*, defaults to 704*1280):
+                Maximum pixel area for latent space calculation. Controls video resolution scaling
+            frame_num (`int`, *optional*, defaults to 81):
+                How many frames to sample from a video. The number should be 4n+1
+            shift (`float`, *optional*, defaults to 5.0):
+                Noise schedule shift parameter. Affects temporal dynamics
+            sample_solver (`str`, *optional*, defaults to 'unipc'):
+                Solver used to sample the video.
+            sampling_steps (`int`, *optional*, defaults to 50):
+                Number of diffusion sampling steps. Higher values improve quality but slow generation
+            guide_scale (`float`, *optional*, defaults 5.0):
+                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
+            n_prompt (`str`, *optional*, defaults to ""):
+                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+            seed (`int`, *optional*, defaults to -1):
+                Random seed for noise generation. If -1, use random seed.
+            offload_model (`bool`, *optional*, defaults to True):
+                If True, offloads models to CPU during generation to save VRAM
+
+        Returns:
+            torch.Tensor:
+                Generated video frames tensor. Dimensions: (C, N H, W) where:
+                - C: Color channels (3 for RGB)
+                - N: Number of frames (81)
+                - H: Frame height (from size)
+                - W: Frame width from size)
+        """
+        # i2v
+        if img is not None:
+            return self.i2v(
+                input_prompt=input_prompt,
+                img=img,
+                max_area=max_area,
+                frame_num=frame_num,
+                shift=shift,
+                sample_solver=sample_solver,
+                sampling_steps=sampling_steps,
+                guide_scale=guide_scale,
+                n_prompt=n_prompt,
+                seed=seed,
+                offload_model=offload_model)
+        # t2v
+        return self.t2v(
+            input_prompt=input_prompt,
+            size=size,
+            frame_num=frame_num,
+            shift=shift,
+            sample_solver=sample_solver,
+            sampling_steps=sampling_steps,
+            guide_scale=guide_scale,
+            n_prompt=n_prompt,
+            seed=seed,
+            offload_model=offload_model)
+
+    def t2v(self,
+            input_prompt,
+            size=(1280, 704),
+            frame_num=121,
+            shift=5.0,
+            sample_solver='unipc',
+            sampling_steps=50,
+            guide_scale=5.0,
+            n_prompt="",
+            seed=-1,
+            offload_model=True):
+        r"""
+        Generates video frames from text prompt using diffusion process.
+
+        Args:
+            input_prompt (`str`):
+                Text prompt for content generation
+            size (`tuple[int]`, *optional*, defaults to (1280,704)):
+                Controls video resolution, (width,height).
+            frame_num (`int`, *optional*, defaults to 121):
+                How many frames to sample from a video. The number should be 4n+1
+            shift (`float`, *optional*, defaults to 5.0):
+                Noise schedule shift parameter. Affects temporal dynamics
+            sample_solver (`str`, *optional*, defaults to 'unipc'):
+                Solver used to sample the video.
+            sampling_steps (`int`, *optional*, defaults to 50):
+                Number of diffusion sampling steps. Higher values improve quality but slow generation
+            guide_scale (`float`, *optional*, defaults 5.0):
+                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
+            n_prompt (`str`, *optional*, defaults to ""):
+                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+            seed (`int`, *optional*, defaults to -1):
+                Random seed for noise generation. If -1, use random seed.
+            offload_model (`bool`, *optional*, defaults to True):
+                If True, offloads models to CPU during generation to save VRAM
+
+        Returns:
+            torch.Tensor:
+                Generated video frames tensor. Dimensions: (C, N H, W) where:
+                - C: Color channels (3 for RGB)
+                - N: Number of frames (81)
+                - H: Frame height (from size)
+                - W: Frame width from size)
+        """
+        # preprocess
+        F = frame_num
+        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
+                        size[1] // self.vae_stride[1],
+                        size[0] // self.vae_stride[2])
+
+        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
+                            (self.patch_size[1] * self.patch_size[2]) *
+                            target_shape[1] / self.sp_size) * self.sp_size
+
+        if n_prompt == "":
+            n_prompt = self.sample_neg_prompt
+        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+        seed_g = torch.Generator(device=self.device)
+        seed_g.manual_seed(seed)
+
+        if not self.t5_cpu:
+            self.text_encoder.model.to(self.device)
+            context = self.text_encoder([input_prompt], self.device)
+            context_null = self.text_encoder([n_prompt], self.device)
+            if offload_model:
+                self.text_encoder.model.cpu()
+        else:
+            context = self.text_encoder([input_prompt], torch.device('cpu'))
+            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            context = [t.to(self.device) for t in context]
+            context_null = [t.to(self.device) for t in context_null]
+
+        noise = [
+            torch.randn(
+                target_shape[0],
+                target_shape[1],
+                target_shape[2],
+                target_shape[3],
+                dtype=torch.float32,
+                device=self.device,
+                generator=seed_g)
+        ]
+
+        @contextmanager
+        def noop_no_sync():
+            yield
+
+        no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+
+        # evaluation mode
+        with (
+                torch.amp.autocast('cuda', dtype=self.param_dtype),
+                torch.no_grad(),
+                no_sync(),
+        ):
+
+            if sample_solver == 'unipc':
+                sample_scheduler = FlowUniPCMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sample_scheduler.set_timesteps(
+                    sampling_steps, device=self.device, shift=shift)
+                timesteps = sample_scheduler.timesteps
+            elif sample_solver == 'dpm++':
+                sample_scheduler = FlowDPMSolverMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+                timesteps, _ = retrieve_timesteps(
+                    sample_scheduler,
+                    device=self.device,
+                    sigmas=sampling_sigmas)
+            else:
+                raise NotImplementedError("Unsupported solver.")
+
+            # sample videos
+            latents = noise
+            mask1, mask2 = masks_like(noise, zero=False)
+
+            arg_c = {'context': context, 'seq_len': seq_len}
+            arg_null = {'context': context_null, 'seq_len': seq_len}
+
+            if offload_model or self.init_on_cpu:
+                self.model.to(self.device)
+                torch.cuda.empty_cache()
+
+            for _, t in enumerate(tqdm(timesteps)):
+                latent_model_input = latents
+                timestep = [t]
+
+                timestep = torch.stack(timestep)
+
+                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
+                temp_ts = torch.cat([
+                    temp_ts,
+                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
+                ])
+                timestep = temp_ts.unsqueeze(0)
+
+                noise_pred_cond = self.model(
+                    latent_model_input, t=timestep, **arg_c)[0]
+                noise_pred_uncond = self.model(
+                    latent_model_input, t=timestep, **arg_null)[0]
+
+                noise_pred = noise_pred_uncond + guide_scale * (
+                    noise_pred_cond - noise_pred_uncond)
+
+                temp_x0 = sample_scheduler.step(
+                    noise_pred.unsqueeze(0),
+                    t,
+                    latents[0].unsqueeze(0),
+                    return_dict=False,
+                    generator=seed_g)[0]
+                latents = [temp_x0.squeeze(0)]
+            x0 = latents
+            if offload_model:
+                self.model.cpu()
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
+            if self.rank == 0:
+                videos = self.vae.decode(x0)
+
+        del noise, latents
+        del sample_scheduler
+        if offload_model:
+            gc.collect()
+            torch.cuda.synchronize()
+        if dist.is_initialized():
+            dist.barrier()
+
+        return videos[0] if self.rank == 0 else None
+
+    def i2v(self,
+            input_prompt,
+            img,
+            max_area=704 * 1280,
+            frame_num=121,
+            shift=5.0,
+            sample_solver='unipc',
+            sampling_steps=40,
+            guide_scale=5.0,
+            n_prompt="",
+            seed=-1,
+            offload_model=True):
+        r"""
+        Generates video frames from input image and text prompt using diffusion process.
+
+        Args:
+            input_prompt (`str`):
+                Text prompt for content generation.
+            img (PIL.Image.Image):
+                Input image tensor. Shape: [3, H, W]
+            max_area (`int`, *optional*, defaults to 704*1280):
+                Maximum pixel area for latent space calculation. Controls video resolution scaling
+            frame_num (`int`, *optional*, defaults to 121):
+                How many frames to sample from a video. The number should be 4n+1
+            shift (`float`, *optional*, defaults to 5.0):
+                Noise schedule shift parameter. Affects temporal dynamics
+                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
+            sample_solver (`str`, *optional*, defaults to 'unipc'):
+                Solver used to sample the video.
+            sampling_steps (`int`, *optional*, defaults to 40):
+                Number of diffusion sampling steps. Higher values improve quality but slow generation
+            guide_scale (`float`, *optional*, defaults 5.0):
+                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
+            n_prompt (`str`, *optional*, defaults to ""):
+                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
+            seed (`int`, *optional*, defaults to -1):
+                Random seed for noise generation. If -1, use random seed
+            offload_model (`bool`, *optional*, defaults to True):
+                If True, offloads models to CPU during generation to save VRAM
+
+        Returns:
+            torch.Tensor:
+                Generated video frames tensor. Dimensions: (C, N H, W) where:
+                - C: Color channels (3 for RGB)
+                - N: Number of frames (121)
+                - H: Frame height (from max_area)
+                - W: Frame width (from max_area)
+        """
+        # preprocess
+        ih, iw = img.height, img.width
+        dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[
+            2] * self.vae_stride[2]
+        ow, oh = best_output_size(iw, ih, dw, dh, max_area)
+
+        scale = max(ow / iw, oh / ih)
+        img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
+
+        # center-crop
+        x1 = (img.width - ow) // 2
+        y1 = (img.height - oh) // 2
+        img = img.crop((x1, y1, x1 + ow, y1 + oh))
+        assert img.width == ow and img.height == oh
+
+        # to tensor
+        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
+
+        F = frame_num
+        seq_len = ((F - 1) // self.vae_stride[0] + 1) * (
+            oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // (
+                self.patch_size[1] * self.patch_size[2])
+        seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size
+
+        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
+        seed_g = torch.Generator(device=self.device)
+        seed_g.manual_seed(seed)
+        noise = torch.randn(
+            self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
+            oh // self.vae_stride[1],
+            ow // self.vae_stride[2],
+            dtype=torch.float32,
+            generator=seed_g,
+            device=self.device)
+
+        if n_prompt == "":
+            n_prompt = self.sample_neg_prompt
+
+        # preprocess
+        if not self.t5_cpu:
+            self.text_encoder.model.to(self.device)
+            context = self.text_encoder([input_prompt], self.device)
+            context_null = self.text_encoder([n_prompt], self.device)
+            if offload_model:
+                self.text_encoder.model.cpu()
+        else:
+            context = self.text_encoder([input_prompt], torch.device('cpu'))
+            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
+            context = [t.to(self.device) for t in context]
+            context_null = [t.to(self.device) for t in context_null]
+
+        z = self.vae.encode([img])
+
+        @contextmanager
+        def noop_no_sync():
+            yield
+
+        no_sync = getattr(self.model, 'no_sync', noop_no_sync)
+
+        # evaluation mode
+        with (
+                torch.amp.autocast('cuda', dtype=self.param_dtype),
+                torch.no_grad(),
+                no_sync(),
+        ):
+
+            if sample_solver == 'unipc':
+                sample_scheduler = FlowUniPCMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sample_scheduler.set_timesteps(
+                    sampling_steps, device=self.device, shift=shift)
+                timesteps = sample_scheduler.timesteps
+            elif sample_solver == 'dpm++':
+                sample_scheduler = FlowDPMSolverMultistepScheduler(
+                    num_train_timesteps=self.num_train_timesteps,
+                    shift=1,
+                    use_dynamic_shifting=False)
+                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
+                timesteps, _ = retrieve_timesteps(
+                    sample_scheduler,
+                    device=self.device,
+                    sigmas=sampling_sigmas)
+            else:
+                raise NotImplementedError("Unsupported solver.")
+
+            # sample videos
+            latent = noise
+            mask1, mask2 = masks_like([noise], zero=True)
+            latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
+
+            arg_c = {
+                'context': [context[0]],
+                'seq_len': seq_len,
+            }
+
+            arg_null = {
+                'context': context_null,
+                'seq_len': seq_len,
+            }
+
+            if offload_model or self.init_on_cpu:
+                self.model.to(self.device)
+                torch.cuda.empty_cache()
+
+            for _, t in enumerate(tqdm(timesteps)):
+                latent_model_input = [latent.to(self.device)]
+                timestep = [t]
+
+                timestep = torch.stack(timestep).to(self.device)
+
+                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
+                temp_ts = torch.cat([
+                    temp_ts,
+                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
+                ])
+                timestep = temp_ts.unsqueeze(0)
+
+                noise_pred_cond = self.model(
+                    latent_model_input, t=timestep, **arg_c)[0]
+                if offload_model:
+                    torch.cuda.empty_cache()
+                noise_pred_uncond = self.model(
+                    latent_model_input, t=timestep, **arg_null)[0]
+                if offload_model:
+                    torch.cuda.empty_cache()
+                noise_pred = noise_pred_uncond + guide_scale * (
+                    noise_pred_cond - noise_pred_uncond)
+
+                temp_x0 = sample_scheduler.step(
+                    noise_pred.unsqueeze(0),
+                    t,
+                    latent.unsqueeze(0),
+                    return_dict=False,
+                    generator=seed_g)[0]
+                latent = temp_x0.squeeze(0)
+                latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
+
+            x0 = [latent]
+            del latent_model_input, timestep
+
+            if offload_model:
+                self.model.cpu()
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
+
+            if self.rank == 0:
+                videos = self.vae.decode(x0)
+
+        del noise, latent, x0
+        del sample_scheduler
+        if offload_model:
+            gc.collect()
+            torch.cuda.synchronize()
+        if dist.is_initialized():
+            dist.barrier()
+
+        return videos[0] if self.rank == 0 else None
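For orientation, here is a rough sketch of how this class might be driven outside the Space's own entry points (generate.py and app.py). The checkpoint path is an assumption, and the `WAN_CONFIGS` mapping and `wan.WanTI2V` export are assumed to follow the `wan/configs/__init__.py` and `wan/__init__.py` files added in this upload; treat it as illustrative, not as the canonical CLI.

# Hypothetical driver; adjust checkpoint_dir and config key to your setup.
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS  # assumed mapping, e.g. 'ti2v-5B'

cfg = WAN_CONFIGS['ti2v-5B']
pipe = wan.WanTI2V(config=cfg, checkpoint_dir='./Wan2.2-TI2V-5B', device_id=0)

# Text-to-video: no image, so generate() dispatches to t2v().
video = pipe.generate('a corgi surfing at sunset',
                      size=(1280, 704), frame_num=121)

# Image-to-video: passing a PIL image routes to i2v(), which keeps the
# first latent frame clamped to the encoded image via masks_like().
video = pipe.generate('the waves keep rolling',
                      img=Image.open('examples/i2v_input.JPG'),
                      max_area=704 * 1280)
# On rank 0, `video` is a (C, N, H, W) float tensor ready for saving.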
wan/utils/__init__.py
ADDED
@@ -0,0 +1,12 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from .fm_solvers import (
+    FlowDPMSolverMultistepScheduler,
+    get_sampling_sigmas,
+    retrieve_timesteps,
+)
+from .fm_solvers_unipc import FlowUniPCMultistepScheduler
+
+__all__ = [
+    'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps',
+    'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
+]
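These re-exports make the flow-matching schedulers importable from `wan.utils` directly. The snippet below simply mirrors how the UniPC branch of the pipelines above constructs one (a sketch; the 1000-step training horizon and CPU device are assumptions for illustration).

import torch
from wan.utils import FlowUniPCMultistepScheduler

# Same pattern as the samplers above: shift=1 at construction, the real
# shift applied when the inference timestep grid is laid out.
scheduler = FlowUniPCMultistepScheduler(
    num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
scheduler.set_timesteps(50, device=torch.device('cpu'), shift=5.0)
print(scheduler.timesteps[:5])  # descending flow-matching timesteps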
wan/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (393 Bytes).