DeepBeepMeep committed on
Commit
948696e
·
1 Parent(s): 7d5369f

Added 10% boost, improved Loras and Teacache

Files changed (5)
  1. README.md +25 -10
  2. gradio_server.py +178 -76
  3. wan/image2video.py +30 -12
  4. wan/modules/model.py +123 -60
  5. wan/text2video.py +19 -10
README.md CHANGED
@@ -19,10 +19,11 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
19
 
20
 
21
  ## 🔥 Latest News!!
22
- * Mar 03, 2025: 👋 Wan2.1GP v1.4: Fix Pytorch compilation, now it is really 20% faster when activated
23
- * Mar 03, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiples images for different images / prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
 
24
  If you upgrade you will need to do a 'pip install -r requirements.txt' again.
25
- * Mar 03, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
26
  * Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
27
  * Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
28
  - Support for all Wan including the Image to Video model
@@ -152,15 +153,29 @@ python gradio_server.py --attention sdpa
152
 
153
  ### Loras support
154
 
155
- -- Ready to be used but theoretical as no lora for Wan have been released as of today. ---
156
 
157
- Every lora stored in the subfoler 'loras' will be automatically loaded. You will be then able to activate / desactive any of them when running the application.
158
 
159
- For each activated Lora, you may specify a *multiplier* that is one float number that corresponds to its weight (default is 1.0), alternatively you may specify a list of floats multipliers separated by a "," that gives the evolution of this Lora's multiplier over the steps. For instance let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7* then for the steps ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.
 
160
 
161
- You can edit, save or delete Loras presets (combinations of loras with their corresponding multipliers) directly from the gradio interface. Each preset, is a file with ".lset" extension stored in the loras directory and can be shared with other users
162
 
163
- Then you can pre activate loras corresponding to a preset when launching the gradio server:
164
  ```bash
165
  python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
166
  ```
@@ -180,11 +195,11 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
180
  --open-browser : open automatically Browser when launching Gradio Server\
181
  --lock-config : prevent modifying the video engine configuration from the interface\
182
  --share : create a shareable URL on huggingface so that your server can be accessed remotely\
183
- --multiple-images : Images as a starting point for new videos\
184
  --compile : turn on pytorch compilation\
185
  --attention mode: force attention mode among, sdpa, flash, sage, sage2\
186
  --profile no : default (4) : no of profile between 1 and 5\
187
- --preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware
188
 
189
  ### Profiles (for power users only)
190
  You can choose between 5 profiles, but two are really relevant here :
 
19
 
20
 
21
  ## 🔥 Latest News!!
22
+ * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official Teacache support + Smart Teacache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved Lora presets (they can now include prompts and comments to guide the user)
23
+ * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation, it is now really 20% faster when activated
24
+ * Mar 04, 2025: 👋 Wan2.1GP v1.3: Support for Image to Video with multiple images for different image / prompt combinations (requires the *--multiple-images* switch), and added the command line switch *--preload x* to preload x MB of the main diffusion model in VRAM if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
25
  If you upgrade you will need to do a 'pip install -r requirements.txt' again.
26
+ * Mar 04, 2025: 👋 Wan2.1GP v1.2: Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
27
  * Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
28
  * Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
29
  - Support for all Wan including the Image to Video model
 
153
 
154
  ### Loras support
155
 
 
156
 
157
+ Every Lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be automatically loaded. You will then be able to activate / deactivate any of them when running the application by selecting them in the area below "Activated Loras".
158
 
159
+ For each activated Lora, you may specify a *multiplier*, a single float number that corresponds to its weight (default is 1.0). The multipliers for the different Loras should be separated by a space character or a carriage return. For instance:\
160
+ *1.2 0.8* means that the first Lora will have a 1.2 multiplier and the second one 0.8.
161
 
162
+ Alternatively, for each Lora you may specify a list of float multipliers separated by a "," (no space) that gives the evolution of this Lora's multiplier over the denoising steps. For instance, let's assume there are 30 denoising steps and the multiplier is *0.9,0.8,0.7*: for the step ranges 0-9, 10-19 and 20-29 the Lora multiplier will be respectively 0.9, 0.8 and 0.7.
163
 
164
+ If multiple Loras are defined, remember that the multipliers associated with different Loras should still be separated by a space or a carriage return, so the evolution of multipliers can be specified for several Loras at once. For instance, for two Loras (press Shift+Return to force a carriage return):
165
+
166
+ ```
167
+ 0.9,0.8,0.7
168
+ 1.2,1.1,1.0
169
+ ```
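(Editor's note: a minimal Python sketch, not part of this commit, of how such a comma-separated schedule maps onto the denoising steps; the helper name is illustrative only.)

```python
# Illustrative only: map a per-Lora multiplier schedule such as "0.9,0.8,0.7"
# onto 30 denoising steps the way the README describes it
# (steps 0-9 -> 0.9, steps 10-19 -> 0.8, steps 20-29 -> 0.7).
def multiplier_for_step(schedule: str, step: int, num_steps: int = 30) -> float:
    phases = [float(m) for m in schedule.split(",")]
    phase_len = num_steps / len(phases)          # each value covers an equal slice of steps
    return phases[min(int(step / phase_len), len(phases) - 1)]

assert multiplier_for_step("0.9,0.8,0.7", 5) == 0.9
assert multiplier_for_step("0.9,0.8,0.7", 15) == 0.8
assert multiplier_for_step("0.9,0.8,0.7", 29) == 0.7
```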
170
+ You can edit, save or delete Lora presets (combinations of Loras with their corresponding multipliers) directly from the Gradio Web interface. These presets also save the *comment* part of the prompt, which should contain instructions on how to use the corresponding Loras (for instance by specifying a trigger word or providing an example). A comment in the prompt is a line that starts with a #; it is ignored by the video generator. For instance:
171
+
172
+ ```
173
+ # use the keyword ohnvx to trigger the Lora
174
+ A ohnvx is driving a car
175
+ ```
176
+ Each preset is a file with the ".lset" extension stored in the loras directory and can be shared with other users.
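(Editor's note: based on the save_lset code added to gradio_server.py below, a preset stores the selected Lora files, their multipliers and, optionally, the prompt comments or the full prompt; a sketch of the data it carries follows, with the on-disk serialization format assumed rather than taken from the diff.)

```python
# Sketch of the fields a ".lset" preset carries, mirroring the keys built in
# gradio_server.py's save_lset; the exact file format is an assumption.
example_lset = {
    "loras": ["my_style_lora.safetensors", "my_character_lora.safetensors"],  # hypothetical file names
    "loras_mult": "1.2 0.9,0.8,0.7",   # one multiplier (or schedule) per Lora, space separated
    "prompt": "# use the keyword ohnvx to trigger the Lora",  # optional comments or full prompt
    "full_prompt": False,              # True when the whole prompt was saved, not just the comments
}
```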
177
+
178
+ Last but not least, you can preactivate the Loras of a preset and prefill the prompt (comments only or full prompt) by specifying this preset when launching the gradio server:
179
  ```bash
180
  python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.lset' is a preset stored in the 'loras' folder
181
  ```
 
195
  --open-browser : open automatically Browser when launching Gradio Server\
196
  --lock-config : prevent modifying the video engine configuration from the interface\
197
  --share : create a shareable URL on huggingface so that your server can be accessed remotely\
198
+ --multiple-images : allow the user to choose multiple images as different starting points for new videos\
199
  --compile : turn on pytorch compilation\
200
  --attention mode: force attention mode among, sdpa, flash, sage, sage2\
201
  --profile no : default (4) : no of profile between 1 and 5\
202
+ --preload no : number of megabytes of the main diffusion model to preload in VRAM; may offer slight speed gains, especially on older hardware. Works only with profiles 2 and 4.
203
 
204
  ### Profiles (for power users only)
205
  You can choose between 5 profiles, but two are really relevant here :
gradio_server.py CHANGED
@@ -57,18 +57,18 @@ def _parse_args():
57
  )
58
 
59
 
60
- parser.add_argument(
61
- "--lora-dir-i2v",
62
- type=str,
63
- default="loras_i2v",
64
- help="Path to a directory that contains Loras for i2v"
65
- )
66
 
67
 
68
  parser.add_argument(
69
  "--lora-dir",
70
  type=str,
71
- default="loras",
72
  help="Path to a directory that contains Loras"
73
  )
74
 
@@ -80,12 +80,12 @@ def _parse_args():
80
  help="Lora preset to preload"
81
  )
82
 
83
- parser.add_argument(
84
- "--lora-preset-i2v",
85
- type=str,
86
- default="",
87
- help="Lora preset to preload for i2v"
88
- )
89
 
90
  parser.add_argument(
91
  "--profile",
@@ -198,6 +198,7 @@ if not Path(server_config_filename).is_file():
198
  "text_encoder_filename" : text_encoder_choices[1],
199
  "compile" : "",
200
  "default_ui": "t2v",
 
201
  "vae_config": 0,
202
  "profile" : profile_type.LowRAM_LowVRAM }
203
 
@@ -223,6 +224,7 @@ if len(args.attention)> 0:
223
 
224
  profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
225
  compile = server_config.get("compile", "")
 
226
  vae_config = server_config.get("vae_config", 0)
227
  if len(args.vae_config) > 0:
228
  vae_config = int(args.vae_config)
@@ -234,13 +236,14 @@ if args.t2v:
234
  if args.i2v:
235
  use_image2video = True
236
 
237
- if use_image2video:
238
- lora_dir =args.lora_dir_i2v
239
- lora_preselected_preset = args.lora_preset_i2v
240
- else:
241
- lora_dir =args.lora_dir
242
- lora_preselected_preset = args.lora_preset
243
-
 
244
  default_tea_cache = 0
245
  # if args.fast : #or args.fastest
246
  # transformer_filename_t2v = transformer_choices_t2v[2]
@@ -321,8 +324,16 @@ def extract_preset(lset_name, loras):
321
  raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")
322
 
323
  loras_mult_choices = lset["loras_mult"]
324
- return loras_choices, loras_mult_choices
 
325
 
 
 
 
 
 
 
 
326
  def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
327
  loras =[]
328
  loras_names = []
@@ -337,7 +348,7 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
337
  raise Exception("--lora-dir should be a path to a directory that contains Loras")
338
 
339
  default_lora_preset = ""
340
-
341
  if lora_dir != None:
342
  import glob
343
  dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
@@ -350,15 +361,16 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
350
 
351
  if len(loras) > 0:
352
  loras_names = [ Path(lora).stem for lora in loras ]
353
- offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
354
 
355
  if len(lora_preselected_preset) > 0:
356
  if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
357
  raise Exception(f"Unknown preset '{lora_preselected_preset}'")
358
  default_lora_preset = lora_preselected_preset
359
- default_loras_choices, default_loras_multis_str= extract_preset(default_lora_preset, loras)
360
-
361
- return loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
 
362
 
363
 
364
  def load_t2v_model(model_filename, value):
@@ -439,13 +451,13 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
439
  kwargs["budgets"] = { "*" : "70%" }
440
 
441
 
442
- loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
443
  offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
444
 
445
 
446
- return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
447
 
448
- wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
449
  gen_in_progress = False
450
 
451
  def get_auto_attention():
@@ -487,13 +499,14 @@ def apply_changes( state,
487
  profile_choice,
488
  vae_config_choice,
489
  default_ui_choice ="t2v",
 
490
  ):
491
  if args.lock_config:
492
  return
493
  if gen_in_progress:
494
  yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
495
  return
496
- global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets
497
  server_config = {"attention_mode" : attention_choice,
498
  "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
499
  "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########
@@ -502,6 +515,7 @@ def apply_changes( state,
502
  "profile" : profile_choice,
503
  "vae_config" : vae_config_choice,
504
  "default_ui" : default_ui_choice,
 
505
  }
506
 
507
  if Path(server_config_filename).is_file():
@@ -529,7 +543,7 @@ def apply_changes( state,
529
  state["config_new"] = server_config
530
  state["config_old"] = old_server_config
531
 
532
- global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config
533
  attention_mode = server_config["attention_mode"]
534
  profile = server_config["profile"]
535
  compile = server_config["compile"]
@@ -537,8 +551,8 @@ def apply_changes( state,
537
  transformer_filename_i2v = server_config["transformer_filename_i2v"]
538
  text_encoder_filename = server_config["text_encoder_filename"]
539
  vae_config = server_config["vae_config"]
540
-
541
- if all(change in ["attention_mode", "vae_config", "default_ui"] for change in changes ):
542
  if "attention_mode" in changes:
543
  pass
544
 
@@ -548,7 +562,7 @@ def apply_changes( state,
548
  offloadobj = None
549
  yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
550
 
551
- wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
552
 
553
 
554
  yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
@@ -727,7 +741,9 @@ def generate_video(
727
  if len(prompt) ==0:
728
  return
729
  prompts = prompt.replace("\r", "").split("\n")
730
-
 
 
731
  if use_image2video:
732
  if image_to_continue is not None:
733
  if isinstance(image_to_continue, list):
@@ -772,6 +788,9 @@ def generate_video(
772
  return False
773
  list_mult_choices_nums = []
774
  if len(loras_mult_choices) > 0:
 
 
 
775
  list_mult_choices_str = loras_mult_choices.split(" ")
776
  for i, mult in enumerate(list_mult_choices_str):
777
  mult = mult.strip()
@@ -805,18 +824,36 @@ def generate_video(
805
  # VAE Tiling
806
  device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
807
 
808
-
809
  # TeaCache
810
  trans = wan_model.model
811
  trans.enable_teacache = tea_cache > 0
812
-
813
  import random
814
  if seed == None or seed <0:
815
  seed = random.randint(0, 999999999)
816
 
817
  file_list = []
818
  state["file_list"] = file_list
819
- from einops import rearrange
820
  save_path = os.path.join(os.getcwd(), "gradio_outputs")
821
  os.makedirs(save_path, exist_ok=True)
822
  video_no = 0
@@ -830,14 +867,12 @@ def generate_video(
830
 
831
  if trans.enable_teacache:
832
  trans.teacache_counter = 0
833
- trans.rel_l1_thresh = tea_cache
834
- trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc*num_inference_steps/100),2)
 
 
835
  trans.previous_residual_uncond = None
836
- trans.previous_modulated_input_uncond = None
837
  trans.previous_residual_cond = None
838
- trans.previous_modulated_input_cond= None
839
-
840
- trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu"
841
 
842
  video_no += 1
843
  status = f"Video {video_no}/{total_video}"
@@ -853,7 +888,7 @@ def generate_video(
853
  if use_image2video:
854
  samples = wan_model.generate(
855
  prompt,
856
- image_to_continue[ (video_no-1) % len(image_to_continue)],
857
  frame_num=(video_length // 4)* 4 + 1,
858
  max_area=MAX_AREA_CONFIGS[resolution],
859
  shift=flow_shift,
@@ -864,7 +899,8 @@ def generate_video(
864
  offload_model=False,
865
  callback=callback,
866
  enable_RIFLEx = enable_RIFLEx,
867
- VAE_tile_size = VAE_tile_size
 
868
  )
869
 
870
  else:
@@ -880,7 +916,8 @@ def generate_video(
880
  offload_model=False,
881
  callback=callback,
882
  enable_RIFLEx = enable_RIFLEx,
883
- VAE_tile_size = VAE_tile_size
 
884
  )
885
  except Exception as e:
886
  gen_in_progress = False
@@ -911,6 +948,7 @@ def generate_video(
911
  raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
912
 
913
  if trans.enable_teacache:
 
914
  trans.previous_residual_uncond = None
915
  trans.previous_residual_cond = None
916
 
@@ -957,7 +995,25 @@ def generate_video(
957
 
958
  new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
959
 
960
- def save_lset(lset_name, loras_choices, loras_mult_choices):
961
  global loras_presets
962
 
963
  if len(lset_name) == 0 or lset_name== new_preset_msg:
@@ -968,6 +1024,16 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):
968
 
969
  loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ]
970
  lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices}
 
 
 
 
 
 
 
 
 
 
971
  lset_name_filename = lset_name + ".lset"
972
  full_lset_name_filename = os.path.join(lora_dir, lset_name_filename)
973
 
@@ -982,7 +1048,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices):
982
  lset_choices = [ ( preset, preset) for preset in loras_presets ]
983
  lset_choices.append( (new_preset_msg, ""))
984
 
985
- return gr.Dropdown(choices=lset_choices, value= lset_name)
986
 
987
  def delete_lset(lset_name):
988
  global loras_presets
@@ -1000,23 +1066,31 @@ def delete_lset(lset_name):
1000
 
1001
  lset_choices = [ (preset, preset) for preset in loras_presets]
1002
  lset_choices.append((new_preset_msg, ""))
1003
- return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1])
 
1004
 
1005
- def apply_lset(lset_name, loras_choices, loras_mult_choices):
1006
 
1007
  if len(lset_name) == 0 or lset_name== new_preset_msg:
1008
  gr.Info("Please choose a preset in the list or create one")
1009
  else:
1010
- loras_choices, loras_mult_choices= extract_preset(lset_name, loras)
 
 
 
 
 
 
 
1011
  gr.Info(f"Lora Preset '{lset_name}' has been applied")
1012
 
1013
- return loras_choices, loras_mult_choices
1014
 
1015
  def create_demo():
1016
 
1017
  default_inference_steps = 30
 
1018
  default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
1019
-
1020
  with gr.Blocks() as demo:
1021
  state = gr.State({})
1022
 
@@ -1130,6 +1204,16 @@ def create_demo():
1130
  label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
1131
  )
1132
 
 
 
 
 
 
 
 
 
 
 
1133
  profile_choice = gr.Dropdown(
1134
  choices=[
1135
  ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
@@ -1161,16 +1245,12 @@ def create_demo():
1161
  video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
1162
  if args.multiple_images:
1163
  image_to_continue = gr.Gallery(
1164
- label="Images as a starting point for new videos", type ="numpy", #file_types= "image",
1165
  columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
1166
  else:
1167
  image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
1168
 
1169
- if use_image2video:
1170
- prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.", lines=3)
1171
- else:
1172
- prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos)", value="A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect.", lines=3)
1173
-
1174
 
1175
  with gr.Row():
1176
  if use_image2video:
@@ -1223,9 +1303,21 @@ def create_demo():
1223
  # with gr.Column():
1224
  with gr.Row(height=17):
1225
  apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
 
 
 
 
 
 
 
 
 
1226
  with gr.Row(height=17):
 
 
1227
  save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
1228
  delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
 
1229
 
1230
 
1231
  loras_choices = gr.Dropdown(
@@ -1237,7 +1329,7 @@ def create_demo():
1237
  visible= len(loras)>0,
1238
  label="Activated Loras"
1239
  )
1240
- loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns", value=default_loras_multis_str, visible= len(loras)>0 )
1241
 
1242
  show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
1243
  with gr.Row(visible=False) as advanced_row:
@@ -1250,18 +1342,23 @@ def create_demo():
1250
  guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
1251
  embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
1252
  flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
1253
- tea_cache_setting = gr.Dropdown(
1254
- choices=[
1255
- ("Tea Cache Disabled", 0),
1256
- ("0.03 (around x1.6 speed up)", 0.03),
1257
- ("0.05 (around x2 speed up)", 0.05),
1258
- ("0.10 (around x3 speed up)", 0.1),
1259
- ],
1260
- value=default_tea_cache,
1261
- visible=True,
1262
- label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped but the lower the quality of the video (Tea Cache Consumes VRAM)"
1263
- )
1264
- tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)")
 
 
 
 
 
1265
 
1266
  RIFLEx_setting = gr.Dropdown(
1267
  choices=[
@@ -1283,9 +1380,13 @@ def create_demo():
1283
  generate_btn = gr.Button("Generate")
1284
  abort_btn = gr.Button("Abort")
1285
 
1286
- save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices], outputs=[lset_name])
1287
- delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name])
1288
- apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices], outputs=[loras_choices, loras_mult_choices])
 
 
 
 
1289
 
1290
  gen_status.change(refresh_gallery, inputs = [state], outputs = output )
1291
 
@@ -1335,6 +1436,7 @@ def create_demo():
1335
  profile_choice,
1336
  vae_config_choice,
1337
  default_ui_choice,
 
1338
  ],
1339
  outputs= msg
1340
  ).then(
 
57
  )
58
 
59
 
60
+ # parser.add_argument(
61
+ # "--lora-dir-i2v",
62
+ # type=str,
63
+ # default="loras_i2v",
64
+ # help="Path to a directory that contains Loras for i2v"
65
+ # )
66
 
67
 
68
  parser.add_argument(
69
  "--lora-dir",
70
  type=str,
71
+ default="",
72
  help="Path to a directory that contains Loras"
73
  )
74
 
 
80
  help="Lora preset to preload"
81
  )
82
 
83
+ # parser.add_argument(
84
+ # "--lora-preset-i2v",
85
+ # type=str,
86
+ # default="",
87
+ # help="Lora preset to preload for i2v"
88
+ # )
89
 
90
  parser.add_argument(
91
  "--profile",
 
198
  "text_encoder_filename" : text_encoder_choices[1],
199
  "compile" : "",
200
  "default_ui": "t2v",
201
+ "boost" : 1,
202
  "vae_config": 0,
203
  "profile" : profile_type.LowRAM_LowVRAM }
204
 
 
224
 
225
  profile = force_profile_no if force_profile_no >=0 else server_config["profile"]
226
  compile = server_config.get("compile", "")
227
+ boost = server_config.get("boost", 1)
228
  vae_config = server_config.get("vae_config", 0)
229
  if len(args.vae_config) > 0:
230
  vae_config = int(args.vae_config)
 
236
  if args.i2v:
237
  use_image2video = True
238
 
239
+ # if use_image2video:
240
+ # lora_dir =args.lora_dir_i2v
241
+ # lora_preselected_preset = args.lora_preset_i2v
242
+ # else:
243
+ lora_dir =args.lora_dir
244
+ if len(lora_dir) ==0:
245
+ lora_dir = "loras_i2v" if use_image2video else "loras"
246
+ lora_preselected_preset = args.lora_preset
247
  default_tea_cache = 0
248
  # if args.fast : #or args.fastest
249
  # transformer_filename_t2v = transformer_choices_t2v[2]
 
324
  raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")
325
 
326
  loras_mult_choices = lset["loras_mult"]
327
+ prompt = lset.get("prompt", "")
328
+ return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False)
329
 
330
+ def get_default_prompt(i2v):
331
+ if i2v:
332
+ return "Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field."
333
+ else:
334
+ return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
335
+
336
+
337
  def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
338
  loras =[]
339
  loras_names = []
 
348
  raise Exception("--lora-dir should be a path to a directory that contains Loras")
349
 
350
  default_lora_preset = ""
351
+ default_prompt = ""
352
  if lora_dir != None:
353
  import glob
354
  dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
 
361
 
362
  if len(loras) > 0:
363
  loras_names = [ Path(lora).stem for lora in loras ]
364
+ offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
365
 
366
  if len(lora_preselected_preset) > 0:
367
  if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
368
  raise Exception(f"Unknown preset '{lora_preselected_preset}'")
369
  default_lora_preset = lora_preselected_preset
370
+ default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras)
371
+ if len(default_prompt) == 0:
372
+ default_prompt = get_default_prompt(use_image2video)
373
+ return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
374
 
375
 
376
  def load_t2v_model(model_filename, value):
 
451
  kwargs["budgets"] = { "*" : "70%" }
452
 
453
 
454
+ loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
455
  offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
456
 
457
 
458
+ return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
459
 
460
+ wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
461
  gen_in_progress = False
462
 
463
  def get_auto_attention():
 
499
  profile_choice,
500
  vae_config_choice,
501
  default_ui_choice ="t2v",
502
+ boost_choice = 1
503
  ):
504
  if args.lock_config:
505
  return
506
  if gen_in_progress:
507
  yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
508
  return
509
+ global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
510
  server_config = {"attention_mode" : attention_choice,
511
  "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
512
  "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice], ##########
 
515
  "profile" : profile_choice,
516
  "vae_config" : vae_config_choice,
517
  "default_ui" : default_ui_choice,
518
+ "boost" : boost_choice,
519
  }
520
 
521
  if Path(server_config_filename).is_file():
 
543
  state["config_new"] = server_config
544
  state["config_old"] = old_server_config
545
 
546
+ global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
547
  attention_mode = server_config["attention_mode"]
548
  profile = server_config["profile"]
549
  compile = server_config["compile"]
 
551
  transformer_filename_i2v = server_config["transformer_filename_i2v"]
552
  text_encoder_filename = server_config["text_encoder_filename"]
553
  vae_config = server_config["vae_config"]
554
+ boost = server_config["boost"]
555
+ if all(change in ["attention_mode", "vae_config", "default_ui", "boost"] for change in changes ):
556
  if "attention_mode" in changes:
557
  pass
558
 
 
562
  offloadobj = None
563
  yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
564
 
565
+ wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
566
 
567
 
568
  yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
 
741
  if len(prompt) ==0:
742
  return
743
  prompts = prompt.replace("\r", "").split("\n")
744
+ prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
745
+ if len(prompts) ==0:
746
+ return
747
  if use_image2video:
748
  if image_to_continue is not None:
749
  if isinstance(image_to_continue, list):
 
788
  return False
789
  list_mult_choices_nums = []
790
  if len(loras_mult_choices) > 0:
791
+ loras_mult_choices_list = loras_mult_choices.replace("\r", "").split("\n")
792
+ loras_mult_choices_list = [multi for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")]
793
+ loras_mult_choices = " ".join(loras_mult_choices_list)
794
  list_mult_choices_str = loras_mult_choices.split(" ")
795
  for i, mult in enumerate(list_mult_choices_str):
796
  mult = mult.strip()
 
824
  # VAE Tiling
825
  device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
826
 
827
+ joint_pass = boost ==1
828
  # TeaCache
829
  trans = wan_model.model
830
  trans.enable_teacache = tea_cache > 0
831
+ if trans.enable_teacache:
832
+ if use_image2video:
833
+ if '480p' in transformer_filename_i2v:
834
+ # teacache_thresholds = [0.13, .19, 0.26]
835
+ trans.coefficients = [-3.02331670e+02, 2.23948934e+02, -5.25463970e+01, 5.87348440e+00, -2.01973289e-01]
836
+ elif '720p' in transformer_filename_i2v:
837
+ teacache_thresholds = [0.18, 0.2 , 0.3]
838
+ trans.coefficients = [-114.36346466, 65.26524496, -18.82220707, 4.91518089, -0.23412683]
839
+ else:
840
+ raise gr.Error("Teacache not supported for this model")
841
+ else:
842
+ if '1.3B' in transformer_filename_t2v:
843
+ # teacache_thresholds= [0.05, 0.07, 0.08]
844
+ trans.coefficients = [2.39676752e+03, -1.31110545e+03, 2.01331979e+02, -8.29855975e+00, 1.37887774e-01]
845
+ elif '14B' in transformer_filename_t2v:
846
+ # teacache_thresholds = [0.14, 0.15, 0.2]
847
+ trans.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
848
+ else:
849
+ raise gr.Error("Teacache not supported for this model")
850
+
851
  import random
852
  if seed == None or seed <0:
853
  seed = random.randint(0, 999999999)
854
 
855
  file_list = []
856
  state["file_list"] = file_list
 
857
  save_path = os.path.join(os.getcwd(), "gradio_outputs")
858
  os.makedirs(save_path, exist_ok=True)
859
  video_no = 0
 
867
 
868
  if trans.enable_teacache:
869
  trans.teacache_counter = 0
870
+ trans.teacache_multiplier = tea_cache
871
+ trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
872
+ trans.num_steps = num_inference_steps
873
+ trans.teacache_skipped_steps = 0
874
  trans.previous_residual_uncond = None
 
875
  trans.previous_residual_cond = None
 
 
 
876
 
877
  video_no += 1
878
  status = f"Video {video_no}/{total_video}"
 
888
  if use_image2video:
889
  samples = wan_model.generate(
890
  prompt,
891
+ image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'),
892
  frame_num=(video_length // 4)* 4 + 1,
893
  max_area=MAX_AREA_CONFIGS[resolution],
894
  shift=flow_shift,
 
899
  offload_model=False,
900
  callback=callback,
901
  enable_RIFLEx = enable_RIFLEx,
902
+ VAE_tile_size = VAE_tile_size,
903
+ joint_pass = joint_pass,
904
  )
905
 
906
  else:
 
916
  offload_model=False,
917
  callback=callback,
918
  enable_RIFLEx = enable_RIFLEx,
919
+ VAE_tile_size = VAE_tile_size,
920
+ joint_pass = joint_pass,
921
  )
922
  except Exception as e:
923
  gen_in_progress = False
 
948
  raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
949
 
950
  if trans.enable_teacache:
951
+ print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
952
  trans.previous_residual_uncond = None
953
  trans.previous_residual_cond = None
954
 
 
995
 
996
  new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
997
 
998
+
999
+ def validate_delete_lset(lset_name):
1000
+ if len(lset_name) == 0 or lset_name == new_preset_msg:
1001
+ gr.Info(f"Choose a Preset to delete")
1002
+ return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
1003
+ else:
1004
+ return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
1005
+
1006
+ def validate_save_lset(lset_name):
1007
+ if len(lset_name) == 0 or lset_name == new_preset_msg:
1008
+ gr.Info("Please enter a name for the preset")
1009
+ return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
1010
+ else:
1011
+ return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
1012
+
1013
+ def cancel_lset():
1014
+ return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1015
+
1016
+ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
1017
  global loras_presets
1018
 
1019
  if len(lset_name) == 0 or lset_name== new_preset_msg:
 
1024
 
1025
  loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ]
1026
  lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices}
1027
+ if save_lset_prompt_cbox!=1:
1028
+ prompts = prompt.replace("\r", "").split("\n")
1029
+ prompts = [prompt for prompt in prompts if len(prompt)> 0 and prompt.startswith("#")]
1030
+ prompt = "\n".join(prompts)
1031
+
1032
+ if len(prompt) > 0:
1033
+ lset["prompt"] = prompt
1034
+ lset["full_prompt"] = save_lset_prompt_cbox ==1
1035
+
1036
+
1037
  lset_name_filename = lset_name + ".lset"
1038
  full_lset_name_filename = os.path.join(lora_dir, lset_name_filename)
1039
 
 
1048
  lset_choices = [ ( preset, preset) for preset in loras_presets ]
1049
  lset_choices.append( (new_preset_msg, ""))
1050
 
1051
+ return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1052
 
1053
  def delete_lset(lset_name):
1054
  global loras_presets
 
1066
 
1067
  lset_choices = [ (preset, preset) for preset in loras_presets]
1068
  lset_choices.append((new_preset_msg, ""))
1069
+ return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
1070
+
1071
 
1072
+ def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
1073
 
1074
  if len(lset_name) == 0 or lset_name== new_preset_msg:
1075
  gr.Info("Please choose a preset in the list or create one")
1076
  else:
1077
+ loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
1078
+ if full_prompt:
1079
+ prompt = preset_prompt
1080
+ elif len(preset_prompt) > 0:
1081
+ prompts = prompt.replace("\r", "").split("\n")
1082
+ prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
1083
+ prompt = "\n".join(prompts)
1084
+ prompt = preset_prompt + '\n' + prompt
1085
  gr.Info(f"Lora Preset '{lset_name}' has been applied")
1086
 
1087
+ return loras_choices, loras_mult_choices, prompt
1088
 
1089
  def create_demo():
1090
 
1091
  default_inference_steps = 30
1092
+
1093
  default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
 
1094
  with gr.Blocks() as demo:
1095
  state = gr.State({})
1096
 
 
1204
  label="VAE Tiling - reduce the high VRAM requirements for VAE decoding and VAE encoding (if enabled it will be slower)"
1205
  )
1206
 
1207
+ boost_choice = gr.Dropdown(
1208
+ choices=[
1209
+ # ("Auto (ON if Video longer than 5s)", 0),
1210
+ ("ON", 1),
1211
+ ("OFF", 2),
1212
+ ],
1213
+ value=boost,
1214
+ label="Boost: Give a 10% speed speedup without losing quality at the cost of a litle VRAM (up to 1GB for max frames and resolution)"
1215
+ )
1216
+
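(Editor's note: the "Boost" option above maps to the joint_pass flag threaded through wan/text2video.py and wan/image2video.py further down in this commit; below is a self-contained sketch of the idea using a stand-in model, not the repository's actual API.)

```python
# Stand-in illustration of the joint pass behind "Boost": the conditional and
# unconditional contexts go through the transformer in a single forward call
# instead of two, trading a little extra VRAM for roughly 10% speed.
def fake_model(latent, context, context2=None):
    if context2 is not None:          # joint pass: both predictions in one call
        return latent + 0.1, latent - 0.1
    return [latent + (0.1 if context == "cond" else -0.1)]

def denoise_step(latent, joint_pass, guide_scale=5.0):
    if joint_pass:
        cond, uncond = fake_model(latent, "cond", context2="uncond")
    else:
        cond = fake_model(latent, "cond")[0]
        uncond = fake_model(latent, "uncond")[0]
    return uncond + guide_scale * (cond - uncond)

# Both paths produce the same result; only the number of forward calls differs.
assert denoise_step(1.0, joint_pass=True) == denoise_step(1.0, joint_pass=False)
```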
1217
  profile_choice = gr.Dropdown(
1218
  choices=[
1219
  ("HighRAM_HighVRAM, profile 1: at least 48 GB of RAM and 24 GB of VRAM, the fastest for short videos a RTX 3090 / RTX 4090", 1),
 
1245
  video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
1246
  if args.multiple_images:
1247
  image_to_continue = gr.Gallery(
1248
+ label="Images as a starting point for new videos", type ="pil", #file_types= "image",
1249
  columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=use_image2video)
1250
  else:
1251
  image_to_continue = gr.Image(label= "Image as a starting point for a new video", visible=use_image2video)
1252
 
1253
+ prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos, lines that starts with # are ignored)", value=default_prompt, lines=3)
 
 
 
 
1254
 
1255
  with gr.Row():
1256
  if use_image2video:
 
1303
  # with gr.Column():
1304
  with gr.Row(height=17):
1305
  apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
1306
+ # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
1307
+ save_lset_prompt_drop= gr.Dropdown(
1308
+ choices=[
1309
+ ("Save Prompt Comments Only", 0),
1310
+ ("Save Full Prompt", 1)
1311
+ ], show_label= False, container=False, visible= False
1312
+ )
1313
+
1314
+
1315
  with gr.Row(height=17):
1316
+ confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
1317
+ confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
1318
  save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
1319
  delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
1320
+ cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
1321
 
1322
 
1323
  loras_choices = gr.Dropdown(
 
1329
  visible= len(loras)>0,
1330
  label="Activated Loras"
1331
  )
1332
+ loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
1333
 
1334
  show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
1335
  with gr.Row(visible=False) as advanced_row:
 
1342
  guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
1343
  embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
1344
  flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
1345
+ with gr.Row():
1346
+ gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)")
1347
+ with gr.Row():
1348
+ tea_cache_setting = gr.Dropdown(
1349
+ choices=[
1350
+ ("Tea Cache Disabled", 0),
1351
+ ("around x1.5 speed up", 1.5),
1352
+ ("around x1.75 speed up", 1.75),
1353
+ ("around x2 speed up", 2.0),
1354
+ ("around x2.25 speed up", 2.25),
1355
+ ("around x2.5 speed up", 2.5),
1356
+ ],
1357
+ value=default_tea_cache,
1358
+ visible=True,
1359
+ label="Tea Cache Global Acceleration"
1360
+ )
1361
+ tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")
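(Editor's note: with the new Smart Tea Cache, the dropdown above requests a speed multiplier and compute_teacache_threshold in wan/modules/model.py searches for the skip threshold whose number of computed steps comes closest to len(timesteps) / speed_factor; a quick worked example of that target follows.)

```python
# Worked example of the target the Smart Tea Cache threshold search aims for:
# roughly num_steps / speed_factor steps are actually computed and the rest
# reuse the cached residual.
num_inference_steps = 30
for speed_factor in (1.5, 1.75, 2.0, 2.25, 2.5):
    computed = int(num_inference_steps / speed_factor)   # steps actually run
    skipped = num_inference_steps - computed             # steps served from the cache
    print(f"x{speed_factor}: compute {computed} steps, skip {skipped}")
```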
1362
 
1363
  RIFLEx_setting = gr.Dropdown(
1364
  choices=[
 
1380
  generate_btn = gr.Button("Generate")
1381
  abort_btn = gr.Button("Abort")
1382
 
1383
+ save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
1384
+ confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
1385
+ delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1386
+ confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1387
+ cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
1388
+
1389
+ apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
1390
 
1391
  gen_status.change(refresh_gallery, inputs = [state], outputs = output )
1392
 
 
1436
  profile_choice,
1437
  vae_config_choice,
1438
  default_ui_choice,
1439
+ boost_choice,
1440
  ],
1441
  outputs= msg
1442
  ).then(
wan/image2video.py CHANGED
@@ -146,7 +146,7 @@ class WanI2V:
146
  callback = None,
147
  enable_RIFLEx = False,
148
  VAE_tile_size= 0,
149
-
150
  ):
151
  r"""
152
  Generates video frames from input image and text prompt using diffusion process.
@@ -310,9 +310,22 @@ class WanI2V:
310
  'pipeline' : self
311
  }
312
 
 
 
 
 
 
 
 
 
 
 
313
  if offload_model:
314
  torch.cuda.empty_cache()
315
 
 
 
 
316
  # self.model.to(self.device)
317
  if callback != None:
318
  callback(-1, None)
@@ -323,17 +336,22 @@ class WanI2V:
323
  timestep = [t]
324
 
325
  timestep = torch.stack(timestep).to(self.device)
326
-
327
- noise_pred_cond = self.model(
328
- latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
329
- if self._interrupt:
330
- return None
331
- if offload_model:
332
- torch.cuda.empty_cache()
333
- noise_pred_uncond = self.model(
334
- latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
335
- if self._interrupt:
336
- return None
 
 
 
 
 
337
  del latent_model_input
338
  if offload_model:
339
  torch.cuda.empty_cache()
 
146
  callback = None,
147
  enable_RIFLEx = False,
148
  VAE_tile_size= 0,
149
+ joint_pass = False,
150
  ):
151
  r"""
152
  Generates video frames from input image and text prompt using diffusion process.
 
310
  'pipeline' : self
311
  }
312
 
313
+ arg_both= {
314
+ 'context': [context[0]],
315
+ 'context2': context_null,
316
+ 'clip_fea': clip_context,
317
+ 'seq_len': max_seq_len,
318
+ 'y': [y],
319
+ 'freqs' : freqs,
320
+ 'pipeline' : self
321
+ }
322
+
323
  if offload_model:
324
  torch.cuda.empty_cache()
325
 
326
+ if self.model.enable_teacache:
327
+ self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
328
+
329
  # self.model.to(self.device)
330
  if callback != None:
331
  callback(-1, None)
 
336
  timestep = [t]
337
 
338
  timestep = torch.stack(timestep).to(self.device)
339
+ if joint_pass:
340
+ noise_pred_cond, noise_pred_uncond = self.model(
341
+ latent_model_input, t=timestep, current_step=i, **arg_both)
342
+ if self._interrupt:
343
+ return None
344
+ else:
345
+ noise_pred_cond = self.model(
346
+ latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
347
+ if self._interrupt:
348
+ return None
349
+ if offload_model:
350
+ torch.cuda.empty_cache()
351
+ noise_pred_uncond = self.model(
352
+ latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
353
+ if self._interrupt:
354
+ return None
355
  del latent_model_input
356
  if offload_model:
357
  torch.cuda.empty_cache()
wan/modules/model.py CHANGED
@@ -667,7 +667,43 @@ class WanModel(ModelMixin, ConfigMixin):
667
 
668
  return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
669
 
670
-
671
  def forward(
672
  self,
673
  x,
@@ -679,6 +715,7 @@ class WanModel(ModelMixin, ConfigMixin):
679
  freqs = None,
680
  pipeline = None,
681
  current_step = 0,
 
682
  is_uncond=False
683
  ):
684
  r"""
@@ -722,10 +759,13 @@ class WanModel(ModelMixin, ConfigMixin):
722
  x = [u.flatten(2).transpose(1, 2) for u in x]
723
  seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
724
  assert seq_lens.max() <= seq_len
725
- x = torch.cat([
726
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
727
- dim=1) for u in x
728
- ])
 
 
 
729
 
730
  # time embeddings
731
  e = self.time_embedding(
@@ -740,82 +780,105 @@ class WanModel(ModelMixin, ConfigMixin):
740
  [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
741
  for u in context
742
  ]))
 
 
 
 
 
 
 
743
 
744
  if clip_fea is not None:
745
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
746
  context = torch.concat([context_clip, context], dim=1)
747
- # deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
 
 
 
 
 
 
 
 
 
 
 
748
  should_calc = True
749
- if self.enable_teacache and current_step >= self.teacache_start_step:
750
- if current_step == self.teacache_start_step:
751
- self.accumulated_rel_l1_distance_cond = 0
752
- self.accumulated_rel_l1_distance_uncond = 0
753
- self.teacache_skipped_cond_steps = 0
754
- self.teacache_skipped_uncond_steps = 0
755
  else:
756
- prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
757
- acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
758
-
759
- temb_relative_l1 = relative_l1_distance(prev_input, e0)
760
- setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
761
-
762
- if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
763
- should_calc = False
764
- self.teacache_counter += 1
765
- else:
766
  should_calc = True
767
- setattr(self, acc_distance_attr, 0)
768
-
769
- if is_uncond:
770
- self.previous_modulated_input_uncond = e0.clone()
771
- if should_calc:
772
- self.previous_residual_uncond = None
773
  else:
774
- x += self.previous_residual_uncond
775
- self.teacache_skipped_cond_steps += 1
776
- # print(f"Skipped uncond:{self.teacache_skipped_cond_steps}/{current_step}" )
777
- else:
778
- self.previous_modulated_input_cond = e0.clone()
779
- if should_calc:
780
- self.previous_residual_cond = None
781
- else:
782
- x += self.previous_residual_cond
783
- self.teacache_skipped_uncond_steps += 1
784
- # print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )
785
-
786
- if should_calc:
 
 
 
787
  if self.enable_teacache:
788
- ori_hidden_states = x.clone()
 
 
 
 
789
  # arguments
790
  kwargs = dict(
791
- e=e0,
792
  seq_lens=seq_lens,
793
  grid_sizes=grid_sizes,
794
  freqs=freqs,
795
- context=context,
796
  context_lens=context_lens)
797
 
798
  for block in self.blocks:
799
  if pipeline._interrupt:
800
- return [None]
801
-
802
- x = block(x, **kwargs)
 
 
 
 
803
 
804
  if self.enable_teacache:
805
- residual = ori_hidden_states # just to have a readable code
806
- torch.sub(x, ori_hidden_states, out=residual)
807
- if is_uncond:
808
- self.previous_residual_uncond = residual
809
  else:
810
- self.previous_residual_cond = residual
811
- del residual, ori_hidden_states
812
-
813
- # head
814
- x = self.head(x, e)
815
-
816
- # unpatchify
817
- x = self.unpatchify(x, grid_sizes)
818
- return [u.float() for u in x]
 
 
 
 
 
 
 
 
 
 
 
819
 
820
  def unpatchify(self, x, grid_sizes):
821
  r"""
 
667
 
668
  return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
669
 
670
+ def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
671
+ rescale_func = np.poly1d(self.coefficients)
672
+ e_list = []
673
+ for t in timesteps:
674
+ t = torch.stack([t])
675
+ e_list.append(self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, t)))
676
+
677
+ best_threshold = 0.01
678
+ best_diff = 1000
679
+ target_nb_steps= int(len(timesteps) / speed_factor)
680
+ threshold = 0.01
681
+ while threshold <= 0.6:
682
+ accumulated_rel_l1_distance =0
683
+ nb_steps = 0
684
+ diff = 1000
685
+ for i, t in enumerate(timesteps):
686
+ skip = False
687
+ if not (i<=start_step or i == len(timesteps)-1):
688
+ accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
689
+ if accumulated_rel_l1_distance < threshold:
690
+ skip = True
691
+ else:
692
+ accumulated_rel_l1_distance = 0
693
+ previous_modulated_input = e_list[i]
694
+ if not skip:
695
+ nb_steps += 1
696
+ diff = abs(target_nb_steps - nb_steps)
697
+ if diff < best_diff:
698
+ best_threshold = threshold
699
+ best_diff = diff
700
+ elif diff > best_diff:
701
+ break
702
+ threshold += 0.01
703
+ self.rel_l1_thresh = best_threshold
704
+ print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(len(timesteps) - best_diff):0.1f} for a target of x{speed_factor}")
705
+ return best_threshold
706
+
707
  def forward(
708
  self,
709
  x,
 
715
  freqs = None,
716
  pipeline = None,
717
  current_step = 0,
718
+ context2 = None,
719
  is_uncond=False
720
  ):
721
  r"""
 
759
  x = [u.flatten(2).transpose(1, 2) for u in x]
760
  seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
761
  assert seq_lens.max() <= seq_len
762
+ if len(x)==1 and seq_len == x[0].size(1):
763
+ x = x[0]
764
+ else:
765
+ x = torch.cat([
766
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
767
+ dim=1) for u in x
768
+ ])
769
 
770
  # time embeddings
771
  e = self.time_embedding(
 
780
  [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
781
  for u in context
782
  ]))
783
+ if context2!=None:
784
+ context2 = self.text_embedding(
785
+ torch.stack([
786
+ torch.cat(
787
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
788
+ for u in context2
789
+ ]))
790
 
791
  if clip_fea is not None:
792
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
793
  context = torch.concat([context_clip, context], dim=1)
794
+ if context2 != None:
795
+ context2 = torch.concat([context_clip, context2], dim=1)
796
+
797
+ joint_pass = context2 != None
798
+ if joint_pass:
799
+ x_list = [x, x.clone()]
800
+ context_list = [context, context2]
801
+ is_uncond = False
802
+ else:
803
+ x_list = [x]
804
+ context_list = [context]
805
+ del x
806
  should_calc = True
807
+ if self.enable_teacache:
808
+ if is_uncond:
809
+ should_calc = self.should_calc
 
 
 
810
  else:
811
+ if current_step <= self.teacache_start_step or current_step == self.num_steps-1:
 
 
 
 
 
 
 
 
 
812
  should_calc = True
813
+ self.accumulated_rel_l1_distance = 0
 
 
 
 
 
814
  else:
815
+ rescale_func = np.poly1d(self.coefficients)
816
+ self.accumulated_rel_l1_distance += rescale_func(((e-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
817
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
818
+ should_calc = False
819
+ self.teacache_skipped_steps += 1
820
+ # print(f"Teacache Skipped Step:{self.teacache_skipped_steps}/{current_step}" )
821
+ else:
822
+ should_calc = True
823
+ self.accumulated_rel_l1_distance = 0
824
+ self.previous_modulated_input = e
825
+ self.should_calc = should_calc
826
+
827
+ if not should_calc:
828
+ for i, x in enumerate(x_list):
829
+ x += self.previous_residual_uncond if i==1 or is_uncond else self.previous_residual_cond
830
+ else:
831
  if self.enable_teacache:
832
+ if joint_pass or is_uncond:
833
+ self.previous_residual_uncond = None
834
+ if joint_pass or not is_uncond:
835
+ self.previous_residual_cond = None
836
+ ori_hidden_states = x_list[0].clone()
837
  # arguments
838
  kwargs = dict(
839
+ # e=e0,
840
  seq_lens=seq_lens,
841
  grid_sizes=grid_sizes,
842
  freqs=freqs,
843
+ # context=context,
844
  context_lens=context_lens)
845
 
846
  for block in self.blocks:
847
  if pipeline._interrupt:
848
+ if joint_pass:
849
+ return None, None
850
+ else:
851
+ return [None]
852
+ for i, (x, context) in enumerate(zip(x_list, context_list)):
853
+ x_list[i] = block(x, context = context, e= e0, **kwargs)
854
+ del x
855
 
856
  if self.enable_teacache:
857
+ if joint_pass:
858
+ self.previous_residual_cond = torch.sub(x_list[0], ori_hidden_states)
859
+ self.previous_residual_uncond = ori_hidden_states
860
+ torch.sub(x_list[1], ori_hidden_states, out=self.previous_residual_uncond)
861
  else:
862
+ residual = ori_hidden_states # just to have a readable code
863
+ torch.sub(x_list[0], ori_hidden_states, out=residual)
864
+ if i==1 or is_uncond:
865
+ self.previous_residual_uncond = residual
866
+ else:
867
+ self.previous_residual_cond = residual
868
+ residual, ori_hidden_states = None, None
869
+
870
+ for i, x in enumerate(x_list):
871
+ # head
872
+ x = self.head(x, e)
873
+
874
+ # unpatchify
875
+ x_list[i] = self.unpatchify(x, grid_sizes)
876
+ del x
877
+
878
+ if joint_pass:
879
+ return x_list[0][0], x_list[1][0]
880
+ else:
881
+ return [u.float() for u in x_list[0]]
882
 
883
  def unpatchify(self, x, grid_sizes):
884
  r"""
wan/text2video.py CHANGED
@@ -131,7 +131,8 @@ class WanT2V:
131
  offload_model=True,
132
  callback = None,
133
  enable_RIFLEx = None,
134
- VAE_tile_size = 0
 
135
  ):
136
  r"""
137
  Generates video frames from text prompt using diffusion process.
@@ -240,8 +241,10 @@ class WanT2V:
240
  freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
241
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
242
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 
243
 
244
-
 
245
  if callback != None:
246
  callback(-1, None)
247
  for i, t in enumerate(tqdm(timesteps)):
@@ -251,14 +254,20 @@ class WanT2V:
251
  timestep = torch.stack(timestep)
252
 
253
  # self.model.to(self.device)
254
- noise_pred_cond = self.model(
255
- latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
256
- if self._interrupt:
257
- return None
258
- noise_pred_uncond = self.model(
259
- latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
260
- if self._interrupt:
261
- return None
 
 
 
 
 
 
262
 
263
  del latent_model_input
264
  noise_pred = noise_pred_uncond + guide_scale * (
 
131
  offload_model=True,
132
  callback = None,
133
  enable_RIFLEx = None,
134
+ VAE_tile_size = 0,
135
+ joint_pass = False,
136
  ):
137
  r"""
138
  Generates video frames from text prompt using diffusion process.
 
241
  freqs = get_rotary_pos_embed(frame_num, size[1], size[0], enable_RIFLEx= enable_RIFLEx)
242
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
243
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
244
+ arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
245
 
246
+ if self.model.enable_teacache:
247
+ self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
248
  if callback != None:
249
  callback(-1, None)
250
  for i, t in enumerate(tqdm(timesteps)):
 
254
  timestep = torch.stack(timestep)
255
 
256
  # self.model.to(self.device)
257
+ if joint_pass:
258
+ noise_pred_cond, noise_pred_uncond = self.model(
259
+ latent_model_input, t=timestep,current_step=i, **arg_both)
260
+ if self._interrupt:
261
+ return None
262
+ else:
263
+ noise_pred_cond = self.model(
264
+ latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
265
+ if self._interrupt:
266
+ return None
267
+ noise_pred_uncond = self.model(
268
+ latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
269
+ if self._interrupt:
270
+ return None
271
 
272
  del latent_model_input
273
  noise_pred = noise_pred_uncond + guide_scale * (