Spaces · Running on T4

Commit 948696e — committed by DeepBeepMeep
Parent(s): 7d5369f

Added 10% boost, improved Loras and Teacache
Files changed:
- README.md (+25 -10)
- gradio_server.py (+178 -76)
- wan/image2video.py (+30 -12)
- wan/modules/model.py (+123 -60)
- wan/text2video.py (+19 -10)
README.md
CHANGED
@@ -19,10 +19,11 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid

 ## 🔥 Latest News!!
-* Mar …
-* Mar …
+* Mar 10, 2025: 👋 Wan2.1GP v1.5: Official Teacache support + Smart Teacache (find automatically the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user)
+* Mar 07, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
+* Mar 04, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line switch *--preload x* to preload x MB of the main diffusion model in VRAM, useful if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
 If you upgrade you will need to do a 'pip install -r requirements.txt' again.
-* Mar …
+* Mar 04, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end.
 * Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
 * Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
 - Support for all Wan including the Image to Video model

@@ -152,15 +153,29 @@ python gradio_server.py --attention sdpa

 ### Loras support

--- Ready to be used but theoretical as no lora for Wan have been released as of today. ---

-Every lora stored in the subfolder 'loras' will be automatically loaded. You will then be able to activate / deactivate any of them when running the application.
+Every lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be automatically loaded. You will then be able to activate / deactivate any of them when running the application by selecting them in the area below "Activated Loras".

-For each activated Lora, you may specify a *multiplier*, that is, one float number that corresponds to its weight (default is 1.0)
+For each activated Lora, you may specify a *multiplier*, that is, one float number that corresponds to its weight (default is 1.0). The multipliers for each Lora should be separated by a space character or a carriage return. For instance:\
+*1.2 0.8* means that the first lora will have a 1.2 multiplier and the second one will have 0.8.
-…
+Alternatively, for each Lora's multiplier you may specify a list of float numbers separated by a "," (no space) that gives the evolution of this Lora's multiplier over the steps. For instance, let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7*: then for the step ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.
-…
+If multiple Loras are defined, remember that the multipliers associated with different Loras should be separated by a space or a carriage return, so we can specify the evolution of multipliers for multiple Loras. For instance, for two Loras (press Shift Return to force a carriage return):
+
+```
+0.9,0.8,0.7
+1.2,1.1,1.0
+```
+You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio Web interface. These presets will save the *comment* part of the prompt, which should contain some instructions on how to use the corresponding loras (for instance by specifying a trigger word or providing an example). A comment in the prompt is a line that starts with a #. It will be ignored by the video generator. For instance:
+
+```
+# use the keyword ohnvx to trigger the Lora
+A ohnvx is driving a car
+```
+Each preset is a file with the ".lset" extension stored in the loras directory and can be shared with other users.
+
+Last but not least, you can pre-activate Loras and prefill a prompt (comments only or full prompt) by specifying a preset when launching the gradio server:
 ```bash
 python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
 ```

@@ -180,11 +195,11 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
 --open-browser : open automatically Browser when launching Gradio Server\
 --lock-config : prevent modifying the video engine configuration from the interface\
 --share : create a shareable URL on huggingface so that your server can be accessed remotely\
---multiple-images : …
+--multiple-images : allow the users to choose multiple images as different starting points for new videos\
 --compile : turn on pytorch compilation\
 --attention mode: force attention mode among sdpa, flash, sage, sage2\
 --profile no : default (4) : no of profile between 1 and 5\
---preload no : number in Megabytes to preload partially the diffusion model in VRAM, may offer slight speed gains especially on older hardware
+--preload no : number in Megabytes to preload partially the diffusion model in VRAM, may offer slight speed gains especially on older hardware. Works only with profiles 2 and 4.

 ### Profiles (for power users only)
 You can choose between 5 profiles, but two are really relevant here :
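The per-step multiplier syntax documented above (space- or newline-separated entries, one per Lora, each optionally a comma-separated list that evolves over the denoising steps) is easy to picture with a small standalone sketch. This is only an illustration of the scheme the README describes; `expand_lora_multipliers` is a hypothetical helper, not the parser used by gradio_server.py.

```python
# Hypothetical helper illustrating the multiplier syntax described in the README.
def expand_lora_multipliers(spec, num_steps):
    """Turn a spec like "0.9,0.8,0.7 1.2" into one weight per Lora per denoising step."""
    table = []
    for entry in spec.replace("\r", "").split():        # one entry per Lora, space or newline separated
        phases = [float(v) for v in entry.split(",")]   # "0.9,0.8,0.7" -> three phases over the steps
        weights = []
        for step in range(num_steps):
            phase = min(step * len(phases) // num_steps, len(phases) - 1)
            weights.append(phases[phase])                # each phase covers an equal share of the steps
        table.append(weights)
    return table

# 30 steps: the first Lora fades 0.9 -> 0.8 -> 0.7, the second stays at 1.2 throughout.
table = expand_lora_multipliers("0.9,0.8,0.7 1.2", 30)
assert table[0][0] == 0.9 and table[0][15] == 0.8 and table[0][29] == 0.7
assert all(w == 1.2 for w in table[1])
```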
gradio_server.py
CHANGED
@@ -57,18 +57,18 @@ def _parse_args():
     )


-    parser.add_argument(
-        …
-    )
+    # parser.add_argument(
+    #     "--lora-dir-i2v",
+    #     type=str,
+    #     default="loras_i2v",
+    #     help="Path to a directory that contains Loras for i2v"
+    # )


     parser.add_argument(
         "--lora-dir",
         type=str,
-        default="…
+        default="",
         help="Path to a directory that contains Loras"
     )

@@ -80,12 +80,12 @@ def _parse_args():
         help="Lora preset to preload"
     )

-    parser.add_argument(
-        …
-    )
+    # parser.add_argument(
+    #     "--lora-preset-i2v",
+    #     type=str,
+    #     default="",
+    #     help="Lora preset to preload for i2v"
+    # )

     parser.add_argument(
         "--profile",

@@ -198,6 +198,7 @@ if not Path(server_config_filename).is_file():
                      "text_encoder_filename" : text_encoder_choices[1],
                      "compile" : "",
                      "default_ui": "t2v",
+                     "boost" : 1,
                      "vae_config": 0,
                      "profile" : profile_type.LowRAM_LowVRAM }

@@ -223,6 +224,7 @@ if len(args.attention)> 0:

 profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
 compile = server_config.get("compile", "")
+boost = server_config.get("boost", 1)
 vae_config = server_config.get("vae_config", 0)
 if len(args.vae_config) > 0:
     vae_config = int(args.vae_config)

@@ -234,13 +236,14 @@ if args.t2v:
 if args.i2v:
     use_image2video = True

-if use_image2video:
-    …
-else:
-    …
+# if use_image2video:
+#     lora_dir =args.lora_dir_i2v
+#     lora_preselected_preset = args.lora_preset_i2v
+# else:
+lora_dir =args.lora_dir
+if len(lora_dir) ==0:
+    lora_dir = "loras_i2v" if use_image2video else "loras"
+lora_preselected_preset = args.lora_preset
 default_tea_cache = 0
 # if args.fast : #or args.fastest
 #     transformer_filename_t2v = transformer_choices_t2v[2]

@@ -321,8 +324,16 @@ def extract_preset(lset_name, loras):
         raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")

     loras_mult_choices = lset["loras_mult"]
-    …
+    prompt = lset.get("prompt", "")
+    return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False)

+def get_default_prompt(i2v):
+    if i2v:
+        return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field."
+    else:
+        return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
+
+
 def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
     loras =[]
     loras_names = []

@@ -337,7 +348,7 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
         raise Exception("--lora-dir should be a path to a directory that contains Loras")

     default_lora_preset = ""
-    …
+    default_prompt = ""
     if lora_dir != None:
         import glob
         dir_loras =  glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )

@@ -350,15 +361,16 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):

     if len(loras) > 0:
         loras_names = [ Path(lora).stem for lora in loras ]
-        …
+        offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,

     if len(lora_preselected_preset) > 0:
         if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
             raise Exception(f"Unknown preset '{lora_preselected_preset}'")
         default_lora_preset = lora_preselected_preset
-        default_loras_choices, default_loras_multis_str= extract_preset(default_lora_preset, loras)
-        …
+        default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras)
+    if len(default_prompt) == 0:
+        default_prompt = get_default_prompt(use_image2video)
+    return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets


 def load_t2v_model(model_filename, value):

@@ -439,13 +451,13 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
         kwargs["budgets"] = { "*" : "70%" }


-    loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
+    loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
     offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)


-    return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
+    return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets

-wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
+wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
 gen_in_progress = False

 def get_auto_attention():

@@ -487,13 +499,14 @@ def apply_changes( state,
                     profile_choice,
                     vae_config_choice,
                     default_ui_choice ="t2v",
+                    boost_choice = 1
 ):
     if args.lock_config:
         return
     if gen_in_progress:
         yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
         return
-    global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
+    global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
     server_config = {"attention_mode" : attention_choice,
                      "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
                      "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########

@@ -502,6 +515,7 @@ def apply_changes( state,
                      "profile" : profile_choice,
                      "vae_config" : vae_config_choice,
                      "default_ui" : default_ui_choice,
+                     "boost" : boost_choice,
                      }

     if Path(server_config_filename).is_file():

@@ -529,7 +543,7 @@ def apply_changes( state,
     state["config_new"] = server_config
     state["config_old"] = old_server_config

-    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config
+    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]

@@ -537,8 +551,8 @@ def apply_changes( state,
     transformer_filename_i2v = server_config["transformer_filename_i2v"]
     text_encoder_filename = server_config["text_encoder_filename"]
     vae_config = server_config["vae_config"]
-    …
-    if all(change in ["attention_mode", "vae_config", "default_ui"] for change in changes ):
+    boost = server_config["boost"]
+    if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ):
         if "attention_mode" in changes:
             pass

@@ -548,7 +562,7 @@ def apply_changes( state,
         offloadobj = None
     yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"

-    wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
+    wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )


     yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"

@@ -727,7 +741,9 @@ def generate_video(
     if len(prompt) ==0:
         return
     prompts = prompt.replace("\r", "").split("\n")
-    …
+    prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+    if len(prompts) ==0:
+        return
     if use_image2video:
         if image_to_continue is not None:
             if isinstance(image_to_continue, list):

@@ -772,6 +788,9 @@ def generate_video(
             return False
     list_mult_choices_nums = []
     if len(loras_mult_choices) > 0:
+        loras_mult_choices_list = loras_mult_choices.replace("\r", "").split("\n")
+        loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")]
+        loras_mult_choices = " ".join(loras_mult_choices_list)
         list_mult_choices_str = loras_mult_choices.split(" ")
         for i, mult in enumerate(list_mult_choices_str):
             mult = mult.strip()

@@ -805,18 +824,36 @@ def generate_video(
     # VAE Tiling
     device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576

-    …
+    joint_pass = boost ==1
     # TeaCache
     trans = wan_model.model
     trans.enable_teacache = tea_cache > 0
-    …
+    if trans.enable_teacache:
+        if use_image2video:
+            if '480p' in transformer_filename_i2v:
+                # teacache_thresholds = [0.13, .19, 0.26]
+                trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
+            elif '720p' in transformer_filename_i2v:
+                teacache_thresholds = [0.18, 0.2 , 0.3]
+                trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
+            else:
+                raise gr.Error("Teacache not supported for this model")
+        else:
+            if '1.3B' in transformer_filename_t2v:
+                # teacache_thresholds= [0.05, 0.07, 0.08]
+                trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
+            elif '14B' in transformer_filename_t2v:
+                # teacache_thresholds = [0.14, 0.15, 0.2]
+                trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
+            else:
+                raise gr.Error("Teacache not supported for this model")
+
     import random
     if seed == None or seed <0:
         seed = random.randint(0, 999999999)

     file_list = []
     state["file_list"] = file_list
-    from einops import rearrange
     save_path = os.path.join(os.getcwd(), "gradio_outputs")
     os.makedirs(save_path, exist_ok=True)
     video_no = 0

@@ -830,14 +867,12 @@ def generate_video(

         if trans.enable_teacache:
             trans.teacache_counter = 0
-            trans.…
-            trans.teacache_start_step = …
+            trans.teacache_multiplier = tea_cache
+            trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
+            trans.num_steps = num_inference_steps
+            trans.teacache_skipped_steps = 0
             trans.previous_residual_uncond = None
-            trans.previous_modulated_input_uncond = None
             trans.previous_residual_cond = None
-            trans.previous_modulated_input_cond= None
-            …
-            trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu"

         video_no += 1
         status = f"Video {video_no}/{total_video}"

@@ -853,7 +888,7 @@ def generate_video(
             if use_image2video:
                 samples = wan_model.generate(
                     prompt,
-                    image_to_continue[ (video_no-1) % len(image_to_continue)],
+                    image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'),
                     frame_num=(video_length // 4)* 4 + 1,
                     max_area=MAX_AREA_CONFIGS[resolution],
                     shift=flow_shift,

@@ -864,7 +899,8 @@ def generate_video(
                     offload_model=False,
                     callback=callback,
                     enable_RIFLEx = enable_RIFLEx,
-                    VAE_tile_size = VAE_tile_size
+                    VAE_tile_size = VAE_tile_size,
+                    joint_pass = joint_pass,
                 )

             else:

@@ -880,7 +916,8 @@ def generate_video(
                     offload_model=False,
                     callback=callback,
                     enable_RIFLEx = enable_RIFLEx,
-                    VAE_tile_size = VAE_tile_size
+                    VAE_tile_size = VAE_tile_size,
+                    joint_pass = joint_pass,
                 )
         except Exception as e:
             gen_in_progress = False

@@ -911,6 +948,7 @@ def generate_video(
             raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")

         if trans.enable_teacache:
+            print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
             trans.previous_residual_uncond = None
             trans.previous_residual_cond = None

@@ -957,7 +995,25 @@ def generate_video(

 new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"

-…
+
+def validate_delete_lset(lset_name):
+    if len(lset_name) == 0 or lset_name == new_preset_msg:
+        gr.Info(f"Choose a Preset to delete")
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
+    else:
+        return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
+
+def validate_save_lset(lset_name):
+    if len(lset_name) == 0 or lset_name == new_preset_msg:
+        gr.Info("Please enter a name for the preset")
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
+    else:
+        return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
+
+def cancel_lset():
+    return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
+
+def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
     global loras_presets

     if len(lset_name) == 0 or lset_name== new_preset_msg:

@@ -968,6 +1024,16 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):

         loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ]
         lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices}
+        if save_lset_prompt_cbox!=1:
+            prompts = prompt.replace("\r", "").split("\n")
+            prompts = [prompt for prompt in prompts if len(prompt)> 0 and prompt.startswith("#")]
+            prompt = "\n".join(prompts)
+
+        if len(prompt) > 0:
+            lset["prompt"] = prompt
+        lset["full_prompt"] = save_lset_prompt_cbox ==1
+
+
         lset_name_filename = lset_name + ".lset"
         full_lset_name_filename = os.path.join(lora_dir, lset_name_filename)

@@ -982,7 +1048,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):
         lset_choices = [ ( preset, preset) for preset in loras_presets ]
         lset_choices.append( (new_preset_msg, ""))

-        return gr.Dropdown(choices=lset_choices, value= lset_name)
+        return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)

 def delete_lset(lset_name):
     global loras_presets

@@ -1000,23 +1066,31 @@ def delete_lset(lset_name):

     lset_choices = [ (preset, preset) for preset in loras_presets]
     lset_choices.append((new_preset_msg, ""))
-    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1])
+    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
+

-def apply_lset(lset_name, loras_choices, loras_mult_choices):
+def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):

     if len(lset_name) == 0 or lset_name== new_preset_msg:
         gr.Info("Please choose a preset in the list or create one")
     else:
-        loras_choices, loras_mult_choices= extract_preset(lset_name, loras)
+        loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
+        if full_prompt:
+            prompt = preset_prompt
+        elif len(preset_prompt) > 0:
+            prompts = prompt.replace("\r", "").split("\n")
+            prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+            prompt = "\n".join(prompts)
+            prompt = preset_prompt + '\n' + prompt
         gr.Info(f"Lora Preset '{lset_name}' has been applied")

-    return loras_choices, loras_mult_choices
+    return loras_choices, loras_mult_choices, prompt

 def create_demo():

     default_inference_steps = 30
+
     default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
-    …
     with gr.Blocks() as demo:
         state = gr.State({})

@@ -1130,6 +1204,16 @@ def create_demo():
                     label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
                 )

+                boost_choice = gr.Dropdown(
+                    choices=[
+                        # ("Auto (ON if Video longer than 5s)", 0),
+                        ("ON", 1),
+                        ("OFF", 2),
+                    ],
+                    value=boost,
+                    label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)"
+                )
+
                 profile_choice = gr.Dropdown(
                     choices=[
                         ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),

@@ -1161,16 +1245,12 @@ def create_demo():
             video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
             if args.multiple_images:
                 image_to_continue = gr.Gallery(
-                        label="Images as a starting point for new videos", type ="…
+                        label="Images as a starting point for new videos", type ="pil", #file_types= "image",
                         columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
             else:
                 image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)

-            …
-                prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3)
-            else:
-                prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.", lines=3)
-            …
+            prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos, lines that starts with # are ignored)", value=default_prompt, lines=3)

             with gr.Row():
                 if use_image2video:

@@ -1223,9 +1303,21 @@ def create_demo():
                 # with gr.Column():
                 with gr.Row(height=17):
                     apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
+                    # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
+                    save_lset_prompt_drop= gr.Dropdown(
+                        choices=[
+                            ("Save Prompt Comments Only", 0),
+                            ("Save Full Prompt", 1)
+                        ], show_label= False, container=False, visible= False
+                    )
+
+
                 with gr.Row(height=17):
+                    confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
+                    confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
                     save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
                     delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
+                    cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)


                 loras_choices = gr.Dropdown(

@@ -1237,7 +1329,7 @@ def create_demo():
                     visible= len(loras)>0,
                     label="Activated Loras"
                 )
-                loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns", value=default_loras_multis_str, visible= len(loras)>0 )
+                loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )

                 show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
                 with gr.Row(visible=False) as advanced_row:

@@ -1250,18 +1342,23 @@ def create_demo():
                     guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
                     embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
                     flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
-                    …
+                    with gr.Row():
+                        gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)")
+                    with gr.Row():
+                        tea_cache_setting = gr.Dropdown(
+                            choices=[
+                                ("Tea Cache Disabled", 0),
+                                ("around x1.5 speed up", 1.5),
+                                ("around x1.75 speed up", 1.75),
+                                ("around x2 speed up", 2.0),
+                                ("around x2.25 speed up", 2.25),
+                                ("around x2.5 speed up", 2.5),
+                            ],
+                            value=default_tea_cache,
+                            visible=True,
+                            label="Tea Cache Global Acceleration"
+                        )
+                    tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")

                     RIFLEx_setting = gr.Dropdown(
                         choices=[

@@ -1283,9 +1380,13 @@ def create_demo():
                 generate_btn = gr.Button("Generate")
                 abort_btn = gr.Button("Abort")

-        save_lset_btn.click(…
-        …
+        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
+
+        apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])

         gen_status.change(refresh_gallery, inputs = [state], outputs = output )

@@ -1335,6 +1436,7 @@ def create_demo():
                     profile_choice,
                     vae_config_choice,
                     default_ui_choice,
+                    boost_choice,
                 ],
                 outputs= msg
             ).then(
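For reference, the `lset` dictionary built by `save_lset` above suggests what a preset file contains: the activated Lora file names, their multipliers, and optionally a prompt (comments only, or the full prompt). The serialization itself is outside this diff, so the JSON encoding below is an assumption, and the file and Lora names are made-up examples.

```python
# Assumed shape of a ".lset" preset, based on the lset dict built in save_lset.
# The on-disk encoding is not shown in this diff; JSON is assumed here.
import json, os

lora_dir = "loras"                                   # "loras_i2v" when running the image-to-video model
preset = {
    "loras": ["my_style.safetensors"],               # file names of the activated Loras (example name)
    "loras_mult": "0.9,0.8,0.7",                     # same syntax as the "Loras Multipliers" textbox
    "prompt": "# use the keyword ohnvx to trigger the Lora",  # comment lines shown to the user
    "full_prompt": False,                            # True when the whole prompt is stored, not just comments
}

os.makedirs(lora_dir, exist_ok=True)
with open(os.path.join(lora_dir, "mylorapreset.lset"), "w", encoding="utf-8") as f:
    json.dump(preset, f, indent=2)
```

A file like this could then be preloaded with `python gradio_server.py --lora-preset mylorapreset.lset`.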
wan/image2video.py
CHANGED
@@ -146,7 +146,7 @@ class WanI2V:
         callback = None,
         enable_RIFLEx = False,
         VAE_tile_size= 0,
-        …
+        joint_pass = False,
     ):
         r"""
         Generates video frames from input image and text prompt using diffusion process.

@@ -310,9 +310,22 @@ class WanI2V:
             'pipeline' : self
         }

+        arg_both= {
+            'context': [context[0]],
+            'context2': context_null,
+            'clip_fea': clip_context,
+            'seq_len': max_seq_len,
+            'y': [y],
+            'freqs' : freqs,
+            'pipeline' : self
+        }
+
         if offload_model:
             torch.cuda.empty_cache()

+        if self.model.enable_teacache:
+            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
+
         # self.model.to(self.device)
         if callback != None:
             callback(-1, None)

@@ -323,17 +336,22 @@ class WanI2V:
                 timestep = [t]

                 timestep = torch.stack(timestep).to(self.device)
-                …
+                if joint_pass:
+                    noise_pred_cond, noise_pred_uncond = self.model(
+                        latent_model_input, t=timestep, current_step=i, **arg_both)
+                    if self._interrupt:
+                        return None
+                else:
+                    noise_pred_cond = self.model(
+                        latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
+                    if self._interrupt:
+                        return None
+                    if offload_model:
+                        torch.cuda.empty_cache()
+                    noise_pred_uncond = self.model(
+                        latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
+                    if self._interrupt:
+                        return None
                 del latent_model_input
                 if offload_model:
                     torch.cuda.empty_cache()
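The `joint_pass` branch added above requests the conditional and unconditional noise predictions from a single transformer call instead of two, and the result feeds the usual classifier-free guidance combination (visible in wan/text2video.py as `noise_pred_uncond + guide_scale * (...)`). The sketch below only illustrates that control flow; `fake_model` is a stand-in rather than the Wan model, and the guidance scale is an arbitrary example value.

```python
# Control-flow sketch of the boost / joint_pass option (stand-in model, not WanModel).
import torch

def fake_model(latents, joint=False):
    cond = torch.randn_like(latents)      # stand-in for the conditional prediction
    uncond = torch.randn_like(latents)    # stand-in for the unconditional prediction
    return (cond, uncond) if joint else (cond,)

latents = torch.randn(4, 16, 16)
guide_scale = 5.0
joint_pass = True                         # boost == 1 in gradio_server.py maps to this

if joint_pass:
    # one forward call produces both branches (slightly more VRAM, fewer passes)
    noise_pred_cond, noise_pred_uncond = fake_model(latents, joint=True)
else:
    noise_pred_cond = fake_model(latents)[0]
    noise_pred_uncond = fake_model(latents)[0]

# classifier-free guidance combination, as in wan/text2video.py
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
```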
wan/modules/model.py
CHANGED
@@ -667,7 +667,43 @@ class WanModel(ModelMixin, ConfigMixin):

         return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))

-    …
+    def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
+        rescale_func = np.poly1d(self.coefficients)
+        e_list = []
+        for t in timesteps:
+            t = torch.stack([t])
+            e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
+
+        best_threshold = 0.01
+        best_diff = 1000
+        target_nb_steps= int(len(timesteps) / speed_factor)
+        threshold = 0.01
+        while threshold <= 0.6:
+            accumulated_rel_l1_distance =0
+            nb_steps = 0
+            diff = 1000
+            for i, t in enumerate(timesteps):
+                skip = False
+                if not (i<=start_step or i== len(timesteps)):
+                    accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
+                    if accumulated_rel_l1_distance < threshold:
+                        skip = True
+                    else:
+                        accumulated_rel_l1_distance = 0
+                previous_modulated_input = e_list[i]
+                if not skip:
+                    nb_steps += 1
+                    diff = abs(target_nb_steps - nb_steps)
+            if diff < best_diff:
+                best_threshold = threshold
+                best_diff = diff
+            elif diff > best_diff:
+                break
+            threshold += 0.01
+        self.rel_l1_thresh = best_threshold
+        print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(len(timesteps) - best_diff):0.1f} for a target of x{speed_factor}")
+        return best_threshold
+
     def forward(
         self,
         x,

@@ -679,6 +715,7 @@ class WanModel(ModelMixin, ConfigMixin):
         freqs = None,
         pipeline = None,
         current_step = 0,
+        context2 = None,
         is_uncond=False
     ):
         r"""

@@ -722,10 +759,13 @@ class WanModel(ModelMixin, ConfigMixin):
         x = [u.flatten(2).transpose(1, 2) for u in x]
         seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         assert seq_lens.max() <= seq_len
-        x = …
-        …
+        if len(x)==1 and seq_len == x[0].size(1):
+            x = x[0]
+        else:
+            x = torch.cat([
+                torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
+                dim=1) for u in x
+            ])

         # time embeddings
         e = self.time_embedding(

@@ -740,82 +780,105 @@ class WanModel(ModelMixin, ConfigMixin):
                 [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                 for u in context
             ]))
+        if context2!=None:
+            context2 = self.text_embedding(
+                torch.stack([
+                    torch.cat(
+                        [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
+                    for u in context2
+                ]))

         if clip_fea is not None:
             context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
             context = torch.concat([context_clip, context], dim=1)
-
-        should_calc = True
-        if self.enable_teacache:
-            if …
-                self.accumulated_rel_l1_distance_uncond = 0
-                self.teacache_skipped_cond_steps = 0
-                self.teacache_skipped_uncond_steps = 0
-            else:
-                acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
-                temb_relative_l1 = relative_l1_distance(prev_input, e0)
-                setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
-                if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
-                    should_calc = False
-                    self.teacache_counter += 1
-                else:
-                    should_calc = True
-            if is_uncond:
-                self.previous_modulated_input_uncond = e0.clone()
-                if should_calc:
-                    self.previous_residual_uncond = None
-            else:
-                self.…
-        …
-        if should_calc:
-            if self.enable_teacache:
-                …
+            if context2 != None:
+                context2 = torch.concat([context_clip, context2], dim=1)
+
+        joint_pass = context2 != None
+        if joint_pass:
+            x_list = [x, x.clone()]
+            context_list = [context, context2]
+            is_uncond = False
+        else:
+            x_list = [x]
+            context_list = [context]
+        del x
+        should_calc = True
+        if self.enable_teacache:
+            if is_uncond:
+                should_calc = self.should_calc
+            else:
+                if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
+                    should_calc = True
+                    self.accumulated_rel_l1_distance = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
+                    if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
+                        should_calc = False
+                        self.teacache_skipped_steps += 1
+                        # print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
+                    else:
+                        should_calc = True
+                        self.accumulated_rel_l1_distance = 0
+                self.previous_modulated_input = e
+                self.should_calc = should_calc
+
+        if not should_calc:
+            for i, x in enumerate(x_list):
+                x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond
+        else:
+            if self.enable_teacache:
+                if joint_pass or is_uncond:
+                    self.previous_residual_uncond = None
+                if joint_pass or not is_uncond:
+                    self.previous_residual_cond = None
+                ori_hidden_states = x_list[0].clone()
             # arguments
             kwargs = dict(
-                e=e0,
+                # e=e0,
                 seq_lens=seq_lens,
                 grid_sizes=grid_sizes,
                 freqs=freqs,
-                context=context,
+                # context=context,
                 context_lens=context_lens)

             for block in self.blocks:
                 if pipeline._interrupt:
-                    …
+                    if joint_pass:
+                        return None, None
+                    else:
+                        return [None]
+                for i, (x, context) in enumerate(zip(x_list, context_list)):
+                    x_list[i] = block(x, context = context, e= e0, **kwargs)
+                del x

             if self.enable_teacache:
-                …
-                self.previous_residual_uncond…
-            else:
-                …
+                if joint_pass:
+                    self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
+                    self.previous_residual_uncond = ori_hidden_states
+                    torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
+                else:
+                    residual = ori_hidden_states  # just to have a readable code
+                    torch.sub(x_list[0], ori_hidden_states, out=residual)
+                    if i==1 or is_uncond:
+                        self.previous_residual_uncond = residual
+                    else:
+                        self.previous_residual_cond = residual
+                residual, ori_hidden_states = None, None
+
+        for i, x in enumerate(x_list):
+            # head
+            x = self.head(x, e)
+
+            # unpatchify
+            x_list[i] = self.unpatchify(x, grid_sizes)
+            del x
+
+        if joint_pass:
+            return x_list[0][0], x_list[1][0]
+        else:
+            return [u.float() for u in x_list[0]]

     def unpatchify(self, x, grid_sizes):
         r"""
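The TeaCache logic added to `forward` above boils down to: rescale the change of the time embedding with a model-specific polynomial, accumulate it, and when the accumulated value stays under `rel_l1_thresh`, skip the transformer blocks and re-apply the residual cached from the last fully computed step. The standalone sketch below condenses that decision; `heavy_blocks` and the polynomial coefficients are placeholders, not the real model or its calibration.

```python
# Condensed TeaCache decision: skip the expensive blocks when the (rescaled) change
# of the time embedding stays small, and reuse the cached residual instead.
# heavy_blocks() and the coefficients are placeholders.
import numpy as np
import torch

rescale = np.poly1d([1.0, 0.0])            # placeholder; the real coefficients are per-model
rel_l1_thresh = 0.15                       # placeholder threshold

def heavy_blocks(x):
    return x + 0.01 * torch.randn_like(x)  # stands in for the transformer blocks

x = torch.randn(8, 16)
prev_e = None
prev_residual = None
accumulated = 0.0
skipped = 0

for step in range(30):
    e = torch.randn(8, 16)                 # stands in for this step's time embedding
    should_calc = True
    if prev_e is not None:
        accumulated += rescale(((e - prev_e).abs().mean() / prev_e.abs().mean()).item())
        if accumulated < rel_l1_thresh:
            should_calc = False            # cheap step: reuse the cached residual
        else:
            accumulated = 0.0
    prev_e = e

    if should_calc:
        out = heavy_blocks(x)
        prev_residual = out - x            # cache the residual for later skipped steps
    else:
        out = x + prev_residual
        skipped += 1
    x = out

print(f"skipped {skipped}/30 steps")
```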
wan/text2video.py
CHANGED
@@ -131,7 +131,8 @@ class WanT2V:
         offload_model=True,
         callback = None,
         enable_RIFLEx = None,
-        VAE_tile_size = 0
+        VAE_tile_size = 0,
+        joint_pass = False,
     ):
         r"""
         Generates video frames from text prompt using diffusion process.

@@ -240,8 +241,10 @@ class WanT2V:
         freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
         arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
+        arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}

-        …
+        if self.model.enable_teacache:
+            self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:
             callback(-1, None)
         for i, t in enumerate(tqdm(timesteps)):

@@ -251,14 +254,20 @@ class WanT2V:
                 timestep = torch.stack(timestep)

                 # self.model.to(self.device)
-                …
+                if joint_pass:
+                    noise_pred_cond, noise_pred_uncond = self.model(
+                        latent_model_input, t=timestep,current_step=i, **arg_both)
+                    if self._interrupt:
+                        return None
+                else:
+                    noise_pred_cond = self.model(
+                        latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
+                    if self._interrupt:
+                        return None
+                    noise_pred_uncond = self.model(
+                        latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
+                    if self._interrupt:
+                        return None

                 del latent_model_input
                 noise_pred = noise_pred_uncond + guide_scale * (
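Both pipelines now call `compute_teacache_threshold` before the denoising loop: it dry-runs the timestep schedule for a range of candidate thresholds, simulates which steps would be skipped, and keeps the threshold whose executed-step count is closest to `len(timesteps) / speed_factor`. The stripped-down search below works on a plain list of per-step distances instead of real time embeddings and omits the early-exit of the original; `find_threshold` and the synthetic distances are illustrative only.

```python
# Stripped-down version of the Smart TeaCache threshold search: pick the skip
# threshold whose simulated executed-step count best matches the requested speed-up.
def find_threshold(distances, speed_factor, start_step=0):
    target_steps = int(len(distances) / speed_factor)
    best_threshold, best_diff = 0.01, float("inf")
    threshold = 0.01
    while threshold <= 0.6:
        accumulated, executed = 0.0, 0
        for i, d in enumerate(distances):
            skip = False
            if i > start_step:
                accumulated += d           # d stands in for the rescaled embedding distance
                if accumulated < threshold:
                    skip = True            # this step would reuse the cached residual
                else:
                    accumulated = 0.0
            if not skip:
                executed += 1
        diff = abs(target_steps - executed)
        if diff < best_diff:
            best_threshold, best_diff = threshold, diff
        threshold += 0.01
    return best_threshold

# Example: 30 steps with a constant synthetic distance, asking for roughly a x2 speed-up.
print(find_threshold([0.08] * 30, speed_factor=2.0))
```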