vivienfanghua committed (verified)
Commit 4de1367 · 1 Parent(s): a9ff1a1

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +6 -5
  3. app.py +387 -0
  4. examples/i2v_input.JPG +3 -0
  5. generate.py +411 -0
  6. requirements.txt +15 -0
  7. wan/__init__.py +5 -0
  8. wan/__pycache__/__init__.cpython-310.pyc +0 -0
  9. wan/__pycache__/image2video.cpython-310.pyc +0 -0
  10. wan/__pycache__/text2video.cpython-310.pyc +0 -0
  11. wan/__pycache__/textimage2video.cpython-310.pyc +0 -0
  12. wan/configs/__init__.py +39 -0
  13. wan/configs/__pycache__/__init__.cpython-310.pyc +0 -0
  14. wan/configs/__pycache__/shared_config.cpython-310.pyc +0 -0
  15. wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc +0 -0
  16. wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc +0 -0
  17. wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc +0 -0
  18. wan/configs/shared_config.py +20 -0
  19. wan/configs/wan_i2v_A14B.py +37 -0
  20. wan/configs/wan_t2v_A14B.py +37 -0
  21. wan/configs/wan_ti2v_5B.py +36 -0
  22. wan/distributed/__init__.py +1 -0
  23. wan/distributed/__pycache__/__init__.cpython-310.pyc +0 -0
  24. wan/distributed/__pycache__/fsdp.cpython-310.pyc +0 -0
  25. wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc +0 -0
  26. wan/distributed/__pycache__/ulysses.cpython-310.pyc +0 -0
  27. wan/distributed/__pycache__/util.cpython-310.pyc +0 -0
  28. wan/distributed/fsdp.py +43 -0
  29. wan/distributed/sequence_parallel.py +176 -0
  30. wan/distributed/ulysses.py +47 -0
  31. wan/distributed/util.py +51 -0
  32. wan/image2video.py +431 -0
  33. wan/modules/__init__.py +19 -0
  34. wan/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  35. wan/modules/__pycache__/attention.cpython-310.pyc +0 -0
  36. wan/modules/__pycache__/model.cpython-310.pyc +0 -0
  37. wan/modules/__pycache__/t5.cpython-310.pyc +0 -0
  38. wan/modules/__pycache__/tokenizers.cpython-310.pyc +0 -0
  39. wan/modules/__pycache__/vae2_1.cpython-310.pyc +0 -0
  40. wan/modules/__pycache__/vae2_2.cpython-310.pyc +0 -0
  41. wan/modules/attention.py +179 -0
  42. wan/modules/model.py +546 -0
  43. wan/modules/t5.py +513 -0
  44. wan/modules/tokenizers.py +82 -0
  45. wan/modules/vae2_1.py +663 -0
  46. wan/modules/vae2_2.py +1051 -0
  47. wan/text2video.py +378 -0
  48. wan/textimage2video.py +619 -0
  49. wan/utils/__init__.py +12 -0
  50. wan/utils/__pycache__/__init__.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: Wan2.2 Enhanced Amd
- emoji: 📊
- colorFrom: blue
- colorTo: purple
+ title: wan2.2_enhanced_amd
+ emoji: 🚀
+ colorFrom: pink
+ colorTo: indigo
  sdk: gradio
- sdk_version: 5.45.0
+ sdk_version: 5.39.0
  app_file: app.py
  pinned: false
+ short_description: Wan 2.2 5B
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,387 @@
1
+ import os
2
+ import sys
3
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
4
+
5
+ #import subprocess
6
+ #subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
7
+
8
+ # wan2.2-main/gradio_ti2v.py
9
+ import gradio as gr
10
+ import torch
11
+ from huggingface_hub import snapshot_download
12
+ from PIL import Image
13
+ import random
14
+ import numpy as np
15
+ import spaces
16
+
17
+ import wan
18
+ from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
19
+ from wan.utils.utils import cache_video
20
+
21
+ import gc
22
+
23
+ # --- 1. Global Setup and Model Loading ---
24
+
25
+ print("Starting Gradio App for Wan 2.2 TI2V-5B...")
26
+
27
+ # Download model snapshots from Hugging Face Hub
28
+ repo_id = "Wan-AI/Wan2.2-TI2V-5B"
29
+ print(f"Downloading/loading checkpoints for {repo_id}...")
30
+ ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
31
+ print(f"Using checkpoints from {ckpt_dir}")
32
+
33
+ # Load the model configuration
34
+ TASK_NAME = 'ti2v-5B'
35
+ cfg = WAN_CONFIGS[TASK_NAME]
36
+ FIXED_FPS = 24
37
+ MIN_FRAMES_MODEL = 8
38
+ MAX_FRAMES_MODEL = 121
39
+
40
+ # Instantiate the pipeline in the global scope
41
+ print("Initializing WanTI2V pipeline...")
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+ device_id = 0 if torch.cuda.is_available() else -1
44
+ pipeline = wan.WanTI2V(
45
+ config=cfg,
46
+ checkpoint_dir=ckpt_dir,
47
+ device_id=device_id,
48
+ rank=0,
49
+ t5_fsdp=False,
50
+ dit_fsdp=False,
51
+ use_sp=False,
52
+ t5_cpu=False,
53
+ init_on_cpu=False,
54
+ convert_model_dtype=True,
55
+ )
56
+ print("Pipeline initialized and ready.")
57
+
58
+ # --- Helper Functions ---
59
+ def clear_gpu_memory():
60
+ """Clear GPU memory more thoroughly"""
61
+ if torch.cuda.is_available():
62
+ torch.cuda.empty_cache()
63
+ torch.cuda.ipc_collect()
64
+ gc.collect()
65
+
66
+ def select_best_size_for_image(image, available_sizes):
67
+ """Select the size option with aspect ratio closest to the input image."""
68
+ if image is None:
69
+ return available_sizes[0] # Return first option if no image
70
+
71
+ img_width, img_height = image.size
72
+ img_aspect_ratio = img_height / img_width
73
+
74
+ best_size = available_sizes[0]
75
+ best_diff = float('inf')
76
+
77
+ for size_str in available_sizes:
78
+ # Parse size string like "704*1280"
79
+ height, width = map(int, size_str.split('*'))
80
+ size_aspect_ratio = height / width
81
+ diff = abs(img_aspect_ratio - size_aspect_ratio)
82
+
83
+ if diff < best_diff:
84
+ best_diff = diff
85
+ best_size = size_str
86
+
87
+ return best_size
88
+
89
+ def handle_image_upload(image):
90
+ """Handle image upload and return the best matching size."""
91
+ if image is None:
92
+ return gr.update()
93
+
94
+ pil_image = Image.fromarray(image).convert("RGB")
95
+ available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
96
+ best_size = select_best_size_for_image(pil_image, available_sizes)
97
+
98
+ return gr.update(value=best_size)
99
+
100
+ def validate_inputs(image, prompt, duration_seconds):
101
+ """Validate user inputs"""
102
+ errors = []
103
+
104
+ if not prompt or len(prompt.strip()) < 5:
105
+ errors.append("Prompt must be at least 5 characters long.")
106
+
107
+ if image is not None:
108
+ img = Image.fromarray(image)
109
+ if img.size[0] * img.size[1] > 4096 * 4096:
110
+ errors.append("Image size is too large (maximum 4096x4096).")
111
+
112
+ if duration_seconds > 5.0 and image is None:
113
+ errors.append("Videos longer than 5 seconds require an input image.")
114
+
115
+ return errors
116
+
117
+ def get_duration(image,
118
+ prompt,
119
+ size,
120
+ duration_seconds,
121
+ sampling_steps,
122
+ guide_scale,
123
+ shift,
124
+ seed,
125
+ progress):
126
+ """Calculate dynamic GPU duration based on parameters."""
127
+ if sampling_steps > 35 and duration_seconds >= 2:
128
+ return 120
129
+ elif sampling_steps < 35 or duration_seconds < 2:
130
+ return 105
131
+ else:
132
+ return 90
133
+
134
+ def apply_template(template, current_prompt):
135
+ """Apply prompt template"""
136
+ if "{subject}" in template:
137
+ # Extract the main subject from current prompt (simple heuristic)
138
+ subject = current_prompt.split(",")[0] if "," in current_prompt else current_prompt
139
+ return template.replace("{subject}", subject)
140
+ return template + " " + current_prompt
141
+
142
+ # --- 2. Gradio Inference Function ---
143
+ @spaces.GPU(duration=get_duration)
144
+ def generate_video(
145
+ image,
146
+ prompt,
147
+ size,
148
+ duration_seconds,
149
+ sampling_steps,
150
+ guide_scale,
151
+ shift,
152
+ seed,
153
+ progress=gr.Progress(track_tqdm=True)
154
+ ):
155
+ """The main function to generate video, called by the Gradio interface."""
156
+ # Validate inputs
157
+ errors = validate_inputs(image, prompt, duration_seconds)
158
+ if errors:
159
+ raise gr.Error("\n".join(errors))
160
+
161
+ progress(0, desc="Setting up...")
162
+
163
+ if seed == -1:
164
+ seed = random.randint(0, sys.maxsize)
165
+
166
+ progress(0.1, desc="Processing image...")
167
+
168
+ input_image = None
169
+ if image is not None:
170
+ input_image = Image.fromarray(image).convert("RGB")
171
+ # Resize image to match selected size
172
+ target_height, target_width = map(int, size.split('*'))
173
+ input_image = input_image.resize((target_width, target_height))
174
+
175
+ # Calculate number of frames based on duration
176
+ num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
177
+
178
+ progress(0.2, desc="Generating video...")
179
+
180
+ try:
181
+ video_tensor = pipeline.generate(
182
+ input_prompt=prompt,
183
+ img=input_image, # Pass None for T2V, Image for I2V
184
+ size=SIZE_CONFIGS[size],
185
+ max_area=MAX_AREA_CONFIGS[size],
186
+ frame_num=num_frames, # Use calculated frames instead of cfg.frame_num
187
+ shift=shift,
188
+ sample_solver='unipc',
189
+ sampling_steps=int(sampling_steps),
190
+ guide_scale=guide_scale,
191
+ seed=seed,
192
+ offload_model=True
193
+ )
194
+
195
+ progress(0.9, desc="Saving video...")
196
+
197
+ # Save the video to a temporary file
198
+ video_path = cache_video(
199
+ tensor=video_tensor[None], # Add a batch dimension
200
+ save_file=None, # cache_video will create a temp file
201
+ fps=cfg.sample_fps,
202
+ normalize=True,
203
+ value_range=(-1, 1)
204
+ )
205
+
206
+ progress(1.0, desc="Complete!")
207
+
208
+ except torch.cuda.OutOfMemoryError:
209
+ clear_gpu_memory()
210
+ raise gr.Error("GPU out of memory. Please try with lower settings.")
211
+ except Exception as e:
212
+ raise gr.Error(f"Video generation failed: {str(e)}")
213
+ finally:
214
+ if 'video_tensor' in locals():
215
+ del video_tensor
216
+ clear_gpu_memory()
217
+
218
+ return video_path
219
+
220
+
221
+ # --- 3. Gradio Interface ---
222
+ css = """
223
+ .gradio-container {max-width: 1100px !important; margin: 0 auto}
224
+ #output_video {height: 500px;}
225
+ #input_image {height: 500px;}
226
+ .template-btn {margin: 2px !important;}
227
+ """
228
+
229
+ # Default prompt with motion emphasis
230
+ DEFAULT_PROMPT = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."
231
+
232
+ # Prompt templates
233
+ templates = {
234
+ "Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
235
+ "Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
236
+ "Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
237
+ "Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
238
+ "Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
239
+ }
240
+
241
+ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
242
+ gr.Markdown("""
243
+ # Wan 2.2 TI2V Enhanced running on AMD MI355
244
+
245
+ Generate high quality videos using **Wan 2.2 5B Text-Image-to-Video model**
246
+ [[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)
247
+
248
+ ### 💡 Tips for best results:
249
+ - 🖼️ Upload an image for better control over the video content
250
+ - ⏱️ Longer videos require more processing time
251
+ - 🎯 Be specific and descriptive in your prompts
252
+ - 🎬 Include motion-related keywords for dynamic videos
253
+ """)
254
+
255
+ with gr.Row():
256
+ with gr.Column(scale=2):
257
+ image_input = gr.Image(type="numpy", label="Input Image (Optional)", elem_id="input_image")
258
+ prompt_input = gr.Textbox(
259
+ label="Prompt",
260
+ value=DEFAULT_PROMPT,
261
+ lines=3,
262
+ placeholder="Describe the video you want to generate..."
263
+ )
264
+
265
+ # Prompt templates section
266
+ with gr.Accordion("Prompt Templates", open=False):
267
+ gr.Markdown("Click a template to apply it to your prompt:")
268
+ with gr.Row():
269
+ template_buttons = {}
270
+ for name, template in templates.items():
271
+ btn = gr.Button(name, size="sm", elem_classes=["template-btn"])
272
+ template_buttons[name] = (btn, template)
273
+
274
+ # Connect template buttons
275
+ for name, (btn, template) in template_buttons.items():
276
+ btn.click(
277
+ fn=lambda p, t=template: apply_template(t, p),
278
+ inputs=[prompt_input],
279
+ outputs=prompt_input
280
+ )
281
+
282
+ duration_input = gr.Slider(
283
+ minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
284
+ maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
285
+ step=0.1,
286
+ value=2.0,
287
+ label="Duration (seconds)",
288
+ info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
289
+ )
290
+ size_input = gr.Dropdown(
291
+ label="Output Resolution",
292
+ choices=list(SUPPORTED_SIZES[TASK_NAME]),
293
+ value="704*1280"
294
+ )
295
+
296
+ with gr.Column(scale=2):
297
+ video_output = gr.Video(label="Generated Video", elem_id="output_video")
298
+
299
+ # Status indicators
300
+ with gr.Row():
301
+ status_text = gr.Textbox(
302
+ label="Status",
303
+ value="Ready",
304
+ interactive=False,
305
+ max_lines=1
306
+ )
307
+
308
+ with gr.Accordion("Advanced Settings", open=False):
309
+ steps_input = gr.Slider(
310
+ label="Sampling Steps",
311
+ minimum=10,
312
+ maximum=50,
313
+ value=38,
314
+ step=1,
315
+ info="Higher values = better quality but slower"
316
+ )
317
+ scale_input = gr.Slider(
318
+ label="Guidance Scale",
319
+ minimum=1.0,
320
+ maximum=10.0,
321
+ value=cfg.sample_guide_scale,
322
+ step=0.1,
323
+ info="Higher values = closer to prompt but less creative"
324
+ )
325
+ shift_input = gr.Slider(
326
+ label="Sample Shift",
327
+ minimum=1.0,
328
+ maximum=20.0,
329
+ value=cfg.sample_shift,
330
+ step=0.1,
331
+ info="Affects the sampling process dynamics"
332
+ )
333
+ seed_input = gr.Number(
334
+ label="Seed (-1 for random)",
335
+ value=-1,
336
+ precision=0,
337
+ info="Use same seed for reproducible results"
338
+ )
339
+
340
+ run_button = gr.Button("Generate Video", variant="primary", size="lg")
341
+
342
+ # Add image upload handler
343
+ image_input.upload(
344
+ fn=handle_image_upload,
345
+ inputs=[image_input],
346
+ outputs=[size_input]
347
+ )
348
+
349
+ image_input.clear(
350
+ fn=handle_image_upload,
351
+ inputs=[image_input],
352
+ outputs=[size_input]
353
+ )
354
+
355
+ # Update status when generating
356
+ def update_status_and_generate(*args):
357
+ status_text.value = "Generating..."
358
+ try:
359
+ result = generate_video(*args)
360
+ status_text.value = "Complete!"
361
+ return result
362
+ except Exception as e:
363
+ status_text.value = "Error occurred"
364
+ raise e
365
+
366
+ example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
367
+ gr.Examples(
368
+ examples=[
369
+ [None, "Golden hour, soft lighting, warm colors, saturated colors, wide shot, left-heavy composition. A weathered gondolier stands in a flat-bottomed boat, propelling it forward with a long wooden pole through the flooded ruins of Venice. The decaying buildings on either side are cloaked in creeping vines and marked by rusted metalwork, their once-proud facades now crumbling into the water. The camera moves slowly forward and tilts left, revealing behind him the majestic remnants of the city bathed in the amber glow of the setting sun. Silhouettes of collapsed archways and broken domes rise against the golden skyline, while the still water reflects the warm hues of the sky and surrounding structures.", "1280*704", 4.0],
370
+ [None, "In a surreal video, four miniature skiers glide down a winding, three-dimensional trail of thick white paint on a plain white canvas-like background. The textured paint mimics snow, with visible brushstrokes and uneven edges, enhanced by light and shadow. The skiers, in colorful gear, are posed dynamically from top to bottom, each casting a shadow that heightens the illusion of depth. This scene miniaturizes a grand outdoor sport into a vivid, imaginative artwork.", "1280*704", 2.0],
371
+ [None, "In a time-lapse video, a crane slowly lifts a steel beam on a construction site. The camera pulls back slowly from a close-up, revealing details of the crane and the steel beam. The skyline transitions from day to night, with buildings and machinery in the background constantly operating. As the camera pulls further back, the busy scene of the entire construction site comes into view; cranes and other equipment continue working under the night sky, shaping the city's outline.", "704*1280", 2.5],
372
+ [None, "Cinematic racetrack scene: Low-angle medium long shot of jockey-horse leap. High-contrast backlighting, warm tones, silhouettes. Slow-motion freeze with dust for dynamic tension. Scoreboard detail. Optimized for immersive video generation.", "1280*704", 3.0],
373
+ ],
374
+ inputs=[image_input, prompt_input, size_input, duration_input],
375
+ outputs=video_output,
376
+ fn=generate_video,
377
+ cache_examples=False,
378
+ )
379
+
380
+ run_button.click(
381
+ fn=generate_video,
382
+ inputs=[image_input, prompt_input, size_input, duration_input, steps_input, scale_input, shift_input, seed_input],
383
+ outputs=video_output
384
+ )
385
+
386
+ if __name__ == "__main__":
387
+ demo.launch()
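
For reference, a minimal standalone sketch of the duration-to-frame mapping that generate_video applies above; the constants mirror those defined at the top of app.py, and the helper name duration_to_frames is illustrative rather than part of the file:

import numpy as np

FIXED_FPS = 24          # mirrors app.py
MIN_FRAMES_MODEL = 8    # mirrors app.py
MAX_FRAMES_MODEL = 121  # mirrors app.py

def duration_to_frames(duration_seconds: float) -> int:
    # Same expression as generate_video(): round to the nearest frame at 24 fps,
    # then clamp to the model's supported 8-121 frame range.
    return int(np.clip(int(round(duration_seconds * FIXED_FPS)),
                       MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))

print(duration_to_frames(2.0))   # 48 frames for the default 2.0 s slider value
print(duration_to_frames(10.0))  # 121, i.e. clamped to roughly 5 s at 24 fps

At 24 fps the slider's 0.3-5.0 s range therefore covers exactly the model's 8-121 frame limits.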
examples/i2v_input.JPG ADDED

Git LFS Details

  • SHA256: 077e3d965090c9028c69c00931675f42e1acc815c6eb450ab291b3b72d211a8e
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB
generate.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from datetime import datetime
8
+
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import random
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from PIL import Image
16
+
17
+ import wan
18
+ from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
+ from wan.distributed.util import init_distributed_group
20
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
+ from wan.utils.utils import cache_video, str2bool
22
+
23
+ EXAMPLE_PROMPT = {
24
+ "t2v-A14B": {
25
+ "prompt":
26
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
27
+ },
28
+ "i2v-A14B": {
29
+ "prompt":
30
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
31
+ "image":
32
+ "examples/i2v_input.JPG",
33
+ },
34
+ "ti2v-5B": {
35
+ "prompt":
36
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
37
+ },
38
+ }
39
+
40
+
41
+ def _validate_args(args):
42
+ # Basic check
43
+ assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
44
+ assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
45
+ assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
46
+
47
+ if args.prompt is None:
48
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
49
+ if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
50
+ args.image = EXAMPLE_PROMPT[args.task]["image"]
51
+
52
+ if args.task == "i2v-A14B":
53
+ assert args.image is not None, "Please specify the image path for i2v."
54
+
55
+ cfg = WAN_CONFIGS[args.task]
56
+
57
+ if args.sample_steps is None:
58
+ args.sample_steps = cfg.sample_steps
59
+
60
+ if args.sample_shift is None:
61
+ args.sample_shift = cfg.sample_shift
62
+
63
+ if args.sample_guide_scale is None:
64
+ args.sample_guide_scale = cfg.sample_guide_scale
65
+
66
+ if args.frame_num is None:
67
+ args.frame_num = cfg.frame_num
68
+
69
+ args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
70
+ 0, sys.maxsize)
71
+ # Size check
72
+ assert args.size in SUPPORTED_SIZES[
73
+ args.
74
+ task], f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
75
+
76
+
77
+ def _parse_args():
78
+ parser = argparse.ArgumentParser(
79
+ description="Generate an image or video from a text prompt or image using Wan"
80
+ )
81
+ parser.add_argument(
82
+ "--task",
83
+ type=str,
84
+ default="t2v-A14B",
85
+ choices=list(WAN_CONFIGS.keys()),
86
+ help="The task to run.")
87
+ parser.add_argument(
88
+ "--size",
89
+ type=str,
90
+ default="1280*720",
91
+ choices=list(SIZE_CONFIGS.keys()),
92
+ help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
93
+ )
94
+ parser.add_argument(
95
+ "--frame_num",
96
+ type=int,
97
+ default=None,
98
+ help="How many frames of video are generated. The number should be 4n+1"
99
+ )
100
+ parser.add_argument(
101
+ "--ckpt_dir",
102
+ type=str,
103
+ default=None,
104
+ help="The path to the checkpoint directory.")
105
+ parser.add_argument(
106
+ "--offload_model",
107
+ type=str2bool,
108
+ default=None,
109
+ help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
110
+ )
111
+ parser.add_argument(
112
+ "--ulysses_size",
113
+ type=int,
114
+ default=1,
115
+ help="The size of the ulysses parallelism in DiT.")
116
+ parser.add_argument(
117
+ "--t5_fsdp",
118
+ action="store_true",
119
+ default=False,
120
+ help="Whether to use FSDP for T5.")
121
+ parser.add_argument(
122
+ "--t5_cpu",
123
+ action="store_true",
124
+ default=False,
125
+ help="Whether to place T5 model on CPU.")
126
+ parser.add_argument(
127
+ "--dit_fsdp",
128
+ action="store_true",
129
+ default=False,
130
+ help="Whether to use FSDP for DiT.")
131
+ parser.add_argument(
132
+ "--save_file",
133
+ type=str,
134
+ default=None,
135
+ help="The file to save the generated video to.")
136
+ parser.add_argument(
137
+ "--prompt",
138
+ type=str,
139
+ default=None,
140
+ help="The prompt to generate the video from.")
141
+ parser.add_argument(
142
+ "--use_prompt_extend",
143
+ action="store_true",
144
+ default=False,
145
+ help="Whether to use prompt extend.")
146
+ parser.add_argument(
147
+ "--prompt_extend_method",
148
+ type=str,
149
+ default="local_qwen",
150
+ choices=["dashscope", "local_qwen"],
151
+ help="The prompt extend method to use.")
152
+ parser.add_argument(
153
+ "--prompt_extend_model",
154
+ type=str,
155
+ default=None,
156
+ help="The prompt extend model to use.")
157
+ parser.add_argument(
158
+ "--prompt_extend_target_lang",
159
+ type=str,
160
+ default="zh",
161
+ choices=["zh", "en"],
162
+ help="The target language of prompt extend.")
163
+ parser.add_argument(
164
+ "--base_seed",
165
+ type=int,
166
+ default=-1,
167
+ help="The seed to use for generating the video.")
168
+ parser.add_argument(
169
+ "--image",
170
+ type=str,
171
+ default=None,
172
+ help="The image to generate the video from.")
173
+ parser.add_argument(
174
+ "--sample_solver",
175
+ type=str,
176
+ default='unipc',
177
+ choices=['unipc', 'dpm++'],
178
+ help="The solver used to sample.")
179
+ parser.add_argument(
180
+ "--sample_steps", type=int, default=None, help="The sampling steps.")
181
+ parser.add_argument(
182
+ "--sample_shift",
183
+ type=float,
184
+ default=None,
185
+ help="Sampling shift factor for flow matching schedulers.")
186
+ parser.add_argument(
187
+ "--sample_guide_scale",
188
+ type=float,
189
+ default=None,
190
+ help="Classifier free guidance scale.")
191
+ parser.add_argument(
192
+ "--convert_model_dtype",
193
+ action="store_true",
194
+ default=False,
195
+ help="Whether to convert model parameters dtype.")
196
+
197
+ args = parser.parse_args()
198
+
199
+ _validate_args(args)
200
+
201
+ return args
202
+
203
+
204
+ def _init_logging(rank):
205
+ # logging
206
+ if rank == 0:
207
+ # set format
208
+ logging.basicConfig(
209
+ level=logging.INFO,
210
+ format="[%(asctime)s] %(levelname)s: %(message)s",
211
+ handlers=[logging.StreamHandler(stream=sys.stdout)])
212
+ else:
213
+ logging.basicConfig(level=logging.ERROR)
214
+
215
+
216
+ def generate(args):
217
+ rank = int(os.getenv("RANK", 0))
218
+ world_size = int(os.getenv("WORLD_SIZE", 1))
219
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
220
+ device = local_rank
221
+ _init_logging(rank)
222
+
223
+ if args.offload_model is None:
224
+ args.offload_model = False if world_size > 1 else True
225
+ logging.info(
226
+ f"offload_model is not specified, set to {args.offload_model}.")
227
+ if world_size > 1:
228
+ torch.cuda.set_device(local_rank)
229
+ dist.init_process_group(
230
+ backend="nccl",
231
+ init_method="env://",
232
+ rank=rank,
233
+ world_size=world_size)
234
+ else:
235
+ assert not (
236
+ args.t5_fsdp or args.dit_fsdp
237
+ ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
238
+ assert not (
239
+ args.ulysses_size > 1
240
+ ), "sequence parallel is not supported in non-distributed environments."
241
+
242
+ if args.ulysses_size > 1:
243
+ assert args.ulysses_size == world_size, "ulysses_size must equal the world size."
244
+ init_distributed_group()
245
+
246
+ if args.use_prompt_extend:
247
+ if args.prompt_extend_method == "dashscope":
248
+ prompt_expander = DashScopePromptExpander(
249
+ model_name=args.prompt_extend_model,
250
+ task=args.task,
251
+ is_vl=args.image is not None)
252
+ elif args.prompt_extend_method == "local_qwen":
253
+ prompt_expander = QwenPromptExpander(
254
+ model_name=args.prompt_extend_model,
255
+ task=args.task,
256
+ is_vl=args.image is not None,
257
+ device=rank)
258
+ else:
259
+ raise NotImplementedError(
260
+ f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
261
+
262
+ cfg = WAN_CONFIGS[args.task]
263
+ if args.ulysses_size > 1:
264
+ assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."
265
+
266
+ logging.info(f"Generation job args: {args}")
267
+ logging.info(f"Generation model config: {cfg}")
268
+
269
+ if dist.is_initialized():
270
+ base_seed = [args.base_seed] if rank == 0 else [None]
271
+ dist.broadcast_object_list(base_seed, src=0)
272
+ args.base_seed = base_seed[0]
273
+
274
+ logging.info(f"Input prompt: {args.prompt}")
275
+ img = None
276
+ if args.image is not None:
277
+ img = Image.open(args.image).convert("RGB")
278
+ logging.info(f"Input image: {args.image}")
279
+
280
+ # prompt extend
281
+ if args.use_prompt_extend:
282
+ logging.info("Extending prompt ...")
283
+ if rank == 0:
284
+ prompt_output = prompt_expander(
285
+ args.prompt,
286
+ image=img,
287
+ tar_lang=args.prompt_extend_target_lang,
288
+ seed=args.base_seed)
289
+ if prompt_output.status == False:
290
+ logging.info(
291
+ f"Extending prompt failed: {prompt_output.message}")
292
+ logging.info("Falling back to original prompt.")
293
+ input_prompt = args.prompt
294
+ else:
295
+ input_prompt = prompt_output.prompt
296
+ input_prompt = [input_prompt]
297
+ else:
298
+ input_prompt = [None]
299
+ if dist.is_initialized():
300
+ dist.broadcast_object_list(input_prompt, src=0)
301
+ args.prompt = input_prompt[0]
302
+ logging.info(f"Extended prompt: {args.prompt}")
303
+
304
+ if "t2v" in args.task:
305
+ logging.info("Creating WanT2V pipeline.")
306
+ wan_t2v = wan.WanT2V(
307
+ config=cfg,
308
+ checkpoint_dir=args.ckpt_dir,
309
+ device_id=device,
310
+ rank=rank,
311
+ t5_fsdp=args.t5_fsdp,
312
+ dit_fsdp=args.dit_fsdp,
313
+ use_sp=(args.ulysses_size > 1),
314
+ t5_cpu=args.t5_cpu,
315
+ convert_model_dtype=args.convert_model_dtype,
316
+ )
317
+
318
+ logging.info(f"Generating video ...")
319
+ video = wan_t2v.generate(
320
+ args.prompt,
321
+ size=SIZE_CONFIGS[args.size],
322
+ frame_num=args.frame_num,
323
+ shift=args.sample_shift,
324
+ sample_solver=args.sample_solver,
325
+ sampling_steps=args.sample_steps,
326
+ guide_scale=args.sample_guide_scale,
327
+ seed=args.base_seed,
328
+ offload_model=args.offload_model)
329
+ elif "ti2v" in args.task:
330
+ logging.info("Creating WanTI2V pipeline.")
331
+ wan_ti2v = wan.WanTI2V(
332
+ config=cfg,
333
+ checkpoint_dir=args.ckpt_dir,
334
+ device_id=device,
335
+ rank=rank,
336
+ t5_fsdp=args.t5_fsdp,
337
+ dit_fsdp=args.dit_fsdp,
338
+ use_sp=(args.ulysses_size > 1),
339
+ t5_cpu=args.t5_cpu,
340
+ convert_model_dtype=args.convert_model_dtype,
341
+ )
342
+
343
+ logging.info(f"Generating video ...")
344
+ video = wan_ti2v.generate(
345
+ args.prompt,
346
+ img=img,
347
+ size=SIZE_CONFIGS[args.size],
348
+ max_area=MAX_AREA_CONFIGS[args.size],
349
+ frame_num=args.frame_num,
350
+ shift=args.sample_shift,
351
+ sample_solver=args.sample_solver,
352
+ sampling_steps=args.sample_steps,
353
+ guide_scale=args.sample_guide_scale,
354
+ seed=args.base_seed,
355
+ offload_model=args.offload_model)
356
+ else:
357
+ logging.info("Creating WanI2V pipeline.")
358
+ wan_i2v = wan.WanI2V(
359
+ config=cfg,
360
+ checkpoint_dir=args.ckpt_dir,
361
+ device_id=device,
362
+ rank=rank,
363
+ t5_fsdp=args.t5_fsdp,
364
+ dit_fsdp=args.dit_fsdp,
365
+ use_sp=(args.ulysses_size > 1),
366
+ t5_cpu=args.t5_cpu,
367
+ convert_model_dtype=args.convert_model_dtype,
368
+ )
369
+
370
+ logging.info("Generating video ...")
371
+ video = wan_i2v.generate(
372
+ args.prompt,
373
+ img,
374
+ max_area=MAX_AREA_CONFIGS[args.size],
375
+ frame_num=args.frame_num,
376
+ shift=args.sample_shift,
377
+ sample_solver=args.sample_solver,
378
+ sampling_steps=args.sample_steps,
379
+ guide_scale=args.sample_guide_scale,
380
+ seed=args.base_seed,
381
+ offload_model=args.offload_model)
382
+
383
+ if rank == 0:
384
+ if args.save_file is None:
385
+ formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
386
+ formatted_prompt = args.prompt.replace(" ", "_").replace("/",
387
+ "_")[:50]
388
+ suffix = '.mp4'
389
+ args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix
390
+
391
+ logging.info(f"Saving generated video to {args.save_file}")
392
+ cache_video(
393
+ tensor=video[None],
394
+ save_file=args.save_file,
395
+ fps=cfg.sample_fps,
396
+ nrow=1,
397
+ normalize=True,
398
+ value_range=(-1, 1))
399
+ del video
400
+
401
+ torch.cuda.synchronize()
402
+ if dist.is_initialized():
403
+ dist.barrier()
404
+ dist.destroy_process_group()
405
+
406
+ logging.info("Finished.")
407
+
408
+
409
+ if __name__ == "__main__":
410
+ args = _parse_args()
411
+ generate(args)
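
For reference, a hedged example of running the script above on a single GPU; every flag comes from _parse_args(), while the checkpoint directory path is an assumption and --prompt is omitted so _validate_args falls back to the ti2v-5B entry in EXAMPLE_PROMPT:

import subprocess

# Illustrative single-GPU invocation of generate.py (a sketch, not a tested command).
subprocess.run([
    "python", "generate.py",
    "--task", "ti2v-5B",               # one of WAN_CONFIGS' keys
    "--size", "1280*704",              # must appear in SUPPORTED_SIZES["ti2v-5B"]
    "--ckpt_dir", "./Wan2.2-TI2V-5B",  # assumed local checkpoint directory
    "--offload_model", "True",         # parsed via str2bool
    "--convert_model_dtype",           # cast DiT parameters to config.param_dtype
], check=True)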
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ torch>=2.4.0
2
+ torchvision>=0.19.0
3
+ opencv-python>=4.9.0.80
4
+ diffusers>=0.31.0
5
+ transformers>=4.49.0
6
+ tokenizers>=0.20.3
7
+ accelerate>=1.1.1
8
+ tqdm
9
+ imageio
10
+ easydict
11
+ ftfy
12
+ dashscope
13
+ imageio-ffmpeg
14
+ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.7-cp310-cp310-linux_x86_64.whl
15
+ numpy>=1.23.5,<2
wan/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from . import configs, distributed, modules
3
+ from .image2video import WanI2V
4
+ from .text2video import WanT2V
5
+ from .textimage2video import WanTI2V
wan/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (333 Bytes). View file
 
wan/__pycache__/image2video.cpython-310.pyc ADDED
Binary file (12.3 kB). View file
 
wan/__pycache__/text2video.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
wan/__pycache__/textimage2video.cpython-310.pyc ADDED
Binary file (17.5 kB). View file
 
wan/configs/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import copy
3
+ import os
4
+
5
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
+
7
+ from .wan_i2v_A14B import i2v_A14B
8
+ from .wan_t2v_A14B import t2v_A14B
9
+ from .wan_ti2v_5B import ti2v_5B
10
+
11
+ WAN_CONFIGS = {
12
+ 't2v-A14B': t2v_A14B,
13
+ 'i2v-A14B': i2v_A14B,
14
+ 'ti2v-5B': ti2v_5B,
15
+ }
16
+
17
+ SIZE_CONFIGS = {
18
+ '720*1280': (720, 1280),
19
+ '1280*720': (1280, 720),
20
+ '480*832': (480, 832),
21
+ '832*480': (832, 480),
22
+ '704*1280': (704, 1280),
23
+ '1280*704': (1280, 704)
24
+ }
25
+
26
+ MAX_AREA_CONFIGS = {
27
+ '720*1280': 720 * 1280,
28
+ '1280*720': 1280 * 720,
29
+ '480*832': 480 * 832,
30
+ '832*480': 832 * 480,
31
+ '704*1280': 704 * 1280,
32
+ '1280*704': 1280 * 704,
33
+ }
34
+
35
+ SUPPORTED_SIZES = {
36
+ 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
37
+ 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
38
+ 'ti2v-5B': ('704*1280', '1280*704'),
39
+ }
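
For reference, a short sketch of how these lookup tables are combined by generate.py and app.py when dispatching a job; the import path assumes the package layout uploaded in this commit:

from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES

task, size = "ti2v-5B", "1280*704"
assert size in SUPPORTED_SIZES[task]    # same guard as _validate_args in generate.py
cfg = WAN_CONFIGS[task]                 # EasyDict holding model and sampling defaults
print(SIZE_CONFIGS[size])               # (1280, 704) target resolution
print(MAX_AREA_CONFIGS[size])           # 901120, the pixel budget passed as max_area
print(cfg.sample_fps, cfg.frame_num)    # 24 fps and 121 frames for the 5B TI2V config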
wan/configs/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (737 Bytes). View file
 
wan/configs/__pycache__/shared_config.cpython-310.pyc ADDED
Binary file (848 Bytes). View file
 
wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc ADDED
Binary file (968 Bytes). View file
 
wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc ADDED
Binary file (955 Bytes). View file
 
wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc ADDED
Binary file (863 Bytes). View file
 
wan/configs/shared_config.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ #------------------------ Wan shared config ------------------------#
6
+ wan_shared_cfg = EasyDict()
7
+
8
+ # t5
9
+ wan_shared_cfg.t5_model = 'umt5_xxl'
10
+ wan_shared_cfg.t5_dtype = torch.bfloat16
11
+ wan_shared_cfg.text_len = 512
12
+
13
+ # transformer
14
+ wan_shared_cfg.param_dtype = torch.bfloat16
15
+
16
+ # inference
17
+ wan_shared_cfg.num_train_timesteps = 1000
18
+ wan_shared_cfg.sample_fps = 16
19
+ wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
20
+ wan_shared_cfg.frame_num = 81
wan/configs/wan_i2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .shared_config import wan_shared_cfg
6
+
7
+ #------------------------ Wan I2V A14B ------------------------#
8
+
9
+ i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
10
+ i2v_A14B.update(wan_shared_cfg)
11
+
12
+ i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ i2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ i2v_A14B.patch_size = (1, 2, 2)
21
+ i2v_A14B.dim = 5120
22
+ i2v_A14B.ffn_dim = 13824
23
+ i2v_A14B.freq_dim = 256
24
+ i2v_A14B.num_heads = 40
25
+ i2v_A14B.num_layers = 40
26
+ i2v_A14B.window_size = (-1, -1)
27
+ i2v_A14B.qk_norm = True
28
+ i2v_A14B.cross_attn_norm = True
29
+ i2v_A14B.eps = 1e-6
30
+ i2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ i2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ i2v_A14B.sample_shift = 5.0
35
+ i2v_A14B.sample_steps = 40
36
+ i2v_A14B.boundary = 0.900
37
+ i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise
wan/configs/wan_t2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan T2V A14B ------------------------#
7
+
8
+ t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
9
+ t2v_A14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_A14B.patch_size = (1, 2, 2)
21
+ t2v_A14B.dim = 5120
22
+ t2v_A14B.ffn_dim = 13824
23
+ t2v_A14B.freq_dim = 256
24
+ t2v_A14B.num_heads = 40
25
+ t2v_A14B.num_layers = 40
26
+ t2v_A14B.window_size = (-1, -1)
27
+ t2v_A14B.qk_norm = True
28
+ t2v_A14B.cross_attn_norm = True
29
+ t2v_A14B.eps = 1e-6
30
+ t2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ t2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ t2v_A14B.sample_shift = 12.0
35
+ t2v_A14B.sample_steps = 40
36
+ t2v_A14B.boundary = 0.875
37
+ t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise
wan/configs/wan_ti2v_5B.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan TI2V 5B ------------------------#
7
+
8
+ ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
9
+ ti2v_5B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ ti2v_5B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
17
+ ti2v_5B.vae_stride = (4, 16, 16)
18
+
19
+ # transformer
20
+ ti2v_5B.patch_size = (1, 2, 2)
21
+ ti2v_5B.dim = 3072
22
+ ti2v_5B.ffn_dim = 14336
23
+ ti2v_5B.freq_dim = 256
24
+ ti2v_5B.num_heads = 24
25
+ ti2v_5B.num_layers = 30
26
+ ti2v_5B.window_size = (-1, -1)
27
+ ti2v_5B.qk_norm = True
28
+ ti2v_5B.cross_attn_norm = True
29
+ ti2v_5B.eps = 1e-6
30
+
31
+ # inference
32
+ ti2v_5B.sample_fps = 24
33
+ ti2v_5B.sample_shift = 5.0
34
+ ti2v_5B.sample_steps = 50
35
+ ti2v_5B.sample_guide_scale = 5.0
36
+ ti2v_5B.frame_num = 121
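
For reference, a hedged arithmetic sketch of what vae_stride and patch_size above imply spatially at the supported 1280*704 output; the divisions themselves follow from the config values, while treating them as exact divisors of the frame size is an assumption:

# Spatial bookkeeping for ti2v-5B at a 1280*704 output (sketch only).
H, W = 1280, 704
_, sh, sw = (4, 16, 16)                  # ti2v_5B.vae_stride (temporal, height, width)
_, ph, pw = (1, 2, 2)                    # ti2v_5B.patch_size

lat_h, lat_w = H // sh, W // sw          # 80 x 44 latent grid per frame
tok_h, tok_w = lat_h // ph, lat_w // pw  # 40 x 22 transformer patches per frame
print(lat_h, lat_w, tok_h, tok_w)        # 80 44 40 22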
wan/distributed/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan/distributed/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
wan/distributed/__pycache__/fsdp.cpython-310.pyc ADDED
Binary file (1.36 kB). View file
 
wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc ADDED
Binary file (5.24 kB). View file
 
wan/distributed/__pycache__/ulysses.cpython-310.pyc ADDED
Binary file (1.23 kB). View file
 
wan/distributed/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.93 kB). View file
 
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
+ def shard_model(
13
+ model,
14
+ device_id,
15
+ param_dtype=torch.bfloat16,
16
+ reduce_dtype=torch.float32,
17
+ buffer_dtype=torch.float32,
18
+ process_group=None,
19
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
20
+ sync_module_states=True,
21
+ ):
22
+ model = FSDP(
23
+ module=model,
24
+ process_group=process_group,
25
+ sharding_strategy=sharding_strategy,
26
+ auto_wrap_policy=partial(
27
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
28
+ mixed_precision=MixedPrecision(
29
+ param_dtype=param_dtype,
30
+ reduce_dtype=reduce_dtype,
31
+ buffer_dtype=buffer_dtype),
32
+ device_id=device_id,
33
+ sync_module_states=sync_module_states)
34
+ return model
35
+
36
+
37
+ def free_model(model):
38
+ for m in model.modules():
39
+ if isinstance(m, FSDP):
40
+ _free_storage(m._handle.flat_param.data)
41
+ del model
42
+ gc.collect()
43
+ torch.cuda.empty_cache()
wan/distributed/sequence_parallel.py ADDED
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+
5
+ from ..modules.model import sinusoidal_embedding_1d
6
+ from .ulysses import distributed_attention
7
+ from .util import gather_forward, get_rank, get_world_size
8
+
9
+
10
+ def pad_freqs(original_tensor, target_len):
11
+ seq_len, s1, s2 = original_tensor.shape
12
+ pad_size = target_len - seq_len
13
+ padding_tensor = torch.ones(
14
+ pad_size,
15
+ s1,
16
+ s2,
17
+ dtype=original_tensor.dtype,
18
+ device=original_tensor.device)
19
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
+ return padded_tensor
21
+
22
+
23
+ @torch.amp.autocast('cuda', enabled=False)
24
+ def rope_apply(x, grid_sizes, freqs):
25
+ """
26
+ x: [B, L, N, C].
27
+ grid_sizes: [B, 3].
28
+ freqs: [M, C // 2].
29
+ """
30
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
+ # split freqs
32
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
+
34
+ # loop over samples
35
+ output = []
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
41
+ s, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_world_size()
51
+ sp_rank = get_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
+ s_per_rank), :, :]
56
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
+ x_i = torch.cat([x_i, x[i, s:]])
58
+
59
+ # append to collection
60
+ output.append(x_i)
61
+ return torch.stack(output).float()
62
+
63
+
64
+ def sp_dit_forward(
65
+ self,
66
+ x,
67
+ t,
68
+ context,
69
+ seq_len,
70
+ y=None,
71
+ ):
72
+ """
73
+ x: A list of videos each with shape [C, T, H, W].
74
+ t: [B].
75
+ context: A list of text embeddings each with shape [L, C].
76
+ """
77
+ if self.model_type == 'i2v':
78
+ assert y is not None
79
+ # params
80
+ device = self.patch_embedding.weight.device
81
+ if self.freqs.device != device:
82
+ self.freqs = self.freqs.to(device)
83
+
84
+ if y is not None:
85
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
86
+
87
+ # embeddings
88
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
89
+ grid_sizes = torch.stack(
90
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
91
+ x = [u.flatten(2).transpose(1, 2) for u in x]
92
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
93
+ assert seq_lens.max() <= seq_len
94
+ x = torch.cat([
95
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
96
+ for u in x
97
+ ])
98
+
99
+ # time embeddings
100
+ if t.dim() == 1:
101
+ t = t.expand(t.size(0), seq_len)
102
+ with torch.amp.autocast('cuda', dtype=torch.float32):
103
+ bt = t.size(0)
104
+ t = t.flatten()
105
+ e = self.time_embedding(
106
+ sinusoidal_embedding_1d(self.freq_dim,
107
+ t).unflatten(0, (bt, seq_len)).float())
108
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
109
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
110
+
111
+ # context
112
+ context_lens = None
113
+ context = self.text_embedding(
114
+ torch.stack([
115
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
116
+ for u in context
117
+ ]))
118
+
119
+ # Context Parallel
120
+ x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
121
+ e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
122
+ e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
123
+
124
+ # arguments
125
+ kwargs = dict(
126
+ e=e0,
127
+ seq_lens=seq_lens,
128
+ grid_sizes=grid_sizes,
129
+ freqs=self.freqs,
130
+ context=context,
131
+ context_lens=context_lens)
132
+
133
+ for block in self.blocks:
134
+ x = block(x, **kwargs)
135
+
136
+ # head
137
+ x = self.head(x, e)
138
+
139
+ # Context Parallel
140
+ x = gather_forward(x, dim=1)
141
+
142
+ # unpatchify
143
+ x = self.unpatchify(x, grid_sizes)
144
+ return [u.float() for u in x]
145
+
146
+
147
+ def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
148
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
+ half_dtypes = (torch.float16, torch.bfloat16)
150
+
151
+ def half(x):
152
+ return x if x.dtype in half_dtypes else x.to(dtype)
153
+
154
+ # query, key, value function
155
+ def qkv_fn(x):
156
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
157
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
158
+ v = self.v(x).view(b, s, n, d)
159
+ return q, k, v
160
+
161
+ q, k, v = qkv_fn(x)
162
+ q = rope_apply(q, grid_sizes, freqs)
163
+ k = rope_apply(k, grid_sizes, freqs)
164
+
165
+ x = distributed_attention(
166
+ half(q),
167
+ half(k),
168
+ half(v),
169
+ seq_lens,
170
+ window_size=self.window_size,
171
+ )
172
+
173
+ # output
174
+ x = x.flatten(2)
175
+ x = self.o(x)
176
+ return x
wan/distributed/ulysses.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+ from ..modules.attention import flash_attention
6
+ from .util import all_to_all
7
+
8
+
9
+ def distributed_attention(
10
+ q,
11
+ k,
12
+ v,
13
+ seq_lens,
14
+ window_size=(-1, -1),
15
+ ):
16
+ """
17
+ Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
18
+ please refer to https://arxiv.org/pdf/2309.14509
19
+
20
+ Args:
21
+ q: [B, Lq // p, Nq, C1].
22
+ k: [B, Lk // p, Nk, C1].
23
+ v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
24
+ seq_lens: [B], length of each sequence in batch
25
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
26
+ """
27
+ if not dist.is_initialized():
28
+ raise ValueError("distributed group should be initialized.")
29
+ b = q.shape[0]
30
+
31
+ # gather q/k/v sequence
32
+ q = all_to_all(q, scatter_dim=2, gather_dim=1)
33
+ k = all_to_all(k, scatter_dim=2, gather_dim=1)
34
+ v = all_to_all(v, scatter_dim=2, gather_dim=1)
35
+
36
+ # apply attention
37
+ x = flash_attention(
38
+ q,
39
+ k,
40
+ v,
41
+ k_lens=seq_lens,
42
+ window_size=window_size,
43
+ )
44
+
45
+ # scatter q/k/v sequence
46
+ x = all_to_all(x, scatter_dim=1, gather_dim=2)
47
+ return x
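
For reference, a shape-only sketch of the Ulysses exchange performed by distributed_attention above: all_to_all(scatter_dim=2, gather_dim=1) trades the head dimension for the sequence dimension, so each rank attends over the full sequence with N/p heads, and the second all_to_all reverses the layout. No process group is involved here, and the received chunks are random stand-ins for what other ranks would send:

import torch

B, L, N, C, p = 1, 8, 4, 16, 2            # p = sequence-parallel world size
q_local = torch.randn(B, L // p, N, C)    # per-rank slice: [B, L/p, N, C]

# what this rank scatters: p chunks of shape [B, L/p, N/p, C]
scattered = [u.contiguous() for u in q_local.chunk(p, dim=2)]
print([tuple(u.shape) for u in scattered])

# after the exchange, each rank concatenates its received chunks along dim=1,
# giving [B, L, N/p, C] for the flash_attention call (faked here with randn)
q_exchanged = torch.cat([torch.randn(B, L // p, N // p, C) for _ in range(p)], dim=1)
print(tuple(q_exchanged.shape))           # (1, 8, 2, 16)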
wan/distributed/util.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+
6
+ def init_distributed_group():
7
+ r"""Initialize the sequence parallel group.
8
+ """
9
+ if not dist.is_initialized():
10
+ dist.init_process_group(backend='nccl')
11
+
12
+
13
+ def get_rank():
14
+ return dist.get_rank()
15
+
16
+
17
+ def get_world_size():
18
+ return dist.get_world_size()
19
+
20
+
21
+ def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
22
+ """
23
+ `scatter` along one dimension and `gather` along another.
24
+ """
25
+ world_size = get_world_size()
26
+ if world_size > 1:
27
+ inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
28
+ outputs = [torch.empty_like(u) for u in inputs]
29
+ dist.all_to_all(outputs, inputs, group=group, **kwargs)
30
+ x = torch.cat(outputs, dim=gather_dim).contiguous()
31
+ return x
32
+
33
+
34
+ def all_gather(tensor):
35
+ world_size = dist.get_world_size()
36
+ if world_size == 1:
37
+ return [tensor]
38
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
39
+ torch.distributed.all_gather(tensor_list, tensor)
40
+ return tensor_list
41
+
42
+
43
+ def gather_forward(input, dim):
44
+ # skip if world_size == 1
45
+ world_size = dist.get_world_size()
46
+ if world_size == 1:
47
+ return input
48
+
49
+ # gather sequence
50
+ output = all_gather(input)
51
+ return torch.cat(output, dim=dim).contiguous()
wan/image2video.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
+ from .distributed.util import get_world_size
22
+ from .modules.model import WanModel
23
+ from .modules.t5 import T5EncoderModel
24
+ from .modules.vae2_1 import Wan2_1_VAE
25
+ from .utils.fm_solvers import (
26
+ FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas,
28
+ retrieve_timesteps,
29
+ )
30
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+
32
+
33
+ class WanI2V:
34
+
35
+ def __init__(
36
+ self,
37
+ config,
38
+ checkpoint_dir,
39
+ device_id=0,
40
+ rank=0,
41
+ t5_fsdp=False,
42
+ dit_fsdp=False,
43
+ use_sp=False,
44
+ t5_cpu=False,
45
+ init_on_cpu=True,
46
+ convert_model_dtype=False,
47
+ ):
48
+ r"""
49
+ Initializes the image-to-video generation model components.
50
+
51
+ Args:
52
+ config (EasyDict):
53
+ Object containing model parameters initialized from config.py
54
+ checkpoint_dir (`str`):
55
+ Path to directory containing model checkpoints
56
+ device_id (`int`, *optional*, defaults to 0):
57
+ Id of target GPU device
58
+ rank (`int`, *optional*, defaults to 0):
59
+ Process rank for distributed training
60
+ t5_fsdp (`bool`, *optional*, defaults to False):
61
+ Enable FSDP sharding for T5 model
62
+ dit_fsdp (`bool`, *optional*, defaults to False):
63
+ Enable FSDP sharding for DiT model
64
+ use_sp (`bool`, *optional*, defaults to False):
65
+ Enable distribution strategy of sequence parallel.
66
+ t5_cpu (`bool`, *optional*, defaults to False):
67
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
68
+ init_on_cpu (`bool`, *optional*, defaults to True):
69
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
70
+ convert_model_dtype (`bool`, *optional*, defaults to False):
71
+ Convert DiT model parameters dtype to 'config.param_dtype'.
72
+ Only works without FSDP.
73
+ """
74
+ self.device = torch.device(f"cuda:{device_id}")
75
+ self.config = config
76
+ self.rank = rank
77
+ self.t5_cpu = t5_cpu
78
+ self.init_on_cpu = init_on_cpu
79
+
80
+ self.num_train_timesteps = config.num_train_timesteps
81
+ self.boundary = config.boundary
82
+ self.param_dtype = config.param_dtype
83
+
84
+ if t5_fsdp or dit_fsdp or use_sp:
85
+ self.init_on_cpu = False
86
+
87
+ shard_fn = partial(shard_model, device_id=device_id)
88
+ self.text_encoder = T5EncoderModel(
89
+ text_len=config.text_len,
90
+ dtype=config.t5_dtype,
91
+ device=torch.device('cpu'),
92
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
+ shard_fn=shard_fn if t5_fsdp else None,
95
+ )
96
+
97
+ self.vae_stride = config.vae_stride
98
+ self.patch_size = config.patch_size
99
+ self.vae = Wan2_1_VAE(
100
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
101
+ device=self.device)
102
+
103
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
104
+ self.low_noise_model = WanModel.from_pretrained(
105
+ checkpoint_dir, subfolder=config.low_noise_checkpoint)
106
+ self.low_noise_model = self._configure_model(
107
+ model=self.low_noise_model,
108
+ use_sp=use_sp,
109
+ dit_fsdp=dit_fsdp,
110
+ shard_fn=shard_fn,
111
+ convert_model_dtype=convert_model_dtype)
112
+
113
+ self.high_noise_model = WanModel.from_pretrained(
114
+ checkpoint_dir, subfolder=config.high_noise_checkpoint)
115
+ self.high_noise_model = self._configure_model(
116
+ model=self.high_noise_model,
117
+ use_sp=use_sp,
118
+ dit_fsdp=dit_fsdp,
119
+ shard_fn=shard_fn,
120
+ convert_model_dtype=convert_model_dtype)
121
+ if use_sp:
122
+ self.sp_size = get_world_size()
123
+ else:
124
+ self.sp_size = 1
125
+
126
+ self.sample_neg_prompt = config.sample_neg_prompt
127
+
128
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
129
+ convert_model_dtype):
130
+ """
131
+ Configures a model object. This includes setting evaluation modes,
132
+ applying distributed parallel strategy, and handling device placement.
133
+
134
+ Args:
135
+ model (torch.nn.Module):
136
+ The model instance to configure.
137
+ use_sp (`bool`):
138
+ Enable the sequence-parallel distribution strategy.
139
+ dit_fsdp (`bool`):
140
+ Enable FSDP sharding for DiT model.
141
+ shard_fn (callable):
142
+ The function to apply FSDP sharding.
143
+ convert_model_dtype (`bool`):
144
+ Convert DiT model parameters dtype to 'config.param_dtype'.
145
+ Only works without FSDP.
146
+
147
+ Returns:
148
+ torch.nn.Module:
149
+ The configured model.
150
+ """
151
+ model.eval().requires_grad_(False)
152
+
153
+ if use_sp:
154
+ for block in model.blocks:
155
+ block.self_attn.forward = types.MethodType(
156
+ sp_attn_forward, block.self_attn)
157
+ model.forward = types.MethodType(sp_dit_forward, model)
158
+
159
+ if dist.is_initialized():
160
+ dist.barrier()
161
+
162
+ if dit_fsdp:
163
+ model = shard_fn(model)
164
+ else:
165
+ if convert_model_dtype:
166
+ model.to(self.param_dtype)
167
+ if not self.init_on_cpu:
168
+ model.to(self.device)
169
+
170
+ return model
171
+
172
+ def _prepare_model_for_timestep(self, t, boundary, offload_model):
173
+ r"""
174
+ Prepares and returns the required model for the current timestep.
175
+
176
+ Args:
177
+ t (torch.Tensor):
178
+ current timestep.
179
+ boundary (`int`):
180
+ The timestep threshold. If `t` is at or above this value,
181
+ the `high_noise_model` is selected; otherwise the `low_noise_model` is used.
182
+ offload_model (`bool`):
183
+ A flag intended to control the offloading behavior.
184
+
185
+ Returns:
186
+ torch.nn.Module:
187
+ The active model on the target device for the current timestep.
188
+ """
189
+ if t.item() >= boundary:
190
+ required_model_name = 'high_noise_model'
191
+ offload_model_name = 'low_noise_model'
192
+ else:
193
+ required_model_name = 'low_noise_model'
194
+ offload_model_name = 'high_noise_model'
195
+ if offload_model or self.init_on_cpu:
196
+ if next(getattr(
197
+ self,
198
+ offload_model_name).parameters()).device.type == 'cuda':
199
+ getattr(self, offload_model_name).to('cpu')
200
+ if next(getattr(
201
+ self,
202
+ required_model_name).parameters()).device.type == 'cpu':
203
+ getattr(self, required_model_name).to(self.device)
204
+ return getattr(self, required_model_name)
205
+
206
+ def generate(self,
207
+ input_prompt,
208
+ img,
209
+ max_area=720 * 1280,
210
+ frame_num=81,
211
+ shift=5.0,
212
+ sample_solver='unipc',
213
+ sampling_steps=40,
214
+ guide_scale=5.0,
215
+ n_prompt="",
216
+ seed=-1,
217
+ offload_model=True):
218
+ r"""
219
+ Generates video frames from input image and text prompt using diffusion process.
220
+
221
+ Args:
222
+ input_prompt (`str`):
223
+ Text prompt for content generation.
224
+ img (PIL.Image.Image):
225
+ Input image tensor. Shape: [3, H, W]
226
+ max_area (`int`, *optional*, defaults to 720*1280):
227
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
228
+ frame_num (`int`, *optional*, defaults to 81):
229
+ How many frames to sample from a video. The number should be 4n+1
230
+ shift (`float`, *optional*, defaults to 5.0):
231
+ Noise schedule shift parameter. Affects temporal dynamics
232
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
233
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
234
+ Solver used to sample the video.
235
+ sampling_steps (`int`, *optional*, defaults to 40):
236
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
237
+ guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):
238
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
239
+ If tuple, the first guide_scale will be used for low noise model and
240
+ the second guide_scale will be used for high noise model.
241
+ n_prompt (`str`, *optional*, defaults to ""):
242
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
243
+ seed (`int`, *optional*, defaults to -1):
244
+ Random seed for noise generation. If -1, use random seed
245
+ offload_model (`bool`, *optional*, defaults to True):
246
+ If True, offloads models to CPU during generation to save VRAM
247
+
248
+ Returns:
249
+ torch.Tensor:
250
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
251
+ - C: Color channels (3 for RGB)
252
+ - N: Number of frames (81)
253
+ - H: Frame height (from max_area)
254
+ - W: Frame width (from max_area)
255
+ """
256
+ # preprocess
257
+ guide_scale = (guide_scale, guide_scale) if isinstance(
258
+ guide_scale, float) else guide_scale
259
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
260
+
261
+ F = frame_num
262
+ h, w = img.shape[1:]
263
+ aspect_ratio = h / w
264
+ lat_h = round(
265
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
266
+ self.patch_size[1] * self.patch_size[1])
267
+ lat_w = round(
268
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
269
+ self.patch_size[2] * self.patch_size[2])
270
+ h = lat_h * self.vae_stride[1]
271
+ w = lat_w * self.vae_stride[2]
272
+
273
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
274
+ self.patch_size[1] * self.patch_size[2])
275
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
276
+
277
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
278
+ seed_g = torch.Generator(device=self.device)
279
+ seed_g.manual_seed(seed)
280
+ noise = torch.randn(
281
+ 16,
282
+ 21,
283
+ lat_h,
284
+ lat_w,
285
+ dtype=torch.float32,
286
+ generator=seed_g,
287
+ device=self.device)
288
+
289
+ msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
290
+ msk[:, 1:] = 0
291
+ msk = torch.concat([
292
+ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
293
+ ],
294
+ dim=1)
295
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
296
+ msk = msk.transpose(1, 2)[0]
297
+
298
+ if n_prompt == "":
299
+ n_prompt = self.sample_neg_prompt
300
+
301
+ # preprocess
302
+ if not self.t5_cpu:
303
+ self.text_encoder.model.to(self.device)
304
+ context = self.text_encoder([input_prompt], self.device)
305
+ context_null = self.text_encoder([n_prompt], self.device)
306
+ if offload_model:
307
+ self.text_encoder.model.cpu()
308
+ else:
309
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
310
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
311
+ context = [t.to(self.device) for t in context]
312
+ context_null = [t.to(self.device) for t in context_null]
313
+
314
+ y = self.vae.encode([
315
+ torch.concat([
316
+ torch.nn.functional.interpolate(
317
+ img[None].cpu(), size=(h, w), mode='bicubic').transpose(
318
+ 0, 1),
319
+ torch.zeros(3, 80, h, w)
320
+ ],
321
+ dim=1).to(self.device)
322
+ ])[0]
323
+ y = torch.concat([msk, y])
324
+
325
+ @contextmanager
326
+ def noop_no_sync():
327
+ yield
328
+
329
+ no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
330
+ noop_no_sync)
331
+ no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
332
+ noop_no_sync)
333
+
334
+ # evaluation mode
335
+ with (
336
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
337
+ torch.no_grad(),
338
+ no_sync_low_noise(),
339
+ no_sync_high_noise(),
340
+ ):
341
+ boundary = self.boundary * self.num_train_timesteps
342
+
343
+ if sample_solver == 'unipc':
344
+ sample_scheduler = FlowUniPCMultistepScheduler(
345
+ num_train_timesteps=self.num_train_timesteps,
346
+ shift=1,
347
+ use_dynamic_shifting=False)
348
+ sample_scheduler.set_timesteps(
349
+ sampling_steps, device=self.device, shift=shift)
350
+ timesteps = sample_scheduler.timesteps
351
+ elif sample_solver == 'dpm++':
352
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
353
+ num_train_timesteps=self.num_train_timesteps,
354
+ shift=1,
355
+ use_dynamic_shifting=False)
356
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
357
+ timesteps, _ = retrieve_timesteps(
358
+ sample_scheduler,
359
+ device=self.device,
360
+ sigmas=sampling_sigmas)
361
+ else:
362
+ raise NotImplementedError("Unsupported solver.")
363
+
364
+ # sample videos
365
+ latent = noise
366
+
367
+ arg_c = {
368
+ 'context': [context[0]],
369
+ 'seq_len': max_seq_len,
370
+ 'y': [y],
371
+ }
372
+
373
+ arg_null = {
374
+ 'context': context_null,
375
+ 'seq_len': max_seq_len,
376
+ 'y': [y],
377
+ }
378
+
379
+ if offload_model:
380
+ torch.cuda.empty_cache()
381
+
382
+ for _, t in enumerate(tqdm(timesteps)):
383
+ latent_model_input = [latent.to(self.device)]
384
+ timestep = [t]
385
+
386
+ timestep = torch.stack(timestep).to(self.device)
387
+
388
+ model = self._prepare_model_for_timestep(
389
+ t, boundary, offload_model)
390
+ sample_guide_scale = guide_scale[1] if t.item(
391
+ ) >= boundary else guide_scale[0]
392
+
393
+ noise_pred_cond = model(
394
+ latent_model_input, t=timestep, **arg_c)[0]
395
+ if offload_model:
396
+ torch.cuda.empty_cache()
397
+ noise_pred_uncond = model(
398
+ latent_model_input, t=timestep, **arg_null)[0]
399
+ if offload_model:
400
+ torch.cuda.empty_cache()
401
+ noise_pred = noise_pred_uncond + sample_guide_scale * (
402
+ noise_pred_cond - noise_pred_uncond)
403
+
404
+ temp_x0 = sample_scheduler.step(
405
+ noise_pred.unsqueeze(0),
406
+ t,
407
+ latent.unsqueeze(0),
408
+ return_dict=False,
409
+ generator=seed_g)[0]
410
+ latent = temp_x0.squeeze(0)
411
+
412
+ x0 = [latent]
413
+ del latent_model_input, timestep
414
+
415
+ if offload_model:
416
+ self.low_noise_model.cpu()
417
+ self.high_noise_model.cpu()
418
+ torch.cuda.empty_cache()
419
+
420
+ if self.rank == 0:
421
+ videos = self.vae.decode(x0)
422
+
423
+ del noise, latent, x0
424
+ del sample_scheduler
425
+ if offload_model:
426
+ gc.collect()
427
+ torch.cuda.synchronize()
428
+ if dist.is_initialized():
429
+ dist.barrier()
430
+
431
+ return videos[0] if self.rank == 0 else None
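
The sampling loop above routes every denoising step to one of two DiT experts based on a timestep boundary (`config.boundary * num_train_timesteps`). Below is a minimal sketch of that routing rule only; the schedule and the boundary ratio are illustrative assumptions, not values read from the shipped config.

```python
# Sketch of the two-expert routing used by _prepare_model_for_timestep:
# timesteps at or above `boundary` go to the high-noise expert, the rest
# to the low-noise expert. Numbers below are illustrative assumptions.
import torch

num_train_timesteps = 1000      # assumed stand-in for config.num_train_timesteps
boundary_ratio = 0.9            # assumed stand-in for config.boundary
boundary = boundary_ratio * num_train_timesteps

# a toy descending schedule such as a scheduler would produce
timesteps = torch.linspace(999, 0, steps=10)

for t in timesteps:
    expert = 'high_noise_model' if t.item() >= boundary else 'low_noise_model'
    print(f"t={t.item():7.1f} -> {expert}")
```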
wan/modules/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .attention import flash_attention
3
+ from .model import WanModel
4
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
5
+ from .tokenizers import HuggingfaceTokenizer
6
+ from .vae2_1 import Wan2_1_VAE
7
+ from .vae2_2 import Wan2_2_VAE
8
+
9
+ __all__ = [
10
+ 'Wan2_1_VAE',
11
+ 'Wan2_2_VAE',
12
+ 'WanModel',
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ 'HuggingfaceTokenizer',
18
+ 'flash_attention',
19
+ ]
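
A minimal import sketch, assuming this repo is on `PYTHONPATH` as the `wan` package, showing that the building blocks above are re-exported from a single namespace.

```python
# Minimal sketch (assumption: the repo is importable as the `wan` package).
from wan.modules import WanModel, Wan2_1_VAE, T5EncoderModel, flash_attention
import wan.modules

print(wan.modules.__all__)                      # the export list defined above
print(WanModel.__name__, callable(flash_attention))
```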
wan/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (528 Bytes)

wan/modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (3.95 kB)

wan/modules/__pycache__/model.cpython-310.pyc ADDED
Binary file (16.9 kB)

wan/modules/__pycache__/t5.cpython-310.pyc ADDED
Binary file (12.9 kB)

wan/modules/__pycache__/tokenizers.cpython-310.pyc ADDED
Binary file (2.55 kB)

wan/modules/__pycache__/vae2_1.cpython-310.pyc ADDED
Binary file (16.9 kB)

wan/modules/__pycache__/vae2_2.cpython-310.pyc ADDED
Binary file (22.1 kB)
wan/modules/attention.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+ FLASH_ATTN_3_AVAILABLE = True
7
+ except ModuleNotFoundError:
8
+ FLASH_ATTN_3_AVAILABLE = False
9
+
10
+ try:
11
+ import flash_attn
12
+ FLASH_ATTN_2_AVAILABLE = True
13
+ except ModuleNotFoundError:
14
+ FLASH_ATTN_2_AVAILABLE = False
15
+
16
+ import warnings
17
+
18
+ __all__ = [
19
+ 'flash_attention',
20
+ 'attention',
21
+ ]
22
+
23
+
24
+ def flash_attention(
25
+ q,
26
+ k,
27
+ v,
28
+ q_lens=None,
29
+ k_lens=None,
30
+ dropout_p=0.,
31
+ softmax_scale=None,
32
+ q_scale=None,
33
+ causal=False,
34
+ window_size=(-1, -1),
35
+ deterministic=False,
36
+ dtype=torch.bfloat16,
37
+ version=None,
38
+ ):
39
+ """
40
+ q: [B, Lq, Nq, C1].
41
+ k: [B, Lk, Nk, C1].
42
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
43
+ q_lens: [B].
44
+ k_lens: [B].
45
+ dropout_p: float. Dropout probability.
46
+ softmax_scale: float. The scaling of QK^T before applying softmax.
47
+ causal: bool. Whether to apply causal attention mask.
48
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
49
+ deterministic: bool. If True, slightly slower and uses more memory.
50
+ dtype: torch.dtype. The dtype q/k/v are cast to when they are not already float16/bfloat16.
51
+ """
52
+ half_dtypes = (torch.float16, torch.bfloat16)
53
+ assert dtype in half_dtypes
54
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
55
+
56
+ # params
57
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
58
+
59
+ def half(x):
60
+ return x if x.dtype in half_dtypes else x.to(dtype)
61
+
62
+ # preprocess query
63
+ if q_lens is None:
64
+ q = half(q.flatten(0, 1))
65
+ q_lens = torch.tensor(
66
+ [lq] * b, dtype=torch.int32).to(
67
+ device=q.device, non_blocking=True)
68
+ else:
69
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
70
+
71
+ # preprocess key, value
72
+ if k_lens is None:
73
+ k = half(k.flatten(0, 1))
74
+ v = half(v.flatten(0, 1))
75
+ k_lens = torch.tensor(
76
+ [lk] * b, dtype=torch.int32).to(
77
+ device=k.device, non_blocking=True)
78
+ else:
79
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
80
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
81
+
82
+ q = q.to(v.dtype)
83
+ k = k.to(v.dtype)
84
+
85
+ if q_scale is not None:
86
+ q = q * q_scale
87
+
88
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
89
+ warnings.warn(
90
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
91
+ )
92
+
93
+ # apply attention
94
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
95
+ # Note: dropout_p, window_size are not supported in FA3 now.
96
+ x = flash_attn_interface.flash_attn_varlen_func(
97
+ q=q,
98
+ k=k,
99
+ v=v,
100
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
101
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
102
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
103
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
104
+ seqused_q=None,
105
+ seqused_k=None,
106
+ max_seqlen_q=lq,
107
+ max_seqlen_k=lk,
108
+ softmax_scale=softmax_scale,
109
+ causal=causal,
110
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
111
+ else:
112
+ assert FLASH_ATTN_2_AVAILABLE
113
+ x = flash_attn.flash_attn_varlen_func(
114
+ q=q,
115
+ k=k,
116
+ v=v,
117
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
118
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
119
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
120
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
121
+ max_seqlen_q=lq,
122
+ max_seqlen_k=lk,
123
+ dropout_p=dropout_p,
124
+ softmax_scale=softmax_scale,
125
+ causal=causal,
126
+ window_size=window_size,
127
+ deterministic=deterministic).unflatten(0, (b, lq))
128
+
129
+ # output
130
+ return x.type(out_dtype)
131
+
132
+
133
+ def attention(
134
+ q,
135
+ k,
136
+ v,
137
+ q_lens=None,
138
+ k_lens=None,
139
+ dropout_p=0.,
140
+ softmax_scale=None,
141
+ q_scale=None,
142
+ causal=False,
143
+ window_size=(-1, -1),
144
+ deterministic=False,
145
+ dtype=torch.bfloat16,
146
+ fa_version=None,
147
+ ):
148
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
149
+ return flash_attention(
150
+ q=q,
151
+ k=k,
152
+ v=v,
153
+ q_lens=q_lens,
154
+ k_lens=k_lens,
155
+ dropout_p=dropout_p,
156
+ softmax_scale=softmax_scale,
157
+ q_scale=q_scale,
158
+ causal=causal,
159
+ window_size=window_size,
160
+ deterministic=deterministic,
161
+ dtype=dtype,
162
+ version=fa_version,
163
+ )
164
+ else:
165
+ if q_lens is not None or k_lens is not None:
166
+ warnings.warn(
167
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
168
+ )
169
+ attn_mask = None
170
+
171
+ q = q.transpose(1, 2).to(dtype)
172
+ k = k.transpose(1, 2).to(dtype)
173
+ v = v.transpose(1, 2).to(dtype)
174
+
175
+ out = torch.nn.functional.scaled_dot_product_attention(
176
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
177
+
178
+ out = out.transpose(1, 2).contiguous()
179
+ return out
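
A minimal sketch exercising the `attention` wrapper above with toy CPU tensors; it assumes flash-attn is not installed, so the call takes the `scaled_dot_product_attention` fallback path (the flash-attn paths require a CUDA device).

```python
# Sketch: toy call into the attention wrapper above. Shapes follow the
# documented convention [B, L, N, C_head]. Assumes flash-attn is absent,
# so the SDPA fallback runs on CPU.
import torch
from wan.modules.attention import attention  # assumes this repo's package layout

b, l, n, c = 2, 16, 4, 32
q = torch.randn(b, l, n, c)
k = torch.randn(b, l, n, c)
v = torch.randn(b, l, n, c)

out = attention(q, k, v, dtype=torch.bfloat16)
print(out.shape)  # torch.Size([2, 16, 4, 32])
```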
wan/modules/model.py ADDED
@@ -0,0 +1,546 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+
9
+ from .attention import flash_attention
10
+
11
+ __all__ = ['WanModel']
12
+
13
+
14
+ def sinusoidal_embedding_1d(dim, position):
15
+ # preprocess
16
+ assert dim % 2 == 0
17
+ half = dim // 2
18
+ position = position.type(torch.float64)
19
+
20
+ # calculation
21
+ sinusoid = torch.outer(
22
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
23
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
24
+ return x
25
+
26
+
27
+ @torch.amp.autocast('cuda', enabled=False)
28
+ def rope_params(max_seq_len, dim, theta=10000):
29
+ assert dim % 2 == 0
30
+ freqs = torch.outer(
31
+ torch.arange(max_seq_len),
32
+ 1.0 / torch.pow(theta,
33
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
34
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
35
+ return freqs
36
+
37
+
38
+ @torch.amp.autocast('cuda', enabled=False)
39
+ def rope_apply(x, grid_sizes, freqs):
40
+ n, c = x.size(2), x.size(3) // 2
41
+
42
+ # split freqs
43
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
44
+
45
+ # loop over samples
46
+ output = []
47
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
48
+ seq_len = f * h * w
49
+
50
+ # precompute multipliers
51
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
52
+ seq_len, n, -1, 2))
53
+ freqs_i = torch.cat([
54
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
55
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
56
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
57
+ ],
58
+ dim=-1).reshape(seq_len, 1, -1)
59
+
60
+ # apply rotary embedding
61
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
62
+ x_i = torch.cat([x_i, x[i, seq_len:]])
63
+
64
+ # append to collection
65
+ output.append(x_i)
66
+ return torch.stack(output).float()
67
+
68
+
69
+ class WanRMSNorm(nn.Module):
70
+
71
+ def __init__(self, dim, eps=1e-5):
72
+ super().__init__()
73
+ self.dim = dim
74
+ self.eps = eps
75
+ self.weight = nn.Parameter(torch.ones(dim))
76
+
77
+ def forward(self, x):
78
+ r"""
79
+ Args:
80
+ x(Tensor): Shape [B, L, C]
81
+ """
82
+ return self._norm(x.float()).type_as(x) * self.weight
83
+
84
+ def _norm(self, x):
85
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
86
+
87
+
88
+ class WanLayerNorm(nn.LayerNorm):
89
+
90
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
91
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
92
+
93
+ def forward(self, x):
94
+ r"""
95
+ Args:
96
+ x(Tensor): Shape [B, L, C]
97
+ """
98
+ return super().forward(x.float()).type_as(x)
99
+
100
+
101
+ class WanSelfAttention(nn.Module):
102
+
103
+ def __init__(self,
104
+ dim,
105
+ num_heads,
106
+ window_size=(-1, -1),
107
+ qk_norm=True,
108
+ eps=1e-6):
109
+ assert dim % num_heads == 0
110
+ super().__init__()
111
+ self.dim = dim
112
+ self.num_heads = num_heads
113
+ self.head_dim = dim // num_heads
114
+ self.window_size = window_size
115
+ self.qk_norm = qk_norm
116
+ self.eps = eps
117
+
118
+ # layers
119
+ self.q = nn.Linear(dim, dim)
120
+ self.k = nn.Linear(dim, dim)
121
+ self.v = nn.Linear(dim, dim)
122
+ self.o = nn.Linear(dim, dim)
123
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
124
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
+
126
+ def forward(self, x, seq_lens, grid_sizes, freqs):
127
+ r"""
128
+ Args:
129
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
130
+ seq_lens(Tensor): Shape [B]
131
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
132
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
133
+ """
134
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
135
+
136
+ # query, key, value function
137
+ def qkv_fn(x):
138
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
139
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
140
+ v = self.v(x).view(b, s, n, d)
141
+ return q, k, v
142
+
143
+ q, k, v = qkv_fn(x)
144
+
145
+ x = flash_attention(
146
+ q=rope_apply(q, grid_sizes, freqs),
147
+ k=rope_apply(k, grid_sizes, freqs),
148
+ v=v,
149
+ k_lens=seq_lens,
150
+ window_size=self.window_size)
151
+
152
+ # output
153
+ x = x.flatten(2)
154
+ x = self.o(x)
155
+ return x
156
+
157
+
158
+ class WanCrossAttention(WanSelfAttention):
159
+
160
+ def forward(self, x, context, context_lens):
161
+ r"""
162
+ Args:
163
+ x(Tensor): Shape [B, L1, C]
164
+ context(Tensor): Shape [B, L2, C]
165
+ context_lens(Tensor): Shape [B]
166
+ """
167
+ b, n, d = x.size(0), self.num_heads, self.head_dim
168
+
169
+ # compute query, key, value
170
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
171
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
172
+ v = self.v(context).view(b, -1, n, d)
173
+
174
+ # compute attention
175
+ x = flash_attention(q, k, v, k_lens=context_lens)
176
+
177
+ # output
178
+ x = x.flatten(2)
179
+ x = self.o(x)
180
+ return x
181
+
182
+
183
+ class WanAttentionBlock(nn.Module):
184
+
185
+ def __init__(self,
186
+ dim,
187
+ ffn_dim,
188
+ num_heads,
189
+ window_size=(-1, -1),
190
+ qk_norm=True,
191
+ cross_attn_norm=False,
192
+ eps=1e-6):
193
+ super().__init__()
194
+ self.dim = dim
195
+ self.ffn_dim = ffn_dim
196
+ self.num_heads = num_heads
197
+ self.window_size = window_size
198
+ self.qk_norm = qk_norm
199
+ self.cross_attn_norm = cross_attn_norm
200
+ self.eps = eps
201
+
202
+ # layers
203
+ self.norm1 = WanLayerNorm(dim, eps)
204
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
205
+ eps)
206
+ self.norm3 = WanLayerNorm(
207
+ dim, eps,
208
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
209
+ self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
210
+ eps)
211
+ self.norm2 = WanLayerNorm(dim, eps)
212
+ self.ffn = nn.Sequential(
213
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
214
+ nn.Linear(ffn_dim, dim))
215
+
216
+ # modulation
217
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
218
+
219
+ def forward(
220
+ self,
221
+ x,
222
+ e,
223
+ seq_lens,
224
+ grid_sizes,
225
+ freqs,
226
+ context,
227
+ context_lens,
228
+ ):
229
+ r"""
230
+ Args:
231
+ x(Tensor): Shape [B, L, C]
232
+ e(Tensor): Shape [B, L1, 6, C]
233
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
234
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
235
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
236
+ """
237
+ assert e.dtype == torch.float32
238
+ with torch.amp.autocast('cuda', dtype=torch.float32):
239
+ e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
240
+ assert e[0].dtype == torch.float32
241
+
242
+ # self-attention
243
+ y = self.self_attn(
244
+ self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
245
+ seq_lens, grid_sizes, freqs)
246
+ with torch.amp.autocast('cuda', dtype=torch.float32):
247
+ x = x + y * e[2].squeeze(2)
248
+
249
+ # cross-attention & ffn function
250
+ def cross_attn_ffn(x, context, context_lens, e):
251
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
252
+ y = self.ffn(
253
+ self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
254
+ with torch.amp.autocast('cuda', dtype=torch.float32):
255
+ x = x + y * e[5].squeeze(2)
256
+ return x
257
+
258
+ x = cross_attn_ffn(x, context, context_lens, e)
259
+ return x
260
+
261
+
262
+ class Head(nn.Module):
263
+
264
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
265
+ super().__init__()
266
+ self.dim = dim
267
+ self.out_dim = out_dim
268
+ self.patch_size = patch_size
269
+ self.eps = eps
270
+
271
+ # layers
272
+ out_dim = math.prod(patch_size) * out_dim
273
+ self.norm = WanLayerNorm(dim, eps)
274
+ self.head = nn.Linear(dim, out_dim)
275
+
276
+ # modulation
277
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
278
+
279
+ def forward(self, x, e):
280
+ r"""
281
+ Args:
282
+ x(Tensor): Shape [B, L1, C]
283
+ e(Tensor): Shape [B, L1, C]
284
+ """
285
+ assert e.dtype == torch.float32
286
+ with torch.amp.autocast('cuda', dtype=torch.float32):
287
+ e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
288
+ x = (
289
+ self.head(
290
+ self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))
291
+ return x
292
+
293
+
294
+ class WanModel(ModelMixin, ConfigMixin):
295
+ r"""
296
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
297
+ """
298
+
299
+ ignore_for_config = [
300
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
301
+ ]
302
+ _no_split_modules = ['WanAttentionBlock']
303
+
304
+ @register_to_config
305
+ def __init__(self,
306
+ model_type='t2v',
307
+ patch_size=(1, 2, 2),
308
+ text_len=512,
309
+ in_dim=16,
310
+ dim=2048,
311
+ ffn_dim=8192,
312
+ freq_dim=256,
313
+ text_dim=4096,
314
+ out_dim=16,
315
+ num_heads=16,
316
+ num_layers=32,
317
+ window_size=(-1, -1),
318
+ qk_norm=True,
319
+ cross_attn_norm=True,
320
+ eps=1e-6):
321
+ r"""
322
+ Initialize the diffusion model backbone.
323
+
324
+ Args:
325
+ model_type (`str`, *optional*, defaults to 't2v'):
326
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
327
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
328
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
329
+ text_len (`int`, *optional*, defaults to 512):
330
+ Fixed length for text embeddings
331
+ in_dim (`int`, *optional*, defaults to 16):
332
+ Input video channels (C_in)
333
+ dim (`int`, *optional*, defaults to 2048):
334
+ Hidden dimension of the transformer
335
+ ffn_dim (`int`, *optional*, defaults to 8192):
336
+ Intermediate dimension in feed-forward network
337
+ freq_dim (`int`, *optional*, defaults to 256):
338
+ Dimension for sinusoidal time embeddings
339
+ text_dim (`int`, *optional*, defaults to 4096):
340
+ Input dimension for text embeddings
341
+ out_dim (`int`, *optional*, defaults to 16):
342
+ Output video channels (C_out)
343
+ num_heads (`int`, *optional*, defaults to 16):
344
+ Number of attention heads
345
+ num_layers (`int`, *optional*, defaults to 32):
346
+ Number of transformer blocks
347
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
348
+ Window size for local attention (-1 indicates global attention)
349
+ qk_norm (`bool`, *optional*, defaults to True):
350
+ Enable query/key normalization
351
+ cross_attn_norm (`bool`, *optional*, defaults to False):
352
+ Enable cross-attention normalization
353
+ eps (`float`, *optional*, defaults to 1e-6):
354
+ Epsilon value for normalization layers
355
+ """
356
+
357
+ super().__init__()
358
+
359
+ assert model_type in ['t2v', 'i2v', 'ti2v']
360
+ self.model_type = model_type
361
+
362
+ self.patch_size = patch_size
363
+ self.text_len = text_len
364
+ self.in_dim = in_dim
365
+ self.dim = dim
366
+ self.ffn_dim = ffn_dim
367
+ self.freq_dim = freq_dim
368
+ self.text_dim = text_dim
369
+ self.out_dim = out_dim
370
+ self.num_heads = num_heads
371
+ self.num_layers = num_layers
372
+ self.window_size = window_size
373
+ self.qk_norm = qk_norm
374
+ self.cross_attn_norm = cross_attn_norm
375
+ self.eps = eps
376
+
377
+ # embeddings
378
+ self.patch_embedding = nn.Conv3d(
379
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
380
+ self.text_embedding = nn.Sequential(
381
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
382
+ nn.Linear(dim, dim))
383
+
384
+ self.time_embedding = nn.Sequential(
385
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
386
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
387
+
388
+ # blocks
389
+ self.blocks = nn.ModuleList([
390
+ WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
391
+ cross_attn_norm, eps) for _ in range(num_layers)
392
+ ])
393
+
394
+ # head
395
+ self.head = Head(dim, out_dim, patch_size, eps)
396
+
397
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
398
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
399
+ d = dim // num_heads
400
+ self.freqs = torch.cat([
401
+ rope_params(1024, d - 4 * (d // 6)),
402
+ rope_params(1024, 2 * (d // 6)),
403
+ rope_params(1024, 2 * (d // 6))
404
+ ],
405
+ dim=1)
406
+
407
+ # initialize weights
408
+ self.init_weights()
409
+
410
+ def forward(
411
+ self,
412
+ x,
413
+ t,
414
+ context,
415
+ seq_len,
416
+ y=None,
417
+ ):
418
+ r"""
419
+ Forward pass through the diffusion model
420
+
421
+ Args:
422
+ x (List[Tensor]):
423
+ List of input video tensors, each with shape [C_in, F, H, W]
424
+ t (Tensor):
425
+ Diffusion timesteps tensor of shape [B]
426
+ context (List[Tensor]):
427
+ List of text embeddings each with shape [L, C]
428
+ seq_len (`int`):
429
+ Maximum sequence length for positional encoding
430
+ y (List[Tensor], *optional*):
431
+ Conditional video inputs for image-to-video mode, same shape as x
432
+
433
+ Returns:
434
+ List[Tensor]:
435
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
436
+ """
437
+ if self.model_type == 'i2v':
438
+ assert y is not None
439
+ # params
440
+ device = self.patch_embedding.weight.device
441
+ if self.freqs.device != device:
442
+ self.freqs = self.freqs.to(device)
443
+
444
+ if y is not None:
445
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
446
+
447
+ # embeddings
448
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
449
+ grid_sizes = torch.stack(
450
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
451
+ x = [u.flatten(2).transpose(1, 2) for u in x]
452
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
453
+ assert seq_lens.max() <= seq_len
454
+ x = torch.cat([
455
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
456
+ dim=1) for u in x
457
+ ])
458
+
459
+ # time embeddings
460
+ if t.dim() == 1:
461
+ t = t.expand(t.size(0), seq_len)
462
+ with torch.amp.autocast('cuda', dtype=torch.float32):
463
+ bt = t.size(0)
464
+ t = t.flatten()
465
+ e = self.time_embedding(
466
+ sinusoidal_embedding_1d(self.freq_dim,
467
+ t).unflatten(0, (bt, seq_len)).float())
468
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
469
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
470
+
471
+ # context
472
+ context_lens = None
473
+ context = self.text_embedding(
474
+ torch.stack([
475
+ torch.cat(
476
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
477
+ for u in context
478
+ ]))
479
+
480
+ # arguments
481
+ kwargs = dict(
482
+ e=e0,
483
+ seq_lens=seq_lens,
484
+ grid_sizes=grid_sizes,
485
+ freqs=self.freqs,
486
+ context=context,
487
+ context_lens=context_lens)
488
+
489
+ for block in self.blocks:
490
+ x = block(x, **kwargs)
491
+
492
+ # head
493
+ x = self.head(x, e)
494
+
495
+ # unpatchify
496
+ x = self.unpatchify(x, grid_sizes)
497
+ return [u.float() for u in x]
498
+
499
+ def unpatchify(self, x, grid_sizes):
500
+ r"""
501
+ Reconstruct video tensors from patch embeddings.
502
+
503
+ Args:
504
+ x (List[Tensor]):
505
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
506
+ grid_sizes (Tensor):
507
+ Original spatial-temporal grid dimensions before patching,
508
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
509
+
510
+ Returns:
511
+ List[Tensor]:
512
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
513
+ """
514
+
515
+ c = self.out_dim
516
+ out = []
517
+ for u, v in zip(x, grid_sizes.tolist()):
518
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
519
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
520
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
521
+ out.append(u)
522
+ return out
523
+
524
+ def init_weights(self):
525
+ r"""
526
+ Initialize model parameters using Xavier initialization.
527
+ """
528
+
529
+ # basic init
530
+ for m in self.modules():
531
+ if isinstance(m, nn.Linear):
532
+ nn.init.xavier_uniform_(m.weight)
533
+ if m.bias is not None:
534
+ nn.init.zeros_(m.bias)
535
+
536
+ # init embeddings
537
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
538
+ for m in self.text_embedding.modules():
539
+ if isinstance(m, nn.Linear):
540
+ nn.init.normal_(m.weight, std=.02)
541
+ for m in self.time_embedding.modules():
542
+ if isinstance(m, nn.Linear):
543
+ nn.init.normal_(m.weight, std=.02)
544
+
545
+ # init output layer
546
+ nn.init.zeros_(self.head.head.weight)
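
A minimal sketch of the token-count arithmetic the model above relies on: with `patch_size = (1, 2, 2)` a latent video of shape [C, F, lat_h, lat_w] becomes `F * (lat_h // 2) * (lat_w // 2)` tokens, which is the `seq_len` every sample is padded to in `forward`. The latent grid below is an illustrative assumption (roughly a 720x1280 frame after a spatial VAE stride of 8 and 81 frames after a temporal stride of 4), not a value read from any config.

```python
# Sketch: patch/token arithmetic matching patch_embedding + the seq_len padding
# in WanModel.forward. Grid values are illustrative assumptions.
import math

patch_size = (1, 2, 2)
lat_f, lat_h, lat_w = 21, 90, 160   # assumed latent grid: 81 frames, 720x1280 pixels

grid = (lat_f // patch_size[0], lat_h // patch_size[1], lat_w // patch_size[2])
seq_len = math.prod(grid)
print(grid, seq_len)                # (21, 45, 80) 75600
```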
wan/modules/t5.py ADDED
@@ -0,0 +1,513 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
44
+
45
+
46
+ class GELU(nn.Module):
47
+
48
+ def forward(self, x):
49
+ return 0.5 * x * (1.0 + torch.tanh(
50
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
51
+
52
+
53
+ class T5LayerNorm(nn.Module):
54
+
55
+ def __init__(self, dim, eps=1e-6):
56
+ super(T5LayerNorm, self).__init__()
57
+ self.dim = dim
58
+ self.eps = eps
59
+ self.weight = nn.Parameter(torch.ones(dim))
60
+
61
+ def forward(self, x):
62
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
63
+ self.eps)
64
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
65
+ x = x.type_as(self.weight)
66
+ return self.weight * x
67
+
68
+
69
+ class T5Attention(nn.Module):
70
+
71
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
72
+ assert dim_attn % num_heads == 0
73
+ super(T5Attention, self).__init__()
74
+ self.dim = dim
75
+ self.dim_attn = dim_attn
76
+ self.num_heads = num_heads
77
+ self.head_dim = dim_attn // num_heads
78
+
79
+ # layers
80
+ self.q = nn.Linear(dim, dim_attn, bias=False)
81
+ self.k = nn.Linear(dim, dim_attn, bias=False)
82
+ self.v = nn.Linear(dim, dim_attn, bias=False)
83
+ self.o = nn.Linear(dim_attn, dim, bias=False)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ def forward(self, x, context=None, mask=None, pos_bias=None):
87
+ """
88
+ x: [B, L1, C].
89
+ context: [B, L2, C] or None.
90
+ mask: [B, L2] or [B, L1, L2] or None.
91
+ """
92
+ # check inputs
93
+ context = x if context is None else context
94
+ b, n, c = x.size(0), self.num_heads, self.head_dim
95
+
96
+ # compute query, key, value
97
+ q = self.q(x).view(b, -1, n, c)
98
+ k = self.k(context).view(b, -1, n, c)
99
+ v = self.v(context).view(b, -1, n, c)
100
+
101
+ # attention bias
102
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103
+ if pos_bias is not None:
104
+ attn_bias += pos_bias
105
+ if mask is not None:
106
+ assert mask.ndim in [2, 3]
107
+ mask = mask.view(b, 1, 1,
108
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
109
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110
+
111
+ # compute attention (T5 does not use scaling)
112
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
115
+
116
+ # output
117
+ x = x.reshape(b, -1, n * c)
118
+ x = self.o(x)
119
+ x = self.dropout(x)
120
+ return x
121
+
122
+
123
+ class T5FeedForward(nn.Module):
124
+
125
+ def __init__(self, dim, dim_ffn, dropout=0.1):
126
+ super(T5FeedForward, self).__init__()
127
+ self.dim = dim
128
+ self.dim_ffn = dim_ffn
129
+
130
+ # layers
131
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134
+ self.dropout = nn.Dropout(dropout)
135
+
136
+ def forward(self, x):
137
+ x = self.fc1(x) * self.gate(x)
138
+ x = self.dropout(x)
139
+ x = self.fc2(x)
140
+ x = self.dropout(x)
141
+ return x
142
+
143
+
144
+ class T5SelfAttention(nn.Module):
145
+
146
+ def __init__(self,
147
+ dim,
148
+ dim_attn,
149
+ dim_ffn,
150
+ num_heads,
151
+ num_buckets,
152
+ shared_pos=True,
153
+ dropout=0.1):
154
+ super(T5SelfAttention, self).__init__()
155
+ self.dim = dim
156
+ self.dim_attn = dim_attn
157
+ self.dim_ffn = dim_ffn
158
+ self.num_heads = num_heads
159
+ self.num_buckets = num_buckets
160
+ self.shared_pos = shared_pos
161
+
162
+ # layers
163
+ self.norm1 = T5LayerNorm(dim)
164
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165
+ self.norm2 = T5LayerNorm(dim)
166
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168
+ num_buckets, num_heads, bidirectional=True)
169
+
170
+ def forward(self, x, mask=None, pos_bias=None):
171
+ e = pos_bias if self.shared_pos else self.pos_embedding(
172
+ x.size(1), x.size(1))
173
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
175
+ return x
176
+
177
+
178
+ class T5CrossAttention(nn.Module):
179
+
180
+ def __init__(self,
181
+ dim,
182
+ dim_attn,
183
+ dim_ffn,
184
+ num_heads,
185
+ num_buckets,
186
+ shared_pos=True,
187
+ dropout=0.1):
188
+ super(T5CrossAttention, self).__init__()
189
+ self.dim = dim
190
+ self.dim_attn = dim_attn
191
+ self.dim_ffn = dim_ffn
192
+ self.num_heads = num_heads
193
+ self.num_buckets = num_buckets
194
+ self.shared_pos = shared_pos
195
+
196
+ # layers
197
+ self.norm1 = T5LayerNorm(dim)
198
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199
+ self.norm2 = T5LayerNorm(dim)
200
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201
+ self.norm3 = T5LayerNorm(dim)
202
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204
+ num_buckets, num_heads, bidirectional=False)
205
+
206
+ def forward(self,
207
+ x,
208
+ mask=None,
209
+ encoder_states=None,
210
+ encoder_mask=None,
211
+ pos_bias=None):
212
+ e = pos_bias if self.shared_pos else self.pos_embedding(
213
+ x.size(1), x.size(1))
214
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215
+ x = fp16_clamp(x + self.cross_attn(
216
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
217
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
218
+ return x
219
+
220
+
221
+ class T5RelativeEmbedding(nn.Module):
222
+
223
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224
+ super(T5RelativeEmbedding, self).__init__()
225
+ self.num_buckets = num_buckets
226
+ self.num_heads = num_heads
227
+ self.bidirectional = bidirectional
228
+ self.max_dist = max_dist
229
+
230
+ # layers
231
+ self.embedding = nn.Embedding(num_buckets, num_heads)
232
+
233
+ def forward(self, lq, lk):
234
+ device = self.embedding.weight.device
235
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236
+ # torch.arange(lq).unsqueeze(1).to(device)
237
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238
+ torch.arange(lq, device=device).unsqueeze(1)
239
+ rel_pos = self._relative_position_bucket(rel_pos)
240
+ rel_pos_embeds = self.embedding(rel_pos)
241
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242
+ 0) # [1, N, Lq, Lk]
243
+ return rel_pos_embeds.contiguous()
244
+
245
+ def _relative_position_bucket(self, rel_pos):
246
+ # preprocess
247
+ if self.bidirectional:
248
+ num_buckets = self.num_buckets // 2
249
+ rel_buckets = (rel_pos > 0).long() * num_buckets
250
+ rel_pos = torch.abs(rel_pos)
251
+ else:
252
+ num_buckets = self.num_buckets
253
+ rel_buckets = 0
254
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255
+
256
+ # embeddings for small and large positions
257
+ max_exact = num_buckets // 2
258
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259
+ math.log(self.max_dist / max_exact) *
260
+ (num_buckets - max_exact)).long()
261
+ rel_pos_large = torch.min(
262
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264
+ return rel_buckets
265
+
266
+
267
+ class T5Encoder(nn.Module):
268
+
269
+ def __init__(self,
270
+ vocab,
271
+ dim,
272
+ dim_attn,
273
+ dim_ffn,
274
+ num_heads,
275
+ num_layers,
276
+ num_buckets,
277
+ shared_pos=True,
278
+ dropout=0.1):
279
+ super(T5Encoder, self).__init__()
280
+ self.dim = dim
281
+ self.dim_attn = dim_attn
282
+ self.dim_ffn = dim_ffn
283
+ self.num_heads = num_heads
284
+ self.num_layers = num_layers
285
+ self.num_buckets = num_buckets
286
+ self.shared_pos = shared_pos
287
+
288
+ # layers
289
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290
+ else nn.Embedding(vocab, dim)
291
+ self.pos_embedding = T5RelativeEmbedding(
292
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
293
+ self.dropout = nn.Dropout(dropout)
294
+ self.blocks = nn.ModuleList([
295
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296
+ shared_pos, dropout) for _ in range(num_layers)
297
+ ])
298
+ self.norm = T5LayerNorm(dim)
299
+
300
+ # initialize weights
301
+ self.apply(init_weights)
302
+
303
+ def forward(self, ids, mask=None):
304
+ x = self.token_embedding(ids)
305
+ x = self.dropout(x)
306
+ e = self.pos_embedding(x.size(1),
307
+ x.size(1)) if self.shared_pos else None
308
+ for block in self.blocks:
309
+ x = block(x, mask, pos_bias=e)
310
+ x = self.norm(x)
311
+ x = self.dropout(x)
312
+ return x
313
+
314
+
315
+ class T5Decoder(nn.Module):
316
+
317
+ def __init__(self,
318
+ vocab,
319
+ dim,
320
+ dim_attn,
321
+ dim_ffn,
322
+ num_heads,
323
+ num_layers,
324
+ num_buckets,
325
+ shared_pos=True,
326
+ dropout=0.1):
327
+ super(T5Decoder, self).__init__()
328
+ self.dim = dim
329
+ self.dim_attn = dim_attn
330
+ self.dim_ffn = dim_ffn
331
+ self.num_heads = num_heads
332
+ self.num_layers = num_layers
333
+ self.num_buckets = num_buckets
334
+ self.shared_pos = shared_pos
335
+
336
+ # layers
337
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338
+ else nn.Embedding(vocab, dim)
339
+ self.pos_embedding = T5RelativeEmbedding(
340
+ num_buckets, num_heads, bidirectional=False) if shared_pos else None
341
+ self.dropout = nn.Dropout(dropout)
342
+ self.blocks = nn.ModuleList([
343
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344
+ shared_pos, dropout) for _ in range(num_layers)
345
+ ])
346
+ self.norm = T5LayerNorm(dim)
347
+
348
+ # initialize weights
349
+ self.apply(init_weights)
350
+
351
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352
+ b, s = ids.size()
353
+
354
+ # causal mask
355
+ if mask is None:
356
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357
+ elif mask.ndim == 2:
358
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359
+
360
+ # layers
361
+ x = self.token_embedding(ids)
362
+ x = self.dropout(x)
363
+ e = self.pos_embedding(x.size(1),
364
+ x.size(1)) if self.shared_pos else None
365
+ for block in self.blocks:
366
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367
+ x = self.norm(x)
368
+ x = self.dropout(x)
369
+ return x
370
+
371
+
372
+ class T5Model(nn.Module):
373
+
374
+ def __init__(self,
375
+ vocab_size,
376
+ dim,
377
+ dim_attn,
378
+ dim_ffn,
379
+ num_heads,
380
+ encoder_layers,
381
+ decoder_layers,
382
+ num_buckets,
383
+ shared_pos=True,
384
+ dropout=0.1):
385
+ super(T5Model, self).__init__()
386
+ self.vocab_size = vocab_size
387
+ self.dim = dim
388
+ self.dim_attn = dim_attn
389
+ self.dim_ffn = dim_ffn
390
+ self.num_heads = num_heads
391
+ self.encoder_layers = encoder_layers
392
+ self.decoder_layers = decoder_layers
393
+ self.num_buckets = num_buckets
394
+
395
+ # layers
396
+ self.token_embedding = nn.Embedding(vocab_size, dim)
397
+ self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398
+ num_heads, encoder_layers, num_buckets,
399
+ shared_pos, dropout)
400
+ self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401
+ num_heads, decoder_layers, num_buckets,
402
+ shared_pos, dropout)
403
+ self.head = nn.Linear(dim, vocab_size, bias=False)
404
+
405
+ # initialize weights
406
+ self.apply(init_weights)
407
+
408
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409
+ x = self.encoder(encoder_ids, encoder_mask)
410
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411
+ x = self.head(x)
412
+ return x
413
+
414
+
415
+ def _t5(name,
416
+ encoder_only=False,
417
+ decoder_only=False,
418
+ return_tokenizer=False,
419
+ tokenizer_kwargs={},
420
+ dtype=torch.float32,
421
+ device='cpu',
422
+ **kwargs):
423
+ # sanity check
424
+ assert not (encoder_only and decoder_only)
425
+
426
+ # params
427
+ if encoder_only:
428
+ model_cls = T5Encoder
429
+ kwargs['vocab'] = kwargs.pop('vocab_size')
430
+ kwargs['num_layers'] = kwargs.pop('encoder_layers')
431
+ _ = kwargs.pop('decoder_layers')
432
+ elif decoder_only:
433
+ model_cls = T5Decoder
434
+ kwargs['vocab'] = kwargs.pop('vocab_size')
435
+ kwargs['num_layers'] = kwargs.pop('decoder_layers')
436
+ _ = kwargs.pop('encoder_layers')
437
+ else:
438
+ model_cls = T5Model
439
+
440
+ # init model
441
+ with torch.device(device):
442
+ model = model_cls(**kwargs)
443
+
444
+ # set device
445
+ model = model.to(dtype=dtype, device=device)
446
+
447
+ # init tokenizer
448
+ if return_tokenizer:
449
+ from .tokenizers import HuggingfaceTokenizer
450
+ tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451
+ return model, tokenizer
452
+ else:
453
+ return model
454
+
455
+
456
+ def umt5_xxl(**kwargs):
457
+ cfg = dict(
458
+ vocab_size=256384,
459
+ dim=4096,
460
+ dim_attn=4096,
461
+ dim_ffn=10240,
462
+ num_heads=64,
463
+ encoder_layers=24,
464
+ decoder_layers=24,
465
+ num_buckets=32,
466
+ shared_pos=False,
467
+ dropout=0.1)
468
+ cfg.update(**kwargs)
469
+ return _t5('umt5-xxl', **cfg)
470
+
471
+
472
+ class T5EncoderModel:
473
+
474
+ def __init__(
475
+ self,
476
+ text_len,
477
+ dtype=torch.bfloat16,
478
+ device=torch.cuda.current_device(),
479
+ checkpoint_path=None,
480
+ tokenizer_path=None,
481
+ shard_fn=None,
482
+ ):
483
+ self.text_len = text_len
484
+ self.dtype = dtype
485
+ self.device = device
486
+ self.checkpoint_path = checkpoint_path
487
+ self.tokenizer_path = tokenizer_path
488
+
489
+ # init model
490
+ model = umt5_xxl(
491
+ encoder_only=True,
492
+ return_tokenizer=False,
493
+ dtype=dtype,
494
+ device=device).eval().requires_grad_(False)
495
+ logging.info(f'loading {checkpoint_path}')
496
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
497
+ self.model = model
498
+ if shard_fn is not None:
499
+ self.model = shard_fn(self.model, sync_module_states=False)
500
+ else:
501
+ self.model.to(self.device)
502
+ # init tokenizer
503
+ self.tokenizer = HuggingfaceTokenizer(
504
+ name=tokenizer_path, seq_len=text_len, clean='whitespace')
505
+
506
+ def __call__(self, texts, device):
507
+ ids, mask = self.tokenizer(
508
+ texts, return_mask=True, add_special_tokens=True)
509
+ ids = ids.to(device)
510
+ mask = mask.to(device)
511
+ seq_lens = mask.gt(0).sum(dim=1).long()
512
+ context = self.model(ids, mask)
513
+ return [u[:v] for u, v in zip(context, seq_lens)]
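
A minimal sketch building a toy-sized `T5Encoder` from the classes above and running it on CPU with random token ids; the dimensions are illustrative and far smaller than the umt5-xxl configuration the pipeline actually loads.

```python
# Sketch: toy T5Encoder forward pass on CPU. Sizes are illustrative assumptions.
import torch
from wan.modules.t5 import T5Encoder  # assumes this repo's package layout

enc = T5Encoder(vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
                num_heads=4, num_layers=2, num_buckets=32).eval()

ids = torch.randint(0, 1000, (2, 16))
mask = torch.ones(2, 16, dtype=torch.long)
with torch.no_grad():
    out = enc(ids, mask)
print(out.shape)  # torch.Size([2, 16, 64])
```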
wan/modules/tokenizers.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ['HuggingfaceTokenizer']
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r'\s+', ' ', text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace('_', ' ')
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans('', '', string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string))
30
+ else:
31
+ text = text.translate(str.maketrans('', '', string.punctuation))
32
+ text = text.lower()
33
+ text = re.sub(r'\s+', ' ', text)
34
+ return text.strip()
35
+
36
+
37
+ class HuggingfaceTokenizer:
38
+
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop('return_mask', False)
51
+
52
+ # arguments
53
+ _kwargs = {'return_tensors': 'pt'}
54
+ if self.seq_len is not None:
55
+ _kwargs.update({
56
+ 'padding': 'max_length',
57
+ 'truncation': True,
58
+ 'max_length': self.seq_len
59
+ })
60
+ _kwargs.update(**kwargs)
61
+
62
+ # tokenization
63
+ if isinstance(sequence, str):
64
+ sequence = [sequence]
65
+ if self.clean:
66
+ sequence = [self._clean(u) for u in sequence]
67
+ ids = self.tokenizer(sequence, **_kwargs)
68
+
69
+ # output
70
+ if return_mask:
71
+ return ids.input_ids, ids.attention_mask
72
+ else:
73
+ return ids.input_ids
74
+
75
+ def _clean(self, text):
76
+ if self.clean == 'whitespace':
77
+ text = whitespace_clean(basic_clean(text))
78
+ elif self.clean == 'lower':
79
+ text = whitespace_clean(basic_clean(text)).lower()
80
+ elif self.clean == 'canonicalize':
81
+ text = canonicalize(basic_clean(text))
82
+ return text
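
A minimal usage sketch of `HuggingfaceTokenizer` above; `google/umt5-xxl` is used as an illustrative tokenizer name (the pipeline instead passes a local tokenizer path taken from its config).

```python
# Sketch: tokenize a prompt the same way T5EncoderModel does, padding to a
# fixed text length. The tokenizer name is an illustrative assumption.
from wan.modules.tokenizers import HuggingfaceTokenizer

tok = HuggingfaceTokenizer(name='google/umt5-xxl', seq_len=512, clean='whitespace')
ids, mask = tok('A cat surfing a wave at sunset', return_mask=True)
print(ids.shape, int(mask.sum()))   # torch.Size([1, 512]) and the real token count
```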
wan/modules/vae2_1.py ADDED
@@ -0,0 +1,663 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ __all__ = [
11
+ 'Wan2_1_VAE',
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
25
+ self.padding[1], 2 * self.padding[0], 0)
26
+ self.padding = (0, 0, 0)
27
+
28
+ def forward(self, x, cache_x=None):
29
+ padding = list(self._padding)
30
+ if cache_x is not None and self._padding[4] > 0:
31
+ cache_x = cache_x.to(x.device)
32
+ x = torch.cat([cache_x, x], dim=2)
33
+ padding[4] -= cache_x.shape[2]
34
+ x = F.pad(x, padding)
35
+
36
+ return super().forward(x)
37
+
38
+
39
+ class RMS_norm(nn.Module):
40
+
41
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
42
+ super().__init__()
43
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
44
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
45
+
46
+ self.channel_first = channel_first
47
+ self.scale = dim**0.5
48
+ self.gamma = nn.Parameter(torch.ones(shape))
49
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
50
+
51
+ def forward(self, x):
52
+ return F.normalize(
53
+ x, dim=(1 if self.channel_first else
54
+ -1)) * self.scale * self.gamma + self.bias
55
+
56
+
57
+ class Upsample(nn.Upsample):
58
+
59
+ def forward(self, x):
60
+ """
61
+ Fix bfloat16 support for nearest neighbor interpolation.
62
+ """
63
+ return super().forward(x.float()).type_as(x)
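The Upsample wrapper above runs nearest-exact interpolation in float32 and casts the result back, which is what makes it safe for bfloat16 activations. Minimal usage sketch with the class defined above:

import torch

up = Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact")
x = torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)
y = up(x)
print(y.shape, y.dtype)   # torch.Size([1, 4, 16, 16]) torch.bfloat16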
64
+
65
+
66
+ class Resample(nn.Module):
67
+
68
+ def __init__(self, dim, mode):
69
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
70
+ 'downsample3d')
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.mode = mode
74
+
75
+ # layers
76
+ if mode == 'upsample2d':
77
+ self.resample = nn.Sequential(
78
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
79
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
80
+ elif mode == 'upsample3d':
81
+ self.resample = nn.Sequential(
82
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
84
+ self.time_conv = CausalConv3d(
85
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
+
87
+ elif mode == 'downsample2d':
88
+ self.resample = nn.Sequential(
89
+ nn.ZeroPad2d((0, 1, 0, 1)),
90
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
91
+ elif mode == 'downsample3d':
92
+ self.resample = nn.Sequential(
93
+ nn.ZeroPad2d((0, 1, 0, 1)),
94
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
+ self.time_conv = CausalConv3d(
96
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
97
+
98
+ else:
99
+ self.resample = nn.Identity()
100
+
101
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
102
+ b, c, t, h, w = x.size()
103
+ if self.mode == 'upsample3d':
104
+ if feat_cache is not None:
105
+ idx = feat_idx[0]
106
+ if feat_cache[idx] is None:
107
+ feat_cache[idx] = 'Rep'
108
+ feat_idx[0] += 1
109
+ else:
110
+
111
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
112
+ if cache_x.shape[2] < 2 and feat_cache[
113
+ idx] is not None and feat_cache[idx] != 'Rep':
114
+ # cache last frame of last two chunk
115
+ cache_x = torch.cat([
116
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117
+ cache_x.device), cache_x
118
+ ],
119
+ dim=2)
120
+ if cache_x.shape[2] < 2 and feat_cache[
121
+ idx] is not None and feat_cache[idx] == 'Rep':
122
+ cache_x = torch.cat([
123
+ torch.zeros_like(cache_x).to(cache_x.device),
124
+ cache_x
125
+ ],
126
+ dim=2)
127
+ if feat_cache[idx] == 'Rep':
128
+ x = self.time_conv(x)
129
+ else:
130
+ x = self.time_conv(x, feat_cache[idx])
131
+ feat_cache[idx] = cache_x
132
+ feat_idx[0] += 1
133
+
134
+ x = x.reshape(b, 2, c, t, h, w)
135
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
136
+ 3)
137
+ x = x.reshape(b, c, t * 2, h, w)
138
+ t = x.shape[2]
139
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
140
+ x = self.resample(x)
141
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
142
+
143
+ if self.mode == 'downsample3d':
144
+ if feat_cache is not None:
145
+ idx = feat_idx[0]
146
+ if feat_cache[idx] is None:
147
+ feat_cache[idx] = x.clone()
148
+ feat_idx[0] += 1
149
+ else:
150
+
151
+ cache_x = x[:, :, -1:, :, :].clone()
152
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
153
+ # # cache last frame of last two chunk
154
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
155
+
156
+ x = self.time_conv(
157
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
158
+ feat_cache[idx] = cache_x
159
+ feat_idx[0] += 1
160
+ return x
161
+
162
+ def init_weight(self, conv):
163
+ conv_weight = conv.weight
164
+ nn.init.zeros_(conv_weight)
165
+ c1, c2, t, h, w = conv_weight.size()
166
+ one_matrix = torch.eye(c1, c2)
167
+ init_matrix = one_matrix
168
+ nn.init.zeros_(conv_weight)
169
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
170
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
171
+ conv.weight.data.copy_(conv_weight)
172
+ nn.init.zeros_(conv.bias.data)
173
+
174
+ def init_weight2(self, conv):
175
+ conv_weight = conv.weight.data
176
+ nn.init.zeros_(conv_weight)
177
+ c1, c2, t, h, w = conv_weight.size()
178
+ init_matrix = torch.eye(c1 // 2, c2)
179
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
180
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
181
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
182
+ conv.weight.data.copy_(conv_weight)
183
+ nn.init.zeros_(conv.bias.data)
184
+
185
+
186
+ class ResidualBlock(nn.Module):
187
+
188
+ def __init__(self, in_dim, out_dim, dropout=0.0):
189
+ super().__init__()
190
+ self.in_dim = in_dim
191
+ self.out_dim = out_dim
192
+
193
+ # layers
194
+ self.residual = nn.Sequential(
195
+ RMS_norm(in_dim, images=False), nn.SiLU(),
196
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
197
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
198
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
199
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
200
+ if in_dim != out_dim else nn.Identity()
201
+
202
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
203
+ h = self.shortcut(x)
204
+ for layer in self.residual:
205
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
206
+ idx = feat_idx[0]
207
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
208
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
209
+ # cache last frame of last two chunk
210
+ cache_x = torch.cat([
211
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
212
+ cache_x.device), cache_x
213
+ ],
214
+ dim=2)
215
+ x = layer(x, feat_cache[idx])
216
+ feat_cache[idx] = cache_x
217
+ feat_idx[0] += 1
218
+ else:
219
+ x = layer(x)
220
+ return x + h
221
+
222
+
223
+ class AttentionBlock(nn.Module):
224
+ """
225
+ Causal self-attention with a single head.
226
+ """
227
+
228
+ def __init__(self, dim):
229
+ super().__init__()
230
+ self.dim = dim
231
+
232
+ # layers
233
+ self.norm = RMS_norm(dim)
234
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
235
+ self.proj = nn.Conv2d(dim, dim, 1)
236
+
237
+ # zero out the last layer params
238
+ nn.init.zeros_(self.proj.weight)
239
+
240
+ def forward(self, x):
241
+ identity = x
242
+ b, c, t, h, w = x.size()
243
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
244
+ x = self.norm(x)
245
+ # compute query, key, value
246
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
247
+ -1).permute(0, 1, 3,
248
+ 2).contiguous().chunk(
249
+ 3, dim=-1)
250
+
251
+ # apply attention
252
+ x = F.scaled_dot_product_attention(
253
+ q,
254
+ k,
255
+ v,
256
+ )
257
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
258
+
259
+ # output
260
+ x = self.proj(x)
261
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
262
+ return x + identity
263
+
264
+
265
+ class Encoder3d(nn.Module):
266
+
267
+ def __init__(self,
268
+ dim=128,
269
+ z_dim=4,
270
+ dim_mult=[1, 2, 4, 4],
271
+ num_res_blocks=2,
272
+ attn_scales=[],
273
+ temperal_downsample=[True, True, False],
274
+ dropout=0.0):
275
+ super().__init__()
276
+ self.dim = dim
277
+ self.z_dim = z_dim
278
+ self.dim_mult = dim_mult
279
+ self.num_res_blocks = num_res_blocks
280
+ self.attn_scales = attn_scales
281
+ self.temperal_downsample = temperal_downsample
282
+
283
+ # dimensions
284
+ dims = [dim * u for u in [1] + dim_mult]
285
+ scale = 1.0
286
+
287
+ # init block
288
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
289
+
290
+ # downsample blocks
291
+ downsamples = []
292
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
293
+ # residual (+attention) blocks
294
+ for _ in range(num_res_blocks):
295
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
296
+ if scale in attn_scales:
297
+ downsamples.append(AttentionBlock(out_dim))
298
+ in_dim = out_dim
299
+
300
+ # downsample block
301
+ if i != len(dim_mult) - 1:
302
+ mode = 'downsample3d' if temperal_downsample[
303
+ i] else 'downsample2d'
304
+ downsamples.append(Resample(out_dim, mode=mode))
305
+ scale /= 2.0
306
+ self.downsamples = nn.Sequential(*downsamples)
307
+
308
+ # middle blocks
309
+ self.middle = nn.Sequential(
310
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
311
+ ResidualBlock(out_dim, out_dim, dropout))
312
+
313
+ # output blocks
314
+ self.head = nn.Sequential(
315
+ RMS_norm(out_dim, images=False), nn.SiLU(),
316
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
317
+
318
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
319
+ if feat_cache is not None:
320
+ idx = feat_idx[0]
321
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
322
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
323
+ # cache last frame of last two chunk
324
+ cache_x = torch.cat([
325
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
326
+ cache_x.device), cache_x
327
+ ],
328
+ dim=2)
329
+ x = self.conv1(x, feat_cache[idx])
330
+ feat_cache[idx] = cache_x
331
+ feat_idx[0] += 1
332
+ else:
333
+ x = self.conv1(x)
334
+
335
+ ## downsamples
336
+ for layer in self.downsamples:
337
+ if feat_cache is not None:
338
+ x = layer(x, feat_cache, feat_idx)
339
+ else:
340
+ x = layer(x)
341
+
342
+ ## middle
343
+ for layer in self.middle:
344
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
345
+ x = layer(x, feat_cache, feat_idx)
346
+ else:
347
+ x = layer(x)
348
+
349
+ ## head
350
+ for layer in self.head:
351
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
352
+ idx = feat_idx[0]
353
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
354
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
355
+ # cache last frame of last two chunk
356
+ cache_x = torch.cat([
357
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
358
+ cache_x.device), cache_x
359
+ ],
360
+ dim=2)
361
+ x = layer(x, feat_cache[idx])
362
+ feat_cache[idx] = cache_x
363
+ feat_idx[0] += 1
364
+ else:
365
+ x = layer(x)
366
+ return x
367
+
368
+
369
+ class Decoder3d(nn.Module):
370
+
371
+ def __init__(self,
372
+ dim=128,
373
+ z_dim=4,
374
+ dim_mult=[1, 2, 4, 4],
375
+ num_res_blocks=2,
376
+ attn_scales=[],
377
+ temperal_upsample=[False, True, True],
378
+ dropout=0.0):
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.z_dim = z_dim
382
+ self.dim_mult = dim_mult
383
+ self.num_res_blocks = num_res_blocks
384
+ self.attn_scales = attn_scales
385
+ self.temperal_upsample = temperal_upsample
386
+
387
+ # dimensions
388
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
389
+ scale = 1.0 / 2**(len(dim_mult) - 2)
390
+
391
+ # init block
392
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
393
+
394
+ # middle blocks
395
+ self.middle = nn.Sequential(
396
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
397
+ ResidualBlock(dims[0], dims[0], dropout))
398
+
399
+ # upsample blocks
400
+ upsamples = []
401
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
402
+ # residual (+attention) blocks
403
+ if i == 1 or i == 2 or i == 3:
404
+ in_dim = in_dim // 2
405
+ for _ in range(num_res_blocks + 1):
406
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
407
+ if scale in attn_scales:
408
+ upsamples.append(AttentionBlock(out_dim))
409
+ in_dim = out_dim
410
+
411
+ # upsample block
412
+ if i != len(dim_mult) - 1:
413
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
414
+ upsamples.append(Resample(out_dim, mode=mode))
415
+ scale *= 2.0
416
+ self.upsamples = nn.Sequential(*upsamples)
417
+
418
+ # output blocks
419
+ self.head = nn.Sequential(
420
+ RMS_norm(out_dim, images=False), nn.SiLU(),
421
+ CausalConv3d(out_dim, 3, 3, padding=1))
422
+
423
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
424
+ ## conv1
425
+ if feat_cache is not None:
426
+ idx = feat_idx[0]
427
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
428
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
429
+ # cache last frame of last two chunk
430
+ cache_x = torch.cat([
431
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
432
+ cache_x.device), cache_x
433
+ ],
434
+ dim=2)
435
+ x = self.conv1(x, feat_cache[idx])
436
+ feat_cache[idx] = cache_x
437
+ feat_idx[0] += 1
438
+ else:
439
+ x = self.conv1(x)
440
+
441
+ ## middle
442
+ for layer in self.middle:
443
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
444
+ x = layer(x, feat_cache, feat_idx)
445
+ else:
446
+ x = layer(x)
447
+
448
+ ## upsamples
449
+ for layer in self.upsamples:
450
+ if feat_cache is not None:
451
+ x = layer(x, feat_cache, feat_idx)
452
+ else:
453
+ x = layer(x)
454
+
455
+ ## head
456
+ for layer in self.head:
457
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
458
+ idx = feat_idx[0]
459
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
460
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
461
+ # cache last frame of last two chunk
462
+ cache_x = torch.cat([
463
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
464
+ cache_x.device), cache_x
465
+ ],
466
+ dim=2)
467
+ x = layer(x, feat_cache[idx])
468
+ feat_cache[idx] = cache_x
469
+ feat_idx[0] += 1
470
+ else:
471
+ x = layer(x)
472
+ return x
473
+
474
+
475
+ def count_conv3d(model):
476
+ count = 0
477
+ for m in model.modules():
478
+ if isinstance(m, CausalConv3d):
479
+ count += 1
480
+ return count
481
+
482
+
483
+ class WanVAE_(nn.Module):
484
+
485
+ def __init__(self,
486
+ dim=128,
487
+ z_dim=4,
488
+ dim_mult=[1, 2, 4, 4],
489
+ num_res_blocks=2,
490
+ attn_scales=[],
491
+ temperal_downsample=[True, True, False],
492
+ dropout=0.0):
493
+ super().__init__()
494
+ self.dim = dim
495
+ self.z_dim = z_dim
496
+ self.dim_mult = dim_mult
497
+ self.num_res_blocks = num_res_blocks
498
+ self.attn_scales = attn_scales
499
+ self.temperal_downsample = temperal_downsample
500
+ self.temperal_upsample = temperal_downsample[::-1]
501
+
502
+ # modules
503
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
504
+ attn_scales, self.temperal_downsample, dropout)
505
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
506
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
507
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
508
+ attn_scales, self.temperal_upsample, dropout)
509
+
510
+ def forward(self, x):
511
+ mu, log_var = self.encode(x)
512
+ z = self.reparameterize(mu, log_var)
513
+ x_recon = self.decode(z)
514
+ return x_recon, mu, log_var
515
+
516
+ def encode(self, x, scale):
517
+ self.clear_cache()
518
+ ## cache
519
+ t = x.shape[2]
520
+ iter_ = 1 + (t - 1) // 4
521
+ ## Split the input x of encode along the time axis into chunks of 1, 4, 4, 4, ...
522
+ for i in range(iter_):
523
+ self._enc_conv_idx = [0]
524
+ if i == 0:
525
+ out = self.encoder(
526
+ x[:, :, :1, :, :],
527
+ feat_cache=self._enc_feat_map,
528
+ feat_idx=self._enc_conv_idx)
529
+ else:
530
+ out_ = self.encoder(
531
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
532
+ feat_cache=self._enc_feat_map,
533
+ feat_idx=self._enc_conv_idx)
534
+ out = torch.cat([out, out_], 2)
535
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
536
+ if isinstance(scale[0], torch.Tensor):
537
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
538
+ 1, self.z_dim, 1, 1, 1)
539
+ else:
540
+ mu = (mu - scale[0]) * scale[1]
541
+ self.clear_cache()
542
+ return mu
543
+
544
+ def decode(self, z, scale):
545
+ self.clear_cache()
546
+ # z: [b,c,t,h,w]
547
+ if isinstance(scale[0], torch.Tensor):
548
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
549
+ 1, self.z_dim, 1, 1, 1)
550
+ else:
551
+ z = z / scale[1] + scale[0]
552
+ iter_ = z.shape[2]
553
+ x = self.conv2(z)
554
+ for i in range(iter_):
555
+ self._conv_idx = [0]
556
+ if i == 0:
557
+ out = self.decoder(
558
+ x[:, :, i:i + 1, :, :],
559
+ feat_cache=self._feat_map,
560
+ feat_idx=self._conv_idx)
561
+ else:
562
+ out_ = self.decoder(
563
+ x[:, :, i:i + 1, :, :],
564
+ feat_cache=self._feat_map,
565
+ feat_idx=self._conv_idx)
566
+ out = torch.cat([out, out_], 2)
567
+ self.clear_cache()
568
+ return out
569
+
570
+ def reparameterize(self, mu, log_var):
571
+ std = torch.exp(0.5 * log_var)
572
+ eps = torch.randn_like(std)
573
+ return eps * std + mu
574
+
575
+ def sample(self, imgs, deterministic=False):
576
+ mu, log_var = self.encode(imgs)
577
+ if deterministic:
578
+ return mu
579
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
580
+ return mu + std * torch.randn_like(std)
581
+
582
+ def clear_cache(self):
583
+ self._conv_num = count_conv3d(self.decoder)
584
+ self._conv_idx = [0]
585
+ self._feat_map = [None] * self._conv_num
586
+ #cache encode
587
+ self._enc_conv_num = count_conv3d(self.encoder)
588
+ self._enc_conv_idx = [0]
589
+ self._enc_feat_map = [None] * self._enc_conv_num
590
+
591
+
592
+ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
593
+ """
594
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
595
+ """
596
+ # params
597
+ cfg = dict(
598
+ dim=96,
599
+ z_dim=z_dim,
600
+ dim_mult=[1, 2, 4, 4],
601
+ num_res_blocks=2,
602
+ attn_scales=[],
603
+ temperal_downsample=[False, True, True],
604
+ dropout=0.0)
605
+ cfg.update(**kwargs)
606
+
607
+ # init model
608
+ with torch.device('meta'):
609
+ model = WanVAE_(**cfg)
610
+
611
+ # load checkpoint
612
+ logging.info(f'loading {pretrained_path}')
613
+ model.load_state_dict(
614
+ torch.load(pretrained_path, map_location=device), assign=True)
615
+
616
+ return model
617
+
618
+
619
+ class Wan2_1_VAE:
620
+
621
+ def __init__(self,
622
+ z_dim=16,
623
+ vae_pth='cache/vae_step_411000.pth',
624
+ dtype=torch.float,
625
+ device="cuda"):
626
+ self.dtype = dtype
627
+ self.device = device
628
+
629
+ mean = [
630
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
+ ]
633
+ std = [
634
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
+ ]
637
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
638
+ self.std = torch.tensor(std, dtype=dtype, device=device)
639
+ self.scale = [self.mean, 1.0 / self.std]
640
+
641
+ # init model
642
+ self.model = _video_vae(
643
+ pretrained_path=vae_pth,
644
+ z_dim=z_dim,
645
+ ).eval().requires_grad_(False).to(device)
646
+
647
+ def encode(self, videos):
648
+ """
649
+ videos: A list of videos each with shape [C, T, H, W].
650
+ """
651
+ with amp.autocast(dtype=self.dtype):
652
+ return [
653
+ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
654
+ for u in videos
655
+ ]
656
+
657
+ def decode(self, zs):
658
+ with amp.autocast(dtype=self.dtype):
659
+ return [
660
+ self.model.decode(u.unsqueeze(0),
661
+ self.scale).float().clamp_(-1, 1).squeeze(0)
662
+ for u in zs
663
+ ]
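Both encode and decode above stream the video through the causal VAE chunk by chunk: the first chunk is a single frame and every following chunk is 4 frames, matching the 4x temporal stride so that each chunk yields exactly one latent frame. A small sketch of the chunk boundaries implied by `iter_ = 1 + (t - 1) // 4`:

def encode_chunks(t):
    # first chunk: frame 0 alone; then groups of 4 frames, as in WanVAE_.encode above
    chunks = [(0, 1)]
    for i in range(1, 1 + (t - 1) // 4):
        chunks.append((1 + 4 * (i - 1), 1 + 4 * i))
    return chunks

print(len(encode_chunks(81)))   # 21 chunks -> 21 latent frames for an 81-frame clip
print(encode_chunks(81)[:3])    # [(0, 1), (1, 5), (5, 9)]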
wan/modules/vae2_2.py ADDED
@@ -0,0 +1,1051 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ __all__ = [
11
+ "Wan2_2_VAE",
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (
25
+ self.padding[2],
26
+ self.padding[2],
27
+ self.padding[1],
28
+ self.padding[1],
29
+ 2 * self.padding[0],
30
+ 0,
31
+ )
32
+ self.padding = (0, 0, 0)
33
+
34
+ def forward(self, x, cache_x=None):
35
+ padding = list(self._padding)
36
+ if cache_x is not None and self._padding[4] > 0:
37
+ cache_x = cache_x.to(x.device)
38
+ x = torch.cat([cache_x, x], dim=2)
39
+ padding[4] -= cache_x.shape[2]
40
+ x = F.pad(x, padding)
41
+
42
+ return super().forward(x)
43
+
44
+
45
+ class RMS_norm(nn.Module):
46
+
47
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
48
+ super().__init__()
49
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
50
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
51
+
52
+ self.channel_first = channel_first
53
+ self.scale = dim**0.5
54
+ self.gamma = nn.Parameter(torch.ones(shape))
55
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
56
+
57
+ def forward(self, x):
58
+ return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
59
+ self.scale * self.gamma + self.bias)
60
+
61
+
62
+ class Upsample(nn.Upsample):
63
+
64
+ def forward(self, x):
65
+ """
66
+ Fix bfloat16 support for nearest neighbor interpolation.
67
+ """
68
+ return super().forward(x.float()).type_as(x)
69
+
70
+
71
+ class Resample(nn.Module):
72
+
73
+ def __init__(self, dim, mode):
74
+ assert mode in (
75
+ "none",
76
+ "upsample2d",
77
+ "upsample3d",
78
+ "downsample2d",
79
+ "downsample3d",
80
+ )
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.mode = mode
84
+
85
+ # layers
86
+ if mode == "upsample2d":
87
+ self.resample = nn.Sequential(
88
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
89
+ nn.Conv2d(dim, dim, 3, padding=1),
90
+ )
91
+ elif mode == "upsample3d":
92
+ self.resample = nn.Sequential(
93
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
94
+ nn.Conv2d(dim, dim, 3, padding=1),
95
+ # nn.Conv2d(dim, dim//2, 3, padding=1)
96
+ )
97
+ self.time_conv = CausalConv3d(
98
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
99
+ elif mode == "downsample2d":
100
+ self.resample = nn.Sequential(
101
+ nn.ZeroPad2d((0, 1, 0, 1)),
102
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
103
+ elif mode == "downsample3d":
104
+ self.resample = nn.Sequential(
105
+ nn.ZeroPad2d((0, 1, 0, 1)),
106
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
107
+ self.time_conv = CausalConv3d(
108
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
109
+ else:
110
+ self.resample = nn.Identity()
111
+
112
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
113
+ b, c, t, h, w = x.size()
114
+ if self.mode == "upsample3d":
115
+ if feat_cache is not None:
116
+ idx = feat_idx[0]
117
+ if feat_cache[idx] is None:
118
+ feat_cache[idx] = "Rep"
119
+ feat_idx[0] += 1
120
+ else:
121
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
122
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
123
+ feat_cache[idx] != "Rep"):
124
+ # cache last frame of last two chunk
125
+ cache_x = torch.cat(
126
+ [
127
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
128
+ cache_x.device),
129
+ cache_x,
130
+ ],
131
+ dim=2,
132
+ )
133
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
134
+ feat_cache[idx] == "Rep"):
135
+ cache_x = torch.cat(
136
+ [
137
+ torch.zeros_like(cache_x).to(cache_x.device),
138
+ cache_x
139
+ ],
140
+ dim=2,
141
+ )
142
+ if feat_cache[idx] == "Rep":
143
+ x = self.time_conv(x)
144
+ else:
145
+ x = self.time_conv(x, feat_cache[idx])
146
+ feat_cache[idx] = cache_x
147
+ feat_idx[0] += 1
148
+ x = x.reshape(b, 2, c, t, h, w)
149
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
150
+ 3)
151
+ x = x.reshape(b, c, t * 2, h, w)
152
+ t = x.shape[2]
153
+ x = rearrange(x, "b c t h w -> (b t) c h w")
154
+ x = self.resample(x)
155
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
156
+
157
+ if self.mode == "downsample3d":
158
+ if feat_cache is not None:
159
+ idx = feat_idx[0]
160
+ if feat_cache[idx] is None:
161
+ feat_cache[idx] = x.clone()
162
+ feat_idx[0] += 1
163
+ else:
164
+ cache_x = x[:, :, -1:, :, :].clone()
165
+ x = self.time_conv(
166
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
167
+ feat_cache[idx] = cache_x
168
+ feat_idx[0] += 1
169
+ return x
170
+
171
+ def init_weight(self, conv):
172
+ conv_weight = conv.weight.detach().clone()
173
+ nn.init.zeros_(conv_weight)
174
+ c1, c2, t, h, w = conv_weight.size()
175
+ one_matrix = torch.eye(c1, c2)
176
+ init_matrix = one_matrix
177
+ nn.init.zeros_(conv_weight)
178
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
179
+ conv.weight = nn.Parameter(conv_weight)
180
+ nn.init.zeros_(conv.bias.data)
181
+
182
+ def init_weight2(self, conv):
183
+ conv_weight = conv.weight.data.detach().clone()
184
+ nn.init.zeros_(conv_weight)
185
+ c1, c2, t, h, w = conv_weight.size()
186
+ init_matrix = torch.eye(c1 // 2, c2)
187
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
188
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
189
+ conv.weight = nn.Parameter(conv_weight)
190
+ nn.init.zeros_(conv.bias.data)
191
+
192
+
193
+ class ResidualBlock(nn.Module):
194
+
195
+ def __init__(self, in_dim, out_dim, dropout=0.0):
196
+ super().__init__()
197
+ self.in_dim = in_dim
198
+ self.out_dim = out_dim
199
+
200
+ # layers
201
+ self.residual = nn.Sequential(
202
+ RMS_norm(in_dim, images=False),
203
+ nn.SiLU(),
204
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
205
+ RMS_norm(out_dim, images=False),
206
+ nn.SiLU(),
207
+ nn.Dropout(dropout),
208
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
209
+ )
210
+ self.shortcut = (
211
+ CausalConv3d(in_dim, out_dim, 1)
212
+ if in_dim != out_dim else nn.Identity())
213
+
214
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
215
+ h = self.shortcut(x)
216
+ for layer in self.residual:
217
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
218
+ idx = feat_idx[0]
219
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
220
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
221
+ # cache last frame of last two chunk
222
+ cache_x = torch.cat(
223
+ [
224
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
225
+ cache_x.device),
226
+ cache_x,
227
+ ],
228
+ dim=2,
229
+ )
230
+ x = layer(x, feat_cache[idx])
231
+ feat_cache[idx] = cache_x
232
+ feat_idx[0] += 1
233
+ else:
234
+ x = layer(x)
235
+ return x + h
236
+
237
+
238
+ class AttentionBlock(nn.Module):
239
+ """
240
+ Causal self-attention with a single head.
241
+ """
242
+
243
+ def __init__(self, dim):
244
+ super().__init__()
245
+ self.dim = dim
246
+
247
+ # layers
248
+ self.norm = RMS_norm(dim)
249
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
250
+ self.proj = nn.Conv2d(dim, dim, 1)
251
+
252
+ # zero out the last layer params
253
+ nn.init.zeros_(self.proj.weight)
254
+
255
+ def forward(self, x):
256
+ identity = x
257
+ b, c, t, h, w = x.size()
258
+ x = rearrange(x, "b c t h w -> (b t) c h w")
259
+ x = self.norm(x)
260
+ # compute query, key, value
261
+ q, k, v = (
262
+ self.to_qkv(x).reshape(b * t, 1, c * 3,
263
+ -1).permute(0, 1, 3,
264
+ 2).contiguous().chunk(3, dim=-1))
265
+
266
+ # apply attention
267
+ x = F.scaled_dot_product_attention(
268
+ q,
269
+ k,
270
+ v,
271
+ )
272
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
273
+
274
+ # output
275
+ x = self.proj(x)
276
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
277
+ return x + identity
278
+
279
+
280
+ def patchify(x, patch_size):
281
+ if patch_size == 1:
282
+ return x
283
+ if x.dim() == 4:
284
+ x = rearrange(
285
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
286
+ elif x.dim() == 5:
287
+ x = rearrange(
288
+ x,
289
+ "b c f (h q) (w r) -> b (c r q) f h w",
290
+ q=patch_size,
291
+ r=patch_size,
292
+ )
293
+ else:
294
+ raise ValueError(f"Invalid input shape: {x.shape}")
295
+
296
+ return x
297
+
298
+
299
+ def unpatchify(x, patch_size):
300
+ if patch_size == 1:
301
+ return x
302
+
303
+ if x.dim() == 4:
304
+ x = rearrange(
305
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
306
+ elif x.dim() == 5:
307
+ x = rearrange(
308
+ x,
309
+ "b (c r q) f h w -> b c f (h q) (w r)",
310
+ q=patch_size,
311
+ r=patch_size,
312
+ )
313
+ return x
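patchify/unpatchify above form an exact, parameter-free round trip that trades spatial resolution for channels; with patch_size=2 an RGB video becomes 12 channels at half resolution, which is why Encoder3d.conv1 below takes 12 input channels and the decoder head emits 12. A short check with illustrative shapes:

import torch

x = torch.randn(1, 3, 5, 32, 32)                      # (B, C, F, H, W)
p = patchify(x, patch_size=2)
print(p.shape)                                        # torch.Size([1, 12, 5, 16, 16])
print(torch.equal(unpatchify(p, patch_size=2), x))    # exact inverse -> True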
314
+
315
+
316
+ class AvgDown3D(nn.Module):
317
+
318
+ def __init__(
319
+ self,
320
+ in_channels,
321
+ out_channels,
322
+ factor_t,
323
+ factor_s=1,
324
+ ):
325
+ super().__init__()
326
+ self.in_channels = in_channels
327
+ self.out_channels = out_channels
328
+ self.factor_t = factor_t
329
+ self.factor_s = factor_s
330
+ self.factor = self.factor_t * self.factor_s * self.factor_s
331
+
332
+ assert in_channels * self.factor % out_channels == 0
333
+ self.group_size = in_channels * self.factor // out_channels
334
+
335
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
336
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
337
+ pad = (0, 0, 0, 0, pad_t, 0)
338
+ x = F.pad(x, pad)
339
+ B, C, T, H, W = x.shape
340
+ x = x.view(
341
+ B,
342
+ C,
343
+ T // self.factor_t,
344
+ self.factor_t,
345
+ H // self.factor_s,
346
+ self.factor_s,
347
+ W // self.factor_s,
348
+ self.factor_s,
349
+ )
350
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
351
+ x = x.view(
352
+ B,
353
+ C * self.factor,
354
+ T // self.factor_t,
355
+ H // self.factor_s,
356
+ W // self.factor_s,
357
+ )
358
+ x = x.view(
359
+ B,
360
+ self.out_channels,
361
+ self.group_size,
362
+ T // self.factor_t,
363
+ H // self.factor_s,
364
+ W // self.factor_s,
365
+ )
366
+ x = x.mean(dim=2)
367
+ return x
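AvgDown3D above is a parameter-free downsampler: it folds the time/space factors into the channel dimension (pixel-unshuffle style) and then averages groups of channels down to out_channels; it serves as the shortcut path in Down_ResidualBlock below. Shape sketch with the class defined above (values illustrative):

import torch

down = AvgDown3D(in_channels=16, out_channels=32, factor_t=2, factor_s=2)
x = torch.randn(1, 16, 8, 32, 32)                    # (B, C, T, H, W)
print(down(x).shape)                                 # torch.Size([1, 32, 4, 16, 16])
print(sum(p.numel() for p in down.parameters()))     # 0 -> no learnable weights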
368
+
369
+
370
+ class DupUp3D(nn.Module):
371
+
372
+ def __init__(
373
+ self,
374
+ in_channels: int,
375
+ out_channels: int,
376
+ factor_t,
377
+ factor_s=1,
378
+ ):
379
+ super().__init__()
380
+ self.in_channels = in_channels
381
+ self.out_channels = out_channels
382
+
383
+ self.factor_t = factor_t
384
+ self.factor_s = factor_s
385
+ self.factor = self.factor_t * self.factor_s * self.factor_s
386
+
387
+ assert out_channels * self.factor % in_channels == 0
388
+ self.repeats = out_channels * self.factor // in_channels
389
+
390
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
391
+ x = x.repeat_interleave(self.repeats, dim=1)
392
+ x = x.view(
393
+ x.size(0),
394
+ self.out_channels,
395
+ self.factor_t,
396
+ self.factor_s,
397
+ self.factor_s,
398
+ x.size(2),
399
+ x.size(3),
400
+ x.size(4),
401
+ )
402
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
403
+ x = x.view(
404
+ x.size(0),
405
+ self.out_channels,
406
+ x.size(2) * self.factor_t,
407
+ x.size(4) * self.factor_s,
408
+ x.size(6) * self.factor_s,
409
+ )
410
+ if first_chunk:
411
+ x = x[:, :, self.factor_t - 1:, :, :]
412
+ return x
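DupUp3D above inverts that layout for the decoder shortcut: channels are repeated and folded back into time/space, and first_chunk=True trims the duplicated leading frames so the first decoded chunk keeps the 1 + 4k frame pattern. Shape sketch using the class above:

import torch

up = DupUp3D(in_channels=32, out_channels=16, factor_t=2, factor_s=2)
z = torch.randn(1, 32, 1, 16, 16)
print(up(z).shape)                     # torch.Size([1, 16, 2, 32, 32])
print(up(z, first_chunk=True).shape)   # torch.Size([1, 16, 1, 32, 32])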
413
+
414
+
415
+ class Down_ResidualBlock(nn.Module):
416
+
417
+ def __init__(self,
418
+ in_dim,
419
+ out_dim,
420
+ dropout,
421
+ mult,
422
+ temperal_downsample=False,
423
+ down_flag=False):
424
+ super().__init__()
425
+
426
+ # Shortcut path with downsample
427
+ self.avg_shortcut = AvgDown3D(
428
+ in_dim,
429
+ out_dim,
430
+ factor_t=2 if temperal_downsample else 1,
431
+ factor_s=2 if down_flag else 1,
432
+ )
433
+
434
+ # Main path with residual blocks and downsample
435
+ downsamples = []
436
+ for _ in range(mult):
437
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
438
+ in_dim = out_dim
439
+
440
+ # Add the final downsample block
441
+ if down_flag:
442
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
443
+ downsamples.append(Resample(out_dim, mode=mode))
444
+
445
+ self.downsamples = nn.Sequential(*downsamples)
446
+
447
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
448
+ x_copy = x.clone()
449
+ for module in self.downsamples:
450
+ x = module(x, feat_cache, feat_idx)
451
+
452
+ return x + self.avg_shortcut(x_copy)
453
+
454
+
455
+ class Up_ResidualBlock(nn.Module):
456
+
457
+ def __init__(self,
458
+ in_dim,
459
+ out_dim,
460
+ dropout,
461
+ mult,
462
+ temperal_upsample=False,
463
+ up_flag=False):
464
+ super().__init__()
465
+ # Shortcut path with upsample
466
+ if up_flag:
467
+ self.avg_shortcut = DupUp3D(
468
+ in_dim,
469
+ out_dim,
470
+ factor_t=2 if temperal_upsample else 1,
471
+ factor_s=2 if up_flag else 1,
472
+ )
473
+ else:
474
+ self.avg_shortcut = None
475
+
476
+ # Main path with residual blocks and upsample
477
+ upsamples = []
478
+ for _ in range(mult):
479
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
480
+ in_dim = out_dim
481
+
482
+ # Add the final upsample block
483
+ if up_flag:
484
+ mode = "upsample3d" if temperal_upsample else "upsample2d"
485
+ upsamples.append(Resample(out_dim, mode=mode))
486
+
487
+ self.upsamples = nn.Sequential(*upsamples)
488
+
489
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
490
+ x_main = x.clone()
491
+ for module in self.upsamples:
492
+ x_main = module(x_main, feat_cache, feat_idx)
493
+ if self.avg_shortcut is not None:
494
+ x_shortcut = self.avg_shortcut(x, first_chunk)
495
+ return x_main + x_shortcut
496
+ else:
497
+ return x_main
498
+
499
+
500
+ class Encoder3d(nn.Module):
501
+
502
+ def __init__(
503
+ self,
504
+ dim=128,
505
+ z_dim=4,
506
+ dim_mult=[1, 2, 4, 4],
507
+ num_res_blocks=2,
508
+ attn_scales=[],
509
+ temperal_downsample=[True, True, False],
510
+ dropout=0.0,
511
+ ):
512
+ super().__init__()
513
+ self.dim = dim
514
+ self.z_dim = z_dim
515
+ self.dim_mult = dim_mult
516
+ self.num_res_blocks = num_res_blocks
517
+ self.attn_scales = attn_scales
518
+ self.temperal_downsample = temperal_downsample
519
+
520
+ # dimensions
521
+ dims = [dim * u for u in [1] + dim_mult]
522
+ scale = 1.0
523
+
524
+ # init block
525
+ self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
526
+
527
+ # downsample blocks
528
+ downsamples = []
529
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
530
+ t_down_flag = (
531
+ temperal_downsample[i]
532
+ if i < len(temperal_downsample) else False)
533
+ downsamples.append(
534
+ Down_ResidualBlock(
535
+ in_dim=in_dim,
536
+ out_dim=out_dim,
537
+ dropout=dropout,
538
+ mult=num_res_blocks,
539
+ temperal_downsample=t_down_flag,
540
+ down_flag=i != len(dim_mult) - 1,
541
+ ))
542
+ scale /= 2.0
543
+ self.downsamples = nn.Sequential(*downsamples)
544
+
545
+ # middle blocks
546
+ self.middle = nn.Sequential(
547
+ ResidualBlock(out_dim, out_dim, dropout),
548
+ AttentionBlock(out_dim),
549
+ ResidualBlock(out_dim, out_dim, dropout),
550
+ )
551
+
552
+ # output blocks
553
+ self.head = nn.Sequential(
554
+ RMS_norm(out_dim, images=False),
555
+ nn.SiLU(),
556
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
557
+ )
558
+
559
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
560
+
561
+ if feat_cache is not None:
562
+ idx = feat_idx[0]
563
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
564
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
565
+ cache_x = torch.cat(
566
+ [
567
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
568
+ cache_x.device),
569
+ cache_x,
570
+ ],
571
+ dim=2,
572
+ )
573
+ x = self.conv1(x, feat_cache[idx])
574
+ feat_cache[idx] = cache_x
575
+ feat_idx[0] += 1
576
+ else:
577
+ x = self.conv1(x)
578
+
579
+ ## downsamples
580
+ for layer in self.downsamples:
581
+ if feat_cache is not None:
582
+ x = layer(x, feat_cache, feat_idx)
583
+ else:
584
+ x = layer(x)
585
+
586
+ ## middle
587
+ for layer in self.middle:
588
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
589
+ x = layer(x, feat_cache, feat_idx)
590
+ else:
591
+ x = layer(x)
592
+
593
+ ## head
594
+ for layer in self.head:
595
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
596
+ idx = feat_idx[0]
597
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
598
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
599
+ cache_x = torch.cat(
600
+ [
601
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
602
+ cache_x.device),
603
+ cache_x,
604
+ ],
605
+ dim=2,
606
+ )
607
+ x = layer(x, feat_cache[idx])
608
+ feat_cache[idx] = cache_x
609
+ feat_idx[0] += 1
610
+ else:
611
+ x = layer(x)
612
+
613
+ return x
614
+
615
+
616
+ class Decoder3d(nn.Module):
617
+
618
+ def __init__(
619
+ self,
620
+ dim=128,
621
+ z_dim=4,
622
+ dim_mult=[1, 2, 4, 4],
623
+ num_res_blocks=2,
624
+ attn_scales=[],
625
+ temperal_upsample=[False, True, True],
626
+ dropout=0.0,
627
+ ):
628
+ super().__init__()
629
+ self.dim = dim
630
+ self.z_dim = z_dim
631
+ self.dim_mult = dim_mult
632
+ self.num_res_blocks = num_res_blocks
633
+ self.attn_scales = attn_scales
634
+ self.temperal_upsample = temperal_upsample
635
+
636
+ # dimensions
637
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
638
+ scale = 1.0 / 2**(len(dim_mult) - 2)
639
+ # init block
640
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
641
+
642
+ # middle blocks
643
+ self.middle = nn.Sequential(
644
+ ResidualBlock(dims[0], dims[0], dropout),
645
+ AttentionBlock(dims[0]),
646
+ ResidualBlock(dims[0], dims[0], dropout),
647
+ )
648
+
649
+ # upsample blocks
650
+ upsamples = []
651
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
652
+ t_up_flag = temperal_upsample[i] if i < len(
653
+ temperal_upsample) else False
654
+ upsamples.append(
655
+ Up_ResidualBlock(
656
+ in_dim=in_dim,
657
+ out_dim=out_dim,
658
+ dropout=dropout,
659
+ mult=num_res_blocks + 1,
660
+ temperal_upsample=t_up_flag,
661
+ up_flag=i != len(dim_mult) - 1,
662
+ ))
663
+ self.upsamples = nn.Sequential(*upsamples)
664
+
665
+ # output blocks
666
+ self.head = nn.Sequential(
667
+ RMS_norm(out_dim, images=False),
668
+ nn.SiLU(),
669
+ CausalConv3d(out_dim, 12, 3, padding=1),
670
+ )
671
+
672
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
673
+ if feat_cache is not None:
674
+ idx = feat_idx[0]
675
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
676
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
677
+ cache_x = torch.cat(
678
+ [
679
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
680
+ cache_x.device),
681
+ cache_x,
682
+ ],
683
+ dim=2,
684
+ )
685
+ x = self.conv1(x, feat_cache[idx])
686
+ feat_cache[idx] = cache_x
687
+ feat_idx[0] += 1
688
+ else:
689
+ x = self.conv1(x)
690
+
691
+ for layer in self.middle:
692
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
693
+ x = layer(x, feat_cache, feat_idx)
694
+ else:
695
+ x = layer(x)
696
+
697
+ ## upsamples
698
+ for layer in self.upsamples:
699
+ if feat_cache is not None:
700
+ x = layer(x, feat_cache, feat_idx, first_chunk)
701
+ else:
702
+ x = layer(x)
703
+
704
+ ## head
705
+ for layer in self.head:
706
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
707
+ idx = feat_idx[0]
708
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
709
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
710
+ cache_x = torch.cat(
711
+ [
712
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
713
+ cache_x.device),
714
+ cache_x,
715
+ ],
716
+ dim=2,
717
+ )
718
+ x = layer(x, feat_cache[idx])
719
+ feat_cache[idx] = cache_x
720
+ feat_idx[0] += 1
721
+ else:
722
+ x = layer(x)
723
+ return x
724
+
725
+
726
+ def count_conv3d(model):
727
+ count = 0
728
+ for m in model.modules():
729
+ if isinstance(m, CausalConv3d):
730
+ count += 1
731
+ return count
732
+
733
+
734
+ class WanVAE_(nn.Module):
735
+
736
+ def __init__(
737
+ self,
738
+ dim=160,
739
+ dec_dim=256,
740
+ z_dim=16,
741
+ dim_mult=[1, 2, 4, 4],
742
+ num_res_blocks=2,
743
+ attn_scales=[],
744
+ temperal_downsample=[True, True, False],
745
+ dropout=0.0,
746
+ ):
747
+ super().__init__()
748
+ self.dim = dim
749
+ self.z_dim = z_dim
750
+ self.dim_mult = dim_mult
751
+ self.num_res_blocks = num_res_blocks
752
+ self.attn_scales = attn_scales
753
+ self.temperal_downsample = temperal_downsample
754
+ self.temperal_upsample = temperal_downsample[::-1]
755
+
756
+ # modules
757
+ self.encoder = Encoder3d(
758
+ dim,
759
+ z_dim * 2,
760
+ dim_mult,
761
+ num_res_blocks,
762
+ attn_scales,
763
+ self.temperal_downsample,
764
+ dropout,
765
+ )
766
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
767
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
768
+ self.decoder = Decoder3d(
769
+ dec_dim,
770
+ z_dim,
771
+ dim_mult,
772
+ num_res_blocks,
773
+ attn_scales,
774
+ self.temperal_upsample,
775
+ dropout,
776
+ )
777
+
778
+ def forward(self, x, scale=[0, 1]):
779
+ mu = self.encode(x, scale)
780
+ x_recon = self.decode(mu, scale)
781
+ return x_recon, mu
782
+
783
+ def encode(self, x, scale):
784
+ self.clear_cache()
785
+ x = patchify(x, patch_size=2)
786
+ t = x.shape[2]
787
+ iter_ = 1 + (t - 1) // 4
788
+ for i in range(iter_):
789
+ self._enc_conv_idx = [0]
790
+ if i == 0:
791
+ out = self.encoder(
792
+ x[:, :, :1, :, :],
793
+ feat_cache=self._enc_feat_map,
794
+ feat_idx=self._enc_conv_idx,
795
+ )
796
+ else:
797
+ out_ = self.encoder(
798
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
799
+ feat_cache=self._enc_feat_map,
800
+ feat_idx=self._enc_conv_idx,
801
+ )
802
+ out = torch.cat([out, out_], 2)
803
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
804
+ if isinstance(scale[0], torch.Tensor):
805
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
806
+ 1, self.z_dim, 1, 1, 1)
807
+ else:
808
+ mu = (mu - scale[0]) * scale[1]
809
+ self.clear_cache()
810
+ return mu
811
+
812
+ def decode(self, z, scale):
813
+ self.clear_cache()
814
+ if isinstance(scale[0], torch.Tensor):
815
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
816
+ 1, self.z_dim, 1, 1, 1)
817
+ else:
818
+ z = z / scale[1] + scale[0]
819
+ iter_ = z.shape[2]
820
+ x = self.conv2(z)
821
+ for i in range(iter_):
822
+ self._conv_idx = [0]
823
+ if i == 0:
824
+ out = self.decoder(
825
+ x[:, :, i:i + 1, :, :],
826
+ feat_cache=self._feat_map,
827
+ feat_idx=self._conv_idx,
828
+ first_chunk=True,
829
+ )
830
+ else:
831
+ out_ = self.decoder(
832
+ x[:, :, i:i + 1, :, :],
833
+ feat_cache=self._feat_map,
834
+ feat_idx=self._conv_idx,
835
+ )
836
+ out = torch.cat([out, out_], 2)
837
+ out = unpatchify(out, patch_size=2)
838
+ self.clear_cache()
839
+ return out
840
+
841
+ def reparameterize(self, mu, log_var):
842
+ std = torch.exp(0.5 * log_var)
843
+ eps = torch.randn_like(std)
844
+ return eps * std + mu
845
+
846
+ def sample(self, imgs, deterministic=False):
847
+ mu, log_var = self.encode(imgs)
848
+ if deterministic:
849
+ return mu
850
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
851
+ return mu + std * torch.randn_like(std)
852
+
853
+ def clear_cache(self):
854
+ self._conv_num = count_conv3d(self.decoder)
855
+ self._conv_idx = [0]
856
+ self._feat_map = [None] * self._conv_num
857
+ # cache encode
858
+ self._enc_conv_num = count_conv3d(self.encoder)
859
+ self._enc_conv_idx = [0]
860
+ self._enc_feat_map = [None] * self._enc_conv_num
861
+
862
+
863
+ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
864
+ # params
865
+ cfg = dict(
866
+ dim=dim,
867
+ z_dim=z_dim,
868
+ dim_mult=[1, 2, 4, 4],
869
+ num_res_blocks=2,
870
+ attn_scales=[],
871
+ temperal_downsample=[True, True, True],
872
+ dropout=0.0,
873
+ )
874
+ cfg.update(**kwargs)
875
+
876
+ # init model
877
+ with torch.device("meta"):
878
+ model = WanVAE_(**cfg)
879
+
880
+ # load checkpoint
881
+ logging.info(f"loading {pretrained_path}")
882
+ model.load_state_dict(
883
+ torch.load(pretrained_path, map_location=device), assign=True)
884
+
885
+ return model
886
+
887
+
888
+ class Wan2_2_VAE:
889
+
890
+ def __init__(
891
+ self,
892
+ z_dim=48,
893
+ c_dim=160,
894
+ vae_pth=None,
895
+ dim_mult=[1, 2, 4, 4],
896
+ temperal_downsample=[False, True, True],
897
+ dtype=torch.float,
898
+ device="cuda",
899
+ ):
900
+
901
+ self.dtype = dtype
902
+ self.device = device
903
+
904
+ mean = torch.tensor(
905
+ [
906
+ -0.2289,
907
+ -0.0052,
908
+ -0.1323,
909
+ -0.2339,
910
+ -0.2799,
911
+ 0.0174,
912
+ 0.1838,
913
+ 0.1557,
914
+ -0.1382,
915
+ 0.0542,
916
+ 0.2813,
917
+ 0.0891,
918
+ 0.1570,
919
+ -0.0098,
920
+ 0.0375,
921
+ -0.1825,
922
+ -0.2246,
923
+ -0.1207,
924
+ -0.0698,
925
+ 0.5109,
926
+ 0.2665,
927
+ -0.2108,
928
+ -0.2158,
929
+ 0.2502,
930
+ -0.2055,
931
+ -0.0322,
932
+ 0.1109,
933
+ 0.1567,
934
+ -0.0729,
935
+ 0.0899,
936
+ -0.2799,
937
+ -0.1230,
938
+ -0.0313,
939
+ -0.1649,
940
+ 0.0117,
941
+ 0.0723,
942
+ -0.2839,
943
+ -0.2083,
944
+ -0.0520,
945
+ 0.3748,
946
+ 0.0152,
947
+ 0.1957,
948
+ 0.1433,
949
+ -0.2944,
950
+ 0.3573,
951
+ -0.0548,
952
+ -0.1681,
953
+ -0.0667,
954
+ ],
955
+ dtype=dtype,
956
+ device=device,
957
+ )
958
+ std = torch.tensor(
959
+ [
960
+ 0.4765,
961
+ 1.0364,
962
+ 0.4514,
963
+ 1.1677,
964
+ 0.5313,
965
+ 0.4990,
966
+ 0.4818,
967
+ 0.5013,
968
+ 0.8158,
969
+ 1.0344,
970
+ 0.5894,
971
+ 1.0901,
972
+ 0.6885,
973
+ 0.6165,
974
+ 0.8454,
975
+ 0.4978,
976
+ 0.5759,
977
+ 0.3523,
978
+ 0.7135,
979
+ 0.6804,
980
+ 0.5833,
981
+ 1.4146,
982
+ 0.8986,
983
+ 0.5659,
984
+ 0.7069,
985
+ 0.5338,
986
+ 0.4889,
987
+ 0.4917,
988
+ 0.4069,
989
+ 0.4999,
990
+ 0.6866,
991
+ 0.4093,
992
+ 0.5709,
993
+ 0.6065,
994
+ 0.6415,
995
+ 0.4944,
996
+ 0.5726,
997
+ 1.2042,
998
+ 0.5458,
999
+ 1.6887,
1000
+ 0.3971,
1001
+ 1.0600,
1002
+ 0.3943,
1003
+ 0.5537,
1004
+ 0.5444,
1005
+ 0.4089,
1006
+ 0.7468,
1007
+ 0.7744,
1008
+ ],
1009
+ dtype=dtype,
1010
+ device=device,
1011
+ )
1012
+ self.scale = [mean, 1.0 / std]
1013
+
1014
+ # init model
1015
+ self.model = (
1016
+ _video_vae(
1017
+ pretrained_path=vae_pth,
1018
+ z_dim=z_dim,
1019
+ dim=c_dim,
1020
+ dim_mult=dim_mult,
1021
+ temperal_downsample=temperal_downsample,
1022
+ ).eval().requires_grad_(False).to(device))
1023
+
1024
+ def encode(self, videos):
1025
+ try:
1026
+ if not isinstance(videos, list):
1027
+ raise TypeError("videos should be a list")
1028
+ with amp.autocast(dtype=self.dtype):
1029
+ return [
1030
+ self.model.encode(u.unsqueeze(0),
1031
+ self.scale).float().squeeze(0)
1032
+ for u in videos
1033
+ ]
1034
+ except TypeError as e:
1035
+ logging.info(e)
1036
+ return None
1037
+
1038
+ def decode(self, zs):
1039
+ try:
1040
+ if not isinstance(zs, list):
1041
+ raise TypeError("zs should be a list")
1042
+ with amp.autocast(dtype=self.dtype):
1043
+ return [
1044
+ self.model.decode(u.unsqueeze(0),
1045
+ self.scale).float().clamp_(-1,
1046
+ 1).squeeze(0)
1047
+ for u in zs
1048
+ ]
1049
+ except TypeError as e:
1050
+ logging.info(e)
1051
+ return None
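Both VAE wrappers above hand `scale = [mean, 1.0 / std]` to the raw model: encode whitens the latent per channel and decode applies the exact inverse before running the decoder. A standalone sketch of that convention (random statistics, purely illustrative):

import torch

z_dim = 48
mean, std = torch.randn(z_dim), torch.rand(z_dim) + 0.5
scale = [mean, 1.0 / std]

mu = torch.randn(2, z_dim, 4, 8, 8)   # stands in for the raw encoder output
mu_norm = (mu - scale[0].view(1, z_dim, 1, 1, 1)) * scale[1].view(1, z_dim, 1, 1, 1)   # what encode() does
mu_back = mu_norm / scale[1].view(1, z_dim, 1, 1, 1) + scale[0].view(1, z_dim, 1, 1, 1)  # what decode() undoes
print(torch.allclose(mu_back, mu, atol=1e-5))   # True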
wan/text2video.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import torch
13
+ import torch.cuda.amp as amp
14
+ import torch.distributed as dist
15
+ from tqdm import tqdm
16
+
17
+ from .distributed.fsdp import shard_model
18
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
19
+ from .distributed.util import get_world_size
20
+ from .modules.model import WanModel
21
+ from .modules.t5 import T5EncoderModel
22
+ from .modules.vae2_1 import Wan2_1_VAE
23
+ from .utils.fm_solvers import (
24
+ FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas,
26
+ retrieve_timesteps,
27
+ )
28
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
+
30
+
31
+ class WanT2V:
32
+
33
+ def __init__(
34
+ self,
35
+ config,
36
+ checkpoint_dir,
37
+ device_id=0,
38
+ rank=0,
39
+ t5_fsdp=False,
40
+ dit_fsdp=False,
41
+ use_sp=False,
42
+ t5_cpu=False,
43
+ init_on_cpu=True,
44
+ convert_model_dtype=False,
45
+ ):
46
+ r"""
47
+ Initializes the Wan text-to-video generation model components.
48
+
49
+ Args:
50
+ config (EasyDict):
51
+ Object containing model parameters initialized from config.py
52
+ checkpoint_dir (`str`):
53
+ Path to directory containing model checkpoints
54
+ device_id (`int`, *optional*, defaults to 0):
55
+ Id of target GPU device
56
+ rank (`int`, *optional*, defaults to 0):
57
+ Process rank for distributed training
58
+ t5_fsdp (`bool`, *optional*, defaults to False):
59
+ Enable FSDP sharding for T5 model
60
+ dit_fsdp (`bool`, *optional*, defaults to False):
61
+ Enable FSDP sharding for DiT model
62
+ use_sp (`bool`, *optional*, defaults to False):
63
+ Enable distribution strategy of sequence parallel.
64
+ t5_cpu (`bool`, *optional*, defaults to False):
65
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
66
+ init_on_cpu (`bool`, *optional*, defaults to True):
67
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
68
+ convert_model_dtype (`bool`, *optional*, defaults to False):
69
+ Convert DiT model parameters dtype to 'config.param_dtype'.
70
+ Only works without FSDP.
71
+ """
72
+ self.device = torch.device(f"cuda:{device_id}")
73
+ self.config = config
74
+ self.rank = rank
75
+ self.t5_cpu = t5_cpu
76
+ self.init_on_cpu = init_on_cpu
77
+
78
+ self.num_train_timesteps = config.num_train_timesteps
79
+ self.boundary = config.boundary
80
+ self.param_dtype = config.param_dtype
81
+
82
+ if t5_fsdp or dit_fsdp or use_sp:
83
+ self.init_on_cpu = False
84
+
85
+ shard_fn = partial(shard_model, device_id=device_id)
86
+ self.text_encoder = T5EncoderModel(
87
+ text_len=config.text_len,
88
+ dtype=config.t5_dtype,
89
+ device=torch.device('cpu'),
90
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
91
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
92
+ shard_fn=shard_fn if t5_fsdp else None)
93
+
94
+ self.vae_stride = config.vae_stride
95
+ self.patch_size = config.patch_size
96
+ self.vae = Wan2_1_VAE(
97
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
98
+ device=self.device)
99
+
100
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
101
+ self.low_noise_model = WanModel.from_pretrained(
102
+ checkpoint_dir, subfolder=config.low_noise_checkpoint)
103
+ self.low_noise_model = self._configure_model(
104
+ model=self.low_noise_model,
105
+ use_sp=use_sp,
106
+ dit_fsdp=dit_fsdp,
107
+ shard_fn=shard_fn,
108
+ convert_model_dtype=convert_model_dtype)
109
+
110
+ self.high_noise_model = WanModel.from_pretrained(
111
+ checkpoint_dir, subfolder=config.high_noise_checkpoint)
112
+ self.high_noise_model = self._configure_model(
113
+ model=self.high_noise_model,
114
+ use_sp=use_sp,
115
+ dit_fsdp=dit_fsdp,
116
+ shard_fn=shard_fn,
117
+ convert_model_dtype=convert_model_dtype)
118
+ if use_sp:
119
+ self.sp_size = get_world_size()
120
+ else:
121
+ self.sp_size = 1
122
+
123
+ self.sample_neg_prompt = config.sample_neg_prompt
124
+
125
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
126
+ convert_model_dtype):
127
+ """
128
+ Configures a model object. This includes setting evaluation modes,
129
+ applying distributed parallel strategy, and handling device placement.
130
+
131
+ Args:
132
+ model (torch.nn.Module):
133
+ The model instance to configure.
134
+ use_sp (`bool`):
135
+ Enable distribution strategy of sequence parallel.
136
+ dit_fsdp (`bool`):
137
+ Enable FSDP sharding for DiT model.
138
+ shard_fn (callable):
139
+ The function to apply FSDP sharding.
140
+ convert_model_dtype (`bool`):
141
+ Convert DiT model parameters dtype to 'config.param_dtype'.
142
+ Only works without FSDP.
143
+
144
+ Returns:
145
+ torch.nn.Module:
146
+ The configured model.
147
+ """
148
+ model.eval().requires_grad_(False)
149
+
150
+ if use_sp:
151
+ for block in model.blocks:
152
+ block.self_attn.forward = types.MethodType(
153
+ sp_attn_forward, block.self_attn)
154
+ model.forward = types.MethodType(sp_dit_forward, model)
155
+
156
+ if dist.is_initialized():
157
+ dist.barrier()
158
+
159
+ if dit_fsdp:
160
+ model = shard_fn(model)
161
+ else:
162
+ if convert_model_dtype:
163
+ model.to(self.param_dtype)
164
+ if not self.init_on_cpu:
165
+ model.to(self.device)
166
+
167
+ return model
168
+
169
+ def _prepare_model_for_timestep(self, t, boundary, offload_model):
170
+ r"""
171
+ Prepares and returns the required model for the current timestep.
172
+
173
+ Args:
174
+ t (torch.Tensor):
175
+ current timestep.
176
+ boundary (`int`):
177
+ The timestep threshold. If `t` is at or above this value,
178
+ the `high_noise_model` is considered as the required model.
179
+ offload_model (`bool`):
180
+ A flag intended to control the offloading behavior.
181
+
182
+ Returns:
183
+ torch.nn.Module:
184
+ The active model on the target device for the current timestep.
185
+ """
186
+ if t.item() >= boundary:
187
+ required_model_name = 'high_noise_model'
188
+ offload_model_name = 'low_noise_model'
189
+ else:
190
+ required_model_name = 'low_noise_model'
191
+ offload_model_name = 'high_noise_model'
192
+ if offload_model or self.init_on_cpu:
193
+ if next(getattr(
194
+ self,
195
+ offload_model_name).parameters()).device.type == 'cuda':
196
+ getattr(self, offload_model_name).to('cpu')
197
+ if next(getattr(
198
+ self,
199
+ required_model_name).parameters()).device.type == 'cpu':
200
+ getattr(self, required_model_name).to(self.device)
201
+ return getattr(self, required_model_name)
202
+
203
+ def generate(self,
204
+ input_prompt,
205
+ size=(1280, 720),
206
+ frame_num=81,
207
+ shift=5.0,
208
+ sample_solver='unipc',
209
+ sampling_steps=50,
210
+ guide_scale=5.0,
211
+ n_prompt="",
212
+ seed=-1,
213
+ offload_model=True):
214
+ r"""
215
+ Generates video frames from text prompt using diffusion process.
216
+
217
+ Args:
218
+ input_prompt (`str`):
219
+ Text prompt for content generation
220
+ size (`tuple[int]`, *optional*, defaults to (1280,720)):
221
+ Controls video resolution, (width,height).
222
+ frame_num (`int`, *optional*, defaults to 81):
223
+ How many frames to sample from a video. The number should be 4n+1
224
+ shift (`float`, *optional*, defaults to 5.0):
225
+ Noise schedule shift parameter. Affects temporal dynamics
226
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
227
+ Solver used to sample the video.
228
+ sampling_steps (`int`, *optional*, defaults to 50):
229
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
230
+ guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
231
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
232
+ If tuple, the first guide_scale will be used for low noise model and
233
+ the second guide_scale will be used for high noise model.
234
+ n_prompt (`str`, *optional*, defaults to ""):
235
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
236
+ seed (`int`, *optional*, defaults to -1):
237
+ Random seed for noise generation. If -1, use random seed.
238
+ offload_model (`bool`, *optional*, defaults to True):
239
+ If True, offloads models to CPU during generation to save VRAM
240
+
241
+ Returns:
242
+ torch.Tensor:
243
+ Generated video frames tensor. Dimensions: (C, N H, W) where:
244
+ - C: Color channels (3 for RGB)
245
+ - N: Number of frames (81)
246
+ - H: Frame height (from size)
247
+ - W: Frame width (from size)
248
+ """
249
+ # preprocess
250
+ guide_scale = (guide_scale, guide_scale) if isinstance(
251
+ guide_scale, float) else guide_scale
252
+ F = frame_num
253
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
254
+ size[1] // self.vae_stride[1],
255
+ size[0] // self.vae_stride[2])
256
+
257
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
258
+ (self.patch_size[1] * self.patch_size[2]) *
259
+ target_shape[1] / self.sp_size) * self.sp_size
260
+
261
+ if n_prompt == "":
262
+ n_prompt = self.sample_neg_prompt
263
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
264
+ seed_g = torch.Generator(device=self.device)
265
+ seed_g.manual_seed(seed)
266
+
267
+ if not self.t5_cpu:
268
+ self.text_encoder.model.to(self.device)
269
+ context = self.text_encoder([input_prompt], self.device)
270
+ context_null = self.text_encoder([n_prompt], self.device)
271
+ if offload_model:
272
+ self.text_encoder.model.cpu()
273
+ else:
274
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
275
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
276
+ context = [t.to(self.device) for t in context]
277
+ context_null = [t.to(self.device) for t in context_null]
278
+
279
+ noise = [
280
+ torch.randn(
281
+ target_shape[0],
282
+ target_shape[1],
283
+ target_shape[2],
284
+ target_shape[3],
285
+ dtype=torch.float32,
286
+ device=self.device,
287
+ generator=seed_g)
288
+ ]
289
+
290
+ @contextmanager
291
+ def noop_no_sync():
292
+ yield
293
+
294
+ no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
295
+ noop_no_sync)
296
+ no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
297
+ noop_no_sync)
298
+
299
+ # evaluation mode
300
+ with (
301
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
302
+ torch.no_grad(),
303
+ no_sync_low_noise(),
304
+ no_sync_high_noise(),
305
+ ):
306
+ boundary = self.boundary * self.num_train_timesteps
307
+
308
+ if sample_solver == 'unipc':
309
+ sample_scheduler = FlowUniPCMultistepScheduler(
310
+ num_train_timesteps=self.num_train_timesteps,
311
+ shift=1,
312
+ use_dynamic_shifting=False)
313
+ sample_scheduler.set_timesteps(
314
+ sampling_steps, device=self.device, shift=shift)
315
+ timesteps = sample_scheduler.timesteps
316
+ elif sample_solver == 'dpm++':
317
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
318
+ num_train_timesteps=self.num_train_timesteps,
319
+ shift=1,
320
+ use_dynamic_shifting=False)
321
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
322
+ timesteps, _ = retrieve_timesteps(
323
+ sample_scheduler,
324
+ device=self.device,
325
+ sigmas=sampling_sigmas)
326
+ else:
327
+ raise NotImplementedError("Unsupported solver.")
328
+
329
+ # sample videos
330
+ latents = noise
331
+
332
+ arg_c = {'context': context, 'seq_len': seq_len}
333
+ arg_null = {'context': context_null, 'seq_len': seq_len}
334
+
335
+ for _, t in enumerate(tqdm(timesteps)):
336
+ latent_model_input = latents
337
+ timestep = [t]
338
+
339
+ timestep = torch.stack(timestep)
340
+
341
+ model = self._prepare_model_for_timestep(
342
+ t, boundary, offload_model)
343
+ sample_guide_scale = guide_scale[1] if t.item(
344
+ ) >= boundary else guide_scale[0]
345
+
346
+ noise_pred_cond = model(
347
+ latent_model_input, t=timestep, **arg_c)[0]
348
+ noise_pred_uncond = model(
349
+ latent_model_input, t=timestep, **arg_null)[0]
350
+
351
+ noise_pred = noise_pred_uncond + sample_guide_scale * (
352
+ noise_pred_cond - noise_pred_uncond)
353
+
354
+ temp_x0 = sample_scheduler.step(
355
+ noise_pred.unsqueeze(0),
356
+ t,
357
+ latents[0].unsqueeze(0),
358
+ return_dict=False,
359
+ generator=seed_g)[0]
360
+ latents = [temp_x0.squeeze(0)]
361
+
362
+ x0 = latents
363
+ if offload_model:
364
+ self.low_noise_model.cpu()
365
+ self.high_noise_model.cpu()
366
+ torch.cuda.empty_cache()
367
+ if self.rank == 0:
368
+ videos = self.vae.decode(x0)
369
+
370
+ del noise, latents
371
+ del sample_scheduler
372
+ if offload_model:
373
+ gc.collect()
374
+ torch.cuda.synchronize()
375
+ if dist.is_initialized():
376
+ dist.barrier()
377
+
378
+ return videos[0] if self.rank == 0 else None
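In the sampling loop above, each timestep is routed to one of two denoising experts: `_prepare_model_for_timestep` compares `t` against `boundary` (the configured boundary ratio times `num_train_timesteps`), keeps the matching expert on the GPU and offloads the other, and the classifier-free guidance scale is likewise chosen per expert (`guide_scale[1]` for the high-noise range, `guide_scale[0]` for the low-noise range). Below is a minimal, self-contained sketch of that dispatch rule only; the expert names mirror the attributes above, while the 0.875 boundary ratio and the guidance values are placeholder assumptions, not values read from the shipped configs.

# Sketch of the boundary-based expert/guidance dispatch used by generate().
# The experts are represented by name only; the real code swaps WanModel instances.
import torch

def pick_expert(t, boundary_ratio, num_train_timesteps, guide_scale=(3.0, 4.0)):
    # the high-noise expert handles the early (large-t) part of the schedule
    boundary = boundary_ratio * num_train_timesteps
    if t.item() >= boundary:
        return 'high_noise_model', guide_scale[1]
    return 'low_noise_model', guide_scale[0]

if __name__ == '__main__':
    for t in (999.0, 900.0, 500.0, 10.0):
        name, cfg = pick_expert(torch.tensor(t), boundary_ratio=0.875, num_train_timesteps=1000)
        print(f"t={t:6.1f} -> {name} (cfg scale {cfg})")

Whichever expert is active, the update itself is standard classifier-free guidance: noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond).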
wan/textimage2video.py ADDED
@@ -0,0 +1,619 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import torch
13
+ import torch.cuda.amp as amp
14
+ import torch.distributed as dist
15
+ import torchvision.transforms.functional as TF
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
+ from .distributed.util import get_world_size
22
+ from .modules.model import WanModel
23
+ from .modules.t5 import T5EncoderModel
24
+ from .modules.vae2_2 import Wan2_2_VAE
25
+ from .utils.fm_solvers import (
26
+ FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas,
28
+ retrieve_timesteps,
29
+ )
30
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+ from .utils.utils import best_output_size, masks_like
32
+
33
+
34
+ class WanTI2V:
35
+
36
+ def __init__(
37
+ self,
38
+ config,
39
+ checkpoint_dir,
40
+ device_id=0,
41
+ rank=0,
42
+ t5_fsdp=False,
43
+ dit_fsdp=False,
44
+ use_sp=False,
45
+ t5_cpu=False,
46
+ init_on_cpu=True,
47
+ convert_model_dtype=False,
48
+ ):
49
+ r"""
50
+ Initializes the Wan text/image-to-video generation model components.
51
+
52
+ Args:
53
+ config (EasyDict):
54
+ Object containing model parameters initialized from config.py
55
+ checkpoint_dir (`str`):
56
+ Path to directory containing model checkpoints
57
+ device_id (`int`, *optional*, defaults to 0):
58
+ Id of target GPU device
59
+ rank (`int`, *optional*, defaults to 0):
60
+ Process rank for distributed training
61
+ t5_fsdp (`bool`, *optional*, defaults to False):
62
+ Enable FSDP sharding for T5 model
63
+ dit_fsdp (`bool`, *optional*, defaults to False):
64
+ Enable FSDP sharding for DiT model
65
+ use_sp (`bool`, *optional*, defaults to False):
66
+ Enable distribution strategy of sequence parallel.
67
+ t5_cpu (`bool`, *optional*, defaults to False):
68
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
69
+ init_on_cpu (`bool`, *optional*, defaults to True):
70
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
71
+ convert_model_dtype (`bool`, *optional*, defaults to False):
72
+ Convert DiT model parameters dtype to 'config.param_dtype'.
73
+ Only works without FSDP.
74
+ """
75
+ self.device = torch.device(f"cuda:{device_id}")
76
+ self.config = config
77
+ self.rank = rank
78
+ self.t5_cpu = t5_cpu
79
+ self.init_on_cpu = init_on_cpu
80
+
81
+ self.num_train_timesteps = config.num_train_timesteps
82
+ self.param_dtype = config.param_dtype
83
+
84
+ if t5_fsdp or dit_fsdp or use_sp:
85
+ self.init_on_cpu = False
86
+
87
+ shard_fn = partial(shard_model, device_id=device_id)
88
+ self.text_encoder = T5EncoderModel(
89
+ text_len=config.text_len,
90
+ dtype=config.t5_dtype,
91
+ device=torch.device('cpu'),
92
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
+ shard_fn=shard_fn if t5_fsdp else None)
95
+
96
+ self.vae_stride = config.vae_stride
97
+ self.patch_size = config.patch_size
98
+ self.vae = Wan2_2_VAE(
99
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
100
+ device=self.device)
101
+
102
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
103
+ self.model = WanModel.from_pretrained(checkpoint_dir)
104
+ self.model = self._configure_model(
105
+ model=self.model,
106
+ use_sp=use_sp,
107
+ dit_fsdp=dit_fsdp,
108
+ shard_fn=shard_fn,
109
+ convert_model_dtype=convert_model_dtype)
110
+
111
+ if use_sp:
112
+ self.sp_size = get_world_size()
113
+ else:
114
+ self.sp_size = 1
115
+
116
+ self.sample_neg_prompt = config.sample_neg_prompt
117
+
118
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
119
+ convert_model_dtype):
120
+ """
121
+ Configures a model object. This includes setting evaluation modes,
122
+ applying distributed parallel strategy, and handling device placement.
123
+
124
+ Args:
125
+ model (torch.nn.Module):
126
+ The model instance to configure.
127
+ use_sp (`bool`):
128
+ Enable distribution strategy of sequence parallel.
129
+ dit_fsdp (`bool`):
130
+ Enable FSDP sharding for DiT model.
131
+ shard_fn (callable):
132
+ The function to apply FSDP sharding.
133
+ convert_model_dtype (`bool`):
134
+ Convert DiT model parameters dtype to 'config.param_dtype'.
135
+ Only works without FSDP.
136
+
137
+ Returns:
138
+ torch.nn.Module:
139
+ The configured model.
140
+ """
141
+ model.eval().requires_grad_(False)
142
+
143
+ if use_sp:
144
+ for block in model.blocks:
145
+ block.self_attn.forward = types.MethodType(
146
+ sp_attn_forward, block.self_attn)
147
+ model.forward = types.MethodType(sp_dit_forward, model)
148
+
149
+ if dist.is_initialized():
150
+ dist.barrier()
151
+
152
+ if dit_fsdp:
153
+ model = shard_fn(model)
154
+ else:
155
+ if convert_model_dtype:
156
+ model.to(self.param_dtype)
157
+ if not self.init_on_cpu:
158
+ model.to(self.device)
159
+
160
+ return model
161
+
162
+ def generate(self,
163
+ input_prompt,
164
+ img=None,
165
+ size=(1280, 704),
166
+ max_area=704 * 1280,
167
+ frame_num=81,
168
+ shift=5.0,
169
+ sample_solver='unipc',
170
+ sampling_steps=50,
171
+ guide_scale=5.0,
172
+ n_prompt="",
173
+ seed=-1,
174
+ offload_model=True):
175
+ r"""
176
+ Generates video frames from a text prompt, optionally conditioned on an input image, using a diffusion process.
177
+
178
+ Args:
179
+ input_prompt (`str`):
180
+ Text prompt for content generation
181
+ img (PIL.Image.Image):
182
+ Optional input image. If provided, it is converted internally to a [3, H, W] tensor and the image-to-video path is used.
183
+ size (`tuple[int]`, *optional*, defaults to (1280,704)):
184
+ Controls video resolution, (width,height).
185
+ max_area (`int`, *optional*, defaults to 704*1280):
186
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
187
+ frame_num (`int`, *optional*, defaults to 81):
188
+ How many frames to sample from a video. The number should be 4n+1
189
+ shift (`float`, *optional*, defaults to 5.0):
190
+ Noise schedule shift parameter. Affects temporal dynamics
191
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
192
+ Solver used to sample the video.
193
+ sampling_steps (`int`, *optional*, defaults to 50):
194
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
195
+ guide_scale (`float`, *optional*, defaults to 5.0):
196
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
197
+ n_prompt (`str`, *optional*, defaults to ""):
198
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
199
+ seed (`int`, *optional*, defaults to -1):
200
+ Random seed for noise generation. If -1, use random seed.
201
+ offload_model (`bool`, *optional*, defaults to True):
202
+ If True, offloads models to CPU during generation to save VRAM
203
+
204
+ Returns:
205
+ torch.Tensor:
206
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
207
+ - C: Color channels (3 for RGB)
208
+ - N: Number of frames (81)
209
+ - H: Frame height (from size)
210
+ - W: Frame width (from size)
211
+ """
212
+ # i2v
213
+ if img is not None:
214
+ return self.i2v(
215
+ input_prompt=input_prompt,
216
+ img=img,
217
+ max_area=max_area,
218
+ frame_num=frame_num,
219
+ shift=shift,
220
+ sample_solver=sample_solver,
221
+ sampling_steps=sampling_steps,
222
+ guide_scale=guide_scale,
223
+ n_prompt=n_prompt,
224
+ seed=seed,
225
+ offload_model=offload_model)
226
+ # t2v
227
+ return self.t2v(
228
+ input_prompt=input_prompt,
229
+ size=size,
230
+ frame_num=frame_num,
231
+ shift=shift,
232
+ sample_solver=sample_solver,
233
+ sampling_steps=sampling_steps,
234
+ guide_scale=guide_scale,
235
+ n_prompt=n_prompt,
236
+ seed=seed,
237
+ offload_model=offload_model)
238
+
239
+ def t2v(self,
240
+ input_prompt,
241
+ size=(1280, 704),
242
+ frame_num=121,
243
+ shift=5.0,
244
+ sample_solver='unipc',
245
+ sampling_steps=50,
246
+ guide_scale=5.0,
247
+ n_prompt="",
248
+ seed=-1,
249
+ offload_model=True):
250
+ r"""
251
+ Generates video frames from a text prompt using a diffusion process.
252
+
253
+ Args:
254
+ input_prompt (`str`):
255
+ Text prompt for content generation
256
+ size (`tuple[int]`, *optional*, defaults to (1280,704)):
257
+ Controls video resolution, (width,height).
258
+ frame_num (`int`, *optional*, defaults to 121):
259
+ How many frames to sample from a video. The number should be 4n+1
260
+ shift (`float`, *optional*, defaults to 5.0):
261
+ Noise schedule shift parameter. Affects temporal dynamics
262
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
263
+ Solver used to sample the video.
264
+ sampling_steps (`int`, *optional*, defaults to 50):
265
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
266
+ guide_scale (`float`, *optional*, defaults to 5.0):
267
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
268
+ n_prompt (`str`, *optional*, defaults to ""):
269
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
270
+ seed (`int`, *optional*, defaults to -1):
271
+ Random seed for noise generation. If -1, use random seed.
272
+ offload_model (`bool`, *optional*, defaults to True):
273
+ If True, offloads models to CPU during generation to save VRAM
274
+
275
+ Returns:
276
+ torch.Tensor:
277
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
278
+ - C: Color channels (3 for RGB)
279
+ - N: Number of frames (from frame_num)
280
+ - H: Frame height (from size)
281
+ - W: Frame width (from size)
282
+ """
283
+ # preprocess
284
+ F = frame_num
285
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
286
+ size[1] // self.vae_stride[1],
287
+ size[0] // self.vae_stride[2])
288
+
289
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
290
+ (self.patch_size[1] * self.patch_size[2]) *
291
+ target_shape[1] / self.sp_size) * self.sp_size
292
+
293
+ if n_prompt == "":
294
+ n_prompt = self.sample_neg_prompt
295
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
296
+ seed_g = torch.Generator(device=self.device)
297
+ seed_g.manual_seed(seed)
298
+
299
+ if not self.t5_cpu:
300
+ self.text_encoder.model.to(self.device)
301
+ context = self.text_encoder([input_prompt], self.device)
302
+ context_null = self.text_encoder([n_prompt], self.device)
303
+ if offload_model:
304
+ self.text_encoder.model.cpu()
305
+ else:
306
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
307
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
308
+ context = [t.to(self.device) for t in context]
309
+ context_null = [t.to(self.device) for t in context_null]
310
+
311
+ noise = [
312
+ torch.randn(
313
+ target_shape[0],
314
+ target_shape[1],
315
+ target_shape[2],
316
+ target_shape[3],
317
+ dtype=torch.float32,
318
+ device=self.device,
319
+ generator=seed_g)
320
+ ]
321
+
322
+ @contextmanager
323
+ def noop_no_sync():
324
+ yield
325
+
326
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
327
+
328
+ # evaluation mode
329
+ with (
330
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
331
+ torch.no_grad(),
332
+ no_sync(),
333
+ ):
334
+
335
+ if sample_solver == 'unipc':
336
+ sample_scheduler = FlowUniPCMultistepScheduler(
337
+ num_train_timesteps=self.num_train_timesteps,
338
+ shift=1,
339
+ use_dynamic_shifting=False)
340
+ sample_scheduler.set_timesteps(
341
+ sampling_steps, device=self.device, shift=shift)
342
+ timesteps = sample_scheduler.timesteps
343
+ elif sample_solver == 'dpm++':
344
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
345
+ num_train_timesteps=self.num_train_timesteps,
346
+ shift=1,
347
+ use_dynamic_shifting=False)
348
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
349
+ timesteps, _ = retrieve_timesteps(
350
+ sample_scheduler,
351
+ device=self.device,
352
+ sigmas=sampling_sigmas)
353
+ else:
354
+ raise NotImplementedError("Unsupported solver.")
355
+
356
+ # sample videos
357
+ latents = noise
358
+ mask1, mask2 = masks_like(noise, zero=False)
359
+
360
+ arg_c = {'context': context, 'seq_len': seq_len}
361
+ arg_null = {'context': context_null, 'seq_len': seq_len}
362
+
363
+ if offload_model or self.init_on_cpu:
364
+ self.model.to(self.device)
365
+ torch.cuda.empty_cache()
366
+
367
+ for _, t in enumerate(tqdm(timesteps)):
368
+ latent_model_input = latents
369
+ timestep = [t]
370
+
371
+ timestep = torch.stack(timestep)
372
+
373
+ temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
374
+ temp_ts = torch.cat([
375
+ temp_ts,
376
+ temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
377
+ ])
378
+ timestep = temp_ts.unsqueeze(0)
379
+
380
+ noise_pred_cond = self.model(
381
+ latent_model_input, t=timestep, **arg_c)[0]
382
+ noise_pred_uncond = self.model(
383
+ latent_model_input, t=timestep, **arg_null)[0]
384
+
385
+ noise_pred = noise_pred_uncond + guide_scale * (
386
+ noise_pred_cond - noise_pred_uncond)
387
+
388
+ temp_x0 = sample_scheduler.step(
389
+ noise_pred.unsqueeze(0),
390
+ t,
391
+ latents[0].unsqueeze(0),
392
+ return_dict=False,
393
+ generator=seed_g)[0]
394
+ latents = [temp_x0.squeeze(0)]
395
+ x0 = latents
396
+ if offload_model:
397
+ self.model.cpu()
398
+ torch.cuda.synchronize()
399
+ torch.cuda.empty_cache()
400
+ if self.rank == 0:
401
+ videos = self.vae.decode(x0)
402
+
403
+ del noise, latents
404
+ del sample_scheduler
405
+ if offload_model:
406
+ gc.collect()
407
+ torch.cuda.synchronize()
408
+ if dist.is_initialized():
409
+ dist.barrier()
410
+
411
+ return videos[0] if self.rank == 0 else None
412
+
413
+ def i2v(self,
414
+ input_prompt,
415
+ img,
416
+ max_area=704 * 1280,
417
+ frame_num=121,
418
+ shift=5.0,
419
+ sample_solver='unipc',
420
+ sampling_steps=40,
421
+ guide_scale=5.0,
422
+ n_prompt="",
423
+ seed=-1,
424
+ offload_model=True):
425
+ r"""
426
+ Generates video frames from an input image and a text prompt using a diffusion process.
427
+
428
+ Args:
429
+ input_prompt (`str`):
430
+ Text prompt for content generation.
431
+ img (PIL.Image.Image):
432
+ Input PIL image; converted internally to a [3, H, W] tensor.
433
+ max_area (`int`, *optional*, defaults to 704*1280):
434
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
435
+ frame_num (`int`, *optional*, defaults to 121):
436
+ How many frames to sample from a video. The number should be 4n+1
437
+ shift (`float`, *optional*, defaults to 5.0):
438
+ Noise schedule shift parameter. Affects temporal dynamics
439
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
440
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
441
+ Solver used to sample the video.
442
+ sampling_steps (`int`, *optional*, defaults to 40):
443
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
444
+ guide_scale (`float`, *optional*, defaults to 5.0):
445
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
446
+ n_prompt (`str`, *optional*, defaults to ""):
447
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
448
+ seed (`int`, *optional*, defaults to -1):
449
+ Random seed for noise generation. If -1, use random seed
450
+ offload_model (`bool`, *optional*, defaults to True):
451
+ If True, offloads models to CPU during generation to save VRAM
452
+
453
+ Returns:
454
+ torch.Tensor:
455
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
456
+ - C: Color channels (3 for RGB)
457
+ - N: Number of frames (121)
458
+ - H: Frame height (from max_area)
459
+ - W: Frame width (from max_area)
460
+ """
461
+ # preprocess
462
+ ih, iw = img.height, img.width
463
+ dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[
464
+ 2] * self.vae_stride[2]
465
+ ow, oh = best_output_size(iw, ih, dw, dh, max_area)
466
+
467
+ scale = max(ow / iw, oh / ih)
468
+ img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
469
+
470
+ # center-crop
471
+ x1 = (img.width - ow) // 2
472
+ y1 = (img.height - oh) // 2
473
+ img = img.crop((x1, y1, x1 + ow, y1 + oh))
474
+ assert img.width == ow and img.height == oh
475
+
476
+ # to tensor
477
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
478
+
479
+ F = frame_num
480
+ seq_len = ((F - 1) // self.vae_stride[0] + 1) * (
481
+ oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // (
482
+ self.patch_size[1] * self.patch_size[2])
483
+ seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size
484
+
485
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
486
+ seed_g = torch.Generator(device=self.device)
487
+ seed_g.manual_seed(seed)
488
+ noise = torch.randn(
489
+ self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
490
+ oh // self.vae_stride[1],
491
+ ow // self.vae_stride[2],
492
+ dtype=torch.float32,
493
+ generator=seed_g,
494
+ device=self.device)
495
+
496
+ if n_prompt == "":
497
+ n_prompt = self.sample_neg_prompt
498
+
499
+ # preprocess
500
+ if not self.t5_cpu:
501
+ self.text_encoder.model.to(self.device)
502
+ context = self.text_encoder([input_prompt], self.device)
503
+ context_null = self.text_encoder([n_prompt], self.device)
504
+ if offload_model:
505
+ self.text_encoder.model.cpu()
506
+ else:
507
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
508
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
509
+ context = [t.to(self.device) for t in context]
510
+ context_null = [t.to(self.device) for t in context_null]
511
+
512
+ z = self.vae.encode([img])
513
+
514
+ @contextmanager
515
+ def noop_no_sync():
516
+ yield
517
+
518
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
519
+
520
+ # evaluation mode
521
+ with (
522
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
523
+ torch.no_grad(),
524
+ no_sync(),
525
+ ):
526
+
527
+ if sample_solver == 'unipc':
528
+ sample_scheduler = FlowUniPCMultistepScheduler(
529
+ num_train_timesteps=self.num_train_timesteps,
530
+ shift=1,
531
+ use_dynamic_shifting=False)
532
+ sample_scheduler.set_timesteps(
533
+ sampling_steps, device=self.device, shift=shift)
534
+ timesteps = sample_scheduler.timesteps
535
+ elif sample_solver == 'dpm++':
536
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
537
+ num_train_timesteps=self.num_train_timesteps,
538
+ shift=1,
539
+ use_dynamic_shifting=False)
540
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
541
+ timesteps, _ = retrieve_timesteps(
542
+ sample_scheduler,
543
+ device=self.device,
544
+ sigmas=sampling_sigmas)
545
+ else:
546
+ raise NotImplementedError("Unsupported solver.")
547
+
548
+ # sample videos
549
+ latent = noise
550
+ mask1, mask2 = masks_like([noise], zero=True)
551
+ latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
552
+
553
+ arg_c = {
554
+ 'context': [context[0]],
555
+ 'seq_len': seq_len,
556
+ }
557
+
558
+ arg_null = {
559
+ 'context': context_null,
560
+ 'seq_len': seq_len,
561
+ }
562
+
563
+ if offload_model or self.init_on_cpu:
564
+ self.model.to(self.device)
565
+ torch.cuda.empty_cache()
566
+
567
+ for _, t in enumerate(tqdm(timesteps)):
568
+ latent_model_input = [latent.to(self.device)]
569
+ timestep = [t]
570
+
571
+ timestep = torch.stack(timestep).to(self.device)
572
+
573
+ temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
574
+ temp_ts = torch.cat([
575
+ temp_ts,
576
+ temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
577
+ ])
578
+ timestep = temp_ts.unsqueeze(0)
579
+
580
+ noise_pred_cond = self.model(
581
+ latent_model_input, t=timestep, **arg_c)[0]
582
+ if offload_model:
583
+ torch.cuda.empty_cache()
584
+ noise_pred_uncond = self.model(
585
+ latent_model_input, t=timestep, **arg_null)[0]
586
+ if offload_model:
587
+ torch.cuda.empty_cache()
588
+ noise_pred = noise_pred_uncond + guide_scale * (
589
+ noise_pred_cond - noise_pred_uncond)
590
+
591
+ temp_x0 = sample_scheduler.step(
592
+ noise_pred.unsqueeze(0),
593
+ t,
594
+ latent.unsqueeze(0),
595
+ return_dict=False,
596
+ generator=seed_g)[0]
597
+ latent = temp_x0.squeeze(0)
598
+ latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
599
+
600
+ x0 = [latent]
601
+ del latent_model_input, timestep
602
+
603
+ if offload_model:
604
+ self.model.cpu()
605
+ torch.cuda.synchronize()
606
+ torch.cuda.empty_cache()
607
+
608
+ if self.rank == 0:
609
+ videos = self.vae.decode(x0)
610
+
611
+ del noise, latent, x0
612
+ del sample_scheduler
613
+ if offload_model:
614
+ gc.collect()
615
+ torch.cuda.synchronize()
616
+ if dist.is_initialized():
617
+ dist.barrier()
618
+
619
+ return videos[0] if self.rank == 0 else None
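WanTI2V wraps both paths behind one entry point: `generate()` takes the t2v branch when `img` is None and the i2v branch otherwise, where the reference image is VAE-encoded and blended back into the latent after every solver step (`latent = (1 - mask2) * z + mask2 * latent`), so the conditioning frame stays pinned while the remaining frames are denoised. A hedged usage sketch follows; the checkpoint path and prompts are placeholders, and the top-level `wan.WanTI2V` export and the `WAN_CONFIGS['ti2v-5B']` lookup are assumptions based on the package layout in this commit, not guaranteed API.

# Hypothetical usage sketch for WanTI2V; paths, prompts, and the config lookup are assumptions.
from PIL import Image
import wan
from wan.configs import WAN_CONFIGS  # assumed to expose the ti2v-5B config

cfg = WAN_CONFIGS['ti2v-5B']
pipe = wan.WanTI2V(config=cfg, checkpoint_dir='./Wan2.2-TI2V-5B', device_id=0, rank=0)

# text-to-video: no image, so generate() dispatches to t2v()
video = pipe.generate('a corgi surfing a wave at sunset',
                      size=(1280, 704), frame_num=121, sampling_steps=50)

# image-to-video: passing img dispatches to i2v(), which pins the first latent frame to the image
img = Image.open('examples/i2v_input.JPG').convert('RGB')
video = pipe.generate('the camera slowly pushes in', img=img, max_area=704 * 1280)
# On rank 0, `video` is a (C, N, H, W) float tensor; other ranks receive None.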
wan/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .fm_solvers import (
3
+ FlowDPMSolverMultistepScheduler,
4
+ get_sampling_sigmas,
5
+ retrieve_timesteps,
6
+ )
7
+ from .fm_solvers_unipc import FlowUniPCMultistepScheduler
8
+
9
+ __all__ = [
10
+ 'get_sampling_sigmas', 'retrieve_timesteps',
11
+ 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
12
+ ]
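Both schedulers exported here are consumed the same way by the pipelines above; the sketch below simply factors that shared selection logic out of the `generate()`/`t2v()`/`i2v()` methods so the difference is easy to see: UniPC receives the shift via `set_timesteps`, while DPM++ receives explicit sigmas built by `get_sampling_sigmas` and mapped through `retrieve_timesteps`.

# Solver selection mirroring the sampling setup in text2video.py / textimage2video.py.
import torch
from wan.utils import (FlowDPMSolverMultistepScheduler, FlowUniPCMultistepScheduler,
                       get_sampling_sigmas, retrieve_timesteps)

def build_scheduler(sample_solver, sampling_steps, shift, num_train_timesteps, device):
    if sample_solver == 'unipc':
        scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=num_train_timesteps, shift=1, use_dynamic_shifting=False)
        scheduler.set_timesteps(sampling_steps, device=device, shift=shift)
        timesteps = scheduler.timesteps
    elif sample_solver == 'dpm++':
        scheduler = FlowDPMSolverMultistepScheduler(
            num_train_timesteps=num_train_timesteps, shift=1, use_dynamic_shifting=False)
        sigmas = get_sampling_sigmas(sampling_steps, shift)
        timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sigmas)
    else:
        raise NotImplementedError("Unsupported solver.")
    return scheduler, timesteps

# e.g. scheduler, timesteps = build_scheduler('unipc', 50, 5.0, 1000, torch.device('cuda'))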
wan/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (393 Bytes). View file