deepbeepmeep committed
Commit
21a01ff
1 Parent(s): cad98bc

added Phantom model support

README.md CHANGED
@@ -10,7 +10,9 @@
10
 
11
 
12
  ## 🔥 Latest News!!
13
- * April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Window Sliding section below)
 
 
14
  * April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
15
  * April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
16
  * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
 
10
 
11
 
12
  ## 🔥 Latest News!!
13
+ * April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support, a very good model for transferring people or objects into a video. It works quite well at 720p and with more than 30 denoising steps.
14
+ * April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see Window Sliding section below). Note that SkyReels uses causal attention, which is only supported by SDPA attention, so even if you choose another attention type, some of the processing will still use SDPA attention.
15
+
16
  * April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
17
  * April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
18
  * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
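
For readers new to Phantom: the sketch below is illustrative only (not repository code) and shows the core conditioning idea behind this release, as implemented in the `generate()` changes to `wan/text2video.py` further down in this commit. The VAE latents of the reference images are appended along the time axis of the video latents and re-injected at every denoising step; `inject_ref_latents` is a hypothetical helper name.

```python
# Illustrative sketch of Phantom conditioning: reference-image latents overwrite the
# tail of the noisy latents on every step (see the generate() changes in wan/text2video.py).
import torch

def inject_ref_latents(latent, ref_latent):
    # latent:     [C, T + T_ref, H, W]  noisy video latents (already extended for the refs)
    # ref_latent: [C, T_ref, H, W]      clean VAE latents of the reference images
    return torch.cat([latent[:, :-ref_latent.shape[1]], ref_latent], dim=1)

latent = torch.randn(16, 21 + 1, 60, 104)     # 21 video frames + 1 reference frame, in latent space
ref = torch.randn(16, 1, 60, 104)
print(inject_ref_latents(latent, ref).shape)  # torch.Size([16, 22, 60, 104])
```
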
wan/diffusion_forcing.py CHANGED
@@ -31,6 +31,8 @@ class DTT2V:
31
  text_encoder_filename = None,
32
  quantizeTransformer = False,
33
  dtype = torch.bfloat16,
 
 
34
  ):
35
  self.device = torch.device(f"cuda")
36
  self.config = config
@@ -50,24 +52,22 @@ class DTT2V:
50
  self.vae_stride = config.vae_stride
51
  self.patch_size = config.patch_size
52
 
53
-
54
  self.vae = WanVAE(
55
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
56
  device=self.device)
57
 
58
  logging.info(f"Creating WanModel from {model_filename}")
59
  from mmgp import offload
60
  # model_filename = "model.safetensors"
61
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json"
62
  # offload.load_model_data(self.model, "recam.ckpt")
63
  # self.model.cpu()
64
- if self.dtype == torch.float16 and not "fp16" in model_filename:
65
- self.model.to(self.dtype)
66
- # offload.save_model(self.model, "rt1.3B.safetensors", config_file_path="config.json")
67
- # offload.save_model(self.model, "rtint8.safetensors", do_quantize= "config.json")
68
  # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
69
- if self.dtype == torch.float16:
70
- self.vae.model.to(self.dtype)
71
  self.model.eval().requires_grad_(False)
72
 
73
  self.scheduler = FlowUniPCMultistepScheduler()
@@ -228,11 +228,16 @@ class DTT2V:
228
  latent_height = height // 8
229
  latent_width = width // 8
230
 
231
- prompt_embeds = self.text_encoder([prompt], self.device)
232
- prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds]
 
 
233
  if self.do_classifier_free_guidance:
234
- negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)
235
- negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds]
 
 
 
236
 
237
  self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
238
  init_timesteps = self.scheduler.timesteps
@@ -305,6 +310,17 @@ class DTT2V:
305
  del time_steps_comb
306
  from mmgp import offload
307
  freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False)
 
 
 
 
 
 
 
 
 
 
 
308
  for i, timestep_i in enumerate(tqdm(step_matrix)):
309
  offload.set_step_no_for_lora(self.model, i)
310
  update_mask_i = step_update_mask[i]
@@ -323,52 +339,45 @@ class DTT2V:
323
  * noise_factor
324
  )
325
  timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
326
- kwrags = {
327
  "x" : torch.stack([latent_model_input[0]]),
328
  "t" : timestep,
329
- "freqs" :freqs,
330
- "fps" : fps_embeds,
331
- "causal_block_size" : causal_block_size,
332
- "causal_attention" : causal_attention,
333
- "callback" : callback,
334
- "pipeline" : self,
335
  "current_step" : i,
336
- }
337
- kwrags.update(i2v_extra_kwrags)
338
-
339
- if not self.do_classifier_free_guidance:
340
- noise_pred = self.model(
341
- context=prompt_embeds,
342
- **kwrags,
343
- )[0]
344
- if self._interrupt:
345
- return None
346
- noise_pred= noise_pred.to(torch.float32)
347
- else:
348
- if joint_pass:
349
- noise_pred_cond, noise_pred_uncond = self.model(
350
- context=prompt_embeds,
351
- context2=negative_prompt_embeds,
352
- **kwrags,
353
- )
354
- if self._interrupt:
355
- return None
356
- else:
357
- noise_pred_cond = self.model(
358
- context=prompt_embeds,
359
  **kwrags,
360
  )[0]
361
- if self._interrupt:
362
- return None
363
- noise_pred_uncond = self.model(
364
- context=negative_prompt_embeds,
365
- )[0]
366
  if self._interrupt:
367
  return None
368
- noise_pred_cond= noise_pred_cond.to(torch.float32)
369
- noise_pred_uncond= noise_pred_uncond.to(torch.float32)
370
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
371
- del noise_pred_cond, noise_pred_uncond
 
372
  for idx in range(valid_interval_start, valid_interval_end):
373
  if update_mask_i[idx].item():
374
  latents[0][:, idx] = sample_schedulers[idx].step(
 
31
  text_encoder_filename = None,
32
  quantizeTransformer = False,
33
  dtype = torch.bfloat16,
34
+ VAE_dtype = torch.float32,
35
+ mixed_precision_transformer = False,
36
  ):
37
  self.device = torch.device(f"cuda")
38
  self.config = config
 
52
  self.vae_stride = config.vae_stride
53
  self.patch_size = config.patch_size
54
 
 
55
  self.vae = WanVAE(
56
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
57
  device=self.device)
58
 
59
  logging.info(f"Creating WanModel from {model_filename}")
60
  from mmgp import offload
61
  # model_filename = "model.safetensors"
62
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json")
63
  # offload.load_model_data(self.model, "recam.ckpt")
64
  # self.model.cpu()
65
+ self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
66
+ offload.change_dtype(self.model, dtype, True)
67
+ # offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
68
+ # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json")
69
  # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
70
+
 
71
  self.model.eval().requires_grad_(False)
72
 
73
  self.scheduler = FlowUniPCMultistepScheduler()
 
228
  latent_height = height // 8
229
  latent_width = width // 8
230
 
231
+ if self._interrupt:
232
+ return None
233
+ prompt_embeds = self.text_encoder([prompt], self.device)[0]
234
+ prompt_embeds = prompt_embeds.to(self.dtype).to(self.device)
235
  if self.do_classifier_free_guidance:
236
+ negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)[0]
237
+ negative_prompt_embeds = negative_prompt_embeds.to(self.dtype).to(self.device)
238
+
239
+ if self._interrupt:
240
+ return None
241
 
242
  self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
243
  init_timesteps = self.scheduler.timesteps
 
310
  del time_steps_comb
311
  from mmgp import offload
312
  freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False)
313
+ kwrags = {
314
+ "freqs" :freqs,
315
+ "fps" : fps_embeds,
316
+ "causal_block_size" : causal_block_size,
317
+ "causal_attention" : causal_attention,
318
+ "callback" : callback,
319
+ "pipeline" : self,
320
+ }
321
+ kwrags.update(i2v_extra_kwrags)
322
+
323
+
324
  for i, timestep_i in enumerate(tqdm(step_matrix)):
325
  offload.set_step_no_for_lora(self.model, i)
326
  update_mask_i = step_update_mask[i]
 
339
  * noise_factor
340
  )
341
  timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
342
+ kwrags.update({
343
  "x" : torch.stack([latent_model_input[0]]),
344
  "t" : timestep,
 
 
 
 
 
 
345
  "current_step" : i,
346
+ })
347
+
348
+ # with torch.autocast(device_type="cuda"):
349
+ if True:
350
+ if not self.do_classifier_free_guidance:
351
+ noise_pred = self.model(
352
+ context=[prompt_embeds],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  **kwrags,
354
  )[0]
 
 
 
 
 
355
  if self._interrupt:
356
  return None
357
+ noise_pred= noise_pred.to(torch.float32)
358
+ else:
359
+ if joint_pass:
360
+ noise_pred_cond, noise_pred_uncond = self.model(
361
+ context= [prompt_embeds, negative_prompt_embeds],
362
+ **kwrags,
363
+ )
364
+ if self._interrupt:
365
+ return None
366
+ else:
367
+ noise_pred_cond = self.model(
368
+ context=[prompt_embeds],
369
+ **kwrags,
370
+ )[0]
371
+ if self._interrupt:
372
+ return None
373
+ noise_pred_uncond = self.model(
374
+ context=[negative_prompt_embeds],
375
+ **kwrags,
376
+ )[0]
377
+ if self._interrupt:
378
+ return None
379
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
380
+ del noise_pred_cond, noise_pred_uncond
381
  for idx in range(valid_interval_start, valid_interval_end):
382
  if update_mask_i[idx].item():
383
  latents[0][:, idx] = sample_schedulers[idx].step(
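
The calling convention above replaces the old `context`/`context2` pair with a list of contexts and one prediction per entry. Below is a minimal sketch of the joint classifier-free-guidance step, assuming the list-based `WanModel` interface introduced in this commit; `cfg_step` and `fake_model` are illustrative names, not repository code.

```python
# Sketch of the joint pass: one denoised prediction per context entry, combined with CFG.
import torch

def cfg_step(model, x, t, prompt_embeds, negative_prompt_embeds, guidance_scale, **kwargs):
    # joint pass: predictions come back in the same order as the contexts were given
    cond, uncond = model(x=x, t=t, context=[prompt_embeds, negative_prompt_embeds], **kwargs)
    cond, uncond = cond.float(), uncond.float()
    # classifier-free guidance, as in the loop above
    return uncond + guidance_scale * (cond - uncond)

# toy stand-in for WanModel so the sketch can be exercised without checkpoints
fake_model = lambda x, t, context, **kw: [torch.randn_like(x) for _ in context]
out = cfg_step(fake_model, torch.randn(16, 21, 60, 104), torch.tensor([500]),
               torch.randn(512, 4096), torch.randn(512, 4096), guidance_scale=6.0)
print(out.shape)
```
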
wan/image2video.py CHANGED
@@ -48,47 +48,17 @@ class WanI2V:
48
  self,
49
  config,
50
  checkpoint_dir,
51
- rank=0,
52
- t5_fsdp=False,
53
- dit_fsdp=False,
54
- use_usp=False,
55
- t5_cpu=False,
56
- init_on_cpu=True,
57
- i2v720p= True,
58
  model_filename ="",
59
  text_encoder_filename="",
60
  quantizeTransformer = False,
61
- dtype = torch.bfloat16
 
 
62
  ):
63
- r"""
64
- Initializes the image-to-video generation model components.
65
-
66
- Args:
67
- config (EasyDict):
68
- Object containing model parameters initialized from config.py
69
- checkpoint_dir (`str`):
70
- Path to directory containing model checkpoints
71
- device_id (`int`, *optional*, defaults to 0):
72
- Id of target GPU device
73
- rank (`int`, *optional*, defaults to 0):
74
- Process rank for distributed training
75
- t5_fsdp (`bool`, *optional*, defaults to False):
76
- Enable FSDP sharding for T5 model
77
- dit_fsdp (`bool`, *optional*, defaults to False):
78
- Enable FSDP sharding for DiT model
79
- use_usp (`bool`, *optional*, defaults to False):
80
- Enable distribution strategy of USP.
81
- t5_cpu (`bool`, *optional*, defaults to False):
82
- Whether to place T5 model on CPU. Only works without t5_fsdp.
83
- Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
84
- init_on_cpu (`bool`, *optional*, defaults to True):
85
- """
86
  self.device = torch.device(f"cuda")
87
  self.config = config
88
- self.rank = rank
89
- self.use_usp = use_usp
90
- self.t5_cpu = t5_cpu
91
  self.dtype = dtype
 
92
  self.num_train_timesteps = config.num_train_timesteps
93
  self.param_dtype = config.param_dtype
94
  # shard_fn = partial(shard_model, device_id=device_id)
@@ -104,7 +74,7 @@ class WanI2V:
104
  self.vae_stride = config.vae_stride
105
  self.patch_size = config.patch_size
106
  self.vae = WanVAE(
107
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
108
  device=self.device)
109
 
110
  self.clip = CLIPModel(
@@ -118,11 +88,9 @@ class WanI2V:
118
  from mmgp import offload
119
 
120
  self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
121
- if self.dtype == torch.float16 and not "fp16" in model_filename:
122
- self.model.to(self.dtype)
123
  # offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True)
124
- if self.dtype == torch.float16:
125
- self.vae.model.to(self.dtype)
126
 
127
  # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
128
  self.model.eval().requires_grad_(False)
@@ -142,7 +110,6 @@ class WanI2V:
142
  guide_scale=5.0,
143
  n_prompt="",
144
  seed=-1,
145
- offload_model=True,
146
  callback = None,
147
  enable_RIFLEx = False,
148
  VAE_tile_size= 0,
@@ -212,13 +179,13 @@ class WanI2V:
212
  w = lat_w * self.vae_stride[2]
213
 
214
  clip_image_size = self.clip.model.image_size
215
- img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
216
  img = resize_lanczos(img, clip_image_size, clip_image_size)
217
- img = img.sub_(0.5).div_(0.5).to(self.device, self.dtype)
218
  if img2!= None:
219
- img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device, self.dtype)
220
  img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
221
- img2 = img2.sub_(0.5).div_(0.5).to(self.device, self.dtype)
222
 
223
  max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
224
 
@@ -244,25 +211,19 @@ class WanI2V:
244
  if n_prompt == "":
245
  n_prompt = self.sample_neg_prompt
246
 
 
 
 
247
  # preprocess
248
- if not self.t5_cpu:
249
- # self.text_encoder.model.to(self.device)
250
- context = self.text_encoder([input_prompt], self.device)
251
- context_null = self.text_encoder([n_prompt], self.device)
252
- if offload_model:
253
- self.text_encoder.model.cpu()
254
- else:
255
- context = self.text_encoder([input_prompt], torch.device('cpu'))
256
- context_null = self.text_encoder([n_prompt], torch.device('cpu'))
257
- context = [t.to(self.device) for t in context]
258
- context_null = [t.to(self.device) for t in context_null]
259
 
260
- context = [u.to(self.dtype) for u in context]
261
- context_null = [u.to(self.dtype) for u in context_null]
262
 
263
  clip_context = self.clip.visual([img[:, None, :, :]])
264
- if offload_model:
265
- self.clip.model.cpu()
266
 
267
  from mmgp import offload
268
  offload.last_offload_obj.unload_all()
@@ -270,23 +231,20 @@ class WanI2V:
270
  mean2 = 0
271
  enc= torch.concat([
272
  img_interpolated,
273
- torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.dtype),
274
  img_interpolated2,
275
  ], dim=1).to(self.device)
276
  else:
277
  enc= torch.concat([
278
  img_interpolated,
279
- torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.dtype)
280
  ], dim=1).to(self.device)
 
281
 
282
  lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
283
  y = torch.concat([msk, lat_y])
 
284
 
285
- @contextmanager
286
- def noop_no_sync():
287
- yield
288
-
289
- no_sync = getattr(self.model, 'no_sync', noop_no_sync)
290
 
291
  # evaluation mode
292
 
@@ -317,7 +275,7 @@ class WanI2V:
317
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
318
 
319
  arg_c = {
320
- 'context': [context[0]],
321
  'clip_fea': clip_context,
322
  'y': [y],
323
  'freqs' : freqs,
@@ -326,7 +284,7 @@ class WanI2V:
326
  }
327
 
328
  arg_null = {
329
- 'context': context_null,
330
  'clip_fea': clip_context,
331
  'y': [y],
332
  'freqs' : freqs,
@@ -335,8 +293,7 @@ class WanI2V:
335
  }
336
 
337
  arg_both= {
338
- 'context': [context[0]],
339
- 'context2': context_null,
340
  'clip_fea': clip_context,
341
  'y': [y],
342
  'freqs' : freqs,
@@ -344,9 +301,6 @@ class WanI2V:
344
  'callback' : callback
345
  }
346
 
347
- if offload_model:
348
- torch.cuda.empty_cache()
349
-
350
  if self.model.enable_teacache:
351
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
352
 
@@ -379,8 +333,6 @@ class WanI2V:
379
  )[0]
380
  if self._interrupt:
381
  return None
382
- if offload_model:
383
- torch.cuda.empty_cache()
384
  noise_pred_uncond = self.model(
385
  latent_model_input,
386
  t=timestep,
@@ -392,8 +344,7 @@ class WanI2V:
392
  if self._interrupt:
393
  return None
394
  del latent_model_input
395
- if offload_model:
396
- torch.cuda.empty_cache()
397
  # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
398
  noise_pred_text = noise_pred_cond
399
  if cfg_star_switch:
@@ -412,9 +363,6 @@ class WanI2V:
412
 
413
  del noise_pred_uncond
414
 
415
- latent = latent.to(
416
- torch.device('cpu') if offload_model else self.device)
417
-
418
  temp_x0 = sample_scheduler.step(
419
  noise_pred.unsqueeze(0),
420
  t,
@@ -429,29 +377,18 @@ class WanI2V:
429
  callback(i, latent, False)
430
 
431
 
432
- x0 = [latent.to(self.device, dtype=self.dtype)]
433
-
434
- if offload_model:
435
- self.model.cpu()
436
- torch.cuda.empty_cache()
437
 
438
- if self.rank == 0:
439
- # x0 = [lat_y]
440
- video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
441
 
442
- if any_end_frame and add_frames_for_end_image:
443
- # video[:, -1:] = img_interpolated2
444
- video = video[:, :-1]
445
 
446
- else:
447
- video = None
 
448
 
449
  del noise, latent
450
  del sample_scheduler
451
- if offload_model:
452
- gc.collect()
453
- torch.cuda.synchronize()
454
- if dist.is_initialized():
455
- dist.barrier()
456
 
457
  return video
 
48
  self,
49
  config,
50
  checkpoint_dir,
 
 
 
 
 
 
 
51
  model_filename ="",
52
  text_encoder_filename="",
53
  quantizeTransformer = False,
54
+ dtype = torch.bfloat16,
55
+ VAE_dtype = torch.float32,
56
+ mixed_precision_transformer = False
57
  ):
 
58
  self.device = torch.device(f"cuda")
59
  self.config = config
 
 
 
60
  self.dtype = dtype
61
+ self.VAE_dtype = VAE_dtype
62
  self.num_train_timesteps = config.num_train_timesteps
63
  self.param_dtype = config.param_dtype
64
  # shard_fn = partial(shard_model, device_id=device_id)
 
74
  self.vae_stride = config.vae_stride
75
  self.patch_size = config.patch_size
76
  self.vae = WanVAE(
77
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype = VAE_dtype,
78
  device=self.device)
79
 
80
  self.clip = CLIPModel(
 
88
  from mmgp import offload
89
 
90
  self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
91
+ self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
92
+ offload.change_dtype(self.model, dtype, True)
93
  # offload.save_model(self.model, "i2v_720p_fp16.safetensors",do_quantize=True)
 
 
94
 
95
  # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
96
  self.model.eval().requires_grad_(False)
 
110
  guide_scale=5.0,
111
  n_prompt="",
112
  seed=-1,
 
113
  callback = None,
114
  enable_RIFLEx = False,
115
  VAE_tile_size= 0,
 
179
  w = lat_w * self.vae_stride[2]
180
 
181
  clip_image_size = self.clip.model.image_size
182
+ img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
183
  img = resize_lanczos(img, clip_image_size, clip_image_size)
184
+ img = img.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
185
  if img2!= None:
186
+ img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
187
  img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
188
+ img2 = img2.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
189
 
190
  max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
191
 
 
211
  if n_prompt == "":
212
  n_prompt = self.sample_neg_prompt
213
 
214
+ if self._interrupt:
215
+ return None
216
+
217
  # preprocess
218
+ context = self.text_encoder([input_prompt], self.device)[0]
219
+ context_null = self.text_encoder([n_prompt], self.device)[0]
220
+ context = context.to(self.dtype)
221
+ context_null = context_null.to(self.dtype)
 
 
 
 
 
 
 
222
 
223
+ if self._interrupt:
224
+ return None
225
 
226
  clip_context = self.clip.visual([img[:, None, :, :]])
 
 
227
 
228
  from mmgp import offload
229
  offload.last_offload_obj.unload_all()
 
231
  mean2 = 0
232
  enc= torch.concat([
233
  img_interpolated,
234
+ torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.VAE_dtype),
235
  img_interpolated2,
236
  ], dim=1).to(self.device)
237
  else:
238
  enc= torch.concat([
239
  img_interpolated,
240
+ torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.VAE_dtype)
241
  ], dim=1).to(self.device)
242
+ img, img2, img_interpolated, img_interpolated2 = None, None, None, None
243
 
244
  lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
245
  y = torch.concat([msk, lat_y])
246
+ lat_y = None
247
 
 
 
 
 
 
248
 
249
  # evaluation mode
250
 
 
275
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
276
 
277
  arg_c = {
278
+ 'context': [context],
279
  'clip_fea': clip_context,
280
  'y': [y],
281
  'freqs' : freqs,
 
284
  }
285
 
286
  arg_null = {
287
+ 'context': [context_null],
288
  'clip_fea': clip_context,
289
  'y': [y],
290
  'freqs' : freqs,
 
293
  }
294
 
295
  arg_both= {
296
+ 'context': [context, context_null],
 
297
  'clip_fea': clip_context,
298
  'y': [y],
299
  'freqs' : freqs,
 
301
  'callback' : callback
302
  }
303
 
 
 
 
304
  if self.model.enable_teacache:
305
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
306
 
 
333
  )[0]
334
  if self._interrupt:
335
  return None
 
 
336
  noise_pred_uncond = self.model(
337
  latent_model_input,
338
  t=timestep,
 
344
  if self._interrupt:
345
  return None
346
  del latent_model_input
347
+
 
348
  # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
349
  noise_pred_text = noise_pred_cond
350
  if cfg_star_switch:
 
363
 
364
  del noise_pred_uncond
365
 
 
 
 
366
  temp_x0 = sample_scheduler.step(
367
  noise_pred.unsqueeze(0),
368
  t,
 
377
  callback(i, latent, False)
378
 
379
 
380
+ # x0 = [latent.to(self.device, dtype=self.dtype)]
 
 
 
 
381
 
382
+ x0 = [latent]
 
 
383
 
384
+ # x0 = [lat_y]
385
+ video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
 
386
 
387
+ if any_end_frame and add_frames_for_end_image:
388
+ # video[:, -1:] = img_interpolated2
389
+ video = video[:, :-1]
390
 
391
  del noise, latent
392
  del sample_scheduler
 
 
 
 
 
393
 
394
  return video
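
As a side note on the i2v changes above, the conditioning clip fed to the VAE is now assembled in the VAE dtype (`self.VAE_dtype`) rather than in the transformer dtype. The self-contained sketch below mirrors that assembly: start frame, zero padding, optional end frame. The helper name `build_conditioning_clip` is illustrative, not part of the repository.

```python
# Sketch of how the i2v conditioning clip is built (start frame, zero padding, optional end frame),
# kept in the VAE dtype instead of the transformer dtype.
import torch

def build_conditioning_clip(start_img, frame_num, end_img=None, dtype=torch.float32, device="cpu"):
    # start_img / end_img: [3, 1, H, W] tensors already normalised to [-1, 1]
    _, _, h, w = start_img.shape
    if end_img is not None:
        middle = torch.zeros(3, frame_num - 2, h, w, device=device, dtype=dtype)
        return torch.cat([start_img, middle, end_img], dim=1)
    middle = torch.zeros(3, frame_num - 1, h, w, device=device, dtype=dtype)
    return torch.cat([start_img, middle], dim=1)

clip = build_conditioning_clip(torch.randn(3, 1, 64, 64), frame_num=17)
print(clip.shape)  # torch.Size([3, 17, 64, 64])
```
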
wan/modules/model.py CHANGED
@@ -408,6 +408,9 @@ class WanAttentionBlock(nn.Module):
408
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
409
  """
410
  hint = None
 
 
 
411
  if self.block_id is not None and hints is not None:
412
  kwargs = {
413
  "grid_sizes" : grid_sizes,
@@ -434,9 +437,11 @@ class WanAttentionBlock(nn.Module):
434
  cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
435
  x_mod += cam_emb
436
 
437
- xlist = [x_mod]
438
  del x_mod
439
  y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
 
 
440
  if cam_emb != None:
441
  y = self.projector(y)
442
 
@@ -445,15 +450,18 @@ class WanAttentionBlock(nn.Module):
445
  x, y = reshape_latent(x , 1), reshape_latent(y , 1)
446
  del y
447
  y = self.norm3(x)
 
448
  ylist= [y]
449
  del y
450
- x += self.cross_attn(ylist, context)
 
451
  y = self.norm2(x)
452
 
453
  y = reshape_latent(y , latent_frames)
454
  y *= 1 + e[4]
455
  y += e[3]
456
  y = reshape_latent(y , 1)
 
457
 
458
  ffn = self.ffn[0]
459
  gelu = self.ffn[1]
@@ -469,7 +477,7 @@ class WanAttentionBlock(nn.Module):
469
  y_chunk[...] = ffn2(mlp_chunk)
470
  del mlp_chunk
471
  y = y.view(y_shape)
472
-
473
  x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
474
  x.addcmul_(y, e[5])
475
  x, y = reshape_latent(x , 1), reshape_latent(y , 1)
@@ -532,7 +540,6 @@ class Head(nn.Module):
532
  out_dim = math.prod(patch_size) * out_dim
533
  self.norm = WanLayerNorm(dim, eps)
534
  self.head = nn.Linear(dim, out_dim)
535
-
536
  # modulation
537
  self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
538
 
@@ -552,6 +559,7 @@ class Head(nn.Module):
552
  x *= (1 + e[1])
553
  x += e[0]
554
  x = reshape_latent(x , 1)
 
555
  x = self.head(x)
556
  return x
557
 
@@ -735,6 +743,44 @@ class WanModel(ModelMixin, ConfigMixin):
735
  block.projector.bias = nn.Parameter(torch.zeros(dim))
736
 
737
 
738
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
739
  rescale_func = np.poly1d(self.coefficients)
740
  e_list = []
@@ -788,7 +834,6 @@ class WanModel(ModelMixin, ConfigMixin):
788
  freqs = None,
789
  pipeline = None,
790
  current_step = 0,
791
- context2 = None,
792
  is_uncond=False,
793
  max_steps = 0,
794
  slg_layers=None,
@@ -797,7 +842,10 @@ class WanModel(ModelMixin, ConfigMixin):
797
  fps = None,
798
  causal_block_size = 1,
799
  causal_attention = False,
 
800
  ):
 
 
801
 
802
  if self.model_type == 'i2v':
803
  assert clip_fea is not None and y is not None
@@ -810,9 +858,9 @@ class WanModel(ModelMixin, ConfigMixin):
810
  x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
811
 
812
  # embeddings
813
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
814
- # grid_sizes = torch.stack(
815
- # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
816
 
817
  grid_sizes = [ list(u.shape[2:]) for u in x]
818
  embed_sizes = grid_sizes[0]
@@ -836,57 +884,46 @@ class WanModel(ModelMixin, ConfigMixin):
836
 
837
  x = [u.flatten(2).transpose(1, 2) for u in x]
838
  x = x[0]
 
 
 
839
 
840
  if t.dim() == 2:
841
  b, f = t.shape
842
  _flag_df = True
843
  else:
844
  _flag_df = False
845
-
846
  e = self.time_embedding(
847
- sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
848
  ) # b, dim
849
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
850
 
851
  if self.inject_sample_info:
852
  fps = torch.tensor(fps, dtype=torch.long, device=device)
853
 
854
- fps_emb = self.fps_embedding(fps).float()
855
  if _flag_df:
856
  e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
857
  else:
858
  e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
859
 
860
  # context
861
- context = self.text_embedding(
862
- torch.stack([
863
- torch.cat(
864
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
865
- for u in context
866
- ]))
867
- if context2!=None:
868
- context2 = self.text_embedding(
869
- torch.stack([
870
- torch.cat(
871
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
872
- for u in context2
873
- ]))
874
-
875
  if clip_fea is not None:
876
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
877
- context = torch.concat([context_clip, context], dim=1)
878
- if context2 != None:
879
- context2 = torch.concat([context_clip, context2], dim=1)
880
 
881
- joint_pass = context2 != None
 
882
  if joint_pass:
883
- x_list = [x, x.clone()]
884
- context_list = [context, context2]
 
 
885
  is_uncond = False
886
- else:
887
- x_list = [x]
888
- context_list = [context]
889
  del x
 
890
 
891
  # arguments
892
 
@@ -945,10 +982,7 @@ class WanModel(ModelMixin, ConfigMixin):
945
  if callback != None:
946
  callback(-1, None, False, True)
947
  if pipeline._interrupt:
948
- if joint_pass:
949
- return None, None
950
- else:
951
- return [None]
952
 
953
  if slg_layers is not None and block_idx in slg_layers:
954
  if is_uncond and not joint_pass:
@@ -983,10 +1017,7 @@ class WanModel(ModelMixin, ConfigMixin):
983
  x_list[i] = self.unpatchify(x, grid_sizes)
984
  del x
985
 
986
- if joint_pass:
987
- return x_list[0][0], x_list[1][0]
988
- else:
989
- return [u.float() for u in x_list[0]]
990
 
991
  def unpatchify(self, x, grid_sizes):
992
  r"""
 
408
  freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
409
  """
410
  hint = None
411
+ attention_dtype = self.self_attn.q.weight.dtype
412
+ dtype = x.dtype
413
+
414
  if self.block_id is not None and hints is not None:
415
  kwargs = {
416
  "grid_sizes" : grid_sizes,
 
437
  cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
438
  x_mod += cam_emb
439
 
440
+ xlist = [x_mod.to(attention_dtype)]
441
  del x_mod
442
  y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
443
+ y = y.to(dtype)
444
+
445
  if cam_emb != None:
446
  y = self.projector(y)
447
 
 
450
  x, y = reshape_latent(x , 1), reshape_latent(y , 1)
451
  del y
452
  y = self.norm3(x)
453
+ y = y.to(attention_dtype)
454
  ylist= [y]
455
  del y
456
+ x += self.cross_attn(ylist, context).to(dtype)
457
+
458
  y = self.norm2(x)
459
 
460
  y = reshape_latent(y , latent_frames)
461
  y *= 1 + e[4]
462
  y += e[3]
463
  y = reshape_latent(y , 1)
464
+ y = y.to(attention_dtype)
465
 
466
  ffn = self.ffn[0]
467
  gelu = self.ffn[1]
 
477
  y_chunk[...] = ffn2(mlp_chunk)
478
  del mlp_chunk
479
  y = y.view(y_shape)
480
+ y = y.to(dtype)
481
  x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
482
  x.addcmul_(y, e[5])
483
  x, y = reshape_latent(x , 1), reshape_latent(y , 1)
 
540
  out_dim = math.prod(patch_size) * out_dim
541
  self.norm = WanLayerNorm(dim, eps)
542
  self.head = nn.Linear(dim, out_dim)
 
543
  # modulation
544
  self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
545
 
 
559
  x *= (1 + e[1])
560
  x += e[0]
561
  x = reshape_latent(x , 1)
562
+ x= x.to(self.head.weight.dtype)
563
  x = self.head(x)
564
  return x
565
 
 
743
  block.projector.bias = nn.Parameter(torch.zeros(dim))
744
 
745
 
746
+ def lock_layers_dtypes(self, dtype = torch.float32, force = False):
747
+ count = 0
748
+ layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
749
+ self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
750
+ if hasattr(self, "fps_embedding"):
751
+ layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
752
+
753
+ if hasattr(self, "vace_patch_embedding"):
754
+ layer_list += [self.vace_patch_embedding]
755
+ layer_list += [self.vace_blocks[0].before_proj]
756
+ for block in self.vace_blocks:
757
+ layer_list += [block.after_proj, block.norm3]
758
+
759
+ # cam master
760
+ if hasattr(self.blocks[0], "projector"):
761
+ for block in self.blocks:
762
+ layer_list += [block.projector]
763
+
764
+ for block in self.blocks:
765
+ layer_list += [block.norm3]
766
+ for layer in layer_list:
767
+ if hasattr(layer, "weight"):
768
+ if layer.weight.dtype == dtype :
769
+ count += 1
770
+ elif force:
771
+ if hasattr(layer, "weight"):
772
+ layer.weight.data = layer.weight.data.to(dtype)
773
+ if hasattr(layer, "bias"):
774
+ layer.bias.data = layer.bias.data.to(dtype)
775
+ count += 1
776
+
777
+ layer._lock_dtype = dtype
778
+
779
+
780
+ if count > 0:
781
+ self._lock_dtype = dtype
782
+
783
+
784
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
785
  rescale_func = np.poly1d(self.coefficients)
786
  e_list = []
 
834
  freqs = None,
835
  pipeline = None,
836
  current_step = 0,
 
837
  is_uncond=False,
838
  max_steps = 0,
839
  slg_layers=None,
 
842
  fps = None,
843
  causal_block_size = 1,
844
  causal_attention = False,
845
+ x_neg = None
846
  ):
847
+ # dtype = self.blocks[0].self_attn.q.weight.dtype
848
+ dtype = self.patch_embedding.weight.dtype
849
 
850
  if self.model_type == 'i2v':
851
  assert clip_fea is not None and y is not None
 
858
  x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
859
 
860
  # embeddings
861
+ x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x]
862
+ if x_neg !=None:
863
+ x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg]
864
 
865
  grid_sizes = [ list(u.shape[2:]) for u in x]
866
  embed_sizes = grid_sizes[0]
 
884
 
885
  x = [u.flatten(2).transpose(1, 2) for u in x]
886
  x = x[0]
887
+ if x_neg !=None:
888
+ x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
889
+ x_neg = x_neg[0]
890
 
891
  if t.dim() == 2:
892
  b, f = t.shape
893
  _flag_df = True
894
  else:
895
  _flag_df = False
 
896
  e = self.time_embedding(
897
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype)
898
  ) # b, dim
899
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)
900
 
901
  if self.inject_sample_info:
902
  fps = torch.tensor(fps, dtype=torch.long, device=device)
903
 
904
+ fps_emb = self.fps_embedding(fps).to(dtype) # float()
905
  if _flag_df:
906
  e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
907
  else:
908
  e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
909
 
910
  # context
911
+ context = [self.text_embedding( torch.cat( [u, u.new_zeros(self.text_len - u.size(0), u.size(1))] ).unsqueeze(0) ) for u in context ]
912
+
 
 
 
 
 
 
 
 
 
 
 
 
913
  if clip_fea is not None:
914
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
915
+ context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]
 
 
916
 
917
+ joint_pass = len(context) > 0
918
+ x_list = [x]
919
  if joint_pass:
920
+ if x_neg == None:
921
+ x_list += [x.clone() for i in range(len(context) - 1) ]
922
+ else:
923
+ x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
924
  is_uncond = False
 
 
 
925
  del x
926
+ context_list = context
927
 
928
  # arguments
929
 
 
982
  if callback != None:
983
  callback(-1, None, False, True)
984
  if pipeline._interrupt:
985
+ return [None] * len(x_list)
 
 
 
986
 
987
  if slg_layers is not None and block_idx in slg_layers:
988
  if is_uncond and not joint_pass:
 
1017
  x_list[i] = self.unpatchify(x, grid_sizes)
1018
  del x
1019
 
1020
+ return [x[0].float() for x in x_list]
 
 
 
1021
 
1022
  def unpatchify(self, x, grid_sizes):
1023
  r"""
wan/modules/vae.py CHANGED
@@ -752,7 +752,8 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
752
  logging.info(f'loading {pretrained_path}')
753
  # model.load_state_dict(
754
  # torch.load(pretrained_path, map_location=device), assign=True)
755
- offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
 
756
  return model
757
 
758
 
@@ -782,20 +783,22 @@ class WanVAE:
782
  self.model = _video_vae(
783
  pretrained_path=vae_pth,
784
  z_dim=z_dim,
785
- ).eval() #.requires_grad_(False).to(device)
786
-
787
  def encode(self, videos, tile_size = 256, any_end_frame = False):
788
  """
789
  videos: A list of videos each with shape [C, T, H, W].
790
  """
 
 
791
  if tile_size > 0:
792
- return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
793
  else:
794
- return [ self.model.encode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
795
 
796
 
797
  def decode(self, zs, tile_size, any_end_frame = False):
798
  if tile_size > 0:
799
- return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
800
  else:
801
- return [ self.model.decode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
 
752
  logging.info(f'loading {pretrained_path}')
753
  # model.load_state_dict(
754
  # torch.load(pretrained_path, map_location=device), assign=True)
755
+ # offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
756
+ offload.load_model_data(model, pretrained_path.replace(".pth", ".safetensors"), writable_tensors= False)
757
  return model
758
 
759
 
 
783
  self.model = _video_vae(
784
  pretrained_path=vae_pth,
785
  z_dim=z_dim,
786
+ ).to(dtype).eval() #.requires_grad_(False).to(device)
787
+
788
  def encode(self, videos, tile_size = 256, any_end_frame = False):
789
  """
790
  videos: A list of videos each with shape [C, T, H, W].
791
  """
792
+ original_dtype = videos[0].dtype
793
+
794
  if tile_size > 0:
795
+ return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
796
  else:
797
+ return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
798
 
799
 
800
  def decode(self, zs, tile_size, any_end_frame = False):
801
  if tile_size > 0:
802
+ return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
803
  else:
804
+ return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
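
The VAE changes above add a working dtype: inputs are cast to `self.dtype` before encode/decode, while outputs are always handed back as float32 for the rest of the pipeline. The toy stand-in below only illustrates that dtype round-trip; `TinyVAE` is not the repository class (the real model is the video VAE loaded by `_video_vae`).

```python
# Minimal illustration of the dtype round-trip added to WanVAE.encode()/decode().
import torch
import torch.nn as nn

class TinyVAE(nn.Module):
    def __init__(self, dtype=torch.float16):
        super().__init__()
        self.dtype = dtype
        self.scale = nn.Parameter(torch.ones(1, dtype=dtype))

    def encode(self, videos):
        # cast each clip [C, T, H, W] to the VAE dtype, "encode", hand back float32 latents
        return [(v.to(self.dtype) * self.scale).float() for v in videos]

    def decode(self, zs):
        # same round-trip on the way out; the real decoder also clamps to [-1, 1]
        return [(z.to(self.dtype) / self.scale).float().clamp_(-1, 1) for z in zs]

vae = TinyVAE()
lat = vae.encode([torch.randn(3, 5, 32, 32)])[0]
print(lat.dtype)  # torch.float32
```
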
wan/text2video.py CHANGED
@@ -52,7 +52,9 @@ class WanT2V:
52
  model_filename = None,
53
  text_encoder_filename = None,
54
  quantizeTransformer = False,
55
- dtype = torch.bfloat16
 
 
56
  ):
57
  self.device = torch.device(f"cuda")
58
  self.config = config
@@ -71,24 +73,23 @@ class WanT2V:
71
 
72
  self.vae_stride = config.vae_stride
73
  self.patch_size = config.patch_size
74
-
75
 
76
  self.vae = WanVAE(
77
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
78
  device=self.device)
79
 
80
  logging.info(f"Creating WanModel from {model_filename}")
81
  from mmgp import offload
82
-
83
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
84
- # offload.load_model_data(self.model, "recam.ckpt")
 
 
85
  # self.model.cpu()
86
- # offload.save_model(self.model, "recam.safetensors")
87
- if self.dtype == torch.float16 and not "fp16" in model_filename:
88
- self.model.to(self.dtype)
89
- # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
90
- if self.dtype == torch.float16:
91
- self.vae.model.to(self.dtype)
92
  self.model.eval().requires_grad_(False)
93
 
94
 
@@ -252,6 +253,15 @@ class WanT2V:
252
 
253
  return self.vae.decode(trimed_zs, tile_size= tile_size)
254
 
 
 
 
 
 
 
 
 
 
255
  def generate(self,
256
  input_prompt,
257
  input_frames= None,
@@ -320,8 +330,15 @@ class WanT2V:
320
  seed_g = torch.Generator(device=self.device)
321
  seed_g.manual_seed(seed)
322
 
323
- context = self.text_encoder([input_prompt], self.device)
324
- context_null = self.text_encoder([n_prompt], self.device)
 
 
 
 
 
 
 
325
  if target_camera != None:
326
  size = (source_video.shape[2], source_video.shape[1])
327
  source_video = source_video.to(dtype=self.dtype , device=self.device)
@@ -346,8 +363,12 @@ class WanT2V:
346
  target_shape = list(z0[0].shape)
347
  target_shape[0] = int(target_shape[0] / 2)
348
  else:
 
 
 
 
349
  F = frame_num
350
- target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
351
  size[1] // self.vae_stride[1],
352
  size[0] // self.vae_stride[2])
353
 
@@ -355,8 +376,8 @@ class WanT2V:
355
  (self.patch_size[1] * self.patch_size[2]) *
356
  target_shape[1])
357
 
358
- context = [u.to(self.dtype) for u in context]
359
- context_null = [u.to(self.dtype) for u in context_null]
360
 
361
  noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
362
 
@@ -393,21 +414,15 @@ class WanT2V:
393
  freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
394
  else:
395
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
396
- arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
397
- arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
398
- arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
399
 
400
  if target_camera != None:
401
- recam_dict = {'cam_emb': cam_emb}
402
- arg_c.update(recam_dict)
403
- arg_null.update(recam_dict)
404
- arg_both.update(recam_dict)
405
 
406
  if input_frames != None:
407
- vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
408
- arg_c.update(vace_dict)
409
- arg_null.update(vace_dict)
410
- arg_both.update(vace_dict)
411
 
412
  if self.model.enable_teacache:
413
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
@@ -424,39 +439,68 @@ class WanT2V:
424
  timestep = [t]
425
  offload.set_step_no_for_lora(self.model, i)
426
  timestep = torch.stack(timestep)
427
-
 
428
  if joint_pass:
429
- noise_pred_cond, noise_pred_uncond = self.model(
430
- latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
 
 
 
 
 
 
431
  if self._interrupt:
432
  return None
433
  else:
434
- noise_pred_cond = self.model(
435
- latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
436
- if self._interrupt:
437
- return None
438
- noise_pred_uncond = self.model(
439
- latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
440
- if self._interrupt:
441
- return None
 
442
 
443
  # del latent_model_input
444
 
445
  # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
446
- noise_pred_text = noise_pred_cond
447
- if cfg_star_switch:
448
- positive_flat = noise_pred_text.view(batch_size, -1)
449
- negative_flat = noise_pred_uncond.view(batch_size, -1)
 
 
 
 
 
450
 
451
- alpha = optimized_scale(positive_flat,negative_flat)
452
- alpha = alpha.view(batch_size, 1, 1, 1)
453
 
454
- if (i <= cfg_zero_step):
455
- noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
456
- else:
457
- noise_pred_uncond *= alpha
458
- noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
459
- del noise_pred_uncond
460
 
461
  temp_x0 = sample_scheduler.step(
462
  noise_pred[:, :target_shape[1]].unsqueeze(0),
@@ -473,8 +517,12 @@ class WanT2V:
473
  x0 = latents
474
 
475
  if input_frames == None:
 
 
 
476
  videos = self.vae.decode(x0, VAE_tile_size)
477
  else:
 
478
  videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
479
 
480
  del latents
 
52
  model_filename = None,
53
  text_encoder_filename = None,
54
  quantizeTransformer = False,
55
+ dtype = torch.bfloat16,
56
+ VAE_dtype = torch.float32,
57
+ mixed_precision_transformer = False
58
  ):
59
  self.device = torch.device(f"cuda")
60
  self.config = config
 
73
 
74
  self.vae_stride = config.vae_stride
75
  self.patch_size = config.patch_size
 
76
 
77
  self.vae = WanVAE(
78
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
79
  device=self.device)
80
 
81
  logging.info(f"Creating WanModel from {model_filename}")
82
  from mmgp import offload
83
+ # model_filename
84
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
85
+ # offload.load_model_data(self.model, "e:/vace.safetensors")
86
+ # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
87
+ # self.model.to(torch.bfloat16)
88
  # self.model.cpu()
89
+ self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
90
+ offload.change_dtype(self.model, dtype, True)
91
+ # offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
92
+ # offload.save_model(self.model, "phantom_1.3B.safetensors")
 
 
93
  self.model.eval().requires_grad_(False)
94
 
95
 
 
253
 
254
  return self.vae.decode(trimed_zs, tile_size= tile_size)
255
 
256
+ def get_vae_latents(self, ref_images, device, tile_size= 0):
257
+ ref_vae_latents = []
258
+ for ref_image in ref_images:
259
+ ref_image = TF.to_tensor(ref_image).sub_(0.5).div_(0.5).to(self.device)
260
+ img_vae_latent = self.vae.encode([ref_image.unsqueeze(1)], tile_size= tile_size)
261
+ ref_vae_latents.append(img_vae_latent[0])
262
+
263
+ return torch.cat(ref_vae_latents, dim=1)
264
+
265
  def generate(self,
266
  input_prompt,
267
  input_frames= None,
 
330
  seed_g = torch.Generator(device=self.device)
331
  seed_g.manual_seed(seed)
332
 
333
+ if self._interrupt:
334
+ return None
335
+ context = self.text_encoder([input_prompt], self.device)[0]
336
+ context_null = self.text_encoder([n_prompt], self.device)[0]
337
+ context = context.to(self.dtype)
338
+ context_null = context_null.to(self.dtype)
339
+ input_ref_images_neg = None
340
+ phantom = False
341
+
342
  if target_camera != None:
343
  size = (source_video.shape[2], source_video.shape[1])
344
  source_video = source_video.to(dtype=self.dtype , device=self.device)
 
363
  target_shape = list(z0[0].shape)
364
  target_shape[0] = int(target_shape[0] / 2)
365
  else:
366
+ if input_ref_images != None: # Phantom Ref images
367
+ phantom = True
368
+ input_ref_images = [self.get_vae_latents(input_ref_images, self.device)]
369
+ input_ref_images_neg = [torch.zeros_like(input_ref_images[0])]
370
  F = frame_num
371
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0),
372
  size[1] // self.vae_stride[1],
373
  size[0] // self.vae_stride[2])
374
 
 
376
  (self.patch_size[1] * self.patch_size[2]) *
377
  target_shape[1])
378
 
379
+ if self._interrupt:
380
+ return None
381
 
382
  noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
383
 
 
414
  freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
415
  else:
416
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
417
+
418
+ kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}
 
419
 
420
  if target_camera != None:
421
+ kwargs.update({'cam_emb': cam_emb})
 
 
 
422
 
423
  if input_frames != None:
424
+ kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})
425
+
 
 
426
 
427
  if self.model.enable_teacache:
428
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
 
439
  timestep = [t]
440
  offload.set_step_no_for_lora(self.model, i)
441
  timestep = torch.stack(timestep)
442
+ kwargs["current_step"] = i
443
+ kwargs["t"] = timestep
444
  if joint_pass:
445
+ if phantom:
446
+ pos_it, pos_i, neg = self.model(
447
+ [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)],
448
+ x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)],
449
+ context = [context, context_null, context_null], **kwargs)
450
+ else:
451
+ noise_pred_cond, noise_pred_uncond = self.model(
452
+ latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs)
453
  if self._interrupt:
454
  return None
455
  else:
456
+ if phantom:
457
+ pos_it = self.model(
458
+ [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs
459
+ )[0]
460
+ if self._interrupt:
461
+ return None
462
+ pos_i = self.model(
463
+ [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null],**kwargs
464
+ )[0]
465
+ if self._interrupt:
466
+ return None
467
+ neg = self.model(
468
+ [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs
469
+ )[0]
470
+ if self._interrupt:
471
+ return None
472
+ else:
473
+ noise_pred_cond = self.model(
474
+ latent_model_input, is_uncond = False, context = [context], **kwargs)[0]
475
+ if self._interrupt:
476
+ return None
477
+ noise_pred_uncond = self.model(
478
+ latent_model_input, is_uncond = True, slg_layers=slg_layers_local,context = [context_null], **kwargs)[0]
479
+ if self._interrupt:
480
+ return None
481
 
482
  # del latent_model_input
483
 
484
  # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
485
+ if phantom:
486
+ guide_scale_img= 5.0
487
+ guide_scale_text= guide_scale #7.5
488
+ noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)
489
+ else:
490
+ noise_pred_text = noise_pred_cond
491
+ if cfg_star_switch:
492
+ positive_flat = noise_pred_text.view(batch_size, -1)
493
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
494
 
495
+ alpha = optimized_scale(positive_flat,negative_flat)
496
+ alpha = alpha.view(batch_size, 1, 1, 1)
497
 
498
+ if (i <= cfg_zero_step):
499
+ noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
500
+ else:
501
+ noise_pred_uncond *= alpha
502
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
503
+ noise_pred_uncond, noise_pred_cond, noise_pred_text, pos_it, pos_i, neg = None, None, None, None, None, None
504
 
505
  temp_x0 = sample_scheduler.step(
506
  noise_pred[:, :target_shape[1]].unsqueeze(0),
 
517
  x0 = latents
518
 
519
  if input_frames == None:
520
+ if phantom:
521
+ # phantom post processing
522
+ x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0]
523
  videos = self.vae.decode(x0, VAE_tile_size)
524
  else:
525
+ # vace post processing
526
  videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
527
 
528
  del latents
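
Phantom's guidance differs from plain CFG: it blends three predictions, text+image (`pos_it`), image-only (`pos_i`), and unconditional (`neg`), with separate image and text scales (5.0 for the image scale and the user guidance scale, 7.5 in the comment, for text). The sketch below matches the formula in the joint-pass branch above; the standalone function name is illustrative.

```python
# Phantom guidance combination: image guidance first, then text guidance on top.
import torch

def phantom_guidance(pos_it, pos_i, neg, guide_scale_text=7.5, guide_scale_img=5.0):
    return neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)

shape = (16, 22, 60, 104)  # illustrative latent shape (video frames + reference frames)
pred = phantom_guidance(torch.randn(shape), torch.randn(shape), torch.randn(shape))
print(pred.shape)
```
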
wan/utils/utils.py CHANGED
@@ -69,18 +69,29 @@ def remove_background(img, session=None):
69
 
70
 
71
 
72
- def resize_and_remove_background(img_list, budget_width, budget_height, rm_background ):
73
  if rm_background:
74
  session = new_session()
75
 
76
  output_list =[]
77
  for img in img_list:
78
  width, height = img.size
79
- scale = (budget_height * budget_width / (height * width))**(1/2)
80
- new_height = int( round(height * scale / 16) * 16)
81
- new_width = int( round(width * scale / 16) * 16)
82
 
83
- resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  if rm_background:
85
  resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
86
  output_list.append(resized_image)
 
69
 
70
 
71
 
72
+ def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
73
  if rm_background:
74
  session = new_session()
75
 
76
  output_list =[]
77
  for img in img_list:
78
  width, height = img.size
 
 
 
79
 
80
+ if fit_into_canvas:
81
+ white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
82
+ scale = min(budget_height / height, budget_width / width)
83
+ new_height = int(height * scale)
84
+ new_width = int(width * scale)
85
+ resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
86
+ top = (budget_height - new_height) // 2
87
+ left = (budget_width - new_width) // 2
88
+ white_canvas[top:top + new_height, left:left + new_width] = np.array(resized_image)
89
+ resized_image = Image.fromarray(white_canvas)
90
+ else:
91
+ scale = (budget_height * budget_width / (height * width))**(1/2)
92
+ new_height = int( round(height * scale / 16) * 16)
93
+ new_width = int( round(width * scale / 16) * 16)
94
+ resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
95
  if rm_background:
96
  resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
97
  output_list.append(resized_image)
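
The new `fit_into_canvas` branch letterboxes reference images instead of stretching them: the image is scaled to fit inside the budget while preserving its aspect ratio, then pasted centred on a white canvas. A self-contained sketch of that logic (the standalone function name is illustrative):

```python
# Letterbox an image into a fixed budget, as in the fit_into_canvas branch above.
import numpy as np
from PIL import Image

def fit_into_canvas(img, budget_width, budget_height):
    width, height = img.size
    canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
    scale = min(budget_height / height, budget_width / width)
    new_width, new_height = int(width * scale), int(height * scale)
    resized = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
    top = (budget_height - new_height) // 2
    left = (budget_width - new_width) // 2
    canvas[top:top + new_height, left:left + new_width] = np.array(resized)
    return Image.fromarray(canvas)

print(fit_into_canvas(Image.new("RGB", (300, 200)), 832, 480).size)  # (832, 480)
```
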
wgp.py CHANGED
@@ -40,7 +40,7 @@ global_queue_ref = []
40
  AUTOSAVE_FILENAME = "queue.zip"
41
  PROMPT_VARS_MAX = 10
42
 
43
- target_mmgp_version = "3.4.0"
44
  from importlib.metadata import version
45
  mmgp_version = version("mmgp")
46
  if mmgp_version != target_mmgp_version:
@@ -133,10 +133,11 @@ def process_prompt_and_add_tasks(state, model_choice):
133
 
134
  model_filename = state["model_filename"]
135
 
136
- if model_choice != get_model_type(model_filename):
 
 
137
  raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
138
-
139
- inputs = state.get(get_model_type(model_filename), None)
140
  inputs["state"] = state
141
  inputs.pop("lset_name")
142
  if inputs == None:
@@ -176,7 +177,7 @@ def process_prompt_and_add_tasks(state, model_choice):
176
  gr.Info(f"Resolution {resolution} not supported by image 2 video")
177
  return
178
 
179
- if "1.3B" in model_filename and width * height > 848*480:
180
  gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
181
  return
182
 
@@ -186,8 +187,28 @@ def process_prompt_and_add_tasks(state, model_choice):
186
  if video_length > sliding_window_size:
187
  gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated")
188
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
- if "diffusion_forcing" in model_filename:
 
 
 
 
 
 
 
 
191
  image_start = inputs["image_start"]
192
  video_source = inputs["video_source"]
193
  keep_frames_video_source = inputs["keep_frames_video_source"]
@@ -1362,10 +1383,6 @@ quantizeTransformer = args.quantize_transformer
1362
  check_loras = args.check_loras ==1
1363
  advanced = args.advanced
1364
 
1365
- transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors",
1366
- "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"]
1367
- transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
1368
- transformer_choices = transformer_choices_t2v + transformer_choices_i2v
1369
  text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
1370
  server_config_filename = "wgp_config.json"
1371
  if not os.path.isdir("settings"):
@@ -1401,11 +1418,32 @@ else:
1401
  text = reader.read()
1402
  server_config = json.loads(text)
1403
 
 
1404
 
1405
- model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B"]
1406
  model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
1407
  "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
1408
- "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B" }
 
 
1409
 
1410
 
1411
  def get_model_type(model_filename):
@@ -1417,29 +1455,47 @@ def get_model_type(model_filename):
1417
  def test_class_i2v(model_filename):
1418
  return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
1419
 
1420
- def get_model_name(model_filename):
1421
  if "Fun" in model_filename:
1422
  model_name = "Fun InP image2video"
1423
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
 
1424
  elif "Vace" in model_filename:
1425
  model_name = "Vace ControlNet"
1426
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
 
1427
  elif "image" in model_filename:
1428
  model_name = "Wan2.1 image2video"
1429
  model_name += " 720p" if "720p" in model_filename else " 480p"
1430
  elif "recam" in model_filename:
1431
  model_name = "ReCamMaster"
1432
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
 
1433
  elif "FLF2V" in model_filename:
1434
  model_name = "Wan2.1 FLF2V"
1435
  model_name += " 720p" if "720p" in model_filename else " 480p"
 
1436
  elif "sky_reels2_diffusion_forcing" in model_filename:
1437
- model_name = "SkyReels2 diffusion forcing"
1438
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
  else:
1440
  model_name = "Wan2.1 text2video"
1441
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1442
-
 
1443
  return model_name
1444
 
1445
 
@@ -1493,13 +1549,28 @@ def get_default_settings(filename):
1493
  "slg_end_perc": 90
1494
  }
1495
 
1496
- if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B"):
1497
  ui_defaults.update({
1498
  "guidance_scale": 6.0,
1499
  "flow_shift": 8,
1500
- "sliding_window_discard_last_frames" : 0
1501
  })
1502
 
1503
  with open(defaults_filename, "w", encoding="utf-8") as f:
1504
  json.dump(ui_defaults, f, indent=4)
1505
  else:
@@ -1649,7 +1720,7 @@ def download_models(transformer_filename, text_encoder_filename):
1649
  from huggingface_hub import hf_hub_download, snapshot_download
1650
  repoId = "DeepBeepMeep/Wan2.1"
1651
  sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
1652
- fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
1653
  targetRoot = "ckpts/"
1654
  for sourceFolder, files in zip(sourceFolderList,fileList ):
1655
  if len(files)==0:
@@ -1763,12 +1834,12 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
1763
  return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
1764
 
1765
 
1766
- def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
1767
 
1768
  cfg = WAN_CONFIGS['t2v-14B']
1769
  # cfg = WAN_CONFIGS['t2v-1.3B']
1770
  print(f"Loading '{model_filename}' model...")
1771
- if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B"):
1772
  model_factory = wan.DTT2V
1773
  else:
1774
  model_factory = wan.WanT2V
@@ -1779,52 +1850,32 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
1779
  model_filename=model_filename,
1780
  text_encoder_filename= text_encoder_filename,
1781
  quantizeTransformer = quantizeTransformer,
1782
- dtype = dtype
 
 
1783
  )
1784
 
1785
  pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
1786
 
1787
  return wan_model, pipe
1788
 
1789
- def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
1790
 
1791
  print(f"Loading '{model_filename}' model...")
1792
 
1793
- if value == '720P':
1794
- cfg = WAN_CONFIGS['i2v-14B']
1795
- wan_model = wan.WanI2V(
1796
- config=cfg,
1797
- checkpoint_dir="ckpts",
1798
- rank=0,
1799
- t5_fsdp=False,
1800
- dit_fsdp=False,
1801
- use_usp=False,
1802
- i2v720p= True,
1803
- model_filename=model_filename,
1804
- text_encoder_filename=text_encoder_filename,
1805
- quantizeTransformer = quantizeTransformer,
1806
- dtype = dtype
1807
- )
1808
- pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
1809
-
1810
- elif value == '480P':
1811
- cfg = WAN_CONFIGS['i2v-14B']
1812
- wan_model = wan.WanI2V(
1813
- config=cfg,
1814
- checkpoint_dir="ckpts",
1815
- rank=0,
1816
- t5_fsdp=False,
1817
- dit_fsdp=False,
1818
- use_usp=False,
1819
- i2v720p= False,
1820
- model_filename=model_filename,
1821
- text_encoder_filename=text_encoder_filename,
1822
- quantizeTransformer = quantizeTransformer,
1823
- dtype = dtype
1824
- )
1825
- pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
1826
- else:
1827
- raise Exception("Model i2v {value} not supported")
1828
  return wan_model, pipe
1829
 
1830
 
@@ -1836,18 +1887,22 @@ def load_models(model_filename):
1836
  perc_reserved_mem_max = args.perc_reserved_mem_max
1837
 
1838
  major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
1839
- default_dtype = torch.float16 if major < 8 else torch.bfloat16
1840
- # default_dtype = torch.bfloat16
1841
- if default_dtype == torch.float16 or args.fp16:
1842
  print("Switching to f16 model as GPU architecture doesn't support bf16")
1843
  if "quanto" in model_filename:
1844
  model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
1845
  download_models(model_filename, text_encoder_filename)
 
 
1846
  if test_class_i2v(model_filename):
1847
  res720P = "720p" in model_filename
1848
- wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype )
1849
  else:
1850
- wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype)
1851
  wan_model._model_file_name = model_filename
1852
  kwargs = { "extraModelsToQuantize": None}
1853
  if profile == 2 or profile == 4:
@@ -1888,8 +1943,13 @@ def get_default_flow(filename, i2v):
1888
 
1889
 
1890
  def generate_header(model_filename, compile, attention_mode):
1891
-
1892
- header = "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
1893
  if attention_mode not in attention_modes_installed:
1894
  header += " -NOT INSTALLED-"
1895
  elif attention_mode not in attention_modes_supported:
@@ -1907,6 +1967,8 @@ def generate_header(model_filename, compile, attention_mode):
1907
  def apply_changes( state,
1908
  transformer_types_choices,
1909
  text_encoder_choice,
 
 
1910
  save_path_choice,
1911
  attention_choice,
1912
  compile_choice,
@@ -1922,7 +1984,7 @@ def apply_changes( state,
1922
  if args.lock_config:
1923
  return
1924
  if gen_in_progress:
1925
- return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
1926
  global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
1927
  server_config = {"attention_mode" : attention_choice,
1928
  "transformer_types": transformer_types_choices,
@@ -1931,6 +1993,8 @@ def apply_changes( state,
1931
  "compile" : compile_choice,
1932
  "profile" : profile_choice,
1933
  "vae_config" : vae_config_choice,
 
 
1934
  "metadata_type": metadata_choice,
1935
  "transformer_quantization" : quantization_choice,
1936
  "boost" : boost_choice,
@@ -2052,12 +2116,9 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps):
2052
  return callback
2053
  def abort_generation(state):
2054
  gen = get_gen_info(state)
2055
- if "in_progress" in gen:
2056
 
2057
- gen["abort"] = True
2058
- gen["extra_orders"] = 0
2059
- if wan_model != None:
2060
- wan_model._interrupt= True
2061
  msg = "Processing Request to abort Current Generation"
2062
  gen["status"] = msg
2063
  gr.Info(msg)
@@ -2140,13 +2201,6 @@ def finalize_generation(state):
2140
  return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="")
2141
 
2142
 
2143
- def refresh_gallery_on_trigger(state):
2144
- gen = get_gen_info(state)
2145
-
2146
- if(gen.get("update_gallery", False)):
2147
- gen['update_gallery'] = False
2148
- return gr.update(value=gen.get("file_list", []))
2149
-
2150
  def select_video(state , event_data: gr.EventData):
2151
  data= event_data._data
2152
  gen = get_gen_info(state)
@@ -2385,6 +2439,8 @@ def generate_video(
2385
  # VAE Tiling
2386
  device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
2387
  if vae_config == 0:
 
 
2388
  if device_mem_capacity >= 24000:
2389
  use_vae_config = 1
2390
  elif device_mem_capacity >= 8000:
@@ -2497,6 +2553,7 @@ def generate_video(
2497
  max_frames_to_generate = video_length
2498
  diffusion_forcing = "diffusion_forcing" in model_filename
2499
  vace = "Vace" in model_filename
 
2500
  if diffusion_forcing or vace:
2501
  reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
2502
  if diffusion_forcing and source_video != None:
@@ -2536,6 +2593,7 @@ def generate_video(
2536
  extra_windows = 0
2537
  guide_start_frame = 0
2538
  video_length = first_window_video_length
 
2539
  while not abort:
2540
  if sliding_window:
2541
  prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
@@ -2550,7 +2608,9 @@ def generate_video(
2550
  window_no += 1
2551
  gen["window_no"] = window_no
2552
 
2553
- if diffusion_forcing:
 
 
2554
  if video_source != None and len(video_source) > 0 and window_no == 1:
2555
  keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
2556
  prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)
@@ -2559,7 +2619,7 @@ def generate_video(
2559
  prefix_video_frames_count = prefix_video.shape[1]
2560
  pre_video_guide = prefix_video[:, -reuse_frames:]
2561
 
2562
- if vace:
2563
  # video_prompt_type = video_prompt_type +"G"
2564
  image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
2565
  video_guide_copy = video_guide
@@ -2610,7 +2670,7 @@ def generate_video(
2610
  progress_args = [0, status + " - Encoding Prompt"]
2611
  send_cmd("progress", progress_args)
2612
 
2613
- samples = torch.empty( (1,2)) #for testing
2614
  # if False:
2615
 
2616
  try:
@@ -2633,7 +2693,6 @@ def generate_video(
2633
  guide_scale=guidance_scale,
2634
  n_prompt=negative_prompt,
2635
  seed=seed,
2636
- offload_model=False,
2637
  callback=callback,
2638
  enable_RIFLEx = enable_RIFLEx,
2639
  VAE_tile_size = VAE_tile_size,
@@ -2738,6 +2797,7 @@ def generate_video(
2738
  if samples == None:
2739
  abort = True
2740
  state["prompt"] = ""
 
2741
  else:
2742
  sample = samples.cpu()
2743
  if True: # for testing
@@ -2839,7 +2899,6 @@ def generate_video(
2839
 
2840
  print(f"New video saved to Path: "+video_path)
2841
  file_list.append(video_path)
2842
- state['update_gallery'] = True
2843
  send_cmd("output")
2844
  if sliding_window :
2845
  if max_frames_to_generate > 0 and extra_windows == 0:
@@ -2847,8 +2906,6 @@ def generate_video(
2847
  if (current_length - prefix_video_frames_count)>= max_frames_to_generate:
2848
  break
2849
  video_length = min(sliding_window_size, ((max_frames_to_generate - (current_length - prefix_video_frames_count) + reuse_frames + discard_last_frames) // 4) * 4 + 1 )
2850
- else:
2851
- break
2852
 
2853
  seed += 1
2854
 
@@ -3416,7 +3473,7 @@ def prepare_inputs_dict(target, inputs ):
3416
  if not "recam" in model_filename or not "diffusion_forcing" in model_filename:
3417
  inputs.pop("model_mode")
3418
 
3419
- if not "Vace" in model_filename:
3420
  unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_image_ref"]
3421
  for k in unsaved_params:
3422
  inputs.pop(k)
@@ -3776,6 +3833,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3776
  diffusion_forcing = "diffusion_forcing" in model_filename
3777
  recammaster = "recam" in model_filename
3778
  vace = "Vace" in model_filename
 
3779
  with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
3780
  if diffusion_forcing:
3781
  image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
@@ -3835,23 +3893,27 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3835
  model_mode = gr.Dropdown(visible=False)
3836
  keep_frames_video_source = gr.Text(visible=False)
3837
 
3838
- with gr.Column(visible= vace ) as video_prompt_column:
3839
  video_prompt_type_value= ui_defaults.get("video_prompt_type","")
3840
  video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
3841
  with gr.Row():
3842
- video_prompt_type_video_guide = gr.Dropdown(
3843
- choices=[
3844
- ("None", ""),
3845
- ("Transfer Human Motion from the Control Video", "PV"),
3846
- ("Transfer Depth from the Control Video", "DV"),
3847
- ("Recolorize the Control Video", "CV"),
3848
- # ("Alternate Video Ending", "OV"),
3849
- ("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"),
3850
- ("Control Video and Mask video for stronger Inpainting ", "MV"),
3851
- ],
3852
- value=filter_letters(video_prompt_type_value, "ODPCMV"),
3853
- label="Video to Video", scale = 3
3854
- )
3855
  video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
3856
 
3857
  video_prompt_type_image_refs = gr.Dropdown(
@@ -3869,7 +3931,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3869
  image_refs = gr.Gallery( label ="Reference Images",
3870
  type ="pil", show_label= True,
3871
  columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
3872
- value= ui_defaults.get("image_refs", None) )
 
3873
 
3874
  # with gr.Row():
3875
  remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 )
@@ -3929,7 +3992,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3929
  # ("832x1104 (3:4, 720p)", "832x1104"),
3930
  # ("960x960 (1:1, 720p)", "960x960"),
3931
  # 480p
3932
- # ("960x544 (16:9, 480p)", "960x544"),
 
3933
  ("832x480 (16:9, 480p)", "832x480"),
3934
  ("480x832 (9:16, 480p)", "480x832"),
3935
  # ("832x624 (4:3, 540p)", "832x624"),
@@ -4082,13 +4146,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
4082
  gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
4083
  gr.Markdown("<B>It is automatically turned on if the number of frames to generate is higher than the Window Size</B>")
4084
  if diffusion_forcing:
4085
- sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size")
4086
  sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4087
- sliding_window_discard_last_frames = gr.Slider(0, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=1, visible = False)
4088
  else:
4089
  sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
4090
  sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4091
- sliding_window_discard_last_frames = gr.Slider(0, 10, value=ui_defaults.get("sliding_window_discard_last_frames", 16), step=1, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
4092
 
4093
 
4094
  with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
@@ -4429,13 +4493,22 @@ def generate_configuration_tab(state, blocks, header, model_choice):
4429
 
4430
  quantization_choice = gr.Dropdown(
4431
  choices=[
4432
- ("Int8 Quantization (recommended)", "int8"),
4433
  ("16 bits (no quantization)", "bf16"),
4434
  ],
4435
  value= transformer_quantization,
4436
  label="Wan Transformer Model Quantization Type (if available)",
4437
  )
4438
 
4439
  index = text_encoder_choices.index(text_encoder_filename)
4440
  index = 0 if index ==0 else index
4441
  text_encoder_choice = gr.Dropdown(
@@ -4446,6 +4519,16 @@ def generate_configuration_tab(state, blocks, header, model_choice):
4446
  value= index,
4447
  label="Text Encoder model"
4448
  )
4449
  save_path_choice = gr.Textbox(
4450
  label="Output Folder for Generated Videos",
4451
  value=server_config.get("save_path", save_path)
@@ -4510,14 +4593,7 @@ def generate_configuration_tab(state, blocks, header, model_choice):
4510
  value= profile,
4511
  label="Profile (for power users only, not needed to change it)"
4512
  )
4513
- # default_ui_choice = gr.Dropdown(
4514
- # choices=[
4515
- # ("Text to Video", "t2v"),
4516
- # ("Image to Video", "i2v"),
4517
- # ],
4518
- # value= default_ui,
4519
- # label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
4520
- # )
4521
  metadata_choice = gr.Dropdown(
4522
  choices=[
4523
  ("Export JSON files", "json"),
@@ -4563,6 +4639,8 @@ def generate_configuration_tab(state, blocks, header, model_choice):
4563
  state,
4564
  transformer_types_choices,
4565
  text_encoder_choice,
 
 
4566
  save_path_choice,
4567
  attention_choice,
4568
  compile_choice,
@@ -4957,7 +5035,7 @@ def create_demo():
4957
  theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
4958
 
4959
  with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
4960
- gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.3 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
4961
  global model_list
4962
 
4963
  tab_state = gr.State({ "tab_no":0 })
 
40
  AUTOSAVE_FILENAME = "queue.zip"
41
  PROMPT_VARS_MAX = 10
42
 
43
+ target_mmgp_version = "3.4.1"
44
  from importlib.metadata import version
45
  mmgp_version = version("mmgp")
46
  if mmgp_version != target_mmgp_version:
 
133
 
134
  model_filename = state["model_filename"]
135
 
136
+ model_type = get_model_type(model_filename)
137
+ inputs = state.get(model_type, None)
138
+ if model_choice != model_type or inputs ==None:
139
  raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
140
+
 
141
  inputs["state"] = state
142
  inputs.pop("lset_name")
143
  if inputs == None:
 
177
  gr.Info(f"Resolution {resolution} not supported by image 2 video")
178
  return
179
 
180
+ if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
181
  gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
182
  return
183
 
 
187
  if video_length > sliding_window_size:
188
  gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated")
189
 
190
+ if "phantom" in model_filename:
191
+ image_refs = inputs["image_refs"]
192
+
193
+ if isinstance(image_refs, list):
194
+ image_refs = [ convert_image(tup[0]) for tup in image_refs ]
195
+ os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
196
+ from wan.utils.utils import resize_and_remove_background
197
+ image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1, fit_into_canvas= True)
198
+
199
+
200
+ if len(prompts) > 0:
201
+ prompts = ["\n".join(prompts)]
202
 
203
+ for single_prompt in prompts:
204
+ extra_inputs = {
205
+ "prompt" : single_prompt,
206
+ "image_refs": image_refs,
207
+ }
208
+ inputs.update(extra_inputs)
209
+ add_video_task(**inputs)
210
+
211
+ elif "diffusion_forcing" in model_filename:
212
  image_start = inputs["image_start"]
213
  video_source = inputs["video_source"]
214
  keep_frames_video_source = inputs["keep_frames_video_source"]
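For reference, the phantom branch above converts each reference image, optionally strips its background, and then queues one task per prompt. A minimal sketch of what such a preprocessing step might look like (this is not the project's `resize_and_remove_background`; it assumes the `rembg` and Pillow packages):

```python
# Sketch only: prepare Phantom reference images by optionally removing the
# background with rembg and fitting each image onto a width x height canvas.
from PIL import Image
from rembg import remove

def prep_image_refs(images, width, height, remove_background=True):
    prepared = []
    for img in images:
        img = img.convert("RGB")
        if remove_background:
            img = remove(img).convert("RGB")           # rembg returns an RGBA image
        img.thumbnail((width, height), Image.LANCZOS)  # keep aspect ratio
        canvas = Image.new("RGB", (width, height), (255, 255, 255))
        canvas.paste(img, ((width - img.width) // 2, (height - img.height) // 2))
        prepared.append(canvas)
    return prepared
```

In the diff, the actual call additionally receives `fit_into_canvas=True` and the value of the `remove_background_image_ref` checkbox.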
 
1383
  check_loras = args.check_loras ==1
1384
  advanced = args.advanced
1385
 
1386
  text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
1387
  server_config_filename = "wgp_config.json"
1388
  if not os.path.isdir("settings"):
 
1418
  text = reader.read()
1419
  server_config = json.loads(text)
1420
 
1421
+ # for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ):
1422
+ # if Path(src_path).is_file():
1423
+ # shutil.move(src_path, tgt_path) )
1424
+ # for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]:
1425
+ # if Path(path).is_file():
1426
+ # os.remove(path)
1427
+
1428
+ path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
1429
+ if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
1430
+ os.remove(path)
1431
+
1432
+ transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
1433
+ "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
1434
+ "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors",
1435
+ "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
1436
+ transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors",
1437
+ "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
1438
+ "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
1439
+ transformer_choices = transformer_choices_t2v + transformer_choices_i2v
1440
 
1441
+ model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"]
1442
  model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
1443
  "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
1444
+ "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
1445
+ "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
1446
+ "phantom_1.3B" : "phantom_1.3B", }
1447
 
1448
 
1449
  def get_model_type(model_filename):
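The body of `get_model_type` is not shown in this hunk; a minimal sketch consistent with the `model_signatures` mapping above might look like this (the error handling is an assumption):

```python
# Sketch only: resolve a checkpoint filename to its model type using the
# model_signatures mapping defined above.
def get_model_type(model_filename):
    for model_type, signature in model_signatures.items():
        if signature in model_filename:
            return model_type
    raise Exception(f"Unknown model type for '{model_filename}'")  # assumed error handling
```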
 
1455
  def test_class_i2v(model_filename):
1456
  return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
1457
 
1458
+ def get_model_name(model_filename, description_container = [""]):
1459
  if "Fun" in model_filename:
1460
  model_name = "Fun InP image2video"
1461
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1462
+ description = "The Fun model is an alternative image 2 video model that supports out of the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B version also brings image 2 video capability to the 1.3B model size."
1463
  elif "Vace" in model_filename:
1464
  model_name = "Vace ControlNet"
1465
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1466
+ description = "The Vace ControlNet model is a powerful model that lets you control the content of the generated video based on additional custom data: a pose or depth video, or images of objects you want to see in the video."
1467
  elif "image" in model_filename:
1468
  model_name = "Wan2.1 image2video"
1469
  model_name += " 720p" if "720p" in model_filename else " 480p"
1470
+ if "720p" in model_filename:
1471
+ description = "The standard Wan Image 2 Video model specialized to generate 720p videos. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)."
1472
+ else:
1473
+ description = "The standard Wan Image 2 Video model specialized to generate 480p videos. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)."
1474
  elif "recam" in model_filename:
1475
  model_name = "ReCamMaster"
1476
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1477
+ description = "The Recam Master should in theory let you replay a video while applying a different camera movement. The model only supports videos that are at least 81 frames long (any frame beyond that will be ignored)."
1478
  elif "FLF2V" in model_filename:
1479
  model_name = "Wan2.1 FLF2V"
1480
  model_name += " 720p" if "720p" in model_filename else " 480p"
1481
+ description = "The First Last Frame 2 Video model is the official Wan Image 2 Video model that supports Start and End frames."
1482
  elif "sky_reels2_diffusion_forcing" in model_filename:
1483
+ model_name = "SkyReels2 Diffusion Forcing"
1484
+ if "720p" in model_filename :
1485
+ model_name += " 720p"
1486
+ elif not "1.3B" in model_filename :
1487
+ model_name += " 540p"
1488
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1489
+ description = "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceed the usual 5s limit. You can also use this model to extend any existing video."
1490
+ elif "phantom" in model_filename:
1491
+ model_name = "Wan2.1 Phantom"
1492
+ model_name += " 14B" if "14B" in model_filename else " 1.3B"
1493
+ description = "The Phantom model is specialized in transferring people or objects of your choice into a generated video. It produces very nice results when used at 720p."
1494
  else:
1495
  model_name = "Wan2.1 text2video"
1496
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1497
+ description = "The original Wan Text 2 Video model. Most other models have been built on top of it."
1498
+ description_container[0] = description
1499
  return model_name
1500
 
1501
 
 
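`get_model_name` now also returns a human-readable description through the mutable `description_container` argument, so its return value stays a plain string for existing callers. A usage sketch (the printed strings follow from the branches above):

```python
# Usage sketch: retrieve both the display name and the description without
# changing get_model_name's return type.
description_container = [""]
model_name = get_model_name("ckpts/wan2_1_phantom_1.3B_mbf16.safetensors", description_container)
description = description_container[0]
print(model_name)    # "Wan2.1 Phantom 1.3B"
print(description)   # the Phantom description shown in the web UI header
```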
1549
  "slg_end_perc": 90
1550
  }
1551
 
1552
+ if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
1553
  ui_defaults.update({
1554
  "guidance_scale": 6.0,
1555
  "flow_shift": 8,
1556
+ "sliding_window_discard_last_frames" : 0,
1557
+ "resolution": "1280x720" if "720p" in filename else "960x544",
1558
+ "sliding_window_size" : 121 if "720p" in filename else 97,
1559
+ "RIFLEx_setting": 2,
1560
+ "guidance_scale": 6,
1561
+ "flow_shift": 8,
1562
  })
1563
 
1564
+
1565
+ if get_model_type(filename) in ("phantom_1.3B",):
1566
+ ui_defaults.update({
1567
+ "guidance_scale": 7.5,
1568
+ "flow_shift": 5,
1569
+ "resolution": "1280x720"
1570
+ })
1571
+
1572
+
1573
+
1574
  with open(defaults_filename, "w", encoding="utf-8") as f:
1575
  json.dump(ui_defaults, f, indent=4)
1576
  else:
 
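The per-model UI defaults end up serialized with `json.dump` as shown above. As an illustration, the Phantom-specific overrides from this hunk are merged into the base `ui_defaults` before the file is written; the base values and settings path below are hypothetical, only `defaults_filename` appears in the code:

```python
# Phantom-specific overrides (values from the hunk above) merged into the base
# ui_defaults before they are written with json.dump; path and base are assumed.
import json

ui_defaults = {"seed": -1, "num_inference_steps": 30}   # placeholder base defaults
ui_defaults.update({
    "guidance_scale": 7.5,
    "flow_shift": 5,
    "resolution": "1280x720",
})
with open("settings/phantom_1.3B.json", "w", encoding="utf-8") as f:
    json.dump(ui_defaults, f, indent=4)
```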
1720
  from huggingface_hub import hf_hub_download, snapshot_download
1721
  repoId = "DeepBeepMeep/Wan2.1"
1722
  sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
1723
+ fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
1724
  targetRoot = "ckpts/"
1725
  for sourceFolder, files in zip(sourceFolderList,fileList ):
1726
  if len(files)==0:
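The download loop itself is truncated in this hunk. A hedged sketch of how one of these files could be fetched from the `DeepBeepMeep/Wan2.1` repository into `ckpts/` with `huggingface_hub` (not the project's exact loop body):

```python
# Sketch: fetch one checkpoint from the Hugging Face repo into ckpts/ if it is
# not already present.
import os
from huggingface_hub import hf_hub_download

def fetch(repo_id, source_folder, filename, target_root="ckpts/"):
    target = os.path.join(target_root, source_folder, filename) if source_folder else os.path.join(target_root, filename)
    if not os.path.isfile(target):
        hf_hub_download(repo_id=repo_id, filename=filename,
                        subfolder=source_folder or None, local_dir=target_root)

fetch("DeepBeepMeep/Wan2.1", "", "Wan2.1_VAE.safetensors")
```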
 
1834
  return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
1835
 
1836
 
1837
+ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1838
 
1839
  cfg = WAN_CONFIGS['t2v-14B']
1840
  # cfg = WAN_CONFIGS['t2v-1.3B']
1841
  print(f"Loading '{model_filename}' model...")
1842
+ if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
1843
  model_factory = wan.DTT2V
1844
  else:
1845
  model_factory = wan.WanT2V
 
1850
  model_filename=model_filename,
1851
  text_encoder_filename= text_encoder_filename,
1852
  quantizeTransformer = quantizeTransformer,
1853
+ dtype = dtype,
1854
+ VAE_dtype = VAE_dtype,
1855
+ mixed_precision_transformer = mixed_precision_transformer
1856
  )
1857
 
1858
  pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
1859
 
1860
  return wan_model, pipe
1861
 
1862
+ def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
1863
 
1864
  print(f"Loading '{model_filename}' model...")
1865
 
1866
+ cfg = WAN_CONFIGS['i2v-14B']
1867
+ wan_model = wan.WanI2V(
1868
+ config=cfg,
1869
+ checkpoint_dir="ckpts",
1870
+ model_filename=model_filename,
1871
+ text_encoder_filename=text_encoder_filename,
1872
+ quantizeTransformer = quantizeTransformer,
1873
+ dtype = dtype,
1874
+ VAE_dtype = VAE_dtype,
1875
+ mixed_precision_transformer = mixed_precision_transformer
1876
+ )
1877
+ pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
1878
+
1879
  return wan_model, pipe
1880
 
1881
 
 
1887
  perc_reserved_mem_max = args.perc_reserved_mem_max
1888
 
1889
  major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
1890
+ if major < 8:
 
 
1891
  print("Switching to f16 model as GPU architecture doesn't support bf16")
1892
+ default_dtype = torch.float16
1893
+ else:
1894
+ default_dtype = torch.float16 if args.fp16 else torch.bfloat16
1895
+ if default_dtype == torch.float16 :
1896
  if "quanto" in model_filename:
1897
  model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
1898
  download_models(model_filename, text_encoder_filename)
1899
+ VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
1900
+ mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
1901
  if test_class_i2v(model_filename):
1902
  res720P = "720p" in model_filename
1903
+ wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
1904
  else:
1905
+ wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
1906
  wan_model._model_file_name = model_filename
1907
  kwargs = { "extraModelsToQuantize": None}
1908
  if profile == 2 or profile == 4:
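`load_models` now derives three precision settings before instantiating the model: the transformer dtype from the GPU architecture and the `--fp16` switch, the VAE dtype from `vae_precision`, and the mixed-precision flag from `mixed_precision`. A condensed, self-contained restatement of that logic:

```python
# Condensed restatement of the precision choices made in load_models.
import torch

def pick_dtypes(fp16_flag: bool, server_config: dict):
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        # pre-Ampere GPUs do not support bf16, fall back to fp16
        default_dtype = torch.float16
    else:
        default_dtype = torch.float16 if fp16_flag else torch.bfloat16
    VAE_dtype = torch.float16 if server_config.get("vae_precision", "16") == "16" else torch.float32
    mixed_precision_transformer = server_config.get("mixed_precision", "0") == "1"
    return default_dtype, VAE_dtype, mixed_precision_transformer
```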
 
1943
 
1944
 
1945
  def generate_header(model_filename, compile, attention_mode):
1946
+
1947
+ description_container = [""]
1948
+ get_model_name(model_filename, description_container)
1949
+ description = description_container[0]
1950
+ header = "<DIV style='height:40px'>" + description + "</DIV>"
1951
+
1952
+ header += "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
1953
  if attention_mode not in attention_modes_installed:
1954
  header += " -NOT INSTALLED-"
1955
  elif attention_mode not in attention_modes_supported:
 
1967
  def apply_changes( state,
1968
  transformer_types_choices,
1969
  text_encoder_choice,
1970
+ VAE_precision_choice,
1971
+ mixed_precision_choice,
1972
  save_path_choice,
1973
  attention_choice,
1974
  compile_choice,
 
1984
  if args.lock_config:
1985
  return
1986
  if gen_in_progress:
1987
+ return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>", gr.update(), gr.update()
1988
  global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
1989
  server_config = {"attention_mode" : attention_choice,
1990
  "transformer_types": transformer_types_choices,
 
1993
  "compile" : compile_choice,
1994
  "profile" : profile_choice,
1995
  "vae_config" : vae_config_choice,
1996
+ "vae_precision" : VAE_precision_choice,
1997
+ "mixed_precision" : mixed_precision_choice,
1998
  "metadata_type": metadata_choice,
1999
  "transformer_quantization" : quantization_choice,
2000
  "boost" : boost_choice,
 
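The two new keys (`vae_precision`, `mixed_precision`) travel with the rest of the server configuration, which is persisted to `wgp_config.json` (the filename defined earlier) and read back with `server_config.get(...)` at model-load time. A sketch of the save step, assuming a plain JSON dump:

```python
# Sketch: persist the updated settings so the new precision options survive a
# restart (the exact write call is an assumption).
import json

def save_server_config(server_config: dict, filename: str = "wgp_config.json"):
    with open(filename, "w", encoding="utf-8") as writer:
        writer.write(json.dumps(server_config, indent=4))
```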
2116
  return callback
2117
  def abort_generation(state):
2118
  gen = get_gen_info(state)
2119
+ if "in_progress" in gen and wan_model != None:
2120
 
2121
+ wan_model._interrupt= True
2122
  msg = "Processing Request to abort Current Generation"
2123
  gen["status"] = msg
2124
  gr.Info(msg)
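Aborting is cooperative: the handler above only raises `_interrupt` on the loaded model, and the sampling side is expected to poll that flag. A sketch of the polling side, assuming the flag is checked once per denoising step (`denoise_one_step` is a hypothetical stand-in):

```python
# Sketch of the polling side of the abort mechanism: the UI sets
# wan_model._interrupt = True and the sampling loop checks it between steps.
def sampling_loop(wan_model, latents, num_inference_steps, denoise_one_step):
    for step in range(num_inference_steps):
        if getattr(wan_model, "_interrupt", False):
            return None                                # generate_video treats None as an abort
        latents = denoise_one_step(latents, step)      # hypothetical per-step update
    return latents
```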
 
2201
  return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="")
2202
 
2203
 
 
2204
  def select_video(state , event_data: gr.EventData):
2205
  data= event_data._data
2206
  gen = get_gen_info(state)
 
2439
  # VAE Tiling
2440
  device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
2441
  if vae_config == 0:
2442
+ if server_config.get("vae_precision", "16") == "32":
2443
+ device_mem_capacity = device_mem_capacity / 2
2444
  if device_mem_capacity >= 24000:
2445
  use_vae_config = 1
2446
  elif device_mem_capacity >= 8000:
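When the VAE runs in 32 bits its memory footprint roughly doubles, so the automatic tiling choice now halves the usable budget first. A condensed sketch of the selection; the values returned for the lower memory tiers are assumptions since the hunk is cut off:

```python
# Condensed sketch of the automatic VAE tiling choice.
import torch

def pick_vae_config(server_config: dict) -> int:
    device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576  # MB
    if server_config.get("vae_precision", "16") == "32":
        device_mem_capacity /= 2            # fp32 VAE roughly doubles the footprint
    if device_mem_capacity >= 24000:
        return 1                            # no tiling needed
    elif device_mem_capacity >= 8000:
        return 2                            # assumed: moderate tiling
    return 3                                # assumed: aggressive tiling
```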
 
2553
  max_frames_to_generate = video_length
2554
  diffusion_forcing = "diffusion_forcing" in model_filename
2555
  vace = "Vace" in model_filename
2556
+ phantom = "phantom" in model_filename
2557
  if diffusion_forcing or vace:
2558
  reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
2559
  if diffusion_forcing and source_video != None:
 
2593
  extra_windows = 0
2594
  guide_start_frame = 0
2595
  video_length = first_window_video_length
2596
+ gen["extra_windows"] = 0
2597
  while not abort:
2598
  if sliding_window:
2599
  prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
 
2608
  window_no += 1
2609
  gen["window_no"] = window_no
2610
 
2611
+ if phantom:
2612
+ src_ref_images = image_refs.copy() if image_refs != None else None
2613
+ elif diffusion_forcing:
2614
  if video_source != None and len(video_source) > 0 and window_no == 1:
2615
  keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
2616
  prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)
 
2619
  prefix_video_frames_count = prefix_video.shape[1]
2620
  pre_video_guide = prefix_video[:, -reuse_frames:]
2621
 
2622
+ elif vace:
2623
  # video_prompt_type = video_prompt_type +"G"
2624
  image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
2625
  video_guide_copy = video_guide
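Each sliding window now picks its conditioning source per model family: Phantom reuses the reference images, Diffusion Forcing carries the tail of the prefix video into the window, and Vace works on copies because `prepare_source` mutates its inputs. A condensed sketch (not the exact code):

```python
# Condensed sketch of the per-window conditioning choice in generate_video.
def select_window_sources(phantom, diffusion_forcing, vace,
                          image_refs, prefix_video, reuse_frames, video_guide):
    src_ref_images = pre_video_guide = video_guide_copy = None
    if phantom:
        # Phantom: the same reference images are reused for every window
        src_ref_images = image_refs.copy() if image_refs is not None else None
    elif diffusion_forcing:
        # Diffusion Forcing: carry the last reuse_frames frames of the prefix video
        if prefix_video is not None:
            pre_video_guide = prefix_video[:, -reuse_frames:]
    elif vace:
        # Vace: copy because prepare_source modifies its inputs in place
        src_ref_images = image_refs.copy() if image_refs is not None else None
        video_guide_copy = video_guide
    return src_ref_images, pre_video_guide, video_guide_copy
```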
 
2670
  progress_args = [0, status + " - Encoding Prompt"]
2671
  send_cmd("progress", progress_args)
2672
 
2673
+ # samples = torch.empty( (1,2)) #for testing
2674
  # if False:
2675
 
2676
  try:
 
2693
  guide_scale=guidance_scale,
2694
  n_prompt=negative_prompt,
2695
  seed=seed,
 
2696
  callback=callback,
2697
  enable_RIFLEx = enable_RIFLEx,
2698
  VAE_tile_size = VAE_tile_size,
 
2797
  if samples == None:
2798
  abort = True
2799
  state["prompt"] = ""
2800
+ send_cmd("output")
2801
  else:
2802
  sample = samples.cpu()
2803
  if True: # for testing
 
2899
 
2900
  print(f"New video saved to Path: "+video_path)
2901
  file_list.append(video_path)
 
2902
  send_cmd("output")
2903
  if sliding_window :
2904
  if max_frames_to_generate > 0 and extra_windows == 0:
 
2906
  if (current_length - prefix_video_frames_count)>= max_frames_to_generate:
2907
  break
2908
  video_length = min(sliding_window_size, ((max_frames_to_generate - (current_length - prefix_video_frames_count) + reuse_frames + discard_last_frames) // 4) * 4 + 1 )
 
 
2909
 
2910
  seed += 1
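The length of the next window is derived from the frames still to generate plus the overlap and discarded frames, rounded to the `4*k + 1` frame counts the model expects. A small worked example of the formula used above:

```python
# Worked example of the next-window length formula used above.
sliding_window_size = 97
reuse_frames = 17              # overlap kept from the previous window
discard_last_frames = 0
max_frames_to_generate = 200
current_length = 97
prefix_video_frames_count = 0

remaining = max_frames_to_generate - (current_length - prefix_video_frames_count)
video_length = min(sliding_window_size,
                   ((remaining + reuse_frames + discard_last_frames) // 4) * 4 + 1)
print(video_length)            # 97 -> a second full window will be generated
```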
2911
 
 
3473
  if not "recam" in model_filename or not "diffusion_forcing" in model_filename:
3474
  inputs.pop("model_mode")
3475
 
3476
+ if not "Vace" in model_filename and not "phantom" in model_filename:
3477
  unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_image_ref"]
3478
  for k in unsaved_params:
3479
  inputs.pop(k)
 
3833
  diffusion_forcing = "diffusion_forcing" in model_filename
3834
  recammaster = "recam" in model_filename
3835
  vace = "Vace" in model_filename
3836
+ phantom = "phantom" in model_filename
3837
  with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
3838
  if diffusion_forcing:
3839
  image_prompt_type_value= ui_defaults.get("image_prompt_type","S")
 
3893
  model_mode = gr.Dropdown(visible=False)
3894
  keep_frames_video_source = gr.Text(visible=False)
3895
 
3896
+ with gr.Column(visible= vace or phantom) as video_prompt_column:
3897
  video_prompt_type_value= ui_defaults.get("video_prompt_type","")
3898
  video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
3899
  with gr.Row():
3900
+ if vace:
3901
+ video_prompt_type_video_guide = gr.Dropdown(
3902
+ choices=[
3903
+ ("None", ""),
3904
+ ("Transfer Human Motion from the Control Video", "PV"),
3905
+ ("Transfer Depth from the Control Video", "DV"),
3906
+ ("Recolorize the Control Video", "CV"),
3907
+ # ("Alternate Video Ending", "OV"),
3908
+ ("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"),
3909
+ ("Control Video and Mask video for stronger Inpainting ", "MV"),
3910
+ ],
3911
+ value=filter_letters(video_prompt_type_value, "ODPCMV"),
3912
+ label="Video to Video", scale = 3, visible= True
3913
+ )
3914
+ else:
3915
+ video_prompt_type_video_guide = gr.Dropdown(visible= False)
3916
+
3917
  video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
3918
 
3919
  video_prompt_type_image_refs = gr.Dropdown(
 
3931
  image_refs = gr.Gallery( label ="Reference Images",
3932
  type ="pil", show_label= True,
3933
  columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
3934
+ value= ui_defaults.get("image_refs", None),
3935
+ )
3936
 
3937
  # with gr.Row():
3938
  remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 )
 
3992
  # ("832x1104 (3:4, 720p)", "832x1104"),
3993
  # ("960x960 (1:1, 720p)", "960x960"),
3994
  # 480p
3995
+ ("960x544 (16:9, 540p)", "960x544"),
3996
+ ("544x960 (16:9, 540p)", "544x960"),
3997
  ("832x480 (16:9, 480p)", "832x480"),
3998
  ("480x832 (9:16, 480p)", "480x832"),
3999
  # ("832x624 (4:3, 540p)", "832x624"),
 
4146
  gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
4147
  gr.Markdown("<B>It is automatically turned on if the number of frames to generate is higher than the Window Size</B>")
4148
  if diffusion_forcing:
4149
+ sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
4150
  sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4151
+ sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
4152
  else:
4153
  sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
4154
  sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
4155
+ sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
4156
 
4157
 
4158
  with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):
 
4493
 
4494
  quantization_choice = gr.Dropdown(
4495
  choices=[
4496
+ ("Scaled Int8 Quantization (recommended)", "int8"),
4497
  ("16 bits (no quantization)", "bf16"),
4498
  ],
4499
  value= transformer_quantization,
4500
  label="Wan Transformer Model Quantization Type (if available)",
4501
  )
4502
 
4503
+ mixed_precision_choice = gr.Dropdown(
4504
+ choices=[
4505
+ ("16 bits only, requires less VRAM", "0"),
4506
+ ("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
4507
+ ],
4508
+ value= server_config.get("mixed_precision", "0"),
4509
+ label="Transformer Engine Calculation"
4510
+ )
4511
+
4512
  index = text_encoder_choices.index(text_encoder_filename)
4513
  index = 0 if index ==0 else index
4514
  text_encoder_choice = gr.Dropdown(
 
4519
  value= index,
4520
  label="Text Encoder model"
4521
  )
4522
+
4523
+ VAE_precision_choice = gr.Dropdown(
4524
+ choices=[
4525
+ ("16 bits, requires less VRAM and is faster", "16"),
4526
+ ("32 bits, requires twice as much VRAM and is slower but recommended with Sliding Windows", "32"),
4527
+ ],
4528
+ value= server_config.get("vae_precision", "16"),
4529
+ label="VAE Encoding / Decoding precision"
4530
+ )
4531
+
4532
  save_path_choice = gr.Textbox(
4533
  label="Output Folder for Generated Videos",
4534
  value=server_config.get("save_path", save_path)
 
4593
  value= profile,
4594
  label="Profile (for power users only, not needed to change it)"
4595
  )
4596
+
4597
  metadata_choice = gr.Dropdown(
4598
  choices=[
4599
  ("Export JSON files", "json"),
 
4639
  state,
4640
  transformer_types_choices,
4641
  text_encoder_choice,
4642
+ VAE_precision_choice,
4643
+ mixed_precision_choice,
4644
  save_path_choice,
4645
  attention_choice,
4646
  compile_choice,
 
5035
  theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
5036
 
5037
  with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
5038
+ gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
5039
  global model_list
5040
 
5041
  tab_state = gr.State({ "tab_no":0 })