Yuanshi committed on
Commit 0b31b45 · 1 Parent(s): 0a257b6
Files changed (5)
  1. .gitignore +2 -0
  2. app.py +97 -0
  3. vibt/qwen_image.py +347 -0
  4. vibt/scheduler.py +44 -0
  5. vibt/wan.py +132 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+ **/__pycache__/
+ *.pyc
app.py ADDED
@@ -0,0 +1,97 @@
+ import gradio as gr
+ import spaces
+ import torch
+ import os
+
+ # ==========================================
+ # 1. Core processing function (skeleton)
+ # ==========================================
+ @spaces.GPU
+ def run_stylization(input_video_path, prompt, noise_scale, shift_gamma, steps, guidance_scale, seed):
+     """
+     Placeholder for the actual inference logic.
+     """
+     if not input_video_path:
+         return None
+
+     print("========== Inference Start ==========")
+     print(f"Video Path: {input_video_path}")
+     print(f"Prompt: {prompt}")
+     print(f"Params: Noise={noise_scale}, Gamma={shift_gamma}, Steps={steps}, CFG={guidance_scale}, Seed={seed}")
+
+     # [Prototype Logic] Return the input video directly for demonstration
+     return input_video_path
+
+ # ==========================================
+ # 2. UI layout (Gradio Blocks)
+ # ==========================================
+
+ # CSS argument removed to fix a TypeError
+ with gr.Blocks() as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown("# 🎥 ViBT Video Stylization Interface")
+         gr.Markdown("Upload a video and set the stylization parameters.")
+
+         with gr.Row():
+             # --- Left: inputs and settings ---
+             with gr.Column():
+                 # Video input
+                 input_video = gr.Video(label="Source Video", sources=["upload"])
+
+                 # Prompt
+                 prompt_input = gr.Textbox(
+                     label="Style Prompt",
+                     placeholder="e.g., Van Gogh style, cyberpunk city...",
+                     value="Oil painting style, vivid colors"
+                 )
+
+                 # Collapsible advanced settings
+                 with gr.Accordion("Advanced Settings", open=True):
+                     with gr.Row():
+                         noise_scale = gr.Slider(
+                             label="Noise Scale", minimum=0.0, maximum=2.0, step=0.1, value=1.0,
+                             info="Controls how much noise is added."
+                         )
+                         shift_gamma = gr.Slider(
+                             label="Shift Gamma", minimum=1.0, maximum=10.0, step=0.5, value=5.0,
+                             info="Scheduler parameter."
+                         )
+
+                     with gr.Row():
+                         num_steps = gr.Slider(
+                             label="Inference Steps", minimum=10, maximum=50, step=1, value=28,
+                             info="More steps = higher quality but slower."
+                         )
+                         guidance_scale = gr.Slider(
+                             label="Guidance Scale (CFG)", minimum=1.0, maximum=20.0, step=0.5, value=1.5,
+                             info="How closely to follow the prompt."
+                         )
+
+                     seed = gr.Number(label="Seed", value=42, precision=0)
+
+                 # Submit button
+                 run_btn = gr.Button("Generate Video", variant="primary")
+
+             # --- Right: result output ---
+             with gr.Column():
+                 output_video = gr.Video(label="Stylized Result", interactive=False)
+
+     # ==========================================
+     # 3. Event binding
+     # ==========================================
+     run_btn.click(
+         fn=run_stylization,
+         inputs=[
+             input_video,
+             prompt_input,
+             noise_scale,
+             shift_gamma,
+             num_steps,
+             guidance_scale,
+             seed
+         ],
+         outputs=[output_video]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
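Note on the placeholder: `run_stylization` deliberately returns the input video. Purely as orientation, the sketch below shows the kind of round trip the real handler would need, using the helpers added in `vibt/wan.py`. The Wan checkpoint id, the use of diffusers' `load_video`/`export_to_video`, the output path, and the fps are assumptions for illustration, not part of this commit.

```python
# Illustrative sketch only -- the commit intentionally ships a placeholder handler.
import torch
from diffusers import WanPipeline
from diffusers.utils import load_video, export_to_video

from vibt.wan import encode_video, decode_latents


def stylize_roundtrip(input_video_path: str, output_path: str = "stylized.mp4") -> str:
    # Hypothetical checkpoint choice; any Wan 2.1 Diffusers checkpoint would do.
    pipe = WanPipeline.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    frames = load_video(input_video_path)   # list of PIL frames
    latents = encode_video(pipe, frames)    # VAE-encode to video latents
    # ... bridge denoising with ViBTScheduler would run here ...
    video = decode_latents(pipe, latents)   # decode back to a numpy video
    export_to_video(video[0], output_path, fps=16)
    return output_path
```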
vibt/qwen_image.py ADDED
@@ -0,0 +1,347 @@
+ from typing import Any, Dict, List, Optional, Union
+
+ import torch
+ from PIL import Image
+
+
+ from diffusers import QwenImageEditPipeline, QwenImagePipeline
+ from diffusers.image_processor import PipelineImageInput
+ from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput
+
+ from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import retrieve_latents
+
+
+ def encode_vae_image(pipe, image: torch.Tensor, generator: torch.Generator):
+     latent_channels = pipe.vae.config.z_dim if getattr(pipe, "vae", None) else 16
+     image_latents = retrieve_latents(
+         pipe.vae.encode(image), generator=generator, sample_mode="argmax"
+     )
+     latents_mean = (
+         torch.tensor(pipe.vae.config.latents_mean)
+         .view(1, latent_channels, 1, 1, 1)
+         .to(image_latents.device, image_latents.dtype)
+     )
+     latents_std = (
+         torch.tensor(pipe.vae.config.latents_std)
+         .view(1, latent_channels, 1, 1, 1)
+         .to(image_latents.device, image_latents.dtype)
+     )
+     image_latents = (image_latents - latents_mean) / latents_std
+
+     return image_latents
+
+
+ @torch.no_grad()
+ def encode_image(pipe: QwenImagePipeline, image):
+     width, height = image.size
+     image = pipe.image_processor.preprocess(image, height, width)
+     image = image.to(dtype=pipe.dtype, device=pipe.device).unsqueeze(2)
+     image_latents = encode_vae_image(pipe, image, None)
+
+     image_latent_height, image_latent_width = image_latents.shape[3:]
+     image_latents = pipe._pack_latents(
+         image_latents,
+         1,
+         pipe.transformer.config.in_channels // 4,
+         image_latent_height,
+         image_latent_width,
+     )
+     return image_latents
+
+
+ @torch.no_grad()
+ def decode_latents_image(pipe: QwenImagePipeline, latents):
+     latents = pipe._unpack_latents(latents, 1024, 1024, pipe.vae_scale_factor)
+     latents = latents.to(pipe.vae.dtype)
+     latents_mean = (
+         torch.tensor(pipe.vae.config.latents_mean)
+         .view(1, pipe.vae.config.z_dim, 1, 1, 1)
+         .to(latents.device, latents.dtype)
+     )
+     latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
+         1, pipe.vae.config.z_dim, 1, 1, 1
+     ).to(latents.device, latents.dtype)
+     latents = latents / latents_std + latents_mean
+     image = pipe.vae.decode(latents, return_dict=False)[0][:, :, 0]
+     image = pipe.image_processor.postprocess(image, output_type="pil")
+     return image
+
+
+ aspect_ratios = {
+     "1:1": (1328, 1328),
+     "16:9": (1664, 928),
+     "9:16": (928, 1664),
+     "4:3": (1472, 1104),
+     "3:4": (1104, 1472),
+     "3:2": (1584, 1056),
+     "2:3": (1056, 1584),
+ }
+
+
+ def process_input_img(image):
+     # find the closest aspect ratio
+     w, h = image.size
+     aspect_ratio = w / h
+     closest_ratio = min(
+         aspect_ratios.items(),
+         key=lambda x: abs((x[1][0] / x[1][1]) - aspect_ratio),
+     )
+     target_size = closest_ratio[1]
+     return image.resize(target_size, Image.LANCZOS)
+
+
+ @torch.no_grad()
+ def qwen_bridge_gen(
+     self: QwenImageEditPipeline,
+     image: Optional[PipelineImageInput] = None,
+     prompt: Union[str, List[str]] = None,
+     negative_prompt: Union[str, List[str]] = None,
+     true_cfg_scale: float = 4.0,
+     height: Optional[int] = None,
+     width: Optional[int] = None,
+     num_inference_steps: int = 50,
+     guidance_scale: float = 1.0,
+     num_images_per_prompt: int = 1,
+     generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+     latents: Optional[torch.Tensor] = None,
+     prompt_embeds: Optional[torch.Tensor] = None,
+     prompt_embeds_mask: Optional[torch.Tensor] = None,
+     negative_prompt_embeds: Optional[torch.Tensor] = None,
+     negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+     output_type: Optional[str] = "pil",
+     return_dict: bool = True,
+     attention_kwargs: Optional[Dict[str, Any]] = None,
+     callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+     max_sequence_length: int = 512,
+     # Bridge specific
+     return_trajectory=False,
+ ):
+     image_size = image[0].size if isinstance(image, list) else image.size
+     calculated_width, calculated_height = image_size
+     height = height or calculated_height
+     width = width or calculated_width
+
+     multiple_of = self.vae_scale_factor * 2
+     width = width // multiple_of * multiple_of
+     height = height // multiple_of * multiple_of
+
+     # 1. Check inputs. Raise error if not correct
+     self.check_inputs(
+         prompt,
+         height,
+         width,
+         negative_prompt=negative_prompt,
+         prompt_embeds=prompt_embeds,
+         negative_prompt_embeds=negative_prompt_embeds,
+         prompt_embeds_mask=prompt_embeds_mask,
+         negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+         callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+         max_sequence_length=max_sequence_length,
+     )
+
+     self._guidance_scale = guidance_scale
+     self._attention_kwargs = attention_kwargs
+     self._current_timestep = None
+     self._interrupt = False
+
+     # 2. Define call parameters
+     if prompt is not None and isinstance(prompt, str):
+         batch_size = 1
+     elif prompt is not None and isinstance(prompt, list):
+         batch_size = len(prompt)
+     else:
+         batch_size = prompt_embeds.shape[0]
+
+     device = self._execution_device
+     # 3. Preprocess image
+     if image is not None and not (
+         isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels
+     ):
+         image = self.image_processor.resize(image, calculated_height, calculated_width)
+         prompt_image = image
+         image = self.image_processor.preprocess(
+             image, calculated_height, calculated_width
+         )
+         image = image.unsqueeze(2)
+
+     has_neg_prompt = negative_prompt is not None or (
+         negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+     )
+     do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+     prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+         image=prompt_image,
+         prompt=prompt,
+         prompt_embeds=prompt_embeds,
+         prompt_embeds_mask=prompt_embeds_mask,
+         device=device,
+         num_images_per_prompt=num_images_per_prompt,
+         max_sequence_length=max_sequence_length,
+     )
+     if do_true_cfg:
+         negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+             image=prompt_image,
+             prompt=negative_prompt,
+             prompt_embeds=negative_prompt_embeds,
+             prompt_embeds_mask=negative_prompt_embeds_mask,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+         )
+
+     # 4. Prepare latent variables
+     num_channels_latents = self.transformer.config.in_channels // 4
+     noise, image_latents = self.prepare_latents(
+         image,
+         batch_size * num_images_per_prompt,
+         num_channels_latents,
+         height,
+         width,
+         prompt_embeds.dtype,
+         device,
+         generator,
+         latents,
+     )
+     latents = image_latents.clone()
+     img_shapes = [
+         [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]
+     ] * batch_size
+
+     # 5. Prepare timesteps
+     self.scheduler.set_timesteps(num_inference_steps, device=device)
+     timesteps = self.scheduler.timesteps
+     self._num_timesteps = len(timesteps)
+
+     # handle guidance
+     guidance = None
+     txt_seq_lens = (
+         prompt_embeds_mask.sum(dim=1).tolist()
+         if prompt_embeds_mask is not None
+         else None
+     )
+     negative_txt_seq_lens = (
+         negative_prompt_embeds_mask.sum(dim=1).tolist()
+         if negative_prompt_embeds_mask is not None
+         else None
+     )
+
+     trajectory = [latents] if return_trajectory else None
+
+     # 6. Denoising loop
+     with self.progress_bar(total=num_inference_steps) as progress_bar:
+         for i, t in enumerate(timesteps):
+             if self.interrupt:
+                 continue
+
+             self._current_timestep = t
+
+             latent_model_input = latents
+
+             # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+             timestep = t.expand(latents.shape[0]).to(latents.dtype)
+             with self.transformer.cache_context("cond"):
+                 noise_pred = self.transformer(
+                     hidden_states=latent_model_input,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     encoder_hidden_states_mask=prompt_embeds_mask,
+                     encoder_hidden_states=prompt_embeds,
+                     img_shapes=img_shapes,
+                     # img_shapes=[[(1, 64, 64)]],
+                     txt_seq_lens=txt_seq_lens,
+                     attention_kwargs={},
+                     return_dict=False,
+                 )[0]
+                 noise_pred = noise_pred[:, : latents.size(1)]
+
+             if do_true_cfg:
+                 with self.transformer.cache_context("uncond"):
+                     neg_noise_pred = self.transformer(
+                         hidden_states=latent_model_input,
+                         timestep=timestep / 1000,
+                         guidance=guidance,
+                         encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                         encoder_hidden_states=negative_prompt_embeds,
+                         img_shapes=img_shapes,
+                         # img_shapes=[[(1, 64, 64)]],
+                         txt_seq_lens=negative_txt_seq_lens,
+                         attention_kwargs=self.attention_kwargs,
+                         return_dict=False,
+                     )[0]
+                     neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+                 comb_pred = neg_noise_pred + true_cfg_scale * (
+                     noise_pred - neg_noise_pred
+                 )
+
+                 cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                 noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                 noise_pred = comb_pred * (cond_norm / noise_norm)
+
+             # # step
+             # next_t = timesteps[i + 1] if i < len(timesteps) - 1 else 0
+
+             # sigma_t = t / 1000
+             # sigma_next_t = next_t / 1000
+             # sigma_delta = sigma_next_t - sigma_t
+             # print(
+             #     f"sigma_t: {sigma_t}, sigma_next_t: {sigma_next_t}, sigma_delta: {sigma_delta}"
+             # )
+
+             # noise = torch.randn(
+             #     latents.shape,
+             #     dtype=latents.dtype,
+             #     device=latents.device,
+             #     generator=generator,
+             # )
+             # eta = torch.sqrt(-sigma_delta * sigma_next_t / sigma_t)
+             # # eta = torch.sqrt(-sigma_delta)
+
+             # coef = torch.clip(noise_pred.abs(), 0, 1) if rescale_noise else 1
+             # latents = latents + noise_pred * sigma_delta + sigma * eta * noise * coef
+             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+             if return_trajectory:
+                 trajectory.append(latents)
+
+             # call the callback, if provided
+             progress_bar.update()
+
+     self._current_timestep = None
+     if output_type == "latent":
+         image = latents
+     else:
+
+         def decode_latents(latents, height, width):
+             latents = self._unpack_latents(
+                 latents, height, width, self.vae_scale_factor
+             )
+             latents = latents.to(self.vae.dtype)
+             latents_mean = (
+                 torch.tensor(self.vae.config.latents_mean)
+                 .view(1, self.vae.config.z_dim, 1, 1, 1)
+                 .to(latents.device, latents.dtype)
+             )
+             latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+                 1, self.vae.config.z_dim, 1, 1, 1
+             ).to(latents.device, latents.dtype)
+             latents = latents / latents_std + latents_mean
+             image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+             image = self.image_processor.postprocess(image, output_type=output_type)
+             return image
+
+         image = decode_latents(latents, height, width)
+         trajectory = (
+             [decode_latents(t, height, width)[0] for t in trajectory]
+             if return_trajectory
+             else None
+         )
+
+     # Offload all models
+     self.maybe_free_model_hooks()
+
+     if not return_dict:
+         return (image,)
+
+     if return_trajectory:
+         return QwenImagePipelineOutput(images=image), trajectory
+     else:
+         return QwenImagePipelineOutput(images=image)
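Usage note: `qwen_bridge_gen` is written as an unbound function whose first argument is the pipeline, so one way to try it is to bind it with `types.MethodType`. The sketch below is an assumption-laden example (checkpoint id, file names, prompt, and CUDA device are illustrative; the commit does not show how the ViBT scheduler is attached on the Qwen side, so the pipeline's default scheduler is kept here).

```python
# Sketch only: bind qwen_bridge_gen onto a loaded QwenImageEditPipeline and call it.
import types
import torch
from PIL import Image
from diffusers import QwenImageEditPipeline

from vibt.qwen_image import qwen_bridge_gen, process_input_img

pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16
).to("cuda")

# Snap the input to the closest supported aspect ratio from the table above.
image = process_input_img(Image.open("input.png").convert("RGB"))

bridge_gen = types.MethodType(qwen_bridge_gen, pipe)  # bind `self` to the pipeline
result = bridge_gen(
    image=image, prompt="Oil painting style, vivid colors", num_inference_steps=28
)
result.images[0].save("output.png")
```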
vibt/scheduler.py ADDED
@@ -0,0 +1,44 @@
+ from diffusers.schedulers import UniPCMultistepScheduler
+ import torch
+
+
+ class ViBTScheduler(UniPCMultistepScheduler):
+     def __init__(self, **kwargs):
+         super().__init__(**{**kwargs, "use_flow_sigmas": True})
+         self.set_parameters()
+
+     def set_parameters(self, noise_scale=1.0, shift_gamma=5.0, seed=None):
+         self.noise_scale = noise_scale
+         self.config.flow_shift = shift_gamma
+         self.generator = (
+             None if seed is None else torch.Generator("cuda").manual_seed(seed)
+         )
+
+     def step(self, model_output, timestep, sample, **kwargs):
+         delta_t = (
+             max(self.timesteps[self.timesteps < timestep]) - timestep
+             if any(self.timesteps < timestep)
+             else -timestep - 1
+         ) / 1000
+
+         current_t = (timestep + 1) / 1000.0
+         eta = (-delta_t * (current_t + delta_t) / current_t) ** 0.5
+
+         noise = torch.randn(
+             sample.shape,
+             generator=self.generator,
+             device=sample.device,
+             dtype=sample.dtype,
+         )
+         latents = sample + delta_t * model_output + eta * self.noise_scale * noise
+
+         return (latents,)
+
+     @classmethod
+     def from_scheduler(
+         cls, scheduler: UniPCMultistepScheduler, noise_scale=1.0, shift_gamma=5.0
+     ):
+         obj = cls.__new__(cls)
+         obj.__dict__ = scheduler.__dict__.copy()
+         obj.set_parameters(noise_scale, shift_gamma)
+         return obj
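Note: `step` replaces the UniPC update with a stochastic bridge step, `latents = sample + Δt·v + η·noise_scale·ε` with `η = sqrt(−Δt·(t+Δt)/t)` and timesteps normalized by 1000. A minimal sketch of dropping it into a pipeline, under the assumptions that the stock scheduler is a `UniPCMultistepScheduler` (e.g. a Wan 2.1 Diffusers checkpoint) and that this is how the app's `noise_scale`/`shift_gamma` sliders would be exposed:

```python
# Sketch under stated assumptions; checkpoint id is illustrative.
import torch
from diffusers import WanPipeline

from vibt.scheduler import ViBTScheduler

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)
# Wrap the existing UniPC scheduler with the bridge step.
pipe.scheduler = ViBTScheduler.from_scheduler(
    pipe.scheduler, noise_scale=1.0, shift_gamma=5.0
)
# Per-run parameters (e.g. from the Gradio sliders) can be refreshed in place.
pipe.scheduler.set_parameters(noise_scale=0.8, shift_gamma=5.0, seed=42)
```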
vibt/wan.py ADDED
@@ -0,0 +1,132 @@
+ import torch
+ import re
+ from diffusers import WanPipeline
+ from safetensors.torch import load_file
+ from huggingface_hub import hf_hub_download
+
+
+ @torch.no_grad()
+ def encode_video(pipe: WanPipeline, video_frames):
+     video_tensor = pipe.video_processor.preprocess_video(video_frames).to(
+         dtype=pipe.dtype, device=pipe.device
+     )
+     posterior = pipe.vae.encode(video_tensor, return_dict=False)[0]
+     z = posterior.mode()
+     latents_mean = (
+         torch.tensor(pipe.vae.config.latents_mean)
+         .view(1, pipe.vae.config.z_dim, 1, 1, 1)
+         .to(z.device, z.dtype)
+     )
+     latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
+         1, pipe.vae.config.z_dim, 1, 1, 1
+     ).to(z.device, z.dtype)
+     latents = (z - latents_mean) * latents_std
+     return latents
+
+
+ @torch.no_grad()
+ def decode_latents(pipe: WanPipeline, latents):
+     latents = latents.to(pipe.vae.dtype)
+     latents_mean = (
+         torch.tensor(pipe.vae.config.latents_mean)
+         .view(1, pipe.vae.config.z_dim, 1, 1, 1)
+         .to(latents.device, latents.dtype)
+     )
+     latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(
+         1, pipe.vae.config.z_dim, 1, 1, 1
+     ).to(latents.device, latents.dtype)
+     latents = latents / latents_std + latents_mean
+     video = pipe.vae.decode(latents, return_dict=False)[0]
+     video = pipe.video_processor.postprocess_video(video, output_type="np")
+     return video
+
+
+ def name_convert(n: str):
+     # blocks.* attention
+     m = re.match(
+         r"blocks\.(\d+)\.(self_attn|cross_attn)\.(q|k|v|o|norm_k|norm_q)\.(weight|bias)",
+         n,
+     )
+     if m:
+         b, kind, comp, suf = m.groups()
+         attn = "attn1" if kind == "self_attn" else "attn2"
+         if comp in ("q", "k", "v"):
+             return f"blocks.{b}.{attn}.to_{comp}.{suf}"
+         if comp == "o":
+             return f"blocks.{b}.{attn}.to_out.0.{suf}"
+         return f"blocks.{b}.{attn}.{comp}.{suf}"
+
+     # blocks.* ffn
+     m = re.match(r"blocks\.(\d+)\.ffn\.(0|2)\.(weight|bias)", n)
+     if m:
+         b, idx, suf = m.groups()
+         if idx == "0":
+             return f"blocks.{b}.ffn.net.0.proj.{suf}"
+         return f"blocks.{b}.ffn.net.2.{suf}"
+
+     # blocks.* norm3/modulation
+     m = re.match(r"blocks\.(\d+)\.norm3\.(weight|bias)", n)
+     if m:
+         b, suf = m.groups()
+         return f"blocks.{b}.norm2.{suf}"
+
+     m = re.match(r"blocks\.(\d+)\.modulation$", n)
+     if m:
+         b = m.group(1)
+         return f"blocks.{b}.scale_shift_table"
+
+     # patch_embedding
+     if n.startswith("patch_embedding."):
+         return n
+
+     # text / time embedding
+     m = re.match(r"text_embedding\.(0|2)\.(weight|bias)", n)
+     if m:
+         idx, suf = m.groups()
+         lin = "linear_1" if idx == "0" else "linear_2"
+         return f"condition_embedder.text_embedder.{lin}.{suf}"
+
+     m = re.match(r"time_embedding\.(0|2)\.(weight|bias)", n)
+     if m:
+         idx, suf = m.groups()
+         lin = "linear_1" if idx == "0" else "linear_2"
+         return f"condition_embedder.time_embedder.{lin}.{suf}"
+
+     m = re.match(r"time_projection\.1\.(weight|bias)", n)
+     if m:
+         suf = m.group(1)
+         return f"condition_embedder.time_proj.{suf}"
+
+     # head
+     if n == "head.head.weight":
+         return "proj_out.weight"
+     if n == "head.head.bias":
+         return "proj_out.bias"
+     if n == "head.modulation":
+         return "scale_shift_table"
+
+     return n
+
+
+ def load_vibt_weight(
+     transformer, repo_name="Yuanshi/Bridge", weight_path=None, local_path=None
+ ):
+     assert (
+         weight_path or local_path
+     ) is not None, "Either weight_path or local_path must be provided."
+
+     tensors = load_file(local_path or hf_hub_download(repo_name, weight_path))
+
+     new_tensors = {}
+
+     for key, value in tensors.items():
+         key = name_convert(key)
+         new_tensors[key] = value
+
+     for name, param in transformer.named_parameters():
+         device, dtype = param.device, param.dtype
+         if name in new_tensors:
+             assert (
+                 param.shape == new_tensors[name].shape
+             ), f"{name}: {param.shape} != {new_tensors[name].shape}"
+             param.data = new_tensors[name].to(device=device, dtype=dtype)
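For reference, the key renaming done by `name_convert` (original Wan checkpoint names to Diffusers `WanTransformer3DModel` names) can be checked without downloading anything; the expected outputs below follow directly from the regex branches above. The commented load call at the end is a sketch: the safetensors filename is a placeholder, not taken from this commit.

```python
# Quick self-contained check of the key renaming (no downloads needed).
from vibt.wan import name_convert, load_vibt_weight

assert name_convert("blocks.0.self_attn.q.weight") == "blocks.0.attn1.to_q.weight"
assert name_convert("blocks.3.cross_attn.o.bias") == "blocks.3.attn2.to_out.0.bias"
assert name_convert("blocks.7.ffn.0.weight") == "blocks.7.ffn.net.0.proj.weight"
assert name_convert("time_embedding.2.bias") == "condition_embedder.time_embedder.linear_2.bias"
assert name_convert("head.modulation") == "scale_shift_table"

# Loading converted weights into a Diffusers Wan transformer would then look like
# (placeholder filename -- use whichever file the Yuanshi/Bridge repo ships):
# load_vibt_weight(pipe.transformer, repo_name="Yuanshi/Bridge",
#                  weight_path="<vibt_weights>.safetensors")
```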