Arrokothwhi committed
Commit 0d0c9d9 · 1 Parent(s): d77af5d

Add I2V demo
README.md CHANGED
@@ -1,13 +1,23 @@
  ---
- title: Test
- emoji: 💻
- colorFrom: red
- colorTo: gray
+ title: RefDecoder I2V Demo
+ emoji: 🎬
+ colorFrom: green
+ colorTo: blue
  sdk: gradio
  sdk_version: 6.14.0
- python_version: '3.12'
+ python_version: "3.10"
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # RefDecoder I2V Demo
+
+ This Space:
+
+ 1. Generates Wan I2V latents from an input image and prompt
+ 2. Saves the latent tensor as a `.pt` file
+ 3. Decodes the same latents with the Wan VAE
+ 4. Decodes the same latents with RefDecoder
+
+ The RefDecoder checkpoint is downloaded at runtime from:
+ `Arrokothwhi/RefDecoder` -> `I2V_Wan2.1/model.pt`
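+
+ For reference, the saved latents file is a plain `torch.save` dictionary, so it
+ can be inspected offline. A minimal sketch (the keys below are the ones
+ `app.py` writes):
+
+ ```python
+ import torch
+
+ data = torch.load("wan_latents.pt", map_location="cpu")
+ print(data["latents"].shape)  # [B, C, T, H, W] latent tensor
+ print(data["seed"], data["height"], data["width"], data["prompt"])
+ ```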
app.py ADDED
@@ -0,0 +1,301 @@
+ import random
+ import sys
+ import tempfile
+ from functools import lru_cache
+ from pathlib import Path
+
+ import gradio as gr
+ import imageio
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderKLWan as DiffusersWanVAE
+ from diffusers import WanImageToVideoPipeline
+ from huggingface_hub import hf_hub_download
+ from transformers import CLIPVisionModel
+
+ # Make the repo root importable before importing the local src package.
+ ROOT = Path(__file__).resolve().parent
+ if str(ROOT) not in sys.path:
+     sys.path.insert(0, str(ROOT))
+
+ from src.models.Wan.autoencoder_wanT import AutoencoderKLWan
+ from src.models.Wan.transformer_wan import WanDecoderTransformer
+
+ MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+ REFDECODER_REPO_ID = "Arrokothwhi/RefDecoder"
+ REFDECODER_CKPT_PATH_IN_REPO = "I2V_Wan2.1/model.pt"
+ NEGATIVE_PROMPT = (
+     "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, "
+     "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, "
+     "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, "
+     "misshapen limbs, fused fingers, still picture, messy background, three legs, many people "
+     "in the background, walking backwards"
+ )
+ TARGET_AREA = 480 * 832
+ FPS = 16
+ NUM_FRAMES = 17
+ NUM_INFERENCE_STEPS = 50
+ GUIDANCE_SCALE = 5.0
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ PIPE_DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
+
+
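+ # With the Wan 2.1 VAE's 4x temporal compression, NUM_FRAMES = 17 corresponds
+ # to (17 - 1) / 4 + 1 = 5 latent frames; this appears to line up with the
+ # chunk=5 passed to WanDecoderTransformer below (an assumption based on the
+ # standard Wan 2.1 temporal stride, not stated explicitly in this repo).
+
+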
+ @lru_cache(maxsize=1)
+ def get_generation_pipe():
+     image_encoder = CLIPVisionModel.from_pretrained(
+         MODEL_ID,
+         subfolder="image_encoder",
+         torch_dtype=torch.float32,
+     )
+     vae = DiffusersWanVAE.from_pretrained(
+         MODEL_ID,
+         subfolder="vae",
+         torch_dtype=torch.float32,
+     )
+     pipe = WanImageToVideoPipeline.from_pretrained(
+         MODEL_ID,
+         vae=vae,
+         image_encoder=image_encoder,
+         torch_dtype=PIPE_DTYPE,
+     )
+     if DEVICE == "cuda":
+         pipe.enable_model_cpu_offload()
+     else:
+         pipe = pipe.to(DEVICE)
+     return pipe
+
+
+ @lru_cache(maxsize=1)
+ def get_wan_vae():
+     vae = DiffusersWanVAE.from_pretrained(
+         MODEL_ID,
+         subfolder="vae",
+         torch_dtype=torch.float32,
+     )
+     vae = vae.to(DEVICE)
+     vae.eval()
+     return vae
+
+
+ @lru_cache(maxsize=1)
+ def get_refdecoder_module():
+     vae = AutoencoderKLWan(
+         dropout_p=0.0,
+         use_reference=True,
+     ).eval()
+     transformer = WanDecoderTransformer(
+         chunk=5,
+         num_layers=10,
+         num_heads=12,
+         head_dim=128,
+         reusing=True,
+         pretrained=False,
+     ).eval()
+
+     ckpt_path = hf_hub_download(
+         repo_id=REFDECODER_REPO_ID,
+         filename=REFDECODER_CKPT_PATH_IN_REPO,
+     )
+     checkpoint = torch.load(ckpt_path, map_location="cpu")
+     state_dict = checkpoint.get("state_dict", checkpoint.get("module", checkpoint))
+
+     vae_sd = {}
+     transformer_sd = {}
+     for key, value in state_dict.items():
+         if key.startswith("vae."):
+             vae_sd[key[len("vae.") :]] = value
+         elif key.startswith("transformer."):
+             transformer_sd[key[len("transformer.") :]] = value
+
+     vae.load_state_dict(vae_sd, strict=False)
+     transformer.load_state_dict(transformer_sd, strict=False)
+
+     vae = vae.to(DEVICE).eval()
+     transformer = transformer.to(DEVICE).eval()
+     return vae, transformer
+
+
+ def resize_image_for_wan(image, pipe):
+     image = image.convert("RGB")
+     aspect_ratio = image.height / image.width
+     mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+     height = round(np.sqrt(TARGET_AREA * aspect_ratio)) // mod_value * mod_value
+     width = round(np.sqrt(TARGET_AREA / aspect_ratio)) // mod_value * mod_value
+     resized = image.resize((width, height))
+     return resized, height, width
+
+
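+ # Worked example (assuming Wan 2.1's vae_scale_factor_spatial = 8 and a
+ # transformer patch_size[1] of 2, i.e. mod_value = 16): a 1280x720 input has
+ # aspect_ratio = 0.5625, so height = round(sqrt(399360 * 0.5625)) // 16 * 16
+ # = 464 and width = round(sqrt(399360 / 0.5625)) // 16 * 16 = 832.
+
+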
+ def build_reference_frame(image, device):
+     ref_array = np.asarray(image).astype(np.float32)
+     ref_tensor = torch.from_numpy(ref_array).permute(2, 0, 1)
+     ref_tensor = (ref_tensor / 255.0 - 0.5) * 2.0
+     return ref_tensor.unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.float32)
+
+
+ def normalize_latent_shape(latents):
+     if isinstance(latents, list):
+         latents = latents[0]
+     if latents.ndim == 4:
+         latents = latents.unsqueeze(0)
+     if latents.ndim != 5:
+         raise ValueError(f"Expected latent shape [B,C,T,H,W], got {tuple(latents.shape)}")
+     return latents
+
+
+ def save_video_tensor(video_tensor, output_path):
+     video = (video_tensor / 2 + 0.5).clamp(0, 1)
+     video = video.squeeze(0).permute(1, 2, 3, 0).detach().cpu().float().numpy()
+     video = (video * 255).astype(np.uint8)
+     imageio.mimwrite(output_path, video, fps=FPS, quality=10)
+     return str(output_path)
+
+
+ def decode_with_wan_vae(latents):
+     vae = get_wan_vae()
+     latents = latents.to(device=DEVICE, dtype=torch.float32)
+     latents_mean = torch.tensor(vae.config.latents_mean, device=DEVICE, dtype=torch.float32).view(1, -1, 1, 1, 1)
+     latents_std = torch.tensor(vae.config.latents_std, device=DEVICE, dtype=torch.float32).view(1, -1, 1, 1, 1)
+     latents = latents * latents_std + latents_mean
+     with torch.no_grad():
+         video = vae.decode(latents, return_dict=False)[0]
+     return video
+
+
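+ # Both decode paths first undo the pipeline's per-channel latent normalization:
+ # the diffusion latents are stored as z_hat = (z - mean) / std, so the VAE needs
+ # z = z_hat * std + mean before decoding. This appears to mirror what the
+ # Diffusers Wan pipelines do internally when they decode to pixels themselves.
+
+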
+ def decode_with_refdecoder(latents, reference_frame):
+     vae, transformer = get_refdecoder_module()
+     latents = latents.to(device=DEVICE, dtype=torch.float32)
+     latents_mean = torch.tensor(
+         vae.config.latents_mean,
+         device=DEVICE,
+         dtype=torch.float32,
+     ).view(1, -1, 1, 1, 1)
+     latents_std = torch.tensor(
+         vae.config.latents_std,
+         device=DEVICE,
+         dtype=torch.float32,
+     ).view(1, -1, 1, 1, 1)
+     latents = latents * latents_std + latents_mean
+     with torch.no_grad():
+         video = vae.decode(
+             latents,
+             transformer,
+             return_dict=True,
+             reference_frame=reference_frame,
+             skip=False,
+             window_size=-1,
+         ).sample
+     if hasattr(vae, "clear_cache"):
+         vae.clear_cache()
+     return video
+
+
+ def generate_and_decode(image, prompt, seed, progress=gr.Progress(track_tqdm=False)):
+     if image is None:
+         raise gr.Error("Please upload an input image.")
+     if not prompt or not prompt.strip():
+         raise gr.Error("Please enter a prompt.")
+     if DEVICE != "cuda":
+         raise gr.Error("This demo expects a CUDA GPU to run Wan I2V generation.")
+
+     seed = int(seed) if seed is not None else random.randint(0, 2**32 - 1)
+     run_dir = Path(tempfile.mkdtemp(prefix="refdecoder_demo_"))
+
+     progress(0.05, desc="Loading Wan I2V pipeline")
+     pipe = get_generation_pipe()
+
+     progress(0.15, desc="Preparing image")
+     resized_image, height, width = resize_image_for_wan(image, pipe)
+     reference_frame = build_reference_frame(resized_image, DEVICE)
+     generator = torch.Generator(device=DEVICE).manual_seed(seed)
+
+     progress(0.3, desc="Generating latent video")
+     with torch.no_grad():
+         output = pipe(
+             image=resized_image,
+             prompt=prompt.strip(),
+             negative_prompt=NEGATIVE_PROMPT,
+             height=height,
+             width=width,
+             num_frames=NUM_FRAMES,
+             num_inference_steps=NUM_INFERENCE_STEPS,
+             guidance_scale=GUIDANCE_SCALE,
+             generator=generator,
+             output_type="latent",
+         )
+     latents = normalize_latent_shape(output.frames).detach().cpu()
+
+     latent_path = run_dir / "wan_latents.pt"
+     torch.save(
+         {
+             "latents": latents,
+             "height": height,
+             "width": width,
+             "prompt": prompt.strip(),
+             "seed": seed,
+         },
+         latent_path,
+     )
+
+     progress(0.65, desc="Decoding with Wan VAE")
+     wan_video = decode_with_wan_vae(latents)
+     wan_video_path = save_video_tensor(wan_video, run_dir / "wan_vae.mp4")
+
+     progress(0.82, desc="Decoding with RefDecoder")
+     ref_video = decode_with_refdecoder(latents, reference_frame)
+     ref_video_path = save_video_tensor(ref_video, run_dir / "refdecoder.mp4")
+
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     status = (
+         f"Seed: {seed}\n"
+         f"Resolution: {width}x{height}\n"
+         f"Frames: {NUM_FRAMES}\n"
+         f"Latents: {tuple(latents.shape)}"
+     )
+     progress(1.0, desc="Done")
+     return str(latent_path), wan_video_path, ref_video_path, status
+
+
+ with gr.Blocks(title="RefDecoder I2V Demo") as demo:
+     gr.Markdown(
+         """
+         # RefDecoder I2V Demo
+         Upload one image and one prompt. The app generates Wan I2V latents once, then decodes the same latents with:
+         1. Wan's original VAE
+         2. RefDecoder (checkpoint `Arrokothwhi/RefDecoder` -> `I2V_Wan2.1/model.pt`)
+         """
+     )
+
+     with gr.Row():
+         image_input = gr.Image(label="Input Image", type="pil")
+         with gr.Column():
+             prompt_input = gr.Textbox(
+                 label="Prompt",
+                 lines=4,
+                 placeholder="Describe the motion you want to generate...",
+             )
+             seed_input = gr.Number(
+                 label="Seed",
+                 value=0,
+                 precision=0,
+                 info="Use a fixed seed for reproducible results.",
+             )
+             run_button = gr.Button("Generate and Decode", variant="primary")
+
+     with gr.Row():
+         latent_output = gr.File(label="Wan Latents (.pt)")
+         status_output = gr.Textbox(label="Run Info")
+
+     with gr.Row():
+         wan_video_output = gr.Video(label="Wan VAE Decode")
+         ref_video_output = gr.Video(label="RefDecoder Decode")
+
+     run_button.click(
+         fn=generate_and_decode,
+         inputs=[image_input, prompt_input, seed_input],
+         outputs=[latent_output, wan_video_output, ref_video_output, status_output],
+     )
+
+
+ if __name__ == "__main__":
+     demo.queue(max_size=2).launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio==6.14.0
+ imageio==2.37.0
+ numpy==1.26.4
+ torch==2.7.0
+ transformers==4.56.2
+ diffusers==0.36.0
+ accelerate==1.10.1
+ einops==0.8.1
+ sentencepiece==0.2.1
+ safetensors==0.6.2
+ peft==0.18.0
+ huggingface-hub==0.34.4
src/__init__.py ADDED
@@ -0,0 +1 @@
+
src/models/Wan/__init__.py ADDED
@@ -0,0 +1 @@
+
src/models/Wan/autoencoder_wanT.py ADDED
@@ -0,0 +1,1916 @@
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import FromOriginalModelMixin
+ from diffusers.models.autoencoders.autoencoder_kl import (
+     AutoencoderKLOutput,
+     DecoderOutput,
+     DiagonalGaussianDistribution,
+ )
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.utils import logging
+ from diffusers.utils.accelerate_utils import apply_forward_hook
+ from einops import rearrange
+
+ _ACTS = {
+     "silu": nn.SiLU,
+     "swish": nn.SiLU,
+     "gelu": nn.GELU,
+     "relu": nn.ReLU,
+     "mish": nn.Mish,
+     "tanh": nn.Tanh,
+     "sigmoid": nn.Sigmoid,
+     "identity": nn.Identity,
+     "none": nn.Identity,
+ }
+
+
+ def resolve_activation(x):
+     if x is None:
+         return nn.Identity()
+     if isinstance(x, nn.Module):
+         return x
+     name = str(x).strip().lower()
+     if name in _ACTS:
+         return _ACTS[name]()
+     if name in ("lrelu", "leaky_relu"):
+         return nn.LeakyReLU(0.01)
+     raise ValueError(f"Unknown activation: {x}")
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ CACHE_T = 0
+ LATENT_T_STRIDE = 100
+ GRADIENT_CHECKPOINTING = False
+
+
+ class AvgDown3D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         factor_t,
+         factor_s=1,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.factor_t = factor_t
+         self.factor_s = factor_s
+         self.factor = self.factor_t * self.factor_s * self.factor_s
+
+         assert in_channels * self.factor % out_channels == 0
+         self.group_size = in_channels * self.factor // out_channels
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+         pad = (0, 0, 0, 0, pad_t, 0)
+         x = F.pad(x, pad)
+         B, C, T, H, W = x.shape
+         x = x.view(
+             B,
+             C,
+             T // self.factor_t,
+             self.factor_t,
+             H // self.factor_s,
+             self.factor_s,
+             W // self.factor_s,
+             self.factor_s,
+         )
+         x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+         x = x.view(
+             B,
+             C * self.factor,
+             T // self.factor_t,
+             H // self.factor_s,
+             W // self.factor_s,
+         )
+         x = x.view(
+             B,
+             self.out_channels,
+             self.group_size,
+             T // self.factor_t,
+             H // self.factor_s,
+             W // self.factor_s,
+         )
+         x = x.mean(dim=2)
+         return x
+
+
+ class DupUp3D(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         factor_t,
+         factor_s=1,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+
+         self.factor_t = factor_t
+         self.factor_s = factor_s
+         self.factor = self.factor_t * self.factor_s * self.factor_s
+
+         assert out_channels * self.factor % in_channels == 0
+         self.repeats = out_channels * self.factor // in_channels
+
+     def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+         x = x.repeat_interleave(self.repeats, dim=1)
+         x = x.view(
+             x.size(0),
+             self.out_channels,
+             self.factor_t,
+             self.factor_s,
+             self.factor_s,
+             x.size(2),
+             x.size(3),
+             x.size(4),
+         )
+         x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+         x = x.view(
+             x.size(0),
+             self.out_channels,
+             x.size(2) * self.factor_t,
+             x.size(4) * self.factor_s,
+             x.size(6) * self.factor_s,
+         )
+         if first_chunk:
+             x = x[:, :, self.factor_t - 1 :, :, :]
+         return x
+
+
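+ # AvgDown3D and DupUp3D are (approximate) inverses built purely from reshapes:
+ # AvgDown3D folds factor_t * factor_s * factor_s pixels into channels and then
+ # averages channel groups down to out_channels, while DupUp3D repeats channels
+ # and unfolds them back into time/space. For example, with in_channels=16,
+ # out_channels=32, factor_t=2, factor_s=2 (factor=8), AvgDown3D maps
+ # [B, 16, T, H, W] -> [B, 128, T/2, H/2, W/2] -> group mean (group_size =
+ # 16 * 8 // 32 = 4) -> [B, 32, T/2, H/2, W/2].
+
+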
+ class WanCausalConv3d(nn.Conv3d):
+     r"""
+     A custom 3D causal convolution layer with feature caching support.
+
+     This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
+     caching for efficient inference.
+
+     Args:
+         in_channels (int): Number of channels in the input image
+         out_channels (int): Number of channels produced by the convolution
+         kernel_size (int or tuple): Size of the convolving kernel
+         stride (int or tuple, optional): Stride of the convolution. Default: 1
+         padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: Union[int, Tuple[int, int, int]],
+         stride: Union[int, Tuple[int, int, int]] = 1,
+         padding: Union[int, Tuple[int, int, int]] = 0,
+     ) -> None:
+         super().__init__(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+         )
+
+         # Set up causal padding
+         self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
+         self.padding = (0, 0, 0)
+
+     def forward(self, x, cache_x=None, mode=None):
+         padding = list(self._padding)
+         if cache_x is not None and self._padding[4] > 0:
+             cache_x = cache_x.to(x.device)
+             x = torch.cat([cache_x, x], dim=2)
+             padding[4] -= cache_x.shape[2]
+
+         if mode == 'upsample3d':
+             # x: BCTHW
+             assert self.stride[0] == 1 and self.stride[1] == 1 and self.stride[2] == 1
+             assert self.kernel_size[0] == 3
+
+             assert padding[0] == padding[1] and padding[2] == padding[3]
+
+             results = []
+             for i in range(x.shape[2] if padding[-2] == 2 else x.shape[2] - 1):
+                 if padding[-2] == 2:
+                     if i == 0:
+                         out = F.conv3d(x[:, :, 0:1, :, :], self.weight, self.bias, self.stride, (2, padding[2], padding[0]))[:, :, :-2]  # BC1HW
+                     elif i == 1:
+                         out = F.conv3d(x[:, :, 0:2, :, :], self.weight, self.bias, self.stride, (1, padding[2], padding[0]))[:, :, :-1]  # BC1HW
+                     else:
+                         out = F.conv3d(x[:, :, i - 2 : i - 2 + self.kernel_size[0], :, :], self.weight, self.bias, self.stride, (0, padding[2], padding[0]))  # BC1HW
+                 elif padding[-2] == 1:
+                     if i == 0:
+                         out = F.conv3d(x[:, :, 0:2, :, :], self.weight, self.bias, self.stride, (1, padding[2], padding[0]))[:, :, :-1]  # BC1HW
+                     else:
+                         out = F.conv3d(x[:, :, i - 1 : i - 1 + self.kernel_size[0], :, :], self.weight, self.bias, self.stride, (0, padding[2], padding[0]))  # BC1HW
+                 else:
+                     raise ValueError("Invalid padding for causal conv3d in upsample3d mode.")
+                 results.append(out)
+
+             if not results:
+                 raise RuntimeError("Causal conv3d produced no output frames in upsample3d mode.")
+
+             return torch.cat(results, dim=2)  # BCTHW
+
+         x = F.pad(x, padding)
+         return super().forward(x)
+
+
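+ # Causality comes from the asymmetric temporal padding set up in __init__: for
+ # a temporal kernel of 3 (padding[0] = 1), _padding pads 2 frames on the left
+ # of the time axis and 0 on the right, so each output frame t depends only on
+ # input frames <= t. The cache_x argument lets chunked inference substitute the
+ # previous chunk's trailing frames for that left padding.
+
+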
+ class WanRMS_norm(nn.Module):
+     r"""
+     A custom RMS normalization layer.
+
+     Args:
+         dim (int): The number of dimensions to normalize over.
+         channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+             Default is True.
+         images (bool, optional): Whether the input represents image data. Default is True.
+         bias (bool, optional): Whether to include a learnable bias term. Default is False.
+     """
+
+     def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+         super().__init__()
+         broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+         shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+         self.channel_first = channel_first
+         self.scale = dim**0.5
+         self.gamma = nn.Parameter(torch.ones(shape))
+         self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+     def forward(self, x):
+         return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
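+ # Equivalent formula: F.normalize divides by the L2 norm along the channel dim,
+ # and multiplying by scale = sqrt(dim) turns that into RMS normalization,
+ #     y = x / ||x||_2 * sqrt(dim) * gamma + bias = x / RMS(x) * gamma + bias,
+ # where RMS(x) = sqrt(mean_c(x_c^2)).
+
+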
+ class WanUpsample(nn.Upsample):
+     r"""
+     Perform upsampling while ensuring the output tensor has the same data type as the input.
+
+     Args:
+         x (torch.Tensor): Input tensor to be upsampled.
+
+     Returns:
+         torch.Tensor: Upsampled tensor with the same data type as the input.
+     """
+
+     def forward(self, x):
+         return super().forward(x.float()).type_as(x)
+
+
+ class WanResample(nn.Module):
+     r"""
+     A custom resampling module for 2D and 3D data.
+
+     Args:
+         dim (int): The number of input/output channels.
+         mode (str): The resampling mode. Must be one of:
+             - 'none': No resampling (identity operation).
+             - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
+             - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
+             - 'downsample2d': 2D downsampling with zero-padding and convolution.
+             - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
+     """
+
+     def __init__(self, dim: int, mode: str, upsample_out_dim: Optional[int] = None) -> None:
+         super().__init__()
+         self.dim = dim
+         self.mode = mode
+
+         # default to dim // 2
+         if upsample_out_dim is None:
+             upsample_out_dim = dim // 2
+
+         # layers
+         if mode == "upsample2d":
+             self.resample = nn.Sequential(
+                 WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                 nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+             )
+         elif mode == "upsample3d":
+             self.resample = nn.Sequential(
+                 WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                 nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
+             )
+             self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+         elif mode == "downsample2d":
+             self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+         elif mode == "downsample3d":
+             self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
+             self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+         else:
+             self.resample = nn.Identity()
+
+     def forward(self, x, feat_cache=None, feat_idx=[0], is_reference=False, first_chunk=False):
+         b, c, t, h, w = x.size()
+
+         if self.mode == "upsample3d":
+             if feat_cache is not None and not is_reference:
+                 # Latent frames: full caching logic
+                 idx = feat_idx[0]
+
+                 if feat_cache[idx] is None:
+                     if t <= 1:
+                         feat_cache[idx] = "Rep"
+                         feat_idx[0] += 1
+                     else:
+                         subseq = x[:, :, 1:]
+                         cache_x = subseq[:, :, -CACHE_T:, :, :].clone() if CACHE_T > 0 else subseq[:, :, :0, :, :]
+                         if cache_x.shape[2] < 2:
+                             cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+
+                         subseq = self.time_conv(subseq, mode=self.mode)
+
+                         feat_cache[idx] = cache_x
+                         feat_idx[0] += 1
+
+                         subseq = subseq.reshape(b, 2, c, t - 1, h, w)
+                         subseq = torch.stack((subseq[:, 0, :, :, :, :], subseq[:, 1, :, :, :, :]), 3)
+                         subseq = subseq.reshape(b, c, (t - 1) * 2, h, w)
+                         x = torch.cat([x[:, :, :1, :, :], subseq], dim=2)
+                 else:
+                     cache_x = x[:, :, -CACHE_T:, :, :].clone() if CACHE_T > 0 else x[:, :, :0, :, :]
+                     if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
+                         cache_x = torch.cat([feat_cache[idx][:, :, -1:, :, :].to(cache_x.device), cache_x], dim=2)
+                     if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
+                         cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
+
+                     if feat_cache[idx] == "Rep":
+                         x = self.time_conv(x, mode=self.mode)
+                     else:
+                         x = self.time_conv(x, feat_cache[idx], mode=self.mode)
+
+                     feat_cache[idx] = cache_x
+                     feat_idx[0] += 1
+
+                     x = x.reshape(b, 2, c, t, h, w)
+                     x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
+                     x = x.reshape(b, c, t * 2, h, w)
+
+         # Spatial resampling (applies to all paths)
+         t = x.shape[2]
+         x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+         x = self.resample(x)
+         x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
+
+         if self.mode == "downsample3d":
+             if feat_cache is not None and not is_reference:
+                 idx = feat_idx[0]
+                 if feat_cache[idx] is None:
+                     if t <= 1:
+                         feat_cache[idx] = x.clone()
+                         feat_idx[0] += 1
+                     else:
+                         subseq = x[:, :, 1:]
+                         cache_x = subseq[:, :, -1:, :, :].clone()
+                         subseq = self.time_conv(x)
+                         x = torch.cat([x[:, :, :1, :, :], subseq], dim=2)
+                         feat_cache[idx] = cache_x
+                         feat_idx[0] += 1
+                 else:
+                     cache_x = x[:, :, -1:, :, :].clone()
+                     x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+                     feat_cache[idx] = cache_x
+                     feat_idx[0] += 1
+         return x
+
+
+ class WanResidualBlock(nn.Module):
+     r"""
+     A custom residual block module.
+
+     Args:
+         in_dim (int): Number of input channels.
+         out_dim (int): Number of output channels.
+         dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
+         non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+     """
+
+     def __init__(
+         self,
+         in_dim: int,
+         out_dim: int,
+         dropout: float = 0.0,
+         non_linearity: str = "silu",
+     ) -> None:
+         super().__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+         self.nonlinearity = resolve_activation(non_linearity)
+
+         # layers
+         self.norm1 = WanRMS_norm(in_dim, images=False)
+         self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
+         self.norm2 = WanRMS_norm(out_dim, images=False)
+         self.dropout = nn.Dropout(dropout)
+         self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
+         self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         # Apply shortcut connection
+         h = self.conv_shortcut(x)
+
+         # First normalization and activation
+         x = self.norm1(x)
+         x = self.nonlinearity(x)
+
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone() if CACHE_T > 0 else x[:, :, :0, :, :]
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1:, :, :].to(cache_x.device), cache_x], dim=2)
+
+             x = self.conv1(x, feat_cache[idx], mode='upsample3d')
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv1(x, mode='upsample3d')
+
+         # Second normalization and activation
+         x = self.norm2(x)
+         x = self.nonlinearity(x)
+
+         # Dropout
+         x = self.dropout(x)
+
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone() if CACHE_T > 0 else x[:, :, :0, :, :]
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1:, :, :].to(cache_x.device), cache_x], dim=2)
+
+             x = self.conv2(x, feat_cache[idx], mode='upsample3d')
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv2(x, mode='upsample3d')
+
+         # Add residual connection
+         return x + h
+
+
+ class WanAttentionBlock(nn.Module):
+     """
+     Per-frame spatial self-attention with a single head.
+
+     Args:
+         dim (int): The number of channels in the input tensor.
+     """
+
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+
+         # layers
+         self.norm = WanRMS_norm(dim)
+         self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
+         self.proj = nn.Conv2d(dim, dim, 1)
+
+     def forward(self, x):
+         identity = x
+         batch_size, channels, time, height, width = x.size()
+
+         x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
+         x = self.norm(x)
+
+         # compute query, key, value
+         qkv = self.to_qkv(x)
+         qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
+         qkv = qkv.permute(0, 1, 3, 2).contiguous()
+         q, k, v = qkv.chunk(3, dim=-1)
+
+         # apply attention
+         x = F.scaled_dot_product_attention(q, k, v)
+
+         x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
+
+         # output projection
+         x = self.proj(x)
+
+         # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
+         x = x.view(batch_size, time, channels, height, width)
+         x = x.permute(0, 2, 1, 3, 4)
+
+         return x + identity
+
+
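+ # Shape walkthrough for WanAttentionBlock: the [B, C, T, H, W] input is folded
+ # into [(B*T), C, H, W], so attention runs independently per frame over the
+ # H*W spatial positions with a single head of width C; there is no temporal
+ # mixing in this block.
+
+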
+ class WanMidBlock(nn.Module):
+     """
+     Middle block for WanVAE encoder and decoder.
+
+     Args:
+         dim (int): Number of input/output channels.
+         dropout (float): Dropout rate.
+         non_linearity (str): Type of non-linearity to use.
+     """
+
+     def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
+         super().__init__()
+         self.dim = dim
+
+         # Create the components
+         resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
+         attentions = []
+         for _ in range(num_layers):
+             attentions.append(WanAttentionBlock(dim))
+             resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
+         self.attentions = nn.ModuleList(attentions)
+         self.resnets = nn.ModuleList(resnets)
+
+         self.gradient_checkpointing = GRADIENT_CHECKPOINTING
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         # First residual block
+         x = self.resnets[0](x, feat_cache, feat_idx)
+
+         # Process through attention and residual blocks
+         for attn, resnet in zip(self.attentions, self.resnets[1:]):
+             if attn is not None:
+                 if self.gradient_checkpointing:
+                     x = torch.utils.checkpoint.checkpoint(
+                         attn,
+                         x,
+                         use_reentrant=False,
+                     )
+                 else:
+                     x = attn(x)
+
+             if self.gradient_checkpointing and feat_cache is not None:
+                 # Save mutable state before checkpoint; it will be restored on recompute.
+                 initial_idx = feat_idx[0]
+                 initial_cache_snapshot = [
+                     (c.clone() if isinstance(c, torch.Tensor) else c)
+                     for c in feat_cache
+                 ]
+
+                 def checkpoint_fn(x, block=resnet):
+                     feat_idx[0] = initial_idx
+                     for j in range(len(feat_cache)):
+                         val = initial_cache_snapshot[j]
+                         feat_cache[j] = val.clone() if isinstance(val, torch.Tensor) else val
+                     return block(x, feat_cache, feat_idx)
+
+                 x = torch.utils.checkpoint.checkpoint(
+                     checkpoint_fn,
+                     x,
+                     use_reentrant=False,
+                 )
+             else:
+                 x = resnet(x, feat_cache, feat_idx)
+
+         return x
+
+
+ class WanResidualDownBlock(nn.Module):
+     def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
+         super().__init__()
+
+         # Shortcut path with downsample
+         self.avg_shortcut = AvgDown3D(
+             in_dim,
+             out_dim,
+             factor_t=2 if temperal_downsample else 1,
+             factor_s=2 if down_flag else 1,
+         )
+
+         # Main path with residual blocks and downsample
+         resnets = []
+         for _ in range(num_res_blocks):
+             resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
+             in_dim = out_dim
+         self.resnets = nn.ModuleList(resnets)
+
+         # Add the final downsample block
+         if down_flag:
+             mode = "downsample3d" if temperal_downsample else "downsample2d"
+             self.downsampler = WanResample(out_dim, mode=mode)
+         else:
+             self.downsampler = None
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         x_copy = x.clone()
+         for resnet in self.resnets:
+             x = resnet(x, feat_cache, feat_idx)
+         if self.downsampler is not None:
+             x = self.downsampler(x, feat_cache, feat_idx)
+
+         return x + self.avg_shortcut(x_copy)
+
+
+ class WanEncoder3d(nn.Module):
+     r"""
+     A 3D encoder module.
+
+     Args:
+         dim (int): The base number of channels in the first layer.
+         z_dim (int): The dimensionality of the latent space.
+         dim_mult (list of int): Multipliers for the number of channels in each block.
+         num_res_blocks (int): Number of residual blocks in each block.
+         attn_scales (list of float): Scales at which to apply attention mechanisms.
+         temperal_downsample (list of bool): Whether to downsample temporally in each block.
+         dropout (float): Dropout rate for the dropout layers.
+         non_linearity (str): Type of non-linearity to use.
+     """
+
+     def __init__(
+         self,
+         in_channels: int = 3,
+         dim=128,
+         z_dim=4,
+         dim_mult=[1, 2, 4, 4],
+         num_res_blocks=2,
+         attn_scales=[],
+         temperal_downsample=[True, True, False],
+         dropout=0.0,
+         non_linearity: str = "silu",
+         is_residual: bool = False,  # wan 2.2 vae uses a residual down block
+     ):
+         super().__init__()
+         self.dim = dim
+         self.z_dim = z_dim
+         self.dim_mult = dim_mult
+         self.num_res_blocks = num_res_blocks
+         self.attn_scales = attn_scales
+         self.temperal_downsample = temperal_downsample
+         self.nonlinearity = resolve_activation(non_linearity)
+
+         # dimensions
+         dims = [dim * u for u in [1] + dim_mult]
+         scale = 1.0
+
+         # init block
+         self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
+
+         # downsample blocks
+         self.down_blocks = nn.ModuleList([])
+         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+             # residual (+attention) blocks
+             if is_residual:
+                 self.down_blocks.append(
+                     WanResidualDownBlock(
+                         in_dim,
+                         out_dim,
+                         dropout,
+                         num_res_blocks,
+                         temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
+                         down_flag=i != len(dim_mult) - 1,
+                     )
+                 )
+             else:
+                 for _ in range(num_res_blocks):
+                     self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+                     if scale in attn_scales:
+                         self.down_blocks.append(WanAttentionBlock(out_dim))
+                     in_dim = out_dim
+
+                 # downsample block
+                 if i != len(dim_mult) - 1:
+                     mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+                     self.down_blocks.append(WanResample(out_dim, mode=mode))
+                     scale /= 2.0
+
+         # middle blocks
+         self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
+
+         # output blocks
+         self.norm_out = WanRMS_norm(out_dim, images=False)
+         self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+         self.gradient_checkpointing = False
+
+     def forward(self, x, feat_cache=None, feat_idx=[0]):
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone() if CACHE_T > 0 else x[:, :, :0, :, :]
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 # cache the last frame of the previous chunk
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1:, :, :].to(cache_x.device), cache_x], dim=2)
+             x = self.conv_in(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv_in(x)
+
+         ## downsamples
+         for layer in self.down_blocks:
+             if feat_cache is not None:
+                 x = layer(x, feat_cache, feat_idx)
+             else:
+                 x = layer(x)
+
+         ## middle
+         x = self.mid_block(x, feat_cache, feat_idx)
+
+         ## head
+         x = self.norm_out(x)
+         x = self.nonlinearity(x)
+
+         if feat_cache is not None:
+             idx = feat_idx[0]
+             cache_x = x[:, :, -CACHE_T:, :, :].clone() if CACHE_T > 0 else x[:, :, :0, :, :]
+             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                 # cache the last frame of the previous chunk
+                 cache_x = torch.cat([feat_cache[idx][:, :, -1:, :, :].to(cache_x.device), cache_x], dim=2)
+             x = self.conv_out(x, feat_cache[idx])
+             feat_cache[idx] = cache_x
+             feat_idx[0] += 1
+         else:
+             x = self.conv_out(x)
+         return x
+
+
+ class WanResidualUpBlock(nn.Module):
+     """
+     A block that handles upsampling for the WanVAE decoder.
+
+     Args:
+         in_dim (int): Input dimension
+         out_dim (int): Output dimension
+         num_res_blocks (int): Number of residual blocks
+         dropout (float): Dropout rate
+         temperal_upsample (bool): Whether to upsample on temporal dimension
+         up_flag (bool): Whether to upsample or not
+         non_linearity (str): Type of non-linearity to use
+     """
+
+     def __init__(
+         self,
+         in_dim: int,
+         out_dim: int,
+         num_res_blocks: int,
+         dropout: float = 0.0,
+         temperal_upsample: bool = False,
+         up_flag: bool = False,
+         non_linearity: str = "silu",
+     ):
+         super().__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+         if up_flag:
+             self.avg_shortcut = DupUp3D(
+                 in_dim,
+                 out_dim,
+                 factor_t=2 if temperal_upsample else 1,
+                 factor_s=2,
+             )
+         else:
+             self.avg_shortcut = None
+
+         # create residual blocks
+         resnets = []
+         current_dim = in_dim
+         for _ in range(num_res_blocks + 1):
+             resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+             current_dim = out_dim
+
+         self.resnets = nn.ModuleList(resnets)
+
+         # Add upsampling layer if needed
+         if up_flag:
+             upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
+             self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
+         else:
+             self.upsampler = None
+
+         self.gradient_checkpointing = False
+
+     def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False, is_reference=False):
+         """
+         Forward pass through the upsampling block.
+
+         Args:
+             x (torch.Tensor): Input tensor
+             feat_cache (list, optional): Feature cache for causal convolutions
+             feat_idx (list, optional): Feature index for cache management
+             first_chunk (bool, optional): Whether this is the first chunk
+             is_reference (bool, optional): Whether processing reference tokens
+
+         Returns:
+             torch.Tensor: Output tensor
+         """
+         x_copy = x.clone()
+
+         # WanResidualBlock does not accept is_reference
+         for resnet in self.resnets:
+             if feat_cache is not None:
+                 x = resnet(x, feat_cache, feat_idx)
+             else:
+                 x = resnet(x)
+
+         if self.upsampler is not None:
+             if feat_cache is not None:
+                 x = self.upsampler(x, feat_cache, feat_idx)
+             else:
+                 # Pass is_reference to upsampler
+                 x = self.upsampler(x, is_reference=is_reference)
+
+         if self.avg_shortcut is not None:
+             # DupUp3D only accepts first_chunk
+             x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
+
+         return x
+
+
+ class WanUpBlock(nn.Module):
+     """
+     A block that handles upsampling for the WanVAE decoder.
+
+     Args:
+         in_dim (int): Input dimension
+         out_dim (int): Output dimension
+         num_res_blocks (int): Number of residual blocks
+         dropout (float): Dropout rate
+         upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
+         non_linearity (str): Type of non-linearity to use
+     """
+
+     def __init__(
+         self,
+         in_dim: int,
+         out_dim: int,
+         num_res_blocks: int,
+         dropout: float = 0.0,
+         upsample_mode: Optional[str] = None,
+         non_linearity: str = "silu",
+     ):
+         super().__init__()
+         self.in_dim = in_dim
+         self.out_dim = out_dim
+
+         # Create layers list
+         resnets = []
+         # Add residual blocks and attention if needed
+         current_dim = in_dim
+         for _ in range(num_res_blocks + 1):
+             resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+             current_dim = out_dim
+
+         self.resnets = nn.ModuleList(resnets)
+
+         # Add upsampling layer if needed
+         self.upsamplers = None
+         if upsample_mode is not None:
+             self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
+
+         self.gradient_checkpointing = False
+
+     def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None, is_reference=False):
+         """
+         Forward pass through the upsampling block.
+
+         Args:
+             x (torch.Tensor): Input tensor
+             feat_cache (list, optional): Feature cache for causal convolutions
+             feat_idx (list, optional): Feature index for cache management
+             first_chunk (bool, optional): Whether this is the first chunk
+             is_reference (bool, optional): Whether processing reference tokens
+
+         Returns:
+             torch.Tensor: Output tensor
+         """
+         # The resnets do not consume is_reference; only the upsampler uses it
+         for resnet in self.resnets:
+             if feat_cache is not None:
+                 x = resnet(x, feat_cache, feat_idx)
+             else:
+                 x = resnet(x)
+
+         # Pass first_chunk / is_reference through to the upsampler
+         if self.upsamplers is not None:
+             if feat_cache is not None:
+                 x = self.upsamplers[0](x, feat_cache, feat_idx)
+             else:
+                 x = self.upsamplers[0](x, first_chunk=first_chunk, is_reference=is_reference)
+         return x
+
+
+ class RefConvIn(nn.Module):
+     """
+     Tokenizes reference videos by converting spatial resolution into channels.
+     Uses only reshape operations.
+     Converts [b, c, T, h, w] to [b, c_out, T, h/patch_size, w/patch_size]
+     """
+
+     def __init__(
+         self,
+         in_channels=3,
+         out_channels=384,
+         patch_size=8,
+     ):
+         """
+         Args:
+             in_channels (int): Number of input channels (e.g., 3 for RGB)
+             out_channels (int): Number of output channels
+             patch_size (int): Size of spatial patches for downsampling
+         """
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.patch_size = patch_size
+
+         # Calculate intermediate channels after patchification
+         self.patch_channels = in_channels * patch_size * patch_size
+
+         # Conv2d layer to project from patch_channels to out_channels
+         self.proj = nn.Conv2d(self.patch_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+         self.norm = WanRMS_norm(self.out_channels, images=True)
+
+         # Sanity check: the projection width must be an integer multiple of the patch channels
+         assert (
+             self.out_channels % self.patch_channels == 0
+         ), f"out_channels ({self.out_channels}) must be divisible by patch_channels ({self.patch_channels})"
+
+     def forward(self, x):
+         """
+         Tokenize reference input using only reshape operations.
+
+         Args:
+             x: Input tensor [b, in_channels, T, h, w]
+
+         Returns:
+             Tokenized tensor [b, out_channels, T, h/patch_size, w/patch_size]
+         """
+         b, c, T, h, w = x.shape
+         patch_size = self.patch_size
+
+         # Ensure dimensions are divisible by patch_size
+         assert h % patch_size == 0, f"Height {h} must be divisible by patch_size {patch_size}"
+         assert w % patch_size == 0, f"Width {w} must be divisible by patch_size {patch_size}"
+
+         # Step 1: Reshape into patches
+         x = x.view(b, c, T, h // patch_size, patch_size, w // patch_size, patch_size)
+
+         # Step 2: Rearrange dimensions
+         x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
+
+         # Step 3: Flatten patches into channels
+         x = x.view(b, c * patch_size * patch_size, T, h // patch_size, w // patch_size)
+
+         # Step 4: Apply Conv2d projection for each time step
+         # Reshape to merge batch and time dimensions
+         x = x.view(b * T, self.patch_channels, h // patch_size, w // patch_size)
+
+         # Apply convolution
+         x = self.proj(x)
+         x = self.norm(x)
+
+         # Reshape back to separate batch and time dimensions
+         x = x.view(b, self.out_channels, T, h // patch_size, w // patch_size)
+
+         return x
+
+
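+ # Example: a [B, 3, 1, 480, 832] reference frame with patch_size=8 is reshaped
+ # to [B, 192, 1, 60, 104] (patch_channels = 3 * 8 * 8 = 192) and projected by
+ # the 3x3 conv to [B, 384, 1, 60, 104], i.e. one token grid at 1/8 of the
+ # input's spatial resolution.
+
+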
+ class WanRotaryPosEmbed(nn.Module):
+     def __init__(
+         self,
+         attention_head_dim: int,
+         patch_size: Tuple[int, int, int],
+         max_seq_len: int,
+         theta: float = 10000.0,
+     ):
+         super().__init__()
+
+         self.attention_head_dim = attention_head_dim
+         self.patch_size = patch_size
+         self.max_seq_len = max_seq_len
+
+         h_dim = w_dim = 2 * (attention_head_dim // 6)
+         t_dim = attention_head_dim - h_dim - w_dim
+         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+         freqs_cos = []
+         freqs_sin = []
+
+         for dim in [t_dim, h_dim, w_dim]:
+             freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                 dim,
+                 max_seq_len,
+                 theta,
+                 use_real=True,
+                 repeat_interleave_real=True,
+                 freqs_dtype=freqs_dtype,
+             )
+             freqs_cos.append(freq_cos)
+             freqs_sin.append(freq_sin)
+
+         self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+         self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p_t, p_h, p_w = self.patch_size
+         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+         split_sizes = [
+             self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+             self.attention_head_dim // 3,
+             self.attention_head_dim // 3,
+         ]
+
+         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+         freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+         freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+         freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+         freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+         return freqs_cos, freqs_sin
+
+
+ class ReferenceRemover:
+     """
+     Removes reference frame tokens that were concatenated along temporal dimension.
+     Handles cases where temporal upsampling may have occurred.
+     """
+
+     def __init__(self, ref_frame_count: int = 1):
+         """
+         Args:
+             ref_frame_count: Number of reference frames concatenated (default: 1)
+         """
+         self.ref_frame_count = ref_frame_count
+
+     def __call__(self, x: torch.Tensor, original_temporal_dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Split the reference frames off the temporal dimension.
+
+         Args:
+             x: Tensor of shape [B, C, T, H, W]
+             original_temporal_dim: The temporal dimension before concatenating reference
+
+         Returns:
+             Tuple of (reference frames, remaining frames), split along the temporal dimension
+         """
+         current_temporal_dim = x.shape[2]
+
+         # Calculate temporal scale factor from upsampling
+         original_input_frames = original_temporal_dim + 1
+         temporal_scale = current_temporal_dim // original_input_frames
+
+         # Calculate how many frames to remove (scaled reference frames)
+         frames_to_remove = self.ref_frame_count * temporal_scale
+
+         # Split reference frames from the beginning
+         return (x[:, :, :frames_to_remove, :, :], x[:, :, frames_to_remove:, :, :])
+
+
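+ # Worked example: if the decoder saw 1 reference frame plus 5 latent frames
+ # (original_temporal_dim = 5, so original_input_frames = 6) and temporal
+ # upsampling brought the clip to 24 frames, then temporal_scale = 24 // 6 = 4
+ # and the first ref_frame_count * 4 = 4 frames are split off as the upsampled
+ # reference.
+
+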
1110
+ class WanDecoder3d(nn.Module):
+     r"""
+     A 3D decoder module.
+
+     Args:
+         dim (int): The base number of channels in the first layer.
+         z_dim (int): The dimensionality of the latent space.
+         dim_mult (list of int): Multipliers for the number of channels in each block.
+         num_res_blocks (int): Number of residual blocks in each block.
+         attn_scales (list of float): Scales at which to apply attention mechanisms.
+         temperal_upsample (list of bool): Whether to upsample temporally in each block.
+         dropout (float): Dropout rate for the dropout layers.
+         non_linearity (str): Type of non-linearity to use.
+         skip_decoder_attention (bool): If True, skip all attention blocks in the decoder.
+     """
+
+     def __init__(
+         self,
+         dim=128,
+         z_dim=4,
+         dim_mult=[1, 2, 4, 4],
+         num_res_blocks=2,
+         attn_scales=[],
+         temperal_upsample=[False, True, True],
+         dropout=0.0,
+         non_linearity: str = "silu",
+         out_channels: int = 3,
+         is_residual: bool = False,
+         use_reference: bool = False,
+         skip_decoder_attention: bool = False,
+         dc_factor: int = 2,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.z_dim = z_dim
+         self.dim_mult = dim_mult
+         self.num_res_blocks = num_res_blocks
+         self.attn_scales = attn_scales
+         self.temperal_upsample = temperal_upsample
+         self.use_reference = use_reference
+         self.skip_decoder_attention = skip_decoder_attention
+         self.dc_factor = dc_factor
+         self.nonlinearity = resolve_activation(non_linearity)
+
+         # dimensions
+         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+
+         # init block
+         self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
+
+         # middle blocks
+         self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+
+         self.ref_conv_in = RefConvIn(out_channels=dims[0]) if self.use_reference else None
+
+         # upsample block & attention block 1, 2 and 3
+         self.up_blocks = nn.ModuleList([])
+
+         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+             # residual (+attention) blocks
+             if i > 0 and not is_residual:
+                 # wan vae 2.1
+                 in_dim = in_dim // 2
+
+             # determine if we need upsampling
+             up_flag = i != len(dim_mult) - 1
+             # determine upsampling mode, if not upsampling, set to None
+             upsample_mode = None
+             if up_flag and temperal_upsample[i]:
+                 upsample_mode = "upsample3d"
+             elif up_flag:
+                 upsample_mode = "upsample2d"
+             # Create and add the upsampling block
+             if is_residual:
+                 up_block = WanResidualUpBlock(
+                     in_dim=in_dim,
+                     out_dim=out_dim,
+                     num_res_blocks=num_res_blocks,
+                     dropout=dropout,
+                     temperal_upsample=temperal_upsample[i] if up_flag else False,
+                     up_flag=up_flag,
+                     non_linearity=non_linearity,
+                 )
+             else:
+                 up_block = WanUpBlock(
+                     in_dim=in_dim,
+                     out_dim=out_dim,
+                     num_res_blocks=num_res_blocks,
+                     dropout=dropout,
+                     upsample_mode=upsample_mode,
+                     non_linearity=non_linearity,
+                 )
+
+             self.up_blocks.append(up_block)
+
+         # output blocks
+         self.norm_out = WanRMS_norm(out_dim, images=False)
+         self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
+
+         self.gradient_checkpointing = GRADIENT_CHECKPOINTING
+
+     def forward(self, x, transformer, feat_cache=None, feat_idx=[0], first_chunk=False, reference_frame=None, skip=False, window_size=-1):
+         run_attn = not self.skip_decoder_attention and not skip
+         if self.gradient_checkpointing:
+             x = torch.utils.checkpoint.checkpoint(
+                 self.conv_in,
+                 x,
+                 use_reentrant=False
+             )
+         else:
+             x = self.conv_in(x)
+
+         ## middle
+         x = self.mid_block(x, feat_cache, feat_idx)
+         ref_tokens = None
+         if self.use_reference and reference_frame is not None:
+             # ref_tokens: [B, C, 1, H, W] - single frame
+             if self.gradient_checkpointing:
+                 ref_tokens = torch.utils.checkpoint.checkpoint(
+                     self.ref_conv_in,
+                     reference_frame,
+                     use_reentrant=False
+                 )
+             else:
+                 ref_tokens = self.ref_conv_in(reference_frame)
+
+         # Transformer + upblock
+         if run_attn:
+             for i in range(4):
+                 if i <= 2:
+                     if ref_tokens is not None:
+                         x = torch.cat([ref_tokens, x], dim=2)
+                     transformer_output = transformer(
+                         hidden_states=x,
+                         stage_idx=i,
+                         return_dict=True,
+                         window_size=window_size,
+                     )
+                     # Extract the output sample
+                     x = transformer_output.sample if hasattr(transformer_output, 'sample') else transformer_output[0]
+                     if ref_tokens is not None:
+                         ref_tokens, x = x[:, :, :1], x[:, :, 1:]
+                         if i <= 1:
+                             if self.gradient_checkpointing:
+                                 ref_tokens = torch.utils.checkpoint.checkpoint(
+                                     self.up_blocks[i],
+                                     ref_tokens,
+                                     None,
+                                     [0],
+                                     first_chunk,
+                                     True,
+                                     use_reentrant=False
+                                 )
+                             else:
+                                 ref_tokens = self.up_blocks[i](ref_tokens, is_reference=True, first_chunk=first_chunk)
+
+                 if self.gradient_checkpointing:
+                     # Save mutable state before checkpoint - will be restored on each forward run
+                     # (both original forward and backward recompute)
+                     initial_idx = feat_idx[0]
+                     initial_cache_snapshot = [
+                         (c.clone() if isinstance(c, torch.Tensor) else c)
+                         for c in feat_cache
+                     ] if feat_cache is not None else None
+
+                     def checkpoint_fn(x, block_idx=i):
+                         # Restore state before each run to ensure consistency
+                         feat_idx[0] = initial_idx
+                         if initial_cache_snapshot is not None:
+                             for j in range(len(feat_cache)):
+                                 val = initial_cache_snapshot[j]
+                                 feat_cache[j] = val.clone() if isinstance(val, torch.Tensor) else val
+                         return self.up_blocks[block_idx](x, feat_cache, feat_idx, first_chunk=first_chunk)
+
+                     x = torch.utils.checkpoint.checkpoint(
+                         checkpoint_fn,
+                         x,
+                         use_reentrant=False,
+                     )
+                 else:
+                     x = self.up_blocks[i](x, feat_cache, feat_idx, first_chunk=first_chunk)
+         else:
+             print("[DEBUG] transformer skipped")
+             for i in range(4):
+                 x = self.up_blocks[i](x, feat_cache, feat_idx, first_chunk=first_chunk)
+
+         ## head
+         x = self.norm_out(x)
+         x = self.nonlinearity(x)
+
+         if self.gradient_checkpointing:
+             x = torch.utils.checkpoint.checkpoint(
+                 self.conv_out,
+                 x,
+                 None,
+                 'upsample3d',
+                 use_reentrant=False,
+             )
+         else:
+             x = self.conv_out(x, mode='upsample3d')
+         return x
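+
+     # Note on the reference path above (illustrative walkthrough; the shapes are
+     # taken from the code comments, everything else is an assumption): when
+     # `use_reference` is set, the reference frame is embedded by `ref_conv_in`
+     # into a single token-frame and concatenated in front of `x` on the time
+     # axis, so every transformer stage attends jointly over [ref | video]:
+     #
+     #     ref_tokens: [B, C, 1, H, W]   x: [B, C, T, H, W]
+     #     torch.cat([ref_tokens, x], dim=2)  ->  [B, C, 1 + T, H, W]
+     #
+     # After each stage the first temporal slice is split back off
+     # (`ref_tokens, x = x[:, :, :1], x[:, :, 1:]`) and, for the first two
+     # stages, upsampled with `is_reference=True` so it keeps matching x's
+     # growing resolution.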
+
+
+ def patchify(x, patch_size):
+     if patch_size == 1:
+         return x
+
+     if x.dim() != 5:
+         raise ValueError(f"Invalid input shape: {x.shape}")
+     # x shape: [batch_size, channels, frames, height, width]
+     batch_size, channels, frames, height, width = x.shape
+
+     # Ensure height and width are divisible by patch_size
+     if height % patch_size != 0 or width % patch_size != 0:
+         raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+     # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+     x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+
+     # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+     x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
+     x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
+
+     return x
+
+
+ def unpatchify(x, patch_size):
+     if patch_size == 1:
+         return x
+
+     if x.dim() != 5:
+         raise ValueError(f"Invalid input shape: {x.shape}")
+     # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+     batch_size, c_patches, frames, height, width = x.shape
+     channels = c_patches // (patch_size * patch_size)
+
+     # Reshape to [b, c, patch_size, patch_size, f, h, w]
+     x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
+
+     # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+     x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
+     x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+
+     return x
+
+
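+ # Illustrative round trip for patchify/unpatchify (sketch; the dummy shapes
+ # below are assumptions, not values used in this repo). patchify folds each
+ # patch_size x patch_size spatial block into the channel axis; unpatchify is
+ # its exact inverse:
+ #
+ #     >>> x = torch.randn(1, 3, 5, 64, 64)        # [B, C, T, H, W]
+ #     >>> y = patchify(x, patch_size=2)           # [1, 12, 5, 32, 32]
+ #     >>> bool(torch.equal(unpatchify(y, patch_size=2), x))
+ #     True
+
+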
+ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+     r"""
+     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
+     Introduced in [Wan 2.1].
+
+     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+     for all models (such as downloading or saving).
+     """
+
+     _supports_gradient_checkpointing = False
+
+     @register_to_config
+     def __init__(
+         self,
+         base_dim: int = 96,
+         decoder_base_dim: Optional[int] = None,
+         use_reference: bool = False,
+         skip_decoder_attention: bool = False,
+         z_dim: int = 16,
+         dim_mult: List[int] = [1, 2, 4, 4],
+         num_res_blocks: int = 2,
+         attn_scales: List[float] = [],
+         temperal_downsample: List[bool] = [False, True, True],
+         dropout: float = 0.0,
+         latents_mean: List[float] = [
+             -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+             0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
+         ],
+         latents_std: List[float] = [
+             2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+             3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
+         ],
+         is_residual: bool = False,
+         in_channels: int = 3,
+         out_channels: int = 3,
+         patch_size: Optional[int] = None,
+         scale_factor_temporal: Optional[int] = 4,
+         scale_factor_spatial: Optional[int] = 8,
+         inference_w_dropout=False,
+         dropout_p=0.7,
+         gradient_checkpointing=False,
+         **kwargs,
+     ) -> None:
+         global GRADIENT_CHECKPOINTING
+         GRADIENT_CHECKPOINTING = gradient_checkpointing
+         super().__init__()
+         self.inference_w_dropout = inference_w_dropout
+         self.dropout_p = dropout_p
+
+         self.z_dim = z_dim
+         self.temperal_downsample = temperal_downsample
+         self.temperal_upsample = temperal_downsample[::-1]
+
+         if decoder_base_dim is None:
+             decoder_base_dim = base_dim
+
+         self.encoder = WanEncoder3d(
+             in_channels=in_channels,
+             dim=base_dim,
+             z_dim=z_dim * 2,
+             dim_mult=dim_mult,
+             num_res_blocks=num_res_blocks,
+             attn_scales=attn_scales,
+             temperal_downsample=temperal_downsample,
+             dropout=dropout,
+             is_residual=is_residual,
+         )
+         self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
+         self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
+
+         self.decoder = WanDecoder3d(
+             dim=decoder_base_dim,
+             z_dim=z_dim,
+             dim_mult=dim_mult,
+             num_res_blocks=num_res_blocks,
+             attn_scales=attn_scales,
+             temperal_upsample=self.temperal_upsample,
+             dropout=dropout,
+             out_channels=out_channels,
+             is_residual=is_residual,
+             use_reference=use_reference,
+             skip_decoder_attention=skip_decoder_attention,
+         )
+
+         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+         # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+         # to perform decoding of a single video latent at a time.
+         self.use_slicing = False
+
+         # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+         # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+         # intermediate tiles together, the memory requirement can be lowered.
+         self.use_tiling = False
+
+         # The minimal tile height and width for spatial tiling to be used
+         self.tile_sample_min_height = 256
+         self.tile_sample_min_width = 256
+
+         # The minimal distance between two spatial tiles
+         self.tile_sample_stride_height = 192
+         self.tile_sample_stride_width = 192
+
+         # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
+         self._cached_conv_counts = {
+             "decoder": (
+                 sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) if self.decoder is not None else 0
+             ),
+             "encoder": (
+                 sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules()) if self.encoder is not None else 0
+             ),
+         }
+
+         self.reference_frame = None
+
+     def _init_ref_conv_in(self):
+         ref_conv_in = getattr(self.decoder, "ref_conv_in", None)
+         if ref_conv_in is None:
+             return
+
+         with torch.no_grad():
+             nn.init.xavier_uniform_(ref_conv_in.proj.weight)
+             if ref_conv_in.proj.bias is not None:
+                 nn.init.constant_(ref_conv_in.proj.bias, 0.0)
+
+     def _apply_token_dropout(self, x: torch.Tensor) -> torch.Tensor:
+         """
+         Apply token dropout to the input tensor.
+
+         Args:
+             x: Input tensor of shape [B, C, T, H, W]
+
+         Returns:
+             Tensor with random tokens dropped (set to zero)
+         """
+         if self.inference_w_dropout or self.training:
+             if self.training:
+                 p = torch.rand(1).item() * self.dropout_p
+             else:
+                 p = self.dropout_p
+             dropped = torch.rand_like(x[:, :1, :1, :, :]) < p
+             x = torch.where(dropped, torch.zeros_like(x), x)
+         return x
+
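+     # Sketch of the dropout-mask semantics above (illustrative; the dummy
+     # values are assumptions): the Bernoulli mask is drawn once per spatial
+     # location and broadcast over channels and time, so a dropped token is
+     # zeroed in every channel of every frame.
+     #
+     #     >>> x = torch.ones(1, 16, 4, 8, 8)                   # [B, C, T, H, W]
+     #     >>> dropped = torch.rand_like(x[:, :1, :1, :, :]) < 0.5
+     #     >>> torch.where(dropped, torch.zeros_like(x), x)     # zeros span C and T
+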
+     def enable_tiling(
+         self,
+         tile_sample_min_height: Optional[int] = None,
+         tile_sample_min_width: Optional[int] = None,
+         tile_sample_stride_height: Optional[float] = None,
+         tile_sample_stride_width: Optional[float] = None,
+     ) -> None:
+         r"""
+         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+         processing larger images.
+
+         Args:
+             tile_sample_min_height (`int`, *optional*):
+                 The minimum height required for a sample to be separated into tiles across the height dimension.
+             tile_sample_min_width (`int`, *optional*):
+                 The minimum width required for a sample to be separated into tiles across the width dimension.
+             tile_sample_stride_height (`int`, *optional*):
+                 The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                 no tiling artifacts produced across the height dimension.
+             tile_sample_stride_width (`int`, *optional*):
+                 The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                 artifacts produced across the width dimension.
+         """
+         self.use_tiling = True
+         self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+     def disable_tiling(self) -> None:
+         r"""
+         Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+         decoding in one step.
+         """
+         self.use_tiling = False
+
+     def enable_slicing(self) -> None:
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.use_slicing = True
+
+     def disable_slicing(self) -> None:
+         r"""
+         Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+         decoding in one step.
+         """
+         self.use_slicing = False
+
+     def clear_cache(self):
+         # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
+         self._conv_num = self._cached_conv_counts["decoder"]
+         self._conv_idx = [0]
+         self._feat_map = [None] * self._conv_num
+         # cache encode
+         self._enc_conv_num = self._cached_conv_counts["encoder"]
+         self._enc_conv_idx = [0]
+         self._enc_feat_map = [None] * self._enc_conv_num
+
+     def _encode(self, x: torch.Tensor):
+         _, _, num_frame, height, width = x.shape
+
+         if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+             # `tiled_encode` only takes the input tensor (the stray `is_reference`
+             # argument here was undefined in this scope).
+             return self.tiled_encode(x)
+
+         self.clear_cache()
+         if self.config.patch_size is not None:
+             x = patchify(x, patch_size=self.config.patch_size)
+         iter_ = 1  #TODO
+         for i in range(0, iter_):
+             self._enc_conv_idx = [0]
+             if i == 0:
+                 out = self.encoder(
+                     x[:, :, : 4 * LATENT_T_STRIDE - 3, :, :],
+                     feat_cache=self._enc_feat_map,
+                     feat_idx=self._enc_conv_idx,
+                 )
+             else:
+                 out_ = self.encoder(
+                     x[:, :, i * 4 * LATENT_T_STRIDE - 3 : (i + 1) * 4 * LATENT_T_STRIDE - 3, :, :],
+                     feat_cache=self._enc_feat_map,
+                     feat_idx=self._enc_conv_idx,
+                 )
+                 out = torch.cat([out, out_], 2)
+
+         enc = self.quant_conv(out)
+         self.clear_cache()
+         return enc
+
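+     # Temporal chunking sketch for `_encode` (illustrative arithmetic; assumes
+     # the usual Wan 4x temporal compression and LATENT_T_STRIDE == 1): the first
+     # chunk of 4 * LATENT_T_STRIDE - 3 = 1 frame yields one latent step, and
+     # each further chunk of 4 frames adds one more, so T frames map to
+     # 1 + (T - 1) // 4 latent frames (e.g. 17 frames -> 5 latent frames), the
+     # same count used by `frame_range` in `tiled_encode` below.
+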
+     @apply_forward_hook
+     def encode(
+         self, x: torch.Tensor, return_dict: bool = True
+     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+         r"""
+         Encode a batch of images into latents.
+
+         Args:
+             x (`torch.Tensor`): Input batch of images.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+         Returns:
+             The latent representations of the encoded videos. If `return_dict` is True, a
+             [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+         """
+
+         if self.use_slicing and x.shape[0] > 1:
+             encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+             h = torch.cat(encoded_slices)
+         else:
+             h = self._encode(x)
+
+         posterior = DiagonalGaussianDistribution(h)
+
+         if not return_dict:
+             return (posterior,)
+         return AutoencoderKLOutput(latent_dist=posterior)
+
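+     # Hedged usage sketch for `encode` (the shapes are assumptions, not outputs
+     # of a real run): with z_dim=16, 4x temporal and 8x spatial compression,
+     #
+     #     >>> video = torch.randn(1, 3, 17, 480, 832)    # pixels in [-1, 1]
+     #     >>> posterior = vae.encode(video).latent_dist
+     #     >>> z = posterior.sample()                     # [1, 16, 5, 60, 104]
+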
+     def _decode(self, z: torch.Tensor, transformer, return_dict: bool = True, reference_frame=None, skip=False, window_size=-1):
+         _, _, num_frame, height, width = z.shape
+         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+         if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+             # Forward the decoder transformer so the tiled path can call the decoder.
+             return self.tiled_decode(z, transformer, return_dict=return_dict, reference_frame=reference_frame, skip=skip)
+
+         self.clear_cache()
+
+         x = self.post_quant_conv(z)
+
+         x = self._apply_token_dropout(x)
+
+         for i in range(0, num_frame, LATENT_T_STRIDE):
+             self._conv_idx = [0]
+             self._conv_idx_ref = [0]
+             if i == 0:
+                 out = self.decoder(
+                     x[:, :, i : i + LATENT_T_STRIDE, :, :],
+                     transformer=transformer,
+                     feat_cache=self._feat_map,
+                     feat_idx=self._conv_idx,
+                     first_chunk=True,
+                     reference_frame=reference_frame,
+                     skip=skip,
+                     window_size=window_size,
+                 )
+             else:
+                 out_ = self.decoder(
+                     x[:, :, i : i + LATENT_T_STRIDE, :, :],
+                     transformer=transformer,
+                     feat_cache=self._feat_map,
+                     feat_idx=self._conv_idx,
+                     reference_frame=reference_frame,
+                     skip=skip,
+                     window_size=window_size,
+                 )
+                 out = torch.cat([out, out_], 2)
+
+         if self.config.patch_size is not None:
+             out = unpatchify(out, patch_size=self.config.patch_size)
+
+         out = torch.clamp(out, min=-1.0, max=1.0)
+
+         self.clear_cache()
+         if not return_dict:
+             return (out,)
+
+         return DecoderOutput(sample=out)
+
+     @apply_forward_hook
+     def decode(
+         self, z: torch.Tensor, transformer, return_dict: bool = True, reference_frame=None, skip=False, window_size=-1
+     ) -> Union[DecoderOutput, torch.Tensor]:
+         r"""
+         Decode a batch of images.
+
+         Args:
+             z (`torch.Tensor`): Input batch of latent vectors.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+             reference_frame (`torch.Tensor`, *optional*):
+                 Reference frame for decoder attention.
+             skip (`bool`, *optional*, defaults to `False`):
+                 Whether to skip attention in the decoder.
+
+         Returns:
+             [`~models.vae.DecoderOutput`] or `tuple`:
+                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                 returned.
+         """
+         # Use passed reference_frame or fall back to stored one
+         ref_frame = reference_frame if reference_frame is not None else self.reference_frame
+
+         if self.use_slicing and z.shape[0] > 1:
+             decoded_slices = [
+                 self._decode(z_slice, transformer, reference_frame=ref_frame, skip=skip, window_size=window_size).sample
+                 for z_slice in z.split(1)
+             ]
+             decoded = torch.cat(decoded_slices)
+         else:
+             decoded = self._decode(z, transformer, reference_frame=ref_frame, skip=skip, window_size=window_size).sample
+
+         if not return_dict:
+             return (decoded,)
+         return DecoderOutput(sample=decoded)
+
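+     # Hedged usage sketch for `decode` (argument values are assumptions): unlike
+     # the stock Wan VAE, this decoder also needs the RefDecoder transformer and,
+     # optionally, a reference frame:
+     #
+     #     >>> out = vae.decode(z, transformer, reference_frame=ref).sample
+     #     >>> out.shape    # e.g. [1, 3, 17, 480, 832], clamped to [-1, 1]
+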
+     def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+         blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+         for y in range(blend_extent):
+             b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                 y / blend_extent
+             )
+         return b
+
+     def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+         blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+         for x in range(blend_extent):
+             b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                 x / blend_extent
+             )
+         return b
+
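+     # Worked example of the linear crossfade in blend_v/blend_h: with
+     # blend_extent = 4, tile `b` is weighted 0, 1/4, 2/4, 3/4 across the overlap
+     # while tile `a` gets the complementary 1, 3/4, 2/4, 1/4, so the seam fades
+     # smoothly from one tile into the next.
+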
+     def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+         r"""Encode a batch of images using a tiled encoder.
+
+         Args:
+             x (`torch.Tensor`): Input batch of videos.
+
+         Returns:
+             `torch.Tensor`:
+                 The latent representation of the encoded videos.
+         """
+         _, _, num_frames, height, width = x.shape
+         latent_height = height // self.spatial_compression_ratio
+         latent_width = width // self.spatial_compression_ratio
+
+         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+         tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+         tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+         blend_height = tile_latent_min_height - tile_latent_stride_height
+         blend_width = tile_latent_min_width - tile_latent_stride_width
+
+         # Split x into overlapping tiles and encode them separately.
+         # The tiles have an overlap to avoid seams between tiles.
+         rows = []
+         for i in range(0, height, self.tile_sample_stride_height):
+             row = []
+             for j in range(0, width, self.tile_sample_stride_width):
+                 self.clear_cache()
+                 time = []
+                 frame_range = 1 + (num_frames - 1) // 4
+                 for k in range(frame_range):
+                     self._enc_conv_idx = [0]
+                     if k == 0:
+                         tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                     else:
+                         tile = x[
+                             :,
+                             :,
+                             1 + 4 * (k - 1) : 1 + 4 * k,
+                             i : i + self.tile_sample_min_height,
+                             j : j + self.tile_sample_min_width,
+                         ]
+                     tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
+                     tile = self.quant_conv(tile)
+                     time.append(tile)
+                 row.append(torch.cat(time, dim=2))
+             rows.append(row)
+         self.clear_cache()
+
+         result_rows = []
+         for i, row in enumerate(rows):
+             result_row = []
+             for j, tile in enumerate(row):
+                 # blend the above tile and the left tile
+                 # to the current tile and add the current tile to the result row
+                 if i > 0:
+                     tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                 if j > 0:
+                     tile = self.blend_h(row[j - 1], tile, blend_width)
+                 result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
+             result_rows.append(torch.cat(result_row, dim=-1))
+
+         enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+         return enc
+
+     def tiled_decode(
+         self, z: torch.Tensor, transformer, return_dict: bool = True, reference_frame=None, skip=False
+     ) -> Union[DecoderOutput, torch.Tensor]:
+         r"""
+         Decode a batch of images using a tiled decoder.
+
+         Args:
+             z (`torch.Tensor`): Input batch of latent vectors.
+             transformer: The decoder transformer, forwarded to each per-tile decoder call.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+         Returns:
+             [`~models.vae.DecoderOutput`] or `tuple`:
+                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                 returned.
+         """
+         _, _, num_frames, height, width = z.shape
+         sample_height = height * self.spatial_compression_ratio
+         sample_width = width * self.spatial_compression_ratio
+
+         tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+         tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+         tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+         tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+         blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+         blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+         # Split z into overlapping tiles and decode them separately.
+         # The tiles have an overlap to avoid seams between tiles.
+         rows = []
+         for i in range(0, height, tile_latent_stride_height):
+             row = []
+             for j in range(0, width, tile_latent_stride_width):
+                 self.clear_cache()
+                 time = []
+                 for k in range(num_frames):
+                     self._conv_idx = [0]
+                     tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                     tile = self.post_quant_conv(tile)
+
+                     tile = self._apply_token_dropout(tile)
+
+                     decoded = self.decoder(
+                         tile,
+                         transformer=transformer,
+                         feat_cache=self._feat_map,
+                         feat_idx=self._conv_idx,
+                         reference_frame=reference_frame,
+                         skip=skip,
+                     )
+                     time.append(decoded)
+                 row.append(torch.cat(time, dim=2))
+             rows.append(row)
+         self.clear_cache()
+
+         result_rows = []
+         for i, row in enumerate(rows):
+             result_row = []
+             for j, tile in enumerate(row):
+                 # blend the above tile and the left tile
+                 # to the current tile and add the current tile to the result row
+                 if i > 0:
+                     tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                 if j > 0:
+                     tile = self.blend_h(row[j - 1], tile, blend_width)
+                 result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+             result_rows.append(torch.cat(result_row, dim=-1))
+
+         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+
+         if not return_dict:
+             return (dec,)
+         return DecoderOutput(sample=dec)
+
+     def forward(
+         self,
+         sample: torch.Tensor,
+         sample_posterior: bool = False,
+         return_dict: bool = True,
+         generator: Optional[torch.Generator] = None,
+     ) -> Union[DecoderOutput, torch.Tensor]:
+         """
+         Args:
+             sample (`torch.Tensor`): Input sample.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+         """
+         x = sample
+
+         # Store reference frame if using reference attention
+         if self.decoder.use_reference:
+             idx = torch.randint(0, x.size(2), ()).item()
+             self.reference_frame = x[:, :, idx : idx + 1, :, :].clone()
+         else:
+             self.reference_frame = None
+
+         posterior = self.encode(x).latent_dist
+         if sample_posterior:
+             z = posterior.sample(generator=generator)
+         else:
+             z = posterior
src/models/Wan/transformer_wan.py ADDED
@@ -0,0 +1,1049 @@
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from peft import LoraConfig, get_peft_model, TaskType
+ from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
+ from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward
+ from diffusers.models.attention_dispatch import dispatch_attention_fn
+ from diffusers.models.cache_utils import CacheMixin
+ from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.normalization import FP32LayerNorm
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+     # encoder_hidden_states is only passed for cross-attention
+     if encoder_hidden_states is None:
+         encoder_hidden_states = hidden_states
+
+     if attn.fused_projections:
+         if attn.cross_attention_dim_head is None:
+             # In self-attention layers, we can fuse the entire QKV projection into a single linear
+             query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+         else:
+             # In cross-attention layers, we can only fuse the KV projections into a single linear
+             query = attn.to_q(hidden_states)
+             key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+     else:
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+     return query, key, value
+
+
+ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+     if attn.fused_projections:
+         key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+     else:
+         key_img = attn.add_k_proj(encoder_hidden_states_img)
+         value_img = attn.add_v_proj(encoder_hidden_states_img)
+     return key_img, value_img
+
+
+ class WanAttnProcessor:
+     _attention_backend = None
+
+     def __init__(self, return_attention_maps):
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError("WanAttnProcessor requires PyTorch 2.0.")
+         self.return_attention_maps = return_attention_maps
+
+     def __call__(
+         self,
+         attn: "WanAttention",
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         encoder_hidden_states_img = None
+         if attn.add_k_proj is not None:
+             # 512 is the context length of the text encoder, hardcoded for now
+             image_context_length = encoder_hidden_states.shape[1] - 512
+             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+         query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+         query = attn.norm_q(query)
+         key = attn.norm_k(key)
+
+         query = query.unflatten(2, (attn.heads, -1))
+         key = key.unflatten(2, (attn.heads, -1))
+         value = value.unflatten(2, (attn.heads, -1))
+
+         if rotary_emb is not None:
+
+             def apply_rotary_emb(
+                 hidden_states: torch.Tensor,
+                 freqs_cos: torch.Tensor,
+                 freqs_sin: torch.Tensor,
+             ):
+                 x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                 cos = freqs_cos[..., 0::2]
+                 sin = freqs_sin[..., 1::2]
+                 out = torch.empty_like(hidden_states)
+                 out[..., 0::2] = x1 * cos - x2 * sin
+                 out[..., 1::2] = x1 * sin + x2 * cos
+                 return out.type_as(hidden_states)
+
+             query = apply_rotary_emb(query, *rotary_emb)
+             key = apply_rotary_emb(key, *rotary_emb)
+
+         # I2V task
+         hidden_states_img = None
+         if encoder_hidden_states_img is not None:
+             key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+             key_img = attn.norm_added_k(key_img)
+
+             key_img = key_img.unflatten(2, (attn.heads, -1))
+             value_img = value_img.unflatten(2, (attn.heads, -1))
+
+             hidden_states_img = dispatch_attention_fn(
+                 query,
+                 key_img,
+                 value_img,
+                 attn_mask=None,
+                 dropout_p=0.0,
+                 is_causal=False,
+                 backend=self._attention_backend,
+             )
+             hidden_states_img = hidden_states_img.flatten(2, 3)
+             hidden_states_img = hidden_states_img.type_as(query)
+
+         if not self.return_attention_maps:
+             # Use fast dispatch
+             # Cast attention_mask to match query dtype to avoid dtype mismatch
+             attn_mask = attention_mask.to(query.dtype) if attention_mask is not None else None
+
+             hidden_states = dispatch_attention_fn(
+                 query,
+                 key,
+                 value,
+                 attn_mask=attn_mask,
+                 dropout_p=0.0,
+                 is_causal=False,
+                 backend=self._attention_backend,
+             )
+             hidden_states = hidden_states.flatten(2, 3)
+             attn_weights = None
+         else:
+             # Manual attention computation to get attention maps
+             # query, key, value: (B, S, H, D) where H=heads, D=head_dim
+
+             # Transpose to (B, H, S, D) for batched matrix multiplication
+             q = query.transpose(1, 2)  # (B, H, S, D)
+             k = key.transpose(1, 2)  # (B, H, S, D)
+             v = value.transpose(1, 2)  # (B, H, S, D)
+
+             # Compute attention scores: (B, H, S, S)
+             scale = q.size(-1) ** -0.5
+             attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
+
+             # Apply attention mask if provided
+             if attention_mask is not None:
+                 attn_scores = attn_scores + attention_mask
+
+             # Compute attention weights
+             attn_weights = F.softmax(attn_scores, dim=-1)  # (B, H, S, S)
+
+             # Apply attention to values
+             hidden_states = torch.matmul(attn_weights, v)  # (B, H, S, D)
+
+             # Transpose back and flatten: (B, S, H, D) -> (B, S, H*D)
+             hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+         hidden_states = hidden_states.type_as(query)
+
+         if hidden_states_img is not None:
+             hidden_states = hidden_states + hidden_states_img
+
+         hidden_states = attn.to_out[0](hidden_states)
+         hidden_states = attn.to_out[1](hidden_states)
+
+         return hidden_states, attn_weights
+
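+ # Equivalence sketch for the manual attention path above (illustrative check;
+ # the dummy shapes and tolerance are assumptions): the explicit
+ # softmax(q @ k^T / sqrt(d)) @ v computation matches PyTorch's fused kernel up
+ # to numerics, it just also exposes the attention maps:
+ #
+ #     >>> q = k = v = torch.randn(1, 8, 16, 64)   # (B, H, S, D)
+ #     >>> ref = F.scaled_dot_product_attention(q, k, v)
+ #     >>> w = torch.softmax(q @ k.transpose(-2, -1) * q.size(-1) ** -0.5, dim=-1)
+ #     >>> bool(torch.allclose(w @ v, ref, atol=1e-5))
+ #     True
+
+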
+ class WanAttention(torch.nn.Module, AttentionModuleMixin):
+     _default_processor_cls = WanAttnProcessor
+     _available_processors = [WanAttnProcessor]
+
+     def __init__(
+         self,
+         dim: int,
+         heads: int = 8,
+         dim_head: int = 64,
+         eps: float = 1e-5,
+         dropout: float = 0.0,
+         added_kv_proj_dim: Optional[int] = None,  # image embedding dimension
+         cross_attention_dim_head: Optional[int] = None,  # text embedding dimension
+         processor=None,
+         is_cross_attention=None,
+     ):
+         super().__init__()
+
+         self.inner_dim = dim_head * heads
+         self.heads = heads
+         self.added_kv_proj_dim = added_kv_proj_dim
+         self.cross_attention_dim_head = cross_attention_dim_head
+         self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+         self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+         self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+         self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+         self.to_out = torch.nn.ModuleList(
+             [
+                 torch.nn.Linear(self.inner_dim, dim, bias=True),
+                 torch.nn.Dropout(dropout),
+             ]
+         )
+         self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+         self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+         self.add_k_proj = self.add_v_proj = None
+         if added_kv_proj_dim is not None:
+             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+         self.is_cross_attention = cross_attention_dim_head is not None
+
+         self.set_processor(processor)
+
+     def fuse_projections(self):
+         if getattr(self, "fused_projections", False):
+             return
+
+         if self.cross_attention_dim_head is None:
+             concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+             concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+             out_features, in_features = concatenated_weights.shape
+             with torch.device("meta"):
+                 self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+             self.to_qkv.load_state_dict(
+                 {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+             )
+         else:
+             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+             concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+             out_features, in_features = concatenated_weights.shape
+             with torch.device("meta"):
+                 self.to_kv = nn.Linear(in_features, out_features, bias=True)
+             self.to_kv.load_state_dict(
+                 {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+             )
+
+         if self.added_kv_proj_dim is not None:
+             concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+             concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+             out_features, in_features = concatenated_weights.shape
+             with torch.device("meta"):
+                 self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+             self.to_added_kv.load_state_dict(
+                 {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+             )
+
+         self.fused_projections = True
+
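+     # Why `fuse_projections` is safe (sketch; the tensors named here are
+     # assumptions for illustration): concatenating the three weight matrices
+     # along the output dimension and chunking the fused result reproduces the
+     # separate projections exactly:
+     #
+     #     >>> w = torch.cat([attn.to_q.weight, attn.to_k.weight, attn.to_v.weight])
+     #     >>> b = torch.cat([attn.to_q.bias, attn.to_k.bias, attn.to_v.bias])
+     #     >>> q, k, v = F.linear(x, w, b).chunk(3, dim=-1)
+     #     >>> # q, k, v == attn.to_q(x), attn.to_k(x), attn.to_v(x)
+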
+     @torch.no_grad()
+     def unfuse_projections(self):
+         if not getattr(self, "fused_projections", False):
+             return
+
+         if hasattr(self, "to_qkv"):
+             delattr(self, "to_qkv")
+         if hasattr(self, "to_kv"):
+             delattr(self, "to_kv")
+         if hasattr(self, "to_added_kv"):
+             delattr(self, "to_added_kv")
+
+         self.fused_projections = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
+ class WanImageEmbedding(torch.nn.Module):
+     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+         super().__init__()
+
+         self.norm1 = FP32LayerNorm(in_features)
+         self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+         self.norm2 = FP32LayerNorm(out_features)
+         if pos_embed_seq_len is not None:
+             self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+         else:
+             self.pos_embed = None
+
+     def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+         if self.pos_embed is not None:
+             batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+             encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+             encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+         hidden_states = self.norm1(encoder_hidden_states_image)
+         hidden_states = self.ff(hidden_states)
+         hidden_states = self.norm2(hidden_states)
+         return hidden_states
+
+
+ class WanTimeTextImageEmbedding(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         time_freq_dim: int,
+         time_proj_dim: int,
+         text_embed_dim: int,
+         image_embed_dim: Optional[int] = None,
+         pos_embed_seq_len: Optional[int] = None,
+     ):
+         super().__init__()
+
+         self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+         self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+         self.act_fn = nn.SiLU()
+         self.time_proj = nn.Linear(dim, time_proj_dim)
+         self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+         self.image_embedder = None
+         if image_embed_dim is not None:
+             self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+     def forward(
+         self,
+         timestep: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+         timestep_seq_len: Optional[int] = None,
+     ):
+         timestep = self.timesteps_proj(timestep)
+         if timestep_seq_len is not None:
+             timestep = timestep.unflatten(0, (-1, timestep_seq_len))
+
+         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+             timestep = timestep.to(time_embedder_dtype)
+         temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+         timestep_proj = self.time_proj(self.act_fn(temb))
+
+         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+         if encoder_hidden_states_image is not None:
+             encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+ class WanRotaryPosEmbed(nn.Module):
+     def __init__(
+         self,
+         attention_head_dim: int,
+         patch_size: Tuple[int, int, int],
+         max_seq_len: int,
+         theta: float = 10000.0,
+     ):
+         super().__init__()
+
+         self.attention_head_dim = attention_head_dim
+         self.patch_size = patch_size
+         self.max_seq_len = max_seq_len
+
+         h_dim = w_dim = 2 * (attention_head_dim // 6)
+         t_dim = attention_head_dim - h_dim - w_dim
+         freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+         freqs_cos = []
+         freqs_sin = []
+
+         for dim in [t_dim, h_dim, w_dim]:
+             freq_cos, freq_sin = get_1d_rotary_pos_embed(
+                 dim,
+                 max_seq_len,
+                 theta,
+                 use_real=True,
+                 repeat_interleave_real=True,
+                 freqs_dtype=freqs_dtype,
+             )
+             freqs_cos.append(freq_cos)
+             freqs_sin.append(freq_sin)
+
+         self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+         self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p_t, p_h, p_w = self.patch_size
+         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+         split_sizes = [
+             self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+             self.attention_head_dim // 3,
+             self.attention_head_dim // 3,
+         ]
+
+         freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+         freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+         freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+         freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+         freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+         freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+         freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+         return freqs_cos, freqs_sin
+
+
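+ # Worked dimension split for WanRotaryPosEmbed with attention_head_dim = 128
+ # (the head_dim used elsewhere in this file): h_dim = w_dim = 2 * (128 // 6)
+ # = 42 and t_dim = 128 - 42 - 42 = 44, which matches `split_sizes` in forward,
+ # [128 - 2 * (128 // 3), 128 // 3, 128 // 3] = [44, 42, 42]. Time, height and
+ # width thus get factorized 1D rotary tables that are broadcast over the
+ # (frames, height, width) grid and concatenated per token.
+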
+ @maybe_allow_in_graph
+ class WanTransformerBlockOG(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         ffn_dim: int,
+         num_heads: int,
+         qk_norm: str = "rms_norm_across_heads",
+         cross_attn_norm: bool = False,
+         eps: float = 1e-6,
+         added_kv_proj_dim: Optional[int] = None,
+     ):
+         super().__init__()
+
+         # 1. Self-attention
+         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+         self.attn1 = WanAttention(
+             dim=dim,
+             heads=num_heads,
+             dim_head=dim // num_heads,
+             eps=eps,
+             cross_attention_dim_head=None,
+             # WanAttnProcessor requires return_attention_maps; this legacy block
+             # does not use the maps, so disable them.
+             processor=WanAttnProcessor(return_attention_maps=False),
+         )
+
+         # 2. Cross-attention
+         self.attn2 = WanAttention(
+             dim=dim,
+             heads=num_heads,
+             dim_head=dim // num_heads,
+             eps=eps,
+             added_kv_proj_dim=added_kv_proj_dim,
+             cross_attention_dim_head=dim // num_heads,
+             processor=WanAttnProcessor(return_attention_maps=False),
+         )
+         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+         # 3. Feed-forward
+         self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+         self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+         self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         temb: torch.Tensor,
+         rotary_emb: torch.Tensor,
+     ) -> torch.Tensor:
+         if temb.ndim == 4:
+             # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                 self.scale_shift_table.unsqueeze(0) + temb.float()
+             ).chunk(6, dim=2)
+             # batch_size, seq_len, 1, inner_dim
+             shift_msa = shift_msa.squeeze(2)
+             scale_msa = scale_msa.squeeze(2)
+             gate_msa = gate_msa.squeeze(2)
+             c_shift_msa = c_shift_msa.squeeze(2)
+             c_scale_msa = c_scale_msa.squeeze(2)
+             c_gate_msa = c_gate_msa.squeeze(2)
+         else:
+             # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                 self.scale_shift_table + temb.float()
+             ).chunk(6, dim=1)
+
+         # 1. Self-attention (the processor returns (output, attn_weights))
+         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+         attn_output, _ = self.attn1(norm_hidden_states, None, None, rotary_emb)
+         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+         # 2. Cross-attention
+         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+         attn_output, _ = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
+         hidden_states = hidden_states + attn_output
+
+         # 3. Feed-forward
+         norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+             hidden_states
+         )
+         ff_output = self.ffn(norm_hidden_states)
+         hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+         return hidden_states
+
+ @maybe_allow_in_graph
+ class WanTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         ffn_dim: int,
+         num_heads: int,
+         return_attention_maps: bool,
+         qk_norm: str = "rms_norm_across_heads",
+         eps: float = 1e-6,
+     ):
+         super().__init__()
+
+         # 1. Self-attention
+         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+         self.attn1 = WanAttention(
+             dim=dim,
+             heads=num_heads,
+             dim_head=dim // num_heads,
+             eps=eps,
+             cross_attention_dim_head=None,
+             processor=WanAttnProcessor(return_attention_maps=return_attention_maps),
+         )
+
+         # 2. Feed-forward
+         self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+         self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+         # 3. Curriculum learning parameter for spatial attention
+         self.attention_window = -1  # -1 = full attention (default)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         rotary_emb: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         attn_weights = None
+
+         # 1. Self-attention
+         norm_hidden_states = self.norm1(hidden_states.float()).type_as(hidden_states)
+         attn_output, attn_weights = self.attn1(norm_hidden_states, None, attention_mask, rotary_emb)
+         hidden_states = (hidden_states.float() + attn_output).type_as(hidden_states)
+
+         # 2. Feed-forward
+         norm_hidden_states = self.norm3(hidden_states.float()).type_as(hidden_states)
+         ff_output = self.ffn(norm_hidden_states)
+         hidden_states = (hidden_states.float() + ff_output.float()).type_as(hidden_states)
+
+         return hidden_states, attn_weights
+
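+ # The `attention_window` knob above only stores the window size; a mask that
+ # realizes it could look like this hypothetical sketch (an assumption, not
+ # code from this repo): each token attends only within +-w positions.
+ #
+ #     >>> idx = torch.arange(seq_len)
+ #     >>> keep = (idx[None, :] - idx[:, None]).abs() <= w
+ #     >>> bias = torch.zeros(seq_len, seq_len).masked_fill(~keep, float("-inf"))
+ #     >>> # pass `bias` as `attention_mask` to WanTransformerBlock.forward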
+
+ class WanTransformer3DModel(
+     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+ ):
+     r"""
+     A Transformer model for video-like data used in the Wan model.
+
+     Args:
+         patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+             3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+         num_attention_heads (`int`, defaults to `40`):
+             The number of attention heads.
+         attention_head_dim (`int`, defaults to `128`):
+             The number of channels in each head.
+         in_channels (`int`, defaults to `16`):
+             The number of channels in the input.
+         out_channels (`int`, defaults to `16`):
+             The number of channels in the output.
+         text_dim (`int`, defaults to `512`):
+             Input dimension for text embeddings.
+         freq_dim (`int`, defaults to `256`):
+             Dimension for sinusoidal time embeddings.
+         ffn_dim (`int`, defaults to `13824`):
+             Intermediate dimension in feed-forward network.
+         num_layers (`int`, defaults to `40`):
+             The number of layers of transformer blocks to use.
+         window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+             Window size for local attention (-1 indicates global attention).
+         cross_attn_norm (`bool`, defaults to `True`):
+             Enable cross-attention normalization.
+         qk_norm (`bool`, defaults to `True`):
+             Enable query/key normalization.
+         eps (`float`, defaults to `1e-6`):
+             Epsilon value for normalization layers.
+         add_img_emb (`bool`, defaults to `False`):
+             Whether to use img_emb.
+         added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+             The number of channels to use for the added key and value projections. If `None`, no projection is used.
+     """
+
+     _supports_gradient_checkpointing = True
+     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+     _no_split_modules = ["WanTransformerBlock"]
+     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+     _repeated_blocks = ["WanTransformerBlock"]
+
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 40,
+         attention_head_dim: int = 128,
+         ffn_dim: int = 13824,
+         num_layers: int = 40,
+         qk_norm: Optional[str] = "rms_norm_across_heads",
+         eps: float = 1e-6,
+         gradient_checkpointing: bool = False,
+     ) -> None:
+         super().__init__()
+
+         inner_dim = num_attention_heads * attention_head_dim
+
+         # Transformer blocks
+         self.blocks = nn.ModuleList(
+             [
+                 WanTransformerBlock(
+                     inner_dim, ffn_dim, num_attention_heads, False, qk_norm, eps
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.gradient_checkpointing = gradient_checkpointing
+
+ class WanDecoderTransformer(torch.nn.Module):
+     def __init__(
+         self,
+         chunk: int = 2,
+         rope_max_seq_len=None,
+         patch_size=[(1, 2, 2), (1, 4, 4), (1, 8, 8)],
+         num_layers: int = 30,
+         num_heads=12,
+         head_dim=128,
+         channels=[384, 192, 192],
+         use_lora: bool = False,
+         lora_rank: int = 8,
+         lora_alpha: int = 32,
+         lora_dropout: float = 0.1,
+         reusing: bool = False,
+         pretrained: bool = True,
+         gradient_checkpointing: bool = False,
+     ) -> None:
+         super().__init__()
+
+         self.chunk = chunk
+         self.use_lora = use_lora
+         self.attn_weights = []
+
+         # Initialize the transformer
+         if pretrained:
+             self.transformer = WanTransformer3DModel.from_pretrained(
+                 "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+                 subfolder="transformer",
+                 num_attention_heads=12,
+                 attention_head_dim=128,
+                 num_layers=30,
+                 ffn_dim=8960,
+                 eps=1e-6,
+                 qk_norm="rms_norm_across_heads",
+                 gradient_checkpointing=gradient_checkpointing,
+                 torch_dtype=torch.float32,
+                 device_map=None,
+                 ignore_mismatched_sizes=True,
+                 strict=False,
+             )
+         else:
+             self.transformer = WanTransformer3DModel(
+                 num_attention_heads=num_heads,
+                 attention_head_dim=head_dim,
+                 num_layers=num_layers,
+                 ffn_dim=8960,
+                 eps=1e-6,
+                 qk_norm="rms_norm_across_heads",
+                 gradient_checkpointing=gradient_checkpointing,
+             )
+
+         # Apply LoRA if requested
+         if self.use_lora:
+             self._apply_lora(lora_rank, lora_alpha, lora_dropout)
+
+         # Configuration
+         self.channels = channels
+         self.num_attention_heads = num_heads
+         self.attention_head_dim = head_dim
+         self.num_layers = num_layers
+         self.reusing = reusing
+         inner_dim = self.num_attention_heads * self.attention_head_dim
+
+         # Ensure each image has 1560 tokens
+         seq_len_per_chunk = 1560
+         chunk = self.chunk
+         self.patch_size = patch_size
+         if rope_max_seq_len is None:
+             self.rope_max_seq_len = [seq_len_per_chunk * (chunk + 1), seq_len_per_chunk * (2 * chunk), seq_len_per_chunk * (4 * chunk - 2)]
+         else:
+             self.rope_max_seq_len = rope_max_seq_len
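+         # Where 1560 comes from (illustrative arithmetic; the 480x832 input
+         # size is an assumption): an 8x-spatial VAE gives a 60x104 latent grid,
+         # and the first-stage patch size (1, 2, 2) then yields
+         # (60 / 2) * (104 / 2) = 30 * 52 = 1560 tokens per frame.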
719
+ eps = 1e-6
720
+
721
+ # 1. Patch & position embedding
722
+ self.patch_embeddings = nn.ModuleList([
723
+ nn.Conv3d(channels[0], inner_dim, kernel_size=self.patch_size[0], stride=self.patch_size[0]), # First upblock output
724
+ nn.Conv3d(channels[1], inner_dim, kernel_size=self.patch_size[1], stride=self.patch_size[1]), # Second upblock output
725
+ nn.Conv3d(channels[2], inner_dim, kernel_size=self.patch_size[2], stride=self.patch_size[2]), # Third upblock output
726
+ ])
727
+
728
+ self.rope = nn.ModuleList([
729
+ WanRotaryPosEmbed(self.attention_head_dim, self.patch_size[i], self.rope_max_seq_len[i]) for i in range(3)
730
+ ])
731
+
732
+ # Output norms & projections for three resolutions
733
+ self.norm_outs = nn.ModuleList([
734
+ FP32LayerNorm(inner_dim, eps, elementwise_affine=False),
735
+ FP32LayerNorm(inner_dim, eps, elementwise_affine=False),
736
+ FP32LayerNorm(inner_dim, eps, elementwise_affine=False),
737
+ ])
738
+
739
+ self.proj_outs = nn.ModuleList([
740
+ nn.Linear(inner_dim, channels[0] * math.prod(self.patch_size[0])),
741
+ nn.Linear(inner_dim, channels[1] * math.prod(self.patch_size[1])),
742
+ nn.Linear(inner_dim, channels[2] * math.prod(self.patch_size[2])),
743
+ ])
744
+
745
+ self.initialize_decoder_components()
746
+
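With the default `chunk=2` the per-stage RoPE budgets evaluate to 1560 * 3 = 4680, 1560 * 4 = 6240, and 1560 * 6 = 9360 tokens. A hedged construction sketch; `pretrained=False` skips the Hub download, and the assertion just restates the arithmetic above:

```python
# Sketch: build the decoder offline and check the derived RoPE budgets.
from src.models.Wan.transformer_wan import WanDecoderTransformer

decoder = WanDecoderTransformer(chunk=2, pretrained=False)

# seq_len_per_chunk = 1560, chunk = 2:
#   stage 0: 1560 * (2 + 1)     = 4680
#   stage 1: 1560 * (2 * 2)     = 6240
#   stage 2: 1560 * (4 * 2 - 2) = 9360
assert decoder.rope_max_seq_len == [4680, 6240, 9360]
```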
+     def initialize_decoder_components(self):
+         """Initialize the patch embeddings and output projections."""
+         # Re-initialize patch embeddings with their default scheme
+         for patch_embed in self.patch_embeddings:
+             patch_embed.reset_parameters()
+
+         # Xavier-initialize output projections, zero their biases
+         for proj_out in self.proj_outs:
+             nn.init.xavier_uniform_(proj_out.weight)
+             if proj_out.bias is not None:
+                 nn.init.zeros_(proj_out.bias)
+
+     def _apply_lora(self, lora_rank, lora_alpha, lora_dropout):
+         """Wrap the transformer with LoRA adapters on its attention and FFN projections."""
+         lora_config = LoraConfig(
+             r=lora_rank,
+             lora_alpha=lora_alpha,
+             target_modules=[
+                 "to_q", "to_k", "to_v", "to_out.0",
+                 "ffn.net.0.proj", "ffn.net.2",
+             ],
+             lora_dropout=lora_dropout,
+             bias="none",
+             task_type=TaskType.FEATURE_EXTRACTION,
+         )
+
+         self.transformer = get_peft_model(self.transformer, lora_config)
+
+     def get_lora_target_modules(self):
+         """Return the target modules configured for LoRA on the wrapped transformer."""
+         if not self.use_lora:
+             return []
+
+         transformer = getattr(self, "transformer", None)
+         if transformer is None:
+             return []
+
+         peft_config = getattr(transformer, "peft_config", None)
+         if not peft_config:
+             return []
+
+         active_adapter = getattr(transformer, "active_adapter", None)
+         if active_adapter and active_adapter in peft_config:
+             config = peft_config[active_adapter]
+         else:
+             config = next(iter(peft_config.values()))
+
+         target_modules = getattr(config, "target_modules", None)
+         if target_modules is None:
+             return []
+
+         return list(target_modules)
+
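A short sketch of the LoRA path: enable adapters at construction time, then read back which projections PEFT wraps. The printed list mirrors `target_modules` above; PEFT may store it as a set, so it is sorted here for a deterministic result:

```python
# Sketch: decoder with LoRA adapters enabled (offline), inspecting targets.
from src.models.Wan.transformer_wan import WanDecoderTransformer

decoder = WanDecoderTransformer(pretrained=False, use_lora=True, lora_rank=8)
print(sorted(decoder.get_lora_target_modules()))
# ['ffn.net.0.proj', 'ffn.net.2', 'to_k', 'to_out.0', 'to_q', 'to_v']
```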
+     def fuse_lora_weights(self):
+         """
+         Fuse LoRA weights into the base model weights.
+
+         This merges the low-rank adaptation matrices (A and B) with the original weights:
+             W' = W + (scaling * B @ A)
+
+         After fusing, the model behaves identically but without the LoRA overhead,
+         making it more efficient for inference.
+
+         Returns:
+             bool: True if fusion was successful, False otherwise
+         """
+         if not self.use_lora:
+             print("⚠ LoRA is not enabled, nothing to fuse")
+             return False
+
+         try:
+             print("Fusing LoRA weights into base model...")
+             # PEFT merges the adapters into the base weights and drops the wrappers
+             self.transformer = self.transformer.merge_and_unload()
+
+             # LoRA is now folded into the base weights
+             self.use_lora = False
+
+             print("✓ Successfully fused LoRA weights into base model")
+             return True
+
+         except Exception as e:
+             print(f"✗ Error fusing LoRA weights: {e}")
+             return False
+
+     def unfuse_lora_weights(self):
+         """
+         Unfuse/unmerge LoRA weights from the base model.
+
+         This separates the LoRA weights from the base weights if they were
+         previously merged. Note: this only works while the model still has LoRA
+         adapters loaded; after fuse_lora_weights() the adapters are gone and
+         cannot be unmerged.
+
+         Returns:
+             bool: True if unfusion was successful, False otherwise
+         """
+         if not self.use_lora:
+             print("⚠ LoRA is not enabled or already unfused")
+             return False
+
+         try:
+             print("Unfusing LoRA weights from base model...")
+             # PEFT subtracts the merged adapter deltas from the base weights
+             self.transformer.unmerge_adapter()
+
+             print("✓ Successfully unfused LoRA weights from base model")
+             return True
+
+         except Exception as e:
+             print(f"✗ Error unfusing LoRA weights: {e}")
+             return False
+
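To make the fused update concrete, here is a standalone numerical check of the merge rule W' = W + (scaling * B @ A) on a bare `nn.Linear`; it uses only PyTorch and is independent of PEFT:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(16, 16, bias=False)
rank, alpha = 4, 32
scaling = alpha / rank                 # PEFT's default scaling factor
A = torch.randn(rank, 16) * 0.01       # lora_A: (r, in_features)
B = torch.randn(16, rank) * 0.01       # lora_B: (out_features, r)

x = torch.randn(2, 16)
unfused = base(x) + scaling * (x @ A.T @ B.T)  # adapter applied at runtime

fused = nn.Linear(16, 16, bias=False)
with torch.no_grad():
    fused.weight.copy_(base.weight + scaling * (B @ A))  # W' = W + s*B@A

assert torch.allclose(unfused, fused(x), atol=1e-5)
```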
+     def get_map(self):
+         """Return the attention maps collected during forward passes."""
+         return self.attn_weights
+
+     def clear_map(self):
+         """Reset the collected attention maps."""
+         self.attn_weights = []
+
+     def create_spatial_mask(self, attention_window, num_frames, height, width, device):
+         """
+         Create an attention mask for self-attention over the flattened video tokens.
+
+         Each query token at (t, h, w) is restricted to attend to exactly one key:
+         the token at the same spatial position (h, w) in the first frame (t=0).
+
+         Args:
+             attention_window: If negative, return None (full attention); otherwise
+                 apply the frame-0 spatial restriction described above
+             num_frames: Number of temporal frames
+             height: Spatial height of the feature map
+             width: Spatial width of the feature map
+             device: torch device
+
+         Returns:
+             Attention mask [1, 1, seq_len, seq_len], or None for full attention
+         """
+         if attention_window < 0:
+             return None  # Full attention
+
+         seq_len = num_frames * height * width
+
+         # Tokens are ordered as [t0_h0_w0, t0_h0_w1, ..., t0_hH_wW, t1_h0_w0, ...]
+         # Query token i at (t_q, h_q, w_q) attends to the key token at (t=0, h_q, w_q).
+
+         # Indices for spatial positions (h, w) - reused across frames
+         h_indices = torch.arange(height, device=device).repeat_interleave(width)  # [0,0,...,0,1,1,...,1,...]
+         w_indices = torch.arange(width, device=device).repeat(height)  # [0,1,...,W-1,0,1,...,W-1,...]
+
+         # For each spatial position, the corresponding key index in the first frame
+         key_indices_per_spatial_pos = h_indices * width + w_indices  # [height * width]
+
+         # Repeat the pattern for all frames (each query frame uses the same spatial mapping)
+         key_indices = key_indices_per_spatial_pos.repeat(num_frames)  # [seq_len]
+
+         # Initialize with -inf (block all attention), then open one key per query
+         attention_mask = torch.full((seq_len, seq_len), float('-inf'), dtype=torch.float32, device=device)
+         query_indices = torch.arange(seq_len, device=device)
+         attention_mask[query_indices, key_indices] = 0.0
+
+         # Add batch and head dimensions: [1, 1, seq_len, seq_len]
+         attention_mask = attention_mask.unsqueeze(0).unsqueeze(0)
+
+         return attention_mask
+
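A quick check of the mask semantics on a tiny grid: with 2 frames of 2×2 tokens, every query row has exactly one zero entry, pointing at the same (h, w) position in frame 0. Note the dense `(seq_len, seq_len)` buffer grows quadratically, so this is only cheap for small feature maps:

```python
# Sketch: inspect the mask for T=2 frames of 2x2 tokens (seq_len = 8).
from src.models.Wan.transformer_wan import WanDecoderTransformer

decoder = WanDecoderTransformer(pretrained=False)
mask = decoder.create_spatial_mask(
    attention_window=0, num_frames=2, height=2, width=2, device="cpu"
)
print(mask.shape)                      # torch.Size([1, 1, 8, 8])
allowed = (mask[0, 0] == 0).nonzero()  # one allowed (query, key) pair per row
print(allowed[:, 1].tolist())          # [0, 1, 2, 3, 0, 1, 2, 3]
```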
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         stage_idx: int = 0,
+         return_dict: bool = True,
+         window_size=-1,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+         """
+         Args:
+             hidden_states: Input tensor (B, C, T, H, W) where C is 384 or 192
+             stage_idx: 0 for the first stage (384 channels); 1 or 2 for the later
+                 stages (192 channels each)
+             return_dict: Whether to return a Transformer2DModelOutput or a tuple
+             window_size: Negative for full attention; otherwise apply the frame-0
+                 spatial mask from create_spatial_mask
+             attention_kwargs: Additional attention arguments (may carry a LoRA "scale")
+
+         Note: attention maps accumulate in self.attn_weights across calls;
+         use clear_map() to reset them.
+         """
+         assert stage_idx in [0, 1, 2], f"stage_idx must be 0, 1, or 2, got {stage_idx}"
+
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             scale_lora_layers(self, lora_scale)
+         elif lora_scale != 1.0:
+             logger.warning(
+                 "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+             )
+
+         # Get input dimensions
+         batch_size, num_channels, num_frames, height, width = hidden_states.shape
+         p_t, p_h, p_w = self.patch_size[stage_idx]
+
+         # Keep the exact output shape even when T/H/W are not divisible by the
+         # patch size: pad before patch embedding and crop back after unpatchify.
+         pad_t = (p_t - (num_frames % p_t)) % p_t
+         pad_h = (p_h - (height % p_h)) % p_h
+         pad_w = (p_w - (width % p_w)) % p_w
+         if pad_t or pad_h or pad_w:
+             hidden_states = F.pad(hidden_states, (0, pad_w, 0, pad_h, 0, pad_t))
+
+         _, _, padded_num_frames, padded_height, padded_width = hidden_states.shape
+         post_patch_num_frames = padded_num_frames // p_t
+         post_patch_height = padded_height // p_h
+         post_patch_width = padded_width // p_w
+
+         # Select the patch embedding and RoPE table for this stage
+         patch_embedding = self.patch_embeddings[stage_idx]
+         rotary_emb = self.rope[stage_idx](hidden_states)
+
+         # Patch embedding
+         hidden_states = patch_embedding(hidden_states)
+         hidden_states = hidden_states.flatten(2).transpose(1, 2)  # (B, seq_len, inner_dim)
+         assert hidden_states.shape[1] <= self.rope_max_seq_len[stage_idx], (
+             f"Sequence length {hidden_states.shape[1]} is greater than maximum sequence length "
+             f"{self.rope_max_seq_len[stage_idx]} for stage {stage_idx}"
+         )
+
+         # Select transformer blocks: share the full stack across stages when
+         # reusing, otherwise give each stage its own contiguous third.
+         if self.reusing:
+             transformer_blocks = self.transformer.blocks
+         else:
+             blocks_per_stage = self.num_layers // 3
+             transformer_blocks = self.transformer.blocks[stage_idx * blocks_per_stage : (stage_idx + 1) * blocks_per_stage]
+
+         # Run transformer blocks
+         attention_mask = self.create_spatial_mask(
+             window_size,
+             post_patch_num_frames,
+             post_patch_height,
+             post_patch_width,
+             hidden_states.device,
+         )
+         if torch.is_grad_enabled() and getattr(self.transformer, 'gradient_checkpointing', False):
+             for block in transformer_blocks:
+                 hidden_states, attn_weight = torch.utils.checkpoint.checkpoint(
+                     block,
+                     hidden_states,
+                     rotary_emb,
+                     attention_mask,
+                     use_reentrant=False,
+                 )
+                 self.attn_weights.append(attn_weight)
+         else:
+             for block in transformer_blocks:
+                 hidden_states, attn_weight = block(
+                     hidden_states,
+                     rotary_emb,
+                     attention_mask,
+                 )
+                 self.attn_weights.append(attn_weight)
+
+         # Output norm & projection
+         norm_out = self.norm_outs[stage_idx]
+         proj_out = self.proj_outs[stage_idx]
+
+         hidden_states = norm_out(hidden_states.float()).type_as(hidden_states)
+         hidden_states = proj_out(hidden_states)
+
+         # Unpatchify: (B, L, C * pt * ph * pw) -> (B, C, T, H, W)
+         out_channels = self.channels[stage_idx]
+         hidden_states = hidden_states.reshape(
+             batch_size, post_patch_num_frames, post_patch_height, post_patch_width,
+             p_t, p_h, p_w, out_channels
+         )
+         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+         if pad_t or pad_h or pad_w:
+             output = output[:, :, :num_frames, :height, :width]
+
+         if USE_PEFT_BACKEND:
+             unscale_lora_layers(self, lora_scale)
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
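Putting the pieces together, a hedged end-to-end sketch of one stage: stage 0 takes 384 channels with patch size (1, 2, 2), so a 9-frame 30×52 feature map becomes 9 * 15 * 26 = 3510 tokens, inside the 4680-token stage-0 budget, and the pad/crop logic returns exactly the input resolution:

```python
# Sketch: run stage 0 of the decoder on a dummy feature map (offline;
# shapes are illustrative, not the production configuration).
import torch
from src.models.Wan.transformer_wan import WanDecoderTransformer

decoder = WanDecoderTransformer(pretrained=False).eval()

feat = torch.randn(1, 384, 9, 30, 52)  # (B, C, T, H, W), stage-0 channels
with torch.no_grad():
    out = decoder(feat, stage_idx=0, window_size=-1, return_dict=False)[0]

print(out.shape)     # torch.Size([1, 384, 9, 30, 52]) - unchanged resolution
decoder.clear_map()  # attention maps accumulate across calls; reset them
```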
src/models/__init__.py ADDED
@@ -0,0 +1 @@
+