deepbeepmeep committed · Commit 21a01ff · Parent(s): cad98bc

added Phantom model support

Files changed:
- README.md +3 -1
- wan/diffusion_forcing.py +61 -52
- wan/image2video.py +35 -98
- wan/modules/model.py +73 -42
- wan/modules/vae.py +10 -7
- wan/text2video.py +99 -51
- wan/utils/utils.py +16 -5
- wgp.py +192 -114
README.md
CHANGED

```diff
@@ -10,7 +10,9 @@


 ## 🔥 Latest News!!
-* April …
+* April 27 2025: 👋 Wan 2.1GP v4.4: Phantom model support. A very good model for transferring people or objects into a video; it works quite well at 720p and with more than 30 denoising steps.
+* April 25 2025: 👋 Wan 2.1GP v4.3: Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos" (see the Sliding Window section below). Note that Skyreels uses causal attention, which is only supported by Sdpa attention, so even if you choose another attention type, part of the processing will still use Sdpa attention.
+
 * April 18 2025: 👋 Wan 2.1GP v4.2: FLF2V model support, official support from Wan for image2video start and end frames, specialized for 720p.
 * April 17 2025: 👋 Wan 2.1GP v4.1: Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should use at least 15 denoising steps to get good results.
 * April 13 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
```
wan/diffusion_forcing.py
CHANGED

```diff
@@ -31,6 +31,8 @@ class DTT2V:
         text_encoder_filename = None,
         quantizeTransformer = False,
         dtype = torch.bfloat16,
+        VAE_dtype = torch.float32,
+        mixed_precision_transformer = False,
     ):
         self.device = torch.device(f"cuda")
         self.config = config
@@ -50,24 +52,22 @@ class DTT2V:
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size

         self.vae = WanVAE(
-            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
             device=self.device)

         logging.info(f"Creating WanModel from {model_filename}")
         from mmgp import offload
         # model_filename = "model.safetensors"
-        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json"
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, do_quantize= quantizeTransformer, writable_tensors= False) #, forcedConfigPath="config.json")
         # offload.load_model_data(self.model, "recam.ckpt")
         # self.model.cpu()
-        # offload.save_model(self.model, "…
-        # offload.save_model(self.model, "…
+        self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
+        offload.change_dtype(self.model, dtype, True)
+        # offload.save_model(self.model, "sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", config_file_path="config.json")
+        # offload.save_model(self.model, "sky_reels2_diffusion_forcing_720p_14B_quanto_xbf16_int8.safetensors", do_quantize= True, config_file_path="config.json")
         # offload.save_model(self.model, "rtfp16_int8.safetensors", do_quantize= "config.json")
-        self.vae.model.to(self.dtype)
+
         self.model.eval().requires_grad_(False)

         self.scheduler = FlowUniPCMultistepScheduler()
@@ -228,11 +228,16 @@ class DTT2V:
         latent_height = height // 8
         latent_width = width // 8

+        if self._interrupt:
+            return None
+        prompt_embeds = self.text_encoder([prompt], self.device)[0]
+        prompt_embeds = prompt_embeds.to(self.dtype).to(self.device)
         if self.do_classifier_free_guidance:
-            negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)
-            negative_prompt_embeds = …
+            negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)[0]
+            negative_prompt_embeds = negative_prompt_embeds.to(self.dtype).to(self.device)
+
+        if self._interrupt:
+            return None

         self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
         init_timesteps = self.scheduler.timesteps
@@ -305,6 +310,17 @@ class DTT2V:
         del time_steps_comb
         from mmgp import offload
         freqs = get_rotary_pos_embed(latents[0].shape[1 :], enable_RIFLEx= False)
+        kwrags = {
+            "freqs" : freqs,
+            "fps" : fps_embeds,
+            "causal_block_size" : causal_block_size,
+            "causal_attention" : causal_attention,
+            "callback" : callback,
+            "pipeline" : self,
+        }
+        kwrags.update(i2v_extra_kwrags)
+
         for i, timestep_i in enumerate(tqdm(step_matrix)):
             offload.set_step_no_for_lora(self.model, i)
             update_mask_i = step_update_mask[i]
@@ -323,52 +339,45 @@ class DTT2V:
                 * noise_factor
             )
             timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
-            kwrags …
+            kwrags.update({
                 "x" : torch.stack([latent_model_input[0]]),
                 "t" : timestep,
-                "freqs" : freqs,
-                "fps" : fps_embeds,
-                "causal_block_size" : causal_block_size,
-                "causal_attention" : causal_attention,
-                "callback" : callback,
-                "pipeline" : self,
                 "current_step" : i,
-            }
+            })

-            if …
-                )[0]
-                if self._interrupt:
-                    return None
-                noise_pred = noise_pred.to(torch.float32)
-            else:
-                if joint_pass:
-                    noise_pred_cond, noise_pred_uncond = self.model(
-                        context=prompt_embeds,
-                        context2=negative_prompt_embeds,
-                        **kwrags,
-                    )
-                    if self._interrupt:
-                        return None
-                else:
-                    noise_pred_cond = self.model(
-                        context=prompt_embeds,
-                        **kwrags,
-                    )[0]
-                    if self._interrupt:
-                        return None
-                    noise_pred_uncond = self.model(
-                        context=negative_prompt_embeds,
-                    )[0]
-                    if self._interrupt:
-                        return None
+            # with torch.autocast(device_type="cuda"):
+            if True:
+                if not self.do_classifier_free_guidance:
+                    noise_pred = self.model(
+                        context=[prompt_embeds],
+                        **kwrags,
+                    )[0]
+                    if self._interrupt:
+                        return None
+                    noise_pred = noise_pred.to(torch.float32)
+                else:
+                    if joint_pass:
+                        noise_pred_cond, noise_pred_uncond = self.model(
+                            context= [prompt_embeds, negative_prompt_embeds],
+                            **kwrags,
+                        )
+                        if self._interrupt:
+                            return None
+                    else:
+                        noise_pred_cond = self.model(
+                            context=[prompt_embeds],
+                            **kwrags,
+                        )[0]
+                        if self._interrupt:
+                            return None
+                        noise_pred_uncond = self.model(
+                            context=[negative_prompt_embeds],
+                            **kwrags,
+                        )[0]
+                        if self._interrupt:
+                            return None
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+                    del noise_pred_cond, noise_pred_uncond
             for idx in range(valid_interval_start, valid_interval_end):
                 if update_mask_i[idx].item():
                     latents[0][:, idx] = sample_schedulers[idx].step(
```
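The refactor above replaces the old `context`/`context2` keyword pair with a single list of text contexts and reuses one `kwrags` dict across denoising steps. As a rough illustration of the conditional/unconditional dispatch pattern, here is a minimal sketch; the `model(...)` signature (a list of contexts in, one prediction per context out) is an assumption made for illustration, not the exact WanModel API.

```python
import torch

def cfg_denoise(model, latent, timestep, prompt_embeds, negative_prompt_embeds,
                guidance_scale, joint_pass=True, **kwargs):
    """Sketch of the classifier-free-guidance dispatch used above.

    `model` is assumed to accept a list of text contexts and to return one
    noise prediction per context (hypothetical signature for illustration).
    """
    if joint_pass:
        # Conditional and unconditional predictions in a single forward pass.
        cond, uncond = model(x=latent, t=timestep,
                             context=[prompt_embeds, negative_prompt_embeds], **kwargs)
    else:
        # Two separate passes: lower peak memory, roughly twice the compute.
        cond = model(x=latent, t=timestep, context=[prompt_embeds], **kwargs)[0]
        uncond = model(x=latent, t=timestep, context=[negative_prompt_embeds], **kwargs)[0]
    # Standard CFG combination.
    return uncond + guidance_scale * (cond - uncond)
```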
wan/image2video.py
CHANGED

```diff
@@ -48,47 +48,17 @@ class WanI2V:
         self,
         config,
         checkpoint_dir,
-        rank=0,
-        t5_fsdp=False,
-        dit_fsdp=False,
-        use_usp=False,
-        t5_cpu=False,
-        init_on_cpu=True,
-        i2v720p= True,
         model_filename ="",
         text_encoder_filename="",
         quantizeTransformer = False,
-        dtype = torch.bfloat16
+        dtype = torch.bfloat16,
+        VAE_dtype = torch.float32,
+        mixed_precision_transformer = False
     ):
-        r"""
-        Initializes the image-to-video generation model components.
-        … (Args docstring for config, checkpoint_dir, device_id, rank, t5_fsdp, dit_fsdp, use_usp, t5_cpu, init_on_cpu removed)
-        """
         self.device = torch.device(f"cuda")
         self.config = config
-        self.rank = rank
-        self.use_usp = use_usp
-        self.t5_cpu = t5_cpu
         self.dtype = dtype
+        self.VAE_dtype = VAE_dtype
         self.num_train_timesteps = config.num_train_timesteps
         self.param_dtype = config.param_dtype
         # shard_fn = partial(shard_model, device_id=device_id)
@@ -104,7 +74,7 @@ class WanI2V:
         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
         self.vae = WanVAE(
-            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype = VAE_dtype,
             device=self.device)

         self.clip = CLIPModel(
@@ -118,11 +88,9 @@ class WanI2V:
         from mmgp import offload

         self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, do_quantize= quantizeTransformer, writable_tensors= False)
+        self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
+        offload.change_dtype(self.model, dtype, True)
         # offload.save_model(self.model, "i2v_720p_fp16.safetensors", do_quantize=True)
-        if self.dtype == torch.float16:
-            self.vae.model.to(self.dtype)

         # offload.save_model(self.model, "wan2.1_Fun_InP_1.3B_bf16_bis.safetensors")
         self.model.eval().requires_grad_(False)
@@ -142,7 +110,6 @@ class WanI2V:
         guide_scale=5.0,
         n_prompt="",
         seed=-1,
-        offload_model=True,
         callback = None,
         enable_RIFLEx = False,
         VAE_tile_size= 0,
@@ -212,13 +179,13 @@ class WanI2V:
         w = lat_w * self.vae_stride[2]

         clip_image_size = self.clip.model.image_size
-        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device…
+        img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
         img = resize_lanczos(img, clip_image_size, clip_image_size)
-        img = img.sub_(0.5).div_(0.5).to(self.device…
+        img = img.sub_(0.5).div_(0.5).to(self.device) #, self.dtype
         if img2!= None:
-            img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device…
+            img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device) #, self.dtype
             img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
-            img2 = img2.sub_(0.5).div_(0.5).to(self.device…
+            img2 = img2.sub_(0.5).div_(0.5).to(self.device) #, self.dtype

         max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])

@@ -244,25 +211,19 @@ class WanI2V:
         if n_prompt == "":
             n_prompt = self.sample_neg_prompt

+        if self._interrupt:
+            return None
+
         # preprocess
-        if offload_model:
-            self.text_encoder.model.cpu()
-        else:
-            context = self.text_encoder([input_prompt], torch.device('cpu'))
-            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
-            context = [t.to(self.device) for t in context]
-            context_null = [t.to(self.device) for t in context_null]
+        context = self.text_encoder([input_prompt], self.device)[0]
+        context_null = self.text_encoder([n_prompt], self.device)[0]
+        context = context.to(self.dtype)
+        context_null = context_null.to(self.dtype)

+        if self._interrupt:
+            return None

         clip_context = self.clip.visual([img[:, None, :, :]])
-        if offload_model:
-            self.clip.model.cpu()

         from mmgp import offload
         offload.last_offload_obj.unload_all()
@@ -270,23 +231,20 @@ class WanI2V:
             mean2 = 0
             enc= torch.concat([
                     img_interpolated,
-                    torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.…
+                    torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= self.VAE_dtype),
                     img_interpolated2,
             ], dim=1).to(self.device)
         else:
             enc= torch.concat([
                     img_interpolated,
-                    torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.…
+                    torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= self.VAE_dtype)
             ], dim=1).to(self.device)
+        img, img2, img_interpolated, img_interpolated2 = None, None, None, None

         lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
         y = torch.concat([msk, lat_y])
+        lat_y = None

-        @contextmanager
-        def noop_no_sync():
-            yield
-
-        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

         # evaluation mode

@@ -317,7 +275,7 @@ class WanI2V:
         freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)

         arg_c = {
-            'context': [context…
+            'context': [context],
             'clip_fea': clip_context,
             'y': [y],
             'freqs' : freqs,
@@ -326,7 +284,7 @@ class WanI2V:
         }

         arg_null = {
-            'context': context_null,
+            'context': [context_null],
             'clip_fea': clip_context,
             'y': [y],
             'freqs' : freqs,
@@ -335,8 +293,7 @@ class WanI2V:
         }

         arg_both= {
-            'context': [context…
-            'context2': context_null,
+            'context': [context, context_null],
             'clip_fea': clip_context,
             'y': [y],
             'freqs' : freqs,
@@ -344,9 +301,6 @@ class WanI2V:
             'callback' : callback
         }

-        if offload_model:
-            torch.cuda.empty_cache()
-
         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)

@@ -379,8 +333,6 @@ class WanI2V:
                 )[0]
                 if self._interrupt:
                     return None
-                if offload_model:
-                    torch.cuda.empty_cache()
                 noise_pred_uncond = self.model(
                     latent_model_input,
                     t=timestep,
@@ -392,8 +344,7 @@ class WanI2V:
                 if self._interrupt:
                     return None
             del latent_model_input
-
-            torch.cuda.empty_cache()
+
             # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
             noise_pred_text = noise_pred_cond
             if cfg_star_switch:
@@ -412,9 +363,6 @@ class WanI2V:

             del noise_pred_uncond

-            latent = latent.to(
-                torch.device('cpu') if offload_model else self.device)
-
             temp_x0 = sample_scheduler.step(
                 noise_pred.unsqueeze(0),
                 t,
@@ -429,29 +377,18 @@ class WanI2V:
                 callback(i, latent, False)


-        x0 = [latent.to(self.device, dtype=self.dtype)]
-
-        if offload_model:
-            self.model.cpu()
-            torch.cuda.empty_cache()
+        # x0 = [latent.to(self.device, dtype=self.dtype)]

-        # x0 = [lat_y]
-        video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
+        x0 = [latent]

-        video = video[:, :-1]
-
-        video = …
+        # x0 = [lat_y]
+        video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]

+        if any_end_frame and add_frames_for_end_image:
+            # video[:, -1:] = img_interpolated2
+            video = video[:, :-1]

         del noise, latent
         del sample_scheduler
-        if offload_model:
-            gc.collect()
-            torch.cuda.synchronize()
-            if dist.is_initialized():
-                dist.barrier()

         return video
```
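For reference, the start/end-frame conditioning assembled in the hunk above can be summarized by the following sketch; the tensor shapes follow the diff ([3, T, H, W] after the Lanczos resize and normalization to [-1, 1]), while the standalone function name and signature are hypothetical.

```python
import torch

def build_i2v_conditioning(img_start, frame_num, height, width,
                           img_end=None, device="cuda", dtype=torch.float32):
    """Sketch of the start/end-frame conditioning built in WanI2V.generate above.

    `img_start` / `img_end` are assumed to be [3, 1, H, W] tensors already
    normalized to [-1, 1]; the gap between them is filled with zeros (the
    "unknown" frames the model must generate), before VAE encoding and
    concatenation with the mask channels.
    """
    if img_end is not None:
        enc = torch.cat([
            img_start,
            torch.zeros(3, frame_num - 2, height, width, device=device, dtype=dtype),
            img_end,
        ], dim=1)
    else:
        enc = torch.cat([
            img_start,
            torch.zeros(3, frame_num - 1, height, width, device=device, dtype=dtype),
        ], dim=1)
    return enc  # shape [3, frame_num, H, W]
```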
wan/modules/model.py
CHANGED

```diff
@@ -408,6 +408,9 @@ class WanAttentionBlock(nn.Module):
             freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         """
         hint = None
+        attention_dtype = self.self_attn.q.weight.dtype
+        dtype = x.dtype
+
         if self.block_id is not None and hints is not None:
             kwargs = {
                 "grid_sizes" : grid_sizes,
@@ -434,9 +437,11 @@ class WanAttentionBlock(nn.Module):
             cam_emb = rearrange(cam_emb, 'b f h w d -> b (f h w) d')
             x_mod += cam_emb

-        xlist = [x_mod]
+        xlist = [x_mod.to(attention_dtype)]
         del x_mod
         y = self.self_attn( xlist, grid_sizes, freqs, block_mask)
+        y = y.to(dtype)
+
         if cam_emb != None:
             y = self.projector(y)

@@ -445,15 +450,18 @@ class WanAttentionBlock(nn.Module):
         x, y = reshape_latent(x , 1), reshape_latent(y , 1)
         del y
         y = self.norm3(x)
+        y = y.to(attention_dtype)
         ylist= [y]
         del y
-        x += self.cross_attn(ylist, context)
+        x += self.cross_attn(ylist, context).to(dtype)
+
         y = self.norm2(x)

         y = reshape_latent(y , latent_frames)
         y *= 1 + e[4]
         y += e[3]
         y = reshape_latent(y , 1)
+        y = y.to(attention_dtype)

         ffn = self.ffn[0]
         gelu = self.ffn[1]
@@ -469,7 +477,7 @@ class WanAttentionBlock(nn.Module):
             y_chunk[...] = ffn2(mlp_chunk)
             del mlp_chunk
         y = y.view(y_shape)
-
+        y = y.to(dtype)
         x, y = reshape_latent(x , latent_frames), reshape_latent(y , latent_frames)
         x.addcmul_(y, e[5])
         x, y = reshape_latent(x , 1), reshape_latent(y , 1)
@@ -532,7 +540,6 @@ class Head(nn.Module):
         out_dim = math.prod(patch_size) * out_dim
         self.norm = WanLayerNorm(dim, eps)
         self.head = nn.Linear(dim, out_dim)
-
         # modulation
         self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

@@ -552,6 +559,7 @@ class Head(nn.Module):
         x *= (1 + e[1])
         x += e[0]
         x = reshape_latent(x , 1)
+        x = x.to(self.head.weight.dtype)
         x = self.head(x)
         return x

@@ -735,6 +743,44 @@ class WanModel(ModelMixin, ConfigMixin):
             block.projector.bias = nn.Parameter(torch.zeros(dim))


+    def lock_layers_dtypes(self, dtype = torch.float32, force = False):
+        count = 0
+        layer_list = [self.head, self.head.head, self.patch_embedding, self.time_embedding, self.time_embedding[0], self.time_embedding[2],
+                      self.time_projection, self.time_projection[1]] #, self.text_embedding, self.text_embedding[0], self.text_embedding[2] ]
+        if hasattr(self, "fps_embedding"):
+            layer_list += [self.fps_embedding, self.fps_projection, self.fps_projection[0], self.fps_projection[2]]
+
+        if hasattr(self, "vace_patch_embedding"):
+            layer_list += [self.vace_patch_embedding]
+            layer_list += [self.vace_blocks[0].before_proj]
+            for block in self.vace_blocks:
+                layer_list += [block.after_proj, block.norm3]
+
+        # cam master
+        if hasattr(self.blocks[0], "projector"):
+            for block in self.blocks:
+                layer_list += [block.projector]
+
+        for block in self.blocks:
+            layer_list += [block.norm3]
+        for layer in layer_list:
+            if hasattr(layer, "weight"):
+                if layer.weight.dtype == dtype :
+                    count += 1
+                elif force:
+                    if hasattr(layer, "weight"):
+                        layer.weight.data = layer.weight.data.to(dtype)
+                    if hasattr(layer, "bias"):
+                        layer.bias.data = layer.bias.data.to(dtype)
+                    count += 1
+
+            layer._lock_dtype = dtype
+
+
+        if count > 0:
+            self._lock_dtype = dtype
+
+
     def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
         rescale_func = np.poly1d(self.coefficients)
         e_list = []
@@ -788,7 +834,6 @@ class WanModel(ModelMixin, ConfigMixin):
         freqs = None,
         pipeline = None,
         current_step = 0,
-        context2 = None,
         is_uncond=False,
         max_steps = 0,
         slg_layers=None,
@@ -797,7 +842,10 @@ class WanModel(ModelMixin, ConfigMixin):
         fps = None,
         causal_block_size = 1,
         causal_attention = False,
+        x_neg = None
     ):
+        # dtype = self.blocks[0].self_attn.q.weight.dtype
+        dtype = self.patch_embedding.weight.dtype

         if self.model_type == 'i2v':
             assert clip_fea is not None and y is not None
@@ -810,9 +858,9 @@ class WanModel(ModelMixin, ConfigMixin):
             x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

         # embeddings
-        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
+        x = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x]
+        if x_neg != None:
+            x_neg = [self.patch_embedding(u.unsqueeze(0)).to(dtype) for u in x_neg]

         grid_sizes = [ list(u.shape[2:]) for u in x]
         embed_sizes = grid_sizes[0]
@@ -836,57 +884,46 @@ class WanModel(ModelMixin, ConfigMixin):

         x = [u.flatten(2).transpose(1, 2) for u in x]
         x = x[0]
+        if x_neg != None:
+            x_neg = [u.flatten(2).transpose(1, 2) for u in x_neg]
+            x_neg = x_neg[0]

         if t.dim() == 2:
             b, f = t.shape
             _flag_df = True
         else:
             _flag_df = False
-
         e = self.time_embedding(
-            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.patch_embedding.weight.dtype)
+            sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype) # self.patch_embedding.weight.dtype)
         )  # b, dim
         e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(e.dtype)

         if self.inject_sample_info:
             fps = torch.tensor(fps, dtype=torch.long, device=device)

-            fps_emb = self.fps_embedding(fps).float()
+            fps_emb = self.fps_embedding(fps).to(dtype) # float()
             if _flag_df:
                 e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(t.shape[1], 1, 1)
             else:
                 e0 = e0 + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))

         # context
-        context = self.text_embedding(
-            torch.stack([
-                torch.cat(
-                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
-                for u in context
-            ]))
-        if context2!=None:
-            context2 = self.text_embedding(
-                torch.stack([
-                    torch.cat(
-                        [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
-                    for u in context2
-                ]))
-
+        context = [self.text_embedding( torch.cat( [u, u.new_zeros(self.text_len - u.size(0), u.size(1))] ).unsqueeze(0) ) for u in context ]
+
         if clip_fea is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
-            context = torch.…
-            if context2 != None:
-                context2 = torch.concat([context_clip, context2], dim=1)
+            context = [ torch.cat( [context_clip, u ], dim=1 ) for u in context ]

-        joint_pass = …
+        joint_pass = len(context) > 0
+        x_list = [x]
         if joint_pass:
-            …
+            if x_neg == None:
+                x_list += [x.clone() for i in range(len(context) - 1) ]
+            else:
+                x_list += [x.clone() for i in range(len(context) - 2) ] + [x_neg]
             is_uncond = False
-        else:
-            x_list = [x]
-            context_list = [context]
         del x
+        context_list = context

         # arguments

@@ -945,10 +982,7 @@ class WanModel(ModelMixin, ConfigMixin):
             if callback != None:
                 callback(-1, None, False, True)
             if pipeline._interrupt:
-                return None, None
-            else:
-                return [None]
+                return [None] * len(x_list)

             if slg_layers is not None and block_idx in slg_layers:
                 if is_uncond and not joint_pass:
@@ -983,10 +1017,7 @@ class WanModel(ModelMixin, ConfigMixin):
             x_list[i] = self.unpatchify(x, grid_sizes)
             del x

-            return x_list[0][0], x_list[1][0]
-        else:
-            return [u.float() for u in x_list[0]]
+        return [x[0].float() for x in x_list]

     def unpatchify(self, x, grid_sizes):
         r"""
```
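The new `lock_layers_dtypes` method pins a list of numerically sensitive layers (output head, patch/time/fps embeddings, norms, projectors) to a chosen dtype, float32 when mixed precision is enabled, otherwise the transformer dtype, and tags them with `_lock_dtype` so that the subsequent bulk `offload.change_dtype` cast can leave them alone. A stripped-down sketch of the idea, assuming `_lock_dtype` is the marker consumed by the mmgp offload utilities:

```python
import torch
import torch.nn as nn

def lock_layer_dtypes(layers, dtype=torch.float32, force=False):
    """Minimal sketch of the dtype-locking pattern added to WanModel above.

    Sensitive layers are converted (when `force` is set) and marked so that a
    later model-wide cast to bf16/fp16 can skip them.
    """
    for layer in layers:
        if force:
            if getattr(layer, "weight", None) is not None:
                layer.weight.data = layer.weight.data.to(dtype)
            if getattr(layer, "bias", None) is not None:
                layer.bias.data = layer.bias.data.to(dtype)
        # Marker checked by the casting / offloading code (assumed convention).
        layer._lock_dtype = dtype

# Illustrative usage on a toy module.
head = nn.Linear(128, 64).to(torch.bfloat16)
lock_layer_dtypes([head], dtype=torch.float32, force=True)
assert head.weight.dtype == torch.float32
```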
wan/modules/vae.py
CHANGED

```diff
@@ -752,7 +752,8 @@ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
     logging.info(f'loading {pretrained_path}')
     # model.load_state_dict(
     #     torch.load(pretrained_path, map_location=device), assign=True)
-    offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
+    # offload.load_model_data(model, pretrained_path.replace(".pth", "_bf16.safetensors"), writable_tensors= False)
+    offload.load_model_data(model, pretrained_path.replace(".pth", ".safetensors"), writable_tensors= False)
     return model


@@ -782,20 +783,22 @@ class WanVAE:
         self.model = _video_vae(
             pretrained_path=vae_pth,
             z_dim=z_dim,
-        ).eval()  #.requires_grad_(False).to(device)
-
+        ).to(dtype).eval()  #.requires_grad_(False).to(device)
+
     def encode(self, videos, tile_size = 256, any_end_frame = False):
         """
        videos: A list of videos each with shape [C, T, H, W].
         """
+        original_dtype = videos[0].dtype
+
         if tile_size > 0:
-            return [ self.model.spatial_tiled_encode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.spatial_tiled_encode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
         else:
-            return [ self.model.encode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]
+            return [ self.model.encode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).float().squeeze(0) for u in videos ]


     def decode(self, zs, tile_size, any_end_frame = False):
         if tile_size > 0:
-            return [ self.model.spatial_tiled_decode(u.unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.spatial_tiled_decode(u.to(self.dtype).unsqueeze(0), self.scale, tile_size, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
         else:
-            return [ self.model.decode(u.unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
+            return [ self.model.decode(u.to(self.dtype).unsqueeze(0), self.scale, any_end_frame=any_end_frame).clamp_(-1, 1).float().squeeze(0) for u in zs ]
```
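The VAE now carries an explicit working dtype: weights are cast once at load time and inputs are cast per call, while latents are still handed back in float32. A hedged sketch of that pattern; the `spatial_tiled_encode`/`encode` calls mirror the ones in the diff, but the wrapper itself is illustrative rather than the actual `WanVAE` method.

```python
import torch

def encode_videos(vae, videos, tile_size=256, any_end_frame=False):
    """Sketch of the dtype handling added to WanVAE.encode above (illustrative)."""
    outs = []
    for u in videos:  # each video tensor: [C, T, H, W]
        u = u.to(vae.dtype).unsqueeze(0)  # cast the input to the VAE working dtype
        if tile_size > 0:
            z = vae.model.spatial_tiled_encode(u, vae.scale, tile_size, any_end_frame=any_end_frame)
        else:
            z = vae.model.encode(u, vae.scale, any_end_frame=any_end_frame)
        outs.append(z.float().squeeze(0))  # latents returned in float32
    return outs
```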
wan/text2video.py
CHANGED

```diff
@@ -52,7 +52,9 @@ class WanT2V:
         model_filename = None,
         text_encoder_filename = None,
         quantizeTransformer = False,
-        dtype = torch.bfloat16
+        dtype = torch.bfloat16,
+        VAE_dtype = torch.float32,
+        mixed_precision_transformer = False
     ):
         self.device = torch.device(f"cuda")
         self.config = config
@@ -71,24 +73,23 @@ class WanT2V:

         self.vae_stride = config.vae_stride
         self.patch_size = config.patch_size
-

         self.vae = WanVAE(
-            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
+            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), dtype= VAE_dtype,
             device=self.device)

         logging.info(f"Creating WanModel from {model_filename}")
         from mmgp import offload
-
-        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, do_quantize= quantizeTransformer, writable_tensors= False)
-        # offload.load_model_data(self.model, "…
+        # model_filename
+        self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, do_quantize= quantizeTransformer, writable_tensors= False ) #, forcedConfigPath= "e:/vace_config.json")
+        # offload.load_model_data(self.model, "e:/vace.safetensors")
+        # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
+        # self.model.to(torch.bfloat16)
         # self.model.cpu()
-        # offload.save_model(self.model, "…
-        if self.dtype == torch.float16:
-            self.vae.model.to(self.dtype)
+        self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype, True)
+        offload.change_dtype(self.model, dtype, True)
+        # offload.save_model(self.model, "mvace.safetensors", config_file_path="e:/vace_config.json")
+        # offload.save_model(self.model, "phantom_1.3B.safetensors")
         self.model.eval().requires_grad_(False)


@@ -252,6 +253,15 @@ class WanT2V:

         return self.vae.decode(trimed_zs, tile_size= tile_size)

+    def get_vae_latents(self, ref_images, device, tile_size= 0):
+        ref_vae_latents = []
+        for ref_image in ref_images:
+            ref_image = TF.to_tensor(ref_image).sub_(0.5).div_(0.5).to(self.device)
+            img_vae_latent = self.vae.encode([ref_image.unsqueeze(1)], tile_size= tile_size)
+            ref_vae_latents.append(img_vae_latent[0])
+
+        return torch.cat(ref_vae_latents, dim=1)
+
     def generate(self,
                  input_prompt,
                  input_frames= None,
@@ -320,8 +330,15 @@ class WanT2V:
         seed_g = torch.Generator(device=self.device)
         seed_g.manual_seed(seed)

-
-
+        if self._interrupt:
+            return None
+        context = self.text_encoder([input_prompt], self.device)[0]
+        context_null = self.text_encoder([n_prompt], self.device)[0]
+        context = context.to(self.dtype)
+        context_null = context_null.to(self.dtype)
+        input_ref_images_neg = None
+        phantom = False
+
         if target_camera != None:
             size = (source_video.shape[2], source_video.shape[1])
             source_video = source_video.to(dtype=self.dtype , device=self.device)
@@ -346,8 +363,12 @@ class WanT2V:
             target_shape = list(z0[0].shape)
             target_shape[0] = int(target_shape[0] / 2)
         else:
+            if input_ref_images != None: # Phantom Ref images
+                phantom = True
+                input_ref_images = [self.get_vae_latents(input_ref_images, self.device)]
+                input_ref_images_neg = [torch.zeros_like(input_ref_images[0])]
             F = frame_num
-            target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
+            target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + (input_ref_images[0].shape[1] if input_ref_images != None else 0),
                             size[1] // self.vae_stride[1],
                             size[0] // self.vae_stride[2])

@@ -355,8 +376,8 @@ class WanT2V:
                             (self.patch_size[1] * self.patch_size[2]) *
                             target_shape[1])

-
-
+        if self._interrupt:
+            return None

         noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]

@@ -393,21 +414,15 @@ class WanT2V:
             freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
         else:
             freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
-
-        arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
+
+        kwargs = {'freqs': freqs, 'pipeline': self, 'callback': callback}

         if target_camera != None:
-            arg_c.update(recam_dict)
-            arg_null.update(recam_dict)
-            arg_both.update(recam_dict)
+            kwargs.update({'cam_emb': cam_emb})

         if input_frames != None:
-            arg_null.update(vace_dict)
-            arg_both.update(vace_dict)
+            kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale})

         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
@@ -424,39 +439,68 @@ class WanT2V:
             timestep = [t]
             offload.set_step_no_for_lora(self.model, i)
             timestep = torch.stack(timestep)
-
+            kwargs["current_step"] = i
+            kwargs["t"] = timestep
             if joint_pass:
+                if phantom:
+                    pos_it, pos_i, neg = self.model(
+                        [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)],
+                        x_neg = [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)],
+                        context = [context, context_null, context_null], **kwargs)
+                else:
+                    noise_pred_cond, noise_pred_uncond = self.model(
+                        latent_model_input, slg_layers=slg_layers_local, context = [context, context_null], **kwargs)
                 if self._interrupt:
                     return None
             else:
-                …
+                if phantom:
+                    pos_it = self.model(
+                        [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context], **kwargs
+                    )[0]
+                    if self._interrupt:
+                        return None
+                    pos_i = self.model(
+                        [torch.cat([latent[:,:-ref_latent.shape[1]], ref_latent], dim=1) for latent, ref_latent in zip(latent_model_input, input_ref_images)], context = [context_null], **kwargs
+                    )[0]
+                    if self._interrupt:
+                        return None
+                    neg = self.model(
+                        [torch.cat([latent[:,:-ref_latent_neg.shape[1]], ref_latent_neg], dim=1) for latent, ref_latent_neg in zip(latent_model_input, input_ref_images_neg)], context = [context_null], **kwargs
+                    )[0]
+                    if self._interrupt:
+                        return None
+                else:
+                    noise_pred_cond = self.model(
+                        latent_model_input, is_uncond = False, context = [context], **kwargs)[0]
+                    if self._interrupt:
+                        return None
+                    noise_pred_uncond = self.model(
+                        latent_model_input, is_uncond = True, slg_layers=slg_layers_local, context = [context_null], **kwargs)[0]
+                    if self._interrupt:
+                        return None

             # del latent_model_input

             # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
-            noise_pred_text = noise_pred_cond
-            if cfg_star_switch:
-                …
+            if phantom:
+                guide_scale_img = 5.0
+                guide_scale_text = guide_scale #7.5
+                noise_pred = neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)
+            else:
+                noise_pred_text = noise_pred_cond
+                if cfg_star_switch:
+                    positive_flat = noise_pred_text.view(batch_size, -1)
+                    negative_flat = noise_pred_uncond.view(batch_size, -1)

-            …
+                    alpha = optimized_scale(positive_flat, negative_flat)
+                    alpha = alpha.view(batch_size, 1, 1, 1)

+                    if (i <= cfg_zero_step):
+                        noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
+                    else:
+                        noise_pred_uncond *= alpha
+                noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred_uncond, noise_pred_cond, noise_pred_text, pos_it, pos_i, neg = None, None, None, None, None, None

             temp_x0 = sample_scheduler.step(
                 noise_pred[:, :target_shape[1]].unsqueeze(0),
@@ -473,8 +517,12 @@ class WanT2V:
         x0 = latents

         if input_frames == None:
+            if phantom:
+                # phantom post processing
+                x0 = [x0_[:,:-input_ref_images[0].shape[1]] for x0_ in x0]
             videos = self.vae.decode(x0, VAE_tile_size)
         else:
+            # vace post processing
             videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)

         del latents
```
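Phantom generation runs three predictions per step: text plus reference images (`pos_it`), reference images only with the null text prompt (`pos_i`), and a fully unconditional pass with zeroed reference latents (`neg`). It then blends them with two guidance scales, as in the hunk above. A minimal sketch of that blend; the helper name is ours, and the scales follow the diff, where the image scale is hard-coded to 5.0.

```python
import torch

def phantom_guidance(pos_it: torch.Tensor, pos_i: torch.Tensor, neg: torch.Tensor,
                     guide_scale_text: float = 7.5, guide_scale_img: float = 5.0) -> torch.Tensor:
    # neg -> pos_i: pull the prediction toward the reference images,
    # pos_i -> pos_it: then toward the text prompt on top of them.
    return neg + guide_scale_img * (pos_i - neg) + guide_scale_text * (pos_it - pos_i)
```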
wan/utils/utils.py
CHANGED

@@ -69,18 +69,29 @@ def remove_background(img, session=None):
 
-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background ):
+def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, fit_into_canvas = False ):
     if rm_background:
         session = new_session()
 
     output_list =[]
     for img in img_list:
         width, height = img.size
 
-        scale = (budget_height * budget_width / (height * width))**(1/2)
-        new_height = int( round(height * scale / 16) * 16)
-        new_width = int( round(width * scale / 16) * 16)
-        resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
+        if fit_into_canvas:
+            white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
+            scale = min(budget_height / height, budget_width / width)
+            new_height = int(height * scale)
+            new_width = int(width * scale)
+            resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
+            top = (budget_height - new_height) // 2
+            left = (budget_width - new_width) // 2
+            white_canvas[top:top + new_height, left:left + new_width] = np.array(resized_image)
+            resized_image = Image.fromarray(white_canvas)
+        else:
+            scale = (budget_height * budget_width / (height * width))**(1/2)
+            new_height = int( round(height * scale / 16) * 16)
+            new_width = int( round(width * scale / 16) * 16)
+            resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
         if rm_background:
             resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
         output_list.append(resized_image)
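The new fit_into_canvas path letterboxes each reference image onto a white canvas of the requested budget size instead of rescaling to an equivalent pixel count. A small self-contained sketch of the same centering arithmetic, using only PIL and NumPy and independent of the repository helpers:

    import numpy as np
    from PIL import Image

    def fit_into_white_canvas(img: Image.Image, budget_width: int, budget_height: int) -> Image.Image:
        # Scale the image so it fits entirely inside the canvas while keeping its aspect ratio.
        width, height = img.size
        scale = min(budget_height / height, budget_width / width)
        new_width, new_height = int(width * scale), int(height * scale)
        resized = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
        # Paste it centered on a white background so every output has the same size.
        canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
        top = (budget_height - new_height) // 2
        left = (budget_width - new_width) // 2
        canvas[top:top + new_height, left:left + new_width] = np.array(resized.convert("RGB"))
        return Image.fromarray(canvas)

    # Example: a 400x300 test image letterboxed into a 1280x720 canvas.
    print(fit_into_white_canvas(Image.new("RGB", (400, 300), "red"), 1280, 720).size)  # (1280, 720)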
wgp.py
CHANGED

@@ -40,7 +40,7 @@ global_queue_ref = []
 AUTOSAVE_FILENAME = "queue.zip"
 PROMPT_VARS_MAX = 10
 
-target_mmgp_version = "3.4.
+target_mmgp_version = "3.4.1"
 from importlib.metadata import version
 mmgp_version = version("mmgp")
 if mmgp_version != target_mmgp_version:

@@ -133,10 +133,11 @@ def process_prompt_and_add_tasks(state, model_choice):
 
     model_filename = state["model_filename"]
 
-    [removed line not captured in this view]
+    model_type = get_model_type(model_filename)
+    inputs = state.get(model_type, None)
+    if model_choice != model_type or inputs ==None:
         raise gr.Error("Webform can not be used as the App has been restarted since the form was displayed. Please refresh the page")
-    [removed line not captured in this view]
-    inputs = state.get(get_model_type(model_filename), None)
+
     inputs["state"] = state
     inputs.pop("lset_name")
     if inputs == None:

@@ -176,7 +177,7 @@ def process_prompt_and_add_tasks(state, model_choice):
         gr.Info(f"Resolution {resolution} not supported by image 2 video")
         return
 
-    if "1.3B" in model_filename and width * height > 848*480:
+    if "1.3B" in model_filename and width * height > 848*480 and any( model in model_filename for model in ["image2video", "text2video"] ):
         gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
         return
 

@@ -186,8 +187,28 @@ def process_prompt_and_add_tasks(state, model_choice):
     if video_length > sliding_window_size:
         gr.Info(f"The Number of Frames to generate ({video_length}) is greater than the Sliding Window Size ({sliding_window_size}) , multiple Windows will be generated")
 
-    [removed line not captured in this view]
+    if "phantom" in model_filename:
+        image_refs = inputs["image_refs"]
+
+        if isinstance(image_refs, list):
+            image_refs = [ convert_image(tup[0]) for tup in image_refs ]
+            os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
+            from wan.utils.utils import resize_and_remove_background
+            image_refs = resize_and_remove_background(image_refs, width, height, inputs["remove_background_image_ref"] ==1, fit_into_canvas= True)
+
+
+        if len(prompts) > 0:
+            prompts = ["\n".join(prompts)]
 
+        for single_prompt in prompts:
+            extra_inputs = {
+                "prompt" : single_prompt,
+                "image_refs": image_refs,
+            }
+            inputs.update(extra_inputs)
+            add_video_task(**inputs)
+
+    elif "diffusion_forcing" in model_filename:
         image_start = inputs["image_start"]
         video_source = inputs["video_source"]
         keep_frames_video_source = inputs["keep_frames_video_source"]

@@ -1362,10 +1383,6 @@ quantizeTransformer = args.quantize_transformer
 check_loras = args.check_loras ==1
 advanced = args.advanced
 
-transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors",
-                        "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"]
-transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
-transformer_choices = transformer_choices_t2v + transformer_choices_i2v
 text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
 server_config_filename = "wgp_config.json"
 if not os.path.isdir("settings"):

@@ -1401,11 +1418,32 @@ else:
         text = reader.read()
     server_config = json.loads(text)
 
+# for src_path, tgt_path in zip( ["ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors"], ["ckpts/sky_reels2_diffusion_forcing_540p_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_540p_14B_bf16.safetensors"] ):
+#     if Path(src_path).is_file():
+#         shutil.move(src_path, tgt_path) )
+# for path in ["ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"]:
+#     if Path(path).is_file():
+#         os.remove(path)
+
+path= "ckpts/sky_reels2_diffusion_forcing_1.3B_bf16.safetensors"
+if os.path.isfile(path) and os.path.getsize(path) > 4000000000:
+    os.remove(path)
+
+transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_mbf16.safetensors",
+                        "ckpts/wan2.1_recammaster_1.3B_bf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_1.3B_mbf16.safetensors", "ckpts/sky_reels2_diffusion_forcing_14B_bf16.safetensors",
+                        "ckpts/sky_reels2_diffusion_forcing_14B_quanto_int8.safetensors", "ckpts/sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors","ckpts/sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors",
+                        "ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"]
+transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors",
+                        "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors",
+                        "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_bf16.safetensors", "ckpts/wan2.1_FLF2V_720p_14B_quanto_int8.safetensors"]
+transformer_choices = transformer_choices_t2v + transformer_choices_i2v
 
-model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B"]
+model_types = [ "t2v_1.3B", "vace_1.3B", "fun_inp_1.3B", "t2v", "i2v", "i2v_720p", "fun_inp", "recam_1.3B", "flf2v_720p", "sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B", "phantom_1.3B"]
 model_signatures = {"t2v": "text2video_14B", "t2v_1.3B" : "text2video_1.3B", "fun_inp_1.3B" : "Fun_InP_1.3B", "fun_inp" : "Fun_InP_14B",
                     "i2v" : "image2video_480p", "i2v_720p" : "image2video_720p" , "vace_1.3B" : "Vace_1.3B", "recam_1.3B": "recammaster_1.3B",
-                    "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B"
+                    "flf2v_720p" : "FLF2V_720p", "sky_df_1.3B" : "sky_reels2_diffusion_forcing_1.3B", "sky_df_14B" : "sky_reels2_diffusion_forcing_14B",
+                    "sky_df_720p_14B" : "sky_reels2_diffusion_forcing_720p_14B",
+                    "phantom_1.3B" : "phantom_1.3B", }
 
 
 def get_model_type(model_filename):

@@ -1417,29 +1455,47 @@ def get_model_type(model_filename):
 def test_class_i2v(model_filename):
     return "image2video" in model_filename or "Fun_InP" in model_filename or "FLF2V" in model_filename
 
-def get_model_name(model_filename):
+def get_model_name(model_filename, description_container = [""]):
     if "Fun" in model_filename:
         model_name = "Fun InP image2video"
         model_name += " 14B" if "14B" in model_filename else " 1.3B"
+        description = "The Fun model is an alternative image 2 video that supports out the box End Image fixing (contrary to the original Wan image 2 video model). The 1.3B adds also image 2 to video capability to the 1.3B model."
     elif "Vace" in model_filename:
         model_name = "Vace ControlNet"
         model_name += " 14B" if "14B" in model_filename else " 1.3B"
+        description = "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video."
     elif "image" in model_filename:
         model_name = "Wan2.1 image2video"
         model_name += " 720p" if "720p" in model_filename else " 480p"
+        if "720p" in model_filename:
+            description = "The standard Wan Image 2 Video specialized to generate 720p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)"
+        else:
+            description = "The standard Wan Image 2 Video specialized to generate 480p images. It also offers Start and End Image support (End Image is not supported in the original model but seems to work well)"
     elif "recam" in model_filename:
         model_name = "ReCamMaster"
         model_name += " 14B" if "14B" in model_filename else " 1.3B"
+        description = "The Recam Master in theory should allow you to replay a video by applying a different camera movement. The model supports only video that are at least 81 frames long (any frame beyond will be ignored)"
    elif "FLF2V" in model_filename:
         model_name = "Wan2.1 FLF2V"
         model_name += " 720p" if "720p" in model_filename else " 480p"
+        description = "The First Last Frame 2 Video model is the official model Image 2 Video model that support Start and End frames."
     elif "sky_reels2_diffusion_forcing" in model_filename:
-        model_name = "SkyReels2
+        model_name = "SkyReels2 Diffusion Forcing"
+        if "720p" in model_filename :
+            model_name += " 720p"
+        elif not "1.3B" in model_filename :
+            model_name += " 540p"
         model_name += " 14B" if "14B" in model_filename else " 1.3B"
+        description = "The SkyReels 2 Diffusion Forcing model has been designed to generate very long videos that exceeds the usual 5s limit. You can also use this model to extend any existing video."
+    elif "phantom" in model_filename:
+        model_name = "Wan2.1 Phantom"
+        model_name += " 14B" if "14B" in model_filename else " 1.3B"
+        description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video. It produces very nices results when used at 720p."
     else:
         model_name = "Wan2.1 text2video"
         model_name += " 14B" if "14B" in model_filename else " 1.3B"
-    [removed line not captured in this view]
+        description = "The original Wan Text 2 Video model. Most other models have been built on top of it"
+    description_container[0] = description
     return model_name
 
 

@@ -1493,13 +1549,28 @@ def get_default_settings(filename):
             "slg_end_perc": 90
         }
 
-        if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B"):
+        if get_model_type(filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
             ui_defaults.update({
                 "guidance_scale": 6.0,
                 "flow_shift": 8,
-                "sliding_window_discard_last_frames" : 0
+                "sliding_window_discard_last_frames" : 0,
+                "resolution": "1280x720" if "720p" in filename else "960x544",
+                "sliding_window_size" : 121 if "720p" in filename else 97,
+                "RIFLEx_setting": 2,
+                "guidance_scale": 6,
+                "flow_shift": 8,
             })
 
+
+        if get_model_type(filename) in ("phantom_1.3B"):
+            ui_defaults.update({
+                "guidance_scale": 7.5,
+                "flow_shift": 5,
+                "resolution": "1280x720"
+            })
+
+
+
         with open(defaults_filename, "w", encoding="utf-8") as f:
             json.dump(ui_defaults, f, indent=4)
     else:

@@ -1649,7 +1720,7 @@ def download_models(transformer_filename, text_encoder_filename):
     from huggingface_hub import hf_hub_download, snapshot_download
     repoId = "DeepBeepMeep/Wan2.1"
     sourceFolderList = ["xlm-roberta-large", "pose", "depth", "mask", "", ]
-    fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.
+    fileList = [ [], [],[], ["sam_vit_h_4b8939_fp16.safetensors"], ["Wan2.1_VAE.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
     targetRoot = "ckpts/"
     for sourceFolder, files in zip(sourceFolderList,fileList ):
         if len(files)==0:

@@ -1763,12 +1834,12 @@ def setup_loras(model_filename, transformer, lora_dir, lora_preselected_preset,
     return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
 
 
-def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
+def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
 
     cfg = WAN_CONFIGS['t2v-14B']
     # cfg = WAN_CONFIGS['t2v-1.3B']
     print(f"Loading '{model_filename}' model...")
-    if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B"):
+    if get_model_type(model_filename) in ("sky_df_1.3B", "sky_df_14B", "sky_df_720p_14B"):
         model_factory = wan.DTT2V
     else:
         model_factory = wan.WanT2V

@@ -1779,52 +1850,32 @@ def load_t2v_model(model_filename, value, quantizeTransformer = False, dtype = t
         model_filename=model_filename,
         text_encoder_filename= text_encoder_filename,
         quantizeTransformer = quantizeTransformer,
-        dtype = dtype
+        dtype = dtype,
+        VAE_dtype = VAE_dtype,
+        mixed_precision_transformer = mixed_precision_transformer
     )
 
     pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "vae": wan_model.vae.model }
 
     return wan_model, pipe
 
-def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16):
+def load_i2v_model(model_filename, value, quantizeTransformer = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False):
 
     print(f"Loading '{model_filename}' model...")
 
-    [13 removed lines not captured in this view]
-            dtype = dtype
-        )
-        pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
-    [removed line not captured in this view]
-    elif value == '480P':
-        cfg = WAN_CONFIGS['i2v-14B']
-        wan_model = wan.WanI2V(
-            config=cfg,
-            checkpoint_dir="ckpts",
-            rank=0,
-            t5_fsdp=False,
-            dit_fsdp=False,
-            use_usp=False,
-            i2v720p= False,
-            model_filename=model_filename,
-            text_encoder_filename=text_encoder_filename,
-            quantizeTransformer = quantizeTransformer,
-            dtype = dtype
-        )
-        pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
-    else:
-        raise Exception("Model i2v {value} not supported")
+    cfg = WAN_CONFIGS['i2v-14B']
+    wan_model = wan.WanI2V(
+        config=cfg,
+        checkpoint_dir="ckpts",
+        model_filename=model_filename,
+        text_encoder_filename=text_encoder_filename,
+        quantizeTransformer = quantizeTransformer,
+        dtype = dtype,
+        VAE_dtype = VAE_dtype,
+        mixed_precision_transformer = mixed_precision_transformer
+    )
+    pipe = {"transformer": wan_model.model, "text_encoder" : wan_model.text_encoder.model, "text_encoder_2": wan_model.clip.model, "vae": wan_model.vae.model } #
+
     return wan_model, pipe
 
 

@@ -1836,18 +1887,22 @@ def load_models(model_filename):
     perc_reserved_mem_max = args.perc_reserved_mem_max
 
     major, minor = torch.cuda.get_device_capability(args.gpu if len(args.gpu) > 0 else None)
-    [removed line not captured in this view]
-    # default_dtype = torch.bfloat16
-    if default_dtype == torch.float16 or args.fp16:
+    if major < 8:
         print("Switching to f16 model as GPU architecture doesn't support bf16")
+        default_dtype = torch.float16
+    else:
+        default_dtype = torch.float16 if args.fp16 else torch.bfloat16
+    if default_dtype == torch.float16 :
         if "quanto" in model_filename:
             model_filename = model_filename.replace("quanto_int8", "quanto_fp16_int8")
     download_models(model_filename, text_encoder_filename)
+    VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
+    mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
     if test_class_i2v(model_filename):
         res720P = "720p" in model_filename
-        wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype )
+        wan_model, pipe = load_i2v_model(model_filename, "720P" if res720P else "480P", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     else:
-        wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype)
+        wan_model, pipe = load_t2v_model(model_filename, "", quantizeTransformer = quantizeTransformer, dtype = default_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer)
     wan_model._model_file_name = model_filename
     kwargs = { "extraModelsToQuantize": None}
     if profile == 2 or profile == 4:

@@ -1888,8 +1943,13 @@ def get_default_flow(filename, i2v):
 
 
 def generate_header(model_filename, compile, attention_mode):
-    [removed line not captured in this view]
-    [removed line not captured in this view]
+
+    description_container = [""]
+    get_model_name(model_filename, description_container)
+    description = description_container[0]
+    header = "<DIV style='height:40px'>" + description + "</DIV>"
+
+    header += "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
     if attention_mode not in attention_modes_installed:
         header += " -NOT INSTALLED-"
     elif attention_mode not in attention_modes_supported:

@@ -1907,6 +1967,8 @@ def generate_header(model_filename, compile, attention_mode):
 def apply_changes( state,
                     transformer_types_choices,
                     text_encoder_choice,
+                    VAE_precision_choice,
+                    mixed_precision_choice,
                     save_path_choice,
                     attention_choice,
                     compile_choice,

@@ -1922,7 +1984,7 @@ def apply_changes( state,
     if args.lock_config:
         return
     if gen_in_progress:
-        return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
+        return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>", gr.update(), gr.update()
     global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
     server_config = {"attention_mode" : attention_choice,
                      "transformer_types": transformer_types_choices,

@@ -1931,6 +1993,8 @@ def apply_changes( state,
                      "compile" : compile_choice,
                      "profile" : profile_choice,
                      "vae_config" : vae_config_choice,
+                     "vae_precision" : VAE_precision_choice,
+                     "mixed_precision" : mixed_precision_choice,
                      "metadata_type": metadata_choice,
                      "transformer_quantization" : quantization_choice,
                      "boost" : boost_choice,

@@ -2052,12 +2116,9 @@ def build_callback(state, pipe, send_cmd, status, num_inference_steps):
     return callback
 def abort_generation(state):
     gen = get_gen_info(state)
-    if "in_progress" in gen:
+    if "in_progress" in gen and wan_model != None:
 
-        [removed line not captured in this view]
-        gen["extra_orders"] = 0
-        if wan_model != None:
-            wan_model._interrupt= True
+        wan_model._interrupt= True
         msg = "Processing Request to abort Current Generation"
         gen["status"] = msg
         gr.Info(msg)

@@ -2140,13 +2201,6 @@ def finalize_generation(state):
     return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="")
 
 
-def refresh_gallery_on_trigger(state):
-    gen = get_gen_info(state)
-
-    if(gen.get("update_gallery", False)):
-        gen['update_gallery'] = False
-        return gr.update(value=gen.get("file_list", []))
-
 def select_video(state , event_data: gr.EventData):
     data= event_data._data
     gen = get_gen_info(state)

@@ -2385,6 +2439,8 @@ def generate_video(
     # VAE Tiling
     device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
     if vae_config == 0:
+        if server_config.get("vae_precision", "16") == "32":
+            device_mem_capacity = device_mem_capacity / 2
         if device_mem_capacity >= 24000:
             use_vae_config = 1
         elif device_mem_capacity >= 8000:

@@ -2497,6 +2553,7 @@ def generate_video(
     max_frames_to_generate = video_length
     diffusion_forcing = "diffusion_forcing" in model_filename
     vace = "Vace" in model_filename
+    phantom = "phantom" in model_filename
     if diffusion_forcing or vace:
         reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
     if diffusion_forcing and source_video != None:

@@ -2536,6 +2593,7 @@ def generate_video(
     extra_windows = 0
     guide_start_frame = 0
     video_length = first_window_video_length
+    gen["extra_windows"] = 0
     while not abort:
         if sliding_window:
             prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]

@@ -2550,7 +2608,9 @@ def generate_video(
         window_no += 1
         gen["window_no"] = window_no
 
-        if
+        if phantom:
+            src_ref_images = image_refs.copy() if image_refs != None else None
+        elif diffusion_forcing:
             if video_source != None and len(video_source) > 0 and window_no == 1:
                 keep_frames_video_source= 1000 if len(keep_frames_video_source) ==0 else int(keep_frames_video_source)
                 prefix_video = preprocess_video(None, width=width, height=height,video_in=video_source, max_frames= keep_frames_video_source , start_frame = 0, fit_canvas= True, target_fps = fps)

@@ -2559,7 +2619,7 @@ def generate_video(
                 prefix_video_frames_count = prefix_video.shape[1]
                 pre_video_guide = prefix_video[:, -reuse_frames:]
 
-        [removed line not captured in this view]
+        elif vace:
             # video_prompt_type = video_prompt_type +"G"
             image_refs_copy = image_refs.copy() if image_refs != None else None # required since prepare_source do inplace modifications
             video_guide_copy = video_guide

@@ -2610,7 +2670,7 @@ def generate_video(
         progress_args = [0, status + " - Encoding Prompt"]
         send_cmd("progress", progress_args)
 
-        samples = torch.empty( (1,2)) #for testing
+        # samples = torch.empty( (1,2)) #for testing
         # if False:
 
         try:

@@ -2633,7 +2693,6 @@ def generate_video(
                 guide_scale=guidance_scale,
                 n_prompt=negative_prompt,
                 seed=seed,
-                offload_model=False,
                 callback=callback,
                 enable_RIFLEx = enable_RIFLEx,
                 VAE_tile_size = VAE_tile_size,

@@ -2738,6 +2797,7 @@ def generate_video(
         if samples == None:
             abort = True
             state["prompt"] = ""
+            send_cmd("output")
         else:
             sample = samples.cpu()
         if True: # for testing

@@ -2839,7 +2899,6 @@ def generate_video(
 
             print(f"New video saved to Path: "+video_path)
             file_list.append(video_path)
-            state['update_gallery'] = True
             send_cmd("output")
         if sliding_window :
             if max_frames_to_generate > 0 and extra_windows == 0:

@@ -2847,8 +2906,6 @@ def generate_video(
                 if (current_length - prefix_video_frames_count)>= max_frames_to_generate:
                     break
                 video_length = min(sliding_window_size, ((max_frames_to_generate - (current_length - prefix_video_frames_count) + reuse_frames + discard_last_frames) // 4) * 4 + 1 )
-            else:
-                break
 
         seed += 1
 

@@ -3416,7 +3473,7 @@ def prepare_inputs_dict(target, inputs ):
     if not "recam" in model_filename or not "diffusion_forcing" in model_filename:
         inputs.pop("model_mode")
 
-    if not "Vace" in model_filename:
+    if not "Vace" in model_filename or not "phantom" in model_filename:
         unsaved_params = ["keep_frames_video_guide", "video_prompt_type", "remove_background_image_ref"]
         for k in unsaved_params:
             inputs.pop(k)

@@ -3776,6 +3833,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
     diffusion_forcing = "diffusion_forcing" in model_filename
     recammaster = "recam" in model_filename
     vace = "Vace" in model_filename
+    phantom = "phantom" in model_filename
     with gr.Column(visible= test_class_i2v(model_filename) or diffusion_forcing or recammaster) as image_prompt_column:
         if diffusion_forcing:
             image_prompt_type_value= ui_defaults.get("image_prompt_type","S")

@@ -3835,23 +3893,27 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             model_mode = gr.Dropdown(visible=False)
             keep_frames_video_source = gr.Text(visible=False)
 
-    with gr.Column(visible= vace ) as video_prompt_column:
+    with gr.Column(visible= vace or phantom) as video_prompt_column:
         video_prompt_type_value= ui_defaults.get("video_prompt_type","")
         video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
         with gr.Row():
-            [13 removed lines not captured in this view]
+            if vace:
+                video_prompt_type_video_guide = gr.Dropdown(
+                    choices=[
+                        ("None", ""),
+                        ("Transfer Human Motion from the Control Video", "PV"),
+                        ("Transfer Depth from the Control Video", "DV"),
+                        ("Recolorize the Control Video", "CV"),
+                        # ("Alternate Video Ending", "OV"),
+                        ("Video contains Open Pose, Depth, Black & White, Inpainting ", "V"),
+                        ("Control Video and Mask video for stronger Inpainting ", "MV"),
+                    ],
+                    value=filter_letters(video_prompt_type_value, "ODPCMV"),
+                    label="Video to Video", scale = 3, visible= True
+                )
+            else:
+                video_prompt_type_video_guide = gr.Dropdown(visible= False)
+
             video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
 
             video_prompt_type_image_refs = gr.Dropdown(

@@ -3869,7 +3931,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
             image_refs = gr.Gallery( label ="Reference Images",
                     type ="pil",   show_label= True,
                     columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
-                    value= ui_defaults.get("image_refs", None)
+                    value= ui_defaults.get("image_refs", None),
+             )
 
             # with gr.Row():
             remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 )

@@ -3929,7 +3992,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                         # ("832x1104 (3:4, 720p)", "832x1104"),
                         # ("960x960 (1:1, 720p)", "960x960"),
                         # 480p
-                        [removed line not captured in this view]
+                        ("960x544 (16:9, 540p)", "960x544"),
+                        ("544x960 (16:9, 540p)", "544x960"),
                         ("832x480 (16:9, 480p)", "832x480"),
                         ("480x832 (9:16, 480p)", "480x832"),
                         # ("832x624 (4:3, 540p)", "832x624"),

@@ -4082,13 +4146,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
                     gr.Markdown("<B>A Sliding Window allows you to generate video with a duration not limited by the Model</B>")
                     gr.Markdown("<B>It is automatically turned on if the number of frames to generate is higher than the Window Size</B>")
                     if diffusion_forcing:
-                        sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size")
+                        sliding_window_size = gr.Slider(37, 137, value=ui_defaults.get("sliding_window_size", 97), step=20, label="Sliding Window Size (recommended to keep it at 97)")
                         sliding_window_overlap = gr.Slider(17, 97, value=ui_defaults.get("sliding_window_overlap",17), step=20, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
-                        sliding_window_discard_last_frames = gr.Slider(0,
+                        sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
                     else:
                         sliding_window_size = gr.Slider(5, 137, value=ui_defaults.get("sliding_window_size", 81), step=4, label="Sliding Window Size")
                         sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",17), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
-                        sliding_window_discard_last_frames = gr.Slider(0,
+                        sliding_window_discard_last_frames = gr.Slider(0, 12, value=ui_defaults.get("sliding_window_discard_last_frames", 4), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
 
 
                 with gr.Tab("Miscellaneous", visible= not "recam" in model_filename):

@@ -4429,13 +4493,22 @@ def generate_configuration_tab(state, blocks, header, model_choice):
 
                 quantization_choice = gr.Dropdown(
                     choices=[
-                        ("Int8 Quantization (recommended)", "int8"),
+                        ("Scaled Int8 Quantization (recommended)", "int8"),
                         ("16 bits (no quantization)", "bf16"),
                     ],
                     value= transformer_quantization,
                     label="Wan Transformer Model Quantization Type (if available)",
                 )
 
+                mixed_precision_choice = gr.Dropdown(
+                    choices=[
+                        ("16 bits only, requires less VRAM", "0"),
+                        ("Mixed 16 / 32 bits, slightly more VRAM needed but better Quality", "1"),
+                    ],
+                    value= server_config.get("mixed_precision", "0"),
+                    label="Transformer Engine Calculation"
+                )
+
                 index = text_encoder_choices.index(text_encoder_filename)
                 index = 0 if index ==0 else index
                 text_encoder_choice = gr.Dropdown(

@@ -4446,6 +4519,16 @@ def generate_configuration_tab(state, blocks, header, model_choice):
                     value= index,
                     label="Text Encoder model"
                 )
+
+                VAE_precision_choice = gr.Dropdown(
+                    choices=[
+                        ("16 bits, requires less VRAM and faster", "16"),
+                        ("32 bits, requires twice more VRAM and slower but recommended with Window Sliding", "32"),
+                    ],
+                    value= server_config.get("vae_precision", "16"),
+                    label="VAE Encoding / Decoding precision"
+                )
+
                 save_path_choice = gr.Textbox(
                     label="Output Folder for Generated Videos",
                     value=server_config.get("save_path", save_path)

@@ -4510,14 +4593,7 @@ def generate_configuration_tab(state, blocks, header, model_choice):
                     value= profile,
                     label="Profile (for power users only, not needed to change it)"
                 )
-                [removed line not captured in this view]
-                #     choices=[
-                #         ("Text to Video", "t2v"),
-                #         ("Image to Video", "i2v"),
-                #     ],
-                #     value= default_ui,
-                #     label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
-                # )
+
                 metadata_choice = gr.Dropdown(
                     choices=[
                         ("Export JSON files", "json"),

@@ -4563,6 +4639,8 @@ def generate_configuration_tab(state, blocks, header, model_choice):
                     state,
                     transformer_types_choices,
                     text_encoder_choice,
+                    VAE_precision_choice,
+                    mixed_precision_choice,
                     save_path_choice,
                     attention_choice,
                     compile_choice,

@@ -4957,7 +5035,7 @@ def create_demo():
     theme = gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md")
 
     with gr.Blocks(css=css, theme=theme, title= "Wan2GP") as main:
-        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.
+        gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
        global model_list
 
         tab_state = gr.State({ "tab_no":0 })
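The wgp.py hunk around the model_types / model_signatures declarations adds the SkyReels 720p and Phantom entries. get_model_type itself is not part of this diff; presumably it just scans model_signatures for a matching substring, along the lines of the hedged sketch below (the dictionary entries are copied from the diff, but the function body and its error message are assumptions, not the repository's actual code):

    model_signatures = {"t2v": "text2video_14B", "t2v_1.3B": "text2video_1.3B",
                        "sky_df_720p_14B": "sky_reels2_diffusion_forcing_720p_14B",
                        "phantom_1.3B": "phantom_1.3B"}  # abridged copy of the mapping in the diff

    def get_model_type(model_filename: str) -> str:
        # Return the first model type whose signature appears in the checkpoint filename.
        for model_type, signature in model_signatures.items():
            if signature in model_filename:
                return model_type
        raise Exception("Unknown model: " + model_filename)

    print(get_model_type("ckpts/wan2_1_phantom_1.3B_mbf16.safetensors"))  # phantom_1.3B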
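get_model_name now takes a description_container list and generate_header reads the description back out of it, which is a simple way to hand back a second value without changing the function's return type. A small usage sketch of that pattern, heavily abridged from the diff (only the Phantom branch and the fallback are reproduced):

    def get_model_name(model_filename, description_container=[""]):
        if "phantom" in model_filename:
            model_name = "Wan2.1 Phantom" + (" 14B" if "14B" in model_filename else " 1.3B")
            description = "The Phantom model is specialized to transfer people or objects of your choice into a generated Video."
        else:
            model_name = "Wan2.1 text2video" + (" 14B" if "14B" in model_filename else " 1.3B")
            description = "The original Wan Text 2 Video model."
        description_container[0] = description   # second "return value" written into the caller's list
        return model_name

    description_container = [""]
    name = get_model_name("ckpts/wan2_1_phantom_1.3B_mbf16.safetensors", description_container)
    print(name)                      # Wan2.1 Phantom 1.3B
    print(description_container[0])  # The Phantom model is specialized to transfer ...

Note that a mutable default argument is shared across calls; the caller in generate_header always passes a fresh list, which sidesteps that pitfall.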
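load_models now derives three precision settings before instantiating the model: the transformer dtype from the GPU compute capability (bf16 only on major >= 8), the VAE dtype from the new vae_precision option, and the mixed_precision flag. A compact sketch of that selection logic, assuming a server_config dict shaped like the one written by apply_changes (the helper name is mine, not the repository's):

    import torch

    def pick_precisions(major: int, force_fp16: bool, server_config: dict):
        # bf16 is only used on GPUs whose architecture supports it (compute capability 8.x and above).
        if major < 8:
            default_dtype = torch.float16
        else:
            default_dtype = torch.float16 if force_fp16 else torch.bfloat16
        vae_dtype = torch.float16 if server_config.get("vae_precision", "16") == "16" else torch.float32
        mixed_precision_transformer = server_config.get("mixed_precision", "0") == "1"
        return default_dtype, vae_dtype, mixed_precision_transformer

    # Example: an Ampere-class GPU (major=8) with 32-bit VAE and mixed precision enabled.
    print(pick_precisions(8, False, {"vae_precision": "32", "mixed_precision": "1"}))
    # (torch.bfloat16, torch.float32, True)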
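The sliding-window loop keeps the formula video_length = min(sliding_window_size, ((max_frames_to_generate - (current_length - prefix_video_frames_count) + reuse_frames + discard_last_frames) // 4) * 4 + 1). A worked example of that arithmetic with the new SkyReels Diffusion Forcing defaults from this diff (window size 97, overlap/reuse 17, no discarded frames); the function wrapper is only for illustration:

    def next_window_length(sliding_window_size, max_frames_to_generate, current_length,
                           prefix_video_frames_count, reuse_frames, discard_last_frames):
        remaining = max_frames_to_generate - (current_length - prefix_video_frames_count)
        # Snap to the 4n+1 frame counts the model expects, capped by the window size.
        return min(sliding_window_size, ((remaining + reuse_frames + discard_last_frames) // 4) * 4 + 1)

    # 160 frames requested, 97 already produced by the first window, no source-video prefix:
    print(next_window_length(97, 160, 97, 0, 17, 0))  # 81 -> one shorter final window is enough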