DeepBeepMeep committed on
Commit 0fd9e11 · 1 Parent(s): 7ecf9b0

Lora festival part 2, new macros, new user interface

Files changed (4)
  1. README.md +36 -3
  2. gradio_server.py +482 -124
  3. wan/modules/model.py +7 -2
  4. wan/utils/prompt_parser.py +291 -0
README.md CHANGED
@@ -19,6 +19,13 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
19
 
20
 
21
  ## 🔥 Latest News!!
22
  * Mar 14, 2025: 👋 Wan2.1GP v1.7:
23
 - Lora Fest special edition: very fast loading / unloading of Loras for the Lora collectors out there. You can also now add / remove Loras in the Lora folder without restarting the app. You will need to refresh the requirements with *pip install -r requirements.txt*
24
  - Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to the *AmericanPresidentJimmyCarter* for the original implementation
@@ -121,7 +128,11 @@ To run the text to video generator (in Low VRAM mode):
121
  ```bash
122
  python gradio_server.py
123
  #or
124
- python gradio_server.py --t2v
125
 
126
  ```
127
 
@@ -191,10 +202,27 @@ python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.l
191
 
192
  You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
193
194
 
195
  ### Command line parameters for Gradio Server
196
  --i2v : launch the image to video generator\
197
- --t2v : launch the text to video generator\
198
  --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
199
  --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
200
 --lora-preset preset : name of preset file (without the extension) to preload
@@ -208,7 +236,12 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
208
  --compile : turn on pytorch compilation\
209
 --attention mode: force attention mode among sdpa, flash, sage, sage2\
210
 --profile no : (default 4) : number of the profile, between 1 and 5\
211
- --preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware. Works only with profile 2 and 4.
212
 
213
  ### Profiles (for power users only)
214
  You can choose between 5 profiles, but two are really relevant here :
 
19
 
20
 
21
  ## 🔥 Latest News!!
22
+ * Mar 17, 2025: 👋 Wan2.1GP v2.0: The Lora festival continues:
23
+ - Clearer user interface
24
+ - Download 30 Loras in one click to try them all (expand the info section)
25
+ - Easier to use Loras: a Lora preset can now include the subject (or other needed terms) of the Lora, so that you don't have to modify the prompt manually
26
+ - Added a basic macro prompt language to prefill prompts with different values. With one prompt template, you can generate multiple prompts.
27
+ - New multiple-image prompts: you can now combine any number of images with any number of text prompts (launch the app with --multiple-images)
28
+ - New command line options to launch directly the 1.3B or the 14B t2v model
29
  * Mar 14, 2025: 👋 Wan2.1GP v1.7:
30
 - Lora Fest special edition: very fast loading / unloading of Loras for the Lora collectors out there. You can also now add / remove Loras in the Lora folder without restarting the app. You will need to refresh the requirements with *pip install -r requirements.txt*
31
  - Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to the *AmericanPresidentJimmyCarter* for the original implementation
 
128
  ```bash
129
  python gradio_server.py
130
  #or
131
+ python gradio_server.py --t2v #launch the default text 2 video model
132
+ #or
133
+ python gradio_server.py --t2v-14B #for the 14B model
134
+ #or
135
+ python gradio_server.py --t2v-1-3B #for the 1.3B model
136
 
137
  ```
138
 
 
202
 
203
  You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
204
 
205
+ ### Macros (basic)
206
+ In *Advanced Mode*, you can start prompt lines with a "!", for instance:\
207
+ ```
208
+ ! {Subject}="cat","woman","man", {Location}="forest","lake","city", {Possessive}="its", "her", "his"
209
+ In the video, a {Subject} is presented. The {Subject} is in a {Location} and looks at {Possessive} watch.
210
+ ```
211
+
212
+ This will automatically create 3 prompts, which will generate 3 videos:
213
+ ```
214
+ In the video, a cat is presented. The cat is in a forest and looks at its watch.
215
+ In the video, a woman is presented. The woman is in a lake and looks at her watch.
216
+ In the video, a man is presented. The man is in a city and looks at his watch.
217
+ ```
218
+
219
+ You can define multiple macro lines. If there is only one macro line, the app will generate a simple user interface to enter the macro variables when you switch back to *Normal Mode* (advanced mode turned off)
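Conceptually, the expansion is a positional zip over each variable's value list: the i-th value of every variable is substituted together to produce the i-th prompt. A minimal Python sketch of this behavior (an illustrative stand-in, not the actual `wan/utils/prompt_parser.py` implementation; the regex and function name are assumptions):

```python
import re

def expand_macro(macro_line: str, template: str) -> list[str]:
    """Positionally expand a '!' macro line against a prompt template (sketch)."""
    # Collect {Var}="a","b","c" assignments from the macro line.
    assignments = {}
    for name, values in re.findall(r'\{(\w+)\}\s*=\s*((?:"[^"]*"\s*,?\s*)+)', macro_line):
        assignments[name] = re.findall(r'"([^"]*)"', values)
    # One prompt per position; stop at the shortest value list.
    count = min(len(v) for v in assignments.values())
    prompts = []
    for i in range(count):
        prompt = template
        for name, values in assignments.items():
            prompt = prompt.replace("{" + name + "}", values[i])
        prompts.append(prompt)
    return prompts
```

Applied to the macro line and template above, this yields the three prompts shown.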
220
 
221
  ### Command line parameters for Gradio Server
222
  --i2v : launch the image to video generator\
223
+ --t2v : launch the text to video generator (default defined in the configuration)\
224
+ --t2v-14B : launch the 14B model text to video generator\
225
+ --t2v-1-3B : launch the 1.3B model text to video generator\
226
  --quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
227
  --lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
228
 --lora-preset preset : name of preset file (without the extension) to preload
 
236
  --compile : turn on pytorch compilation\
237
 --attention mode: force attention mode among sdpa, flash, sage, sage2\
238
 --profile no : (default 4) : number of the profile, between 1 and 5\
239
+ --preload no : number of megabytes of the diffusion model to preload in VRAM; may offer slight speed gains, especially on older hardware. Works only with profiles 2 and 4.\
240
+ --seed no : set default seed value\
241
+ --frames no : set the default number of frames to generate\
242
+ --steps no : set the default number of denoising steps\
243
+ --check-loras : filter out incompatible Loras (takes a few seconds when refreshing the Lora list or starting the app)\
244
+ --advanced : turn on the advanced mode while launching the app
245
 
246
  ### Profiles (for power users only)
247
  You can choose between 5 profiles, but two are really relevant here :
gradio_server.py CHANGED
@@ -20,6 +20,9 @@ import gc
20
  import traceback
21
  import math
22
  import asyncio
23
 
24
  def _parse_args():
25
  parser = argparse.ArgumentParser(
@@ -60,7 +63,7 @@ def _parse_args():
60
  parser.add_argument(
61
  "--lora-dir-i2v",
62
  type=str,
63
- default="loras_i2v",
64
  help="Path to a directory that contains Loras for i2v"
65
  )
66
 
@@ -72,6 +75,13 @@ def _parse_args():
72
  help="Path to a directory that contains Loras"
73
  )
74
75
 
76
  parser.add_argument(
77
  "--lora-preset",
@@ -101,6 +111,34 @@ def _parse_args():
101
  help="Verbose level"
102
  )
103
104
  parser.add_argument(
105
  "--server-port",
106
  type=str,
@@ -133,6 +171,18 @@ def _parse_args():
133
  help="image to video mode"
134
  )
135
136
  parser.add_argument(
137
  "--compile",
138
  action="store_true",
@@ -196,6 +246,13 @@ preload =int(args.preload)
196
  force_profile_no = int(args.profile)
197
  verbose_level = int(args.verbose)
198
  quantizeTransformer = args.quantize_transformer
199
 
200
  transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
201
  transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
@@ -247,12 +304,26 @@ if args.t2v:
247
  use_image2video = False
248
  if args.i2v:
249
  use_image2video = True
250
 
251
  lora_dir =args.lora_dir
252
  if use_image2video and len(lora_dir)==0:
253
- root_lora_dir =args.lora_dir_i2v
254
  if len(lora_dir) ==0:
255
  root_lora_dir = "loras_i2v" if use_image2video else "loras"
256
  lora_dir = get_lora_dir(root_lora_dir)
257
  lora_preselected_preset = args.lora_preset
258
  default_tea_cache = 0
@@ -378,7 +449,8 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo
378
  dir_presets.sort()
379
  loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
380
 
381
- loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
 
382
 
383
  if len(loras) > 0:
384
  loras_names = [ Path(lora).stem for lora in loras ]
@@ -492,7 +564,9 @@ def get_default_flow(model_filename):
492
  return 3.0 if "480p" in model_filename else 5.0
493
 
494
  def generate_header(model_filename, compile, attention_mode):
495
- header = "<H2 ALIGN=CENTER><SPAN> ----------------- "
 
 
496
 
497
  if "image" in model_filename:
498
  model_name = "Wan2.1 image2video"
@@ -508,7 +582,8 @@ def generate_header(model_filename, compile, attention_mode):
508
 
509
  if compile:
510
  header += ", pytorch compilation ON"
511
- header += ") -----------------</SPAN></H2>"
 
512
 
513
  return header
514
 
@@ -591,19 +666,19 @@ def apply_changes( state,
591
 
592
  # return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"
593
 
594
- def update_defaults(state, num_inference_steps,flow_shift):
595
  if "config_changes" not in state:
596
  return get_default_flow("")
597
  changes = state["config_changes"]
598
  server_config = state["config_new"]
599
  old_server_config = state["config_old"]
600
-
601
  if not use_image2video:
602
  old_is_14B = "14B" in server_config["transformer_filename"]
603
  new_is_14B = "14B" in old_server_config["transformer_filename"]
604
 
605
  trans_file = server_config["transformer_filename"]
606
- # if old_is_14B != new_is_14B:
607
  # num_inference_steps, flow_shift = get_default_flow(trans_file)
608
  else:
609
  old_is_720P = "720P" in server_config["transformer_filename_i2v"]
@@ -615,9 +690,11 @@ def update_defaults(state, num_inference_steps,flow_shift):
615
  header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
616
  new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
617
  lset_choices = [ (preset, preset) for preset in loras_presets]
618
- lset_choices.append( (new_preset_msg, ""))
619
-
620
- return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])
 
 
621
 
622
 
623
  from moviepy.editor import ImageSequenceClip
@@ -661,9 +738,20 @@ def abort_generation(state):
661
  else:
662
  return gr.Button(interactive= True)
663
 
664
- def refresh_gallery(state):
665
  file_list = state.get("file_list", None)
666
- return file_list
667
 
668
  def finalize_gallery(state):
669
  choice = 0
@@ -675,7 +763,7 @@ def finalize_gallery(state):
675
  time.sleep(0.2)
676
  global gen_in_progress
677
  gen_in_progress = False
678
- return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False)
679
 
680
  def select_video(state , event_data: gr.EventData):
681
  data= event_data._data
@@ -697,7 +785,9 @@ def one_more_video(state):
697
  extra_orders = state.get("extra_orders", 0)
698
  extra_orders += 1
699
  state["extra_orders"] = extra_orders
700
- prompts_max = state["prompts_max"]
 
 
701
  prompt_no = state["prompt_no"]
702
  video_no = state["video_no"]
703
  total_video = state["total_video"]
@@ -730,6 +820,7 @@ def generate_video(
730
  flow_shift,
731
  embedded_guidance_scale,
732
  repeat_generation,
 
733
  tea_cache,
734
  tea_cache_start_step_perc,
735
  loras_choices,
@@ -759,8 +850,11 @@ def generate_video(
759
  elif attention_mode in attention_modes_supported:
760
  attn = attention_mode
761
  else:
762
- raise gr.Error(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
 
763
764
  width, height = resolution.split("x")
765
  width, height = int(width), int(height)
766
 
@@ -768,17 +862,18 @@ def generate_video(
768
  slg_layers = None
769
  if use_image2video:
770
  if "480p" in transformer_filename_i2v and width * height > 848*480:
771
- raise gr.Error("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
 
772
 
773
  resolution = str(width) + "*" + str(height)
774
  if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
775
- raise gr.Error(f"Resolution {resolution} not supported by image 2 video")
776
-
777
 
778
  else:
779
  if "1.3B" in transformer_filename_t2v and width * height > 848*480:
780
- raise gr.Error("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
781
-
782
 
783
  offload.shared_state["_attention"] = attn
784
 
@@ -808,6 +903,9 @@ def generate_video(
808
  temp_filename = None
809
  if len(prompt) ==0:
810
  return
811
  prompts = prompt.replace("\r", "").split("\n")
812
  prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
813
  if len(prompts) ==0:
@@ -818,22 +916,31 @@ def generate_video(
818
  image_to_continue = [ tup[0] for tup in image_to_continue ]
819
  else:
820
  image_to_continue = [image_to_continue]
821
- if len(prompts) >= len(image_to_continue):
822
- if len(prompts) % len(image_to_continue) !=0:
823
- raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
824
- rep = len(prompts) // len(image_to_continue)
825
- new_image_to_continue = []
826
- for i, _ in enumerate(prompts):
827
- new_image_to_continue.append(image_to_continue[i//rep] )
828
- image_to_continue = new_image_to_continue
829
- else:
830
- if len(image_to_continue) % len(prompts) !=0:
831
- raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
832
- rep = len(image_to_continue) // len(prompts)
833
  new_prompts = []
834
- for i, _ in enumerate(image_to_continue):
835
- new_prompts.append( prompts[ i//rep] )
 
 
836
  prompts = new_prompts
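The replication rule enforced by this block (when the counts differ, the larger must be divisible by the smaller, and each item of the shorter list is repeated to cover the longer one) can be sketched as a standalone helper; `pair_prompts_and_images` is an illustrative name and signature, not a function from gradio_server.py:

```python
def pair_prompts_and_images(prompts, images):
    """Pair every prompt with an image, replicating the shorter list.

    Mirrors the divisibility check in generate_video: the longer list's
    length must be a multiple of the shorter list's length.
    """
    if len(prompts) >= len(images):
        if len(prompts) % len(images) != 0:
            raise ValueError("number of prompts must be divisible by number of images")
        rep = len(prompts) // len(images)
        # Repeat each image 'rep' times so every prompt gets one.
        images = [images[i // rep] for i in range(len(prompts))]
    else:
        if len(images) % len(prompts) != 0:
            raise ValueError("number of images must be divisible by number of prompts")
        rep = len(images) // len(prompts)
        # Repeat each prompt 'rep' times so every image gets one.
        prompts = [prompts[i // rep] for i in range(len(images))]
    return list(zip(prompts, images))
```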
837
 
838
  elif video_to_continue != None and len(video_to_continue) >0 :
839
  input_image_or_video_path = video_to_continue
@@ -900,6 +1007,10 @@ def generate_video(
900
  # TeaCache
901
  trans.enable_teacache = tea_cache > 0
902
  if trans.enable_teacache:
903
  if use_image2video:
904
  if '480p' in transformer_filename_i2v:
905
  # teacache_thresholds = [0.13, .19, 0.26]
@@ -935,9 +1046,11 @@ def generate_video(
935
  start_time = time.time()
936
  state["prompts_max"] = len(prompts)
937
  for no, prompt in enumerate(prompts):
 
938
  repeat_no = 0
939
  state["prompt_no"] = no
940
  extra_generation = 0
 
941
  while True:
942
  extra_orders = state.get("extra_orders",0)
943
  state["extra_orders"] = 0
@@ -950,8 +1063,6 @@ def generate_video(
950
 
951
  if trans.enable_teacache:
952
  trans.teacache_counter = 0
953
- trans.teacache_multiplier = tea_cache
954
- trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
955
  trans.num_steps = num_inference_steps
956
  trans.teacache_skipped_steps = 0
957
  trans.previous_residual_uncond = None
@@ -1035,6 +1146,7 @@ def generate_video(
1035
  if any( keyword in frame.name for keyword in keyword_list):
1036
  VRAM_crash = True
1037
  break
 
1038
  if VRAM_crash:
1039
  raise gr.Error("The generation of the video has encountered an error: it is likely that you have insufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
1040
  else:
@@ -1054,6 +1166,7 @@ def generate_video(
1054
  if samples == None:
1055
  end_time = time.time()
1056
  abort = True
 
1057
  yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
1058
  else:
1059
  sample = samples.cpu()
@@ -1076,9 +1189,10 @@ def generate_video(
1076
  print(f"New video saved to Path: "+video_path)
1077
  file_list.append(video_path)
1078
  if video_no < total_video:
1079
- yield status
1080
  else:
1081
  end_time = time.time()
 
1082
  yield f"Total Generation Time: {end_time-start_time:.1f}s"
1083
  seed += 1
1084
  repeat_no += 1
@@ -1089,18 +1203,22 @@ def generate_video(
1089
  offload.unload_loras_from_model(trans)
1090
 
1091
 
1092
- new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
1093
 
1094
 
1095
  def validate_delete_lset(lset_name):
1096
- if len(lset_name) == 0 or lset_name == new_preset_msg:
1097
  gr.Info(f"Choose a Preset to delete")
1098
  return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
1099
  else:
1100
  return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
1101
 
1102
  def validate_save_lset(lset_name):
1103
- if len(lset_name) == 0 or lset_name == new_preset_msg:
1104
  gr.Info("Please enter a name for the preset")
1105
  return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
1106
  else:
@@ -1109,10 +1227,12 @@ def validate_save_lset(lset_name):
1109
  def cancel_lset():
1110
  return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1111
 
1112
- def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
1113
  global loras_presets
1114
 
1115
- if len(lset_name) == 0 or lset_name== new_preset_msg:
 
 
1116
  gr.Info("Please enter a name for the preset")
1117
  lset_choices =[("Please enter a name for a Lora Preset","")]
1118
  else:
@@ -1142,14 +1262,14 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_pr
1142
  gr.Info(f"Lora Preset '{lset_name}' has been created")
1143
  loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
1144
  lset_choices = [ ( preset, preset) for preset in loras_presets ]
1145
- lset_choices.append( (new_preset_msg, ""))
1146
 
1147
  return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1148
 
1149
  def delete_lset(lset_name):
1150
  global loras_presets
1151
  lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
1152
- if len(lset_name) > 0 and lset_name != new_preset_msg:
1153
  if not os.path.isfile(lset_name_filename):
1154
  raise gr.Error(f"Preset '{lset_name}' not found ")
1155
  os.remove(lset_name_filename)
@@ -1161,7 +1281,7 @@ def delete_lset(lset_name):
1161
  gr.Info(f"Choose a Preset to delete")
1162
 
1163
  lset_choices = [ (preset, preset) for preset in loras_presets]
1164
- lset_choices.append((new_preset_msg, ""))
1165
  return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
1166
 
1167
  def refresh_lora_list(lset_name, loras_choices):
@@ -1179,15 +1299,15 @@ def refresh_lora_list(lset_name, loras_choices):
1179
  lora_names_selected.append(lora_id)
1180
 
1181
  lset_choices = [ (preset, preset) for preset in loras_presets]
1182
- lset_choices.append((new_preset_msg, ""))
1183
  if lset_name in loras_presets:
1184
  pos = loras_presets.index(lset_name)
1185
  else:
1186
  pos = len(loras_presets)
1187
  lset_name =""
1188
 
1189
- errors = wan_model.model._loras_errors
1190
- if len(errors) > 0:
1191
  error_files = [path for path, _ in errors]
1192
  gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
1193
  else:
@@ -1196,9 +1316,11 @@ def refresh_lora_list(lset_name, loras_choices):
1196
 
1197
  return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
1198
 
1199
- def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
1200
 
1201
- if len(lset_name) == 0 or lset_name== new_preset_msg:
 
 
1202
  gr.Info("Please choose a preset in the list or create one")
1203
  else:
1204
  loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
@@ -1213,34 +1335,221 @@ def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
1213
  prompt = "\n".join(prompts)
1214
  prompt = preset_prompt + '\n' + prompt
1215
  gr.Info(f"Lora Preset '{lset_name}' has been applied")
 
 
1216
 
1217
  return loras_choices, loras_mult_choices, prompt
1218
 
1219
- def create_demo():
1220
-
1221
- default_inference_steps = 30
1222
 
1223
  default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
1224
- with gr.Blocks() as demo:
1225
- state = gr.State({})
1226
 
1227
  if use_image2video:
1228
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
1229
  else:
1230
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
1231
 
1232
- gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP by <B>DeepBeepMeep</B>, a super fast and low VRAM Video Generator !</FONT>")
1233
-
1234
- if use_image2video and False:
1235
- pass
1236
- else:
1237
- gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
1238
- gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
1239
- gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
1240
- gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
1241
- gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
1242
- gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
1243
1244
 
1245
  # css = """<STYLE>
1246
  # h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
@@ -1370,8 +1679,36 @@ def create_demo():
1370
  apply_btn = gr.Button("Apply Changes")
1371
 
1372
 
 
1373
  with gr.Row():
1374
  with gr.Column():
1375
  video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
1376
  if args.multiple_images:
1377
  image_to_continue = gr.Gallery(
@@ -1380,8 +1717,30 @@ def create_demo():
1380
  else:
1381
  image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=use_image2video)
1382
 
1383
- prompt = gr.Textbox(label="Prompts (multiple prompts separated by carriage returns will generate multiple videos, lines that starts with # are ignored)", value=default_prompt, lines=3)
1384
-
1385
  with gr.Row():
1386
  if use_image2video:
1387
  resolution = gr.Dropdown(
@@ -1417,64 +1776,51 @@ def create_demo():
1417
 
1418
  with gr.Row():
1419
  with gr.Column():
1420
- video_length = gr.Slider(5, 193, value=81, step=4, label="Number of frames (16 = 1s)")
1421
  with gr.Column():
1422
- num_inference_steps = gr.Slider(1, 100, value= default_inference_steps, step=1, label="Number of Inference Steps")
1423
 
1424
  with gr.Row():
1425
  max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False) #########
1426
 
1427
 
1428
- with gr.Row(visible= len(loras)>0):
1429
- lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
1430
- with gr.Column(scale=5):
1431
- lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
1432
- with gr.Column(scale=1):
1433
- # with gr.Column():
1434
- with gr.Row(height=17):
1435
- apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
1436
- refresh_lora_btn = gr.Button("Refresh Lora List", size="sm", min_width= 1)
1437
- # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
1438
- save_lset_prompt_drop= gr.Dropdown(
1439
- choices=[
1440
- ("Save Prompt Comments Only", 0),
1441
- ("Save Full Prompt", 1)
1442
- ], show_label= False, container=False, visible= False
1443
- )
1444
-
1445
-
1446
- with gr.Row(height=17):
1447
- confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
1448
- confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
1449
- save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
1450
- delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
1451
- cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
1452
-
1453
-
1454
- loras_choices = gr.Dropdown(
1455
- choices=[
1456
- (lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
1457
- ],
1458
- value= default_loras_choices,
1459
- multiselect= True,
1460
- visible= len(loras)>0,
1461
- label="Activated Loras"
1462
- )
1463
- loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
1464
 
1465
- show_advanced = gr.Checkbox(label="Show Advanced Options", value=False)
1466
- with gr.Row(visible=False) as advanced_row:
1467
  with gr.Column():
1468
- seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
1469
- repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
1470
  with gr.Row():
1471
- negative_prompt = gr.Textbox(label="Negative Prompt", value="")
1472
  with gr.Row():
1473
  guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
1474
  embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
1475
  flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
1476
  with gr.Row():
1477
- gr.Markdown("Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)")
1478
  with gr.Row():
1479
  tea_cache_setting = gr.Dropdown(
1480
  choices=[
@@ -1491,6 +1837,7 @@ def create_demo():
1491
  )
1492
  tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")
1493
 
 
1494
  RIFLEx_setting = gr.Dropdown(
1495
  choices=[
1496
  ("Auto (ON if Video longer than 5s)", 0),
@@ -1503,7 +1850,7 @@ def create_demo():
1503
 
1504
 
1505
  with gr.Row():
1506
- gr.Markdown("Experimental: Skip Layer guidance,should improve video quality")
1507
  with gr.Row():
1508
  slg_switch = gr.Dropdown(
1509
  choices=[
@@ -1529,33 +1876,43 @@ def create_demo():
1529
  slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
1530
 
1531
 
1532
- show_advanced.change(fn=lambda x: gr.Row(visible=x), inputs=[show_advanced], outputs=[advanced_row])
 
1533
 
1534
  with gr.Column():
1535
  gen_status = gr.Text(label="Status", interactive= False)
1536
  output = gr.Gallery(
1537
  label="Generated videos", show_label=False, elem_id="gallery"
1538
- , columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
1539
  generate_btn = gr.Button("Generate")
1540
  onemore_btn = gr.Button("One More Please !", visible= False)
1541
  abort_btn = gr.Button("Abort")
 
1542
 
1543
  save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
1544
- confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
 
1545
  delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1546
  confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1547
  cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
1548
 
1549
- apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
 
 
1550
 
1551
  refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
 
 
1552
 
1553
- gen_status.change(refresh_gallery, inputs = [state], outputs = output )
1554
 
1555
  abort_btn.click(abort_generation,state,abort_btn )
1556
  output.select(select_video, state, None )
1557
  onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
1558
- generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]).then(
1559
  fn=generate_video,
1560
  inputs=[
1561
  prompt,
@@ -1568,6 +1925,7 @@ def create_demo():
1568
  flow_shift,
1569
  embedded_guidance_scale,
1570
  repeat_generation,
 
1571
  tea_cache_setting,
1572
  tea_cache_start_step_perc,
1573
  loras_choices,
@@ -1587,7 +1945,7 @@ def create_demo():
1587
  ).then(
1588
  finalize_gallery,
1589
  [state],
1590
- [output , abort_btn, generate_btn, onemore_btn]
1591
  )
1592
 
1593
  apply_btn.click(
@@ -1607,7 +1965,7 @@ def create_demo():
1607
  outputs= msg
1608
  ).then(
1609
  update_defaults,
1610
- [state, num_inference_steps, flow_shift],
1611
  [num_inference_steps, flow_shift, header, lset_name , loras_choices ]
1612
  )
1613
 
 
 import traceback
 import math
 import asyncio
+ from wan.utils import prompt_parser
+ PROMPT_VARS_MAX = 10
+

 def _parse_args():
     parser = argparse.ArgumentParser(
 
     parser.add_argument(
         "--lora-dir-i2v",
         type=str,
+        default="",
         help="Path to a directory that contains Loras for i2v"
     )
 
 
         help="Path to a directory that contains Loras"
     )

+    parser.add_argument(
+        "--check-loras",
+        type=int,
+        default=0,
+        help="Filter out Loras that are not valid"
+    )
+

     parser.add_argument(
         "--lora-preset",
 
         help="Verbose level"
     )

+    parser.add_argument(
+        "--steps",
+        type=int,
+        default=0,
+        help="default denoising steps"
+    )
+
+    parser.add_argument(
+        "--frames",
+        type=int,
+        default=0,
+        help="default number of frames"
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=-1,
+        help="default generation seed"
+    )
+
+    parser.add_argument(
+        "--advanced",
+        action="store_true",
+        help="Access advanced options by default"
+    )
+

     parser.add_argument(
         "--server-port",
         type=str,
171
  help="image to video mode"
172
  )
173
 
174
+ parser.add_argument(
175
+ "--t2v-14B",
176
+ action="store_true",
177
+ help="text to video mode 14B model"
178
+ )
179
+
180
+ parser.add_argument(
181
+ "--t2v-1-3B",
182
+ action="store_true",
183
+ help="text to video mode 1.3B model"
184
+ )
185
+
186
  parser.add_argument(
187
  "--compile",
188
  action="store_true",
 
 force_profile_no = int(args.profile)
 verbose_level = int(args.verbose)
 quantizeTransformer = args.quantize_transformer
+ default_seed = args.seed
+ default_number_frames = int(args.frames)
+ if default_number_frames > 0:
+     default_number_frames = ((default_number_frames - 1) // 4) * 4 + 1
+ default_inference_steps = args.steps
+ check_loras = args.check_loras == 1
+ advanced = args.advanced

 transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
 transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
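The new `--frames` value above is snapped down to the 4k+1 frame counts that the model's temporal compression expects. A minimal sketch of that snapping (the helper name is hypothetical; the repo inlines the expression):

```python
def snap_frame_count(n: int) -> int:
    # Snap a requested frame count down to the nearest 4k+1 value,
    # mirroring the `((n - 1) // 4) * 4 + 1` expression used for --frames.
    return ((n - 1) // 4) * 4 + 1
```

For example, a request for 80 frames is snapped to 77, while 81 (the default) is already valid and unchanged.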
 
 use_image2video = False
 if args.i2v:
     use_image2video = True
+ if args.t2v_14B:
+     use_image2video = False
+     if not "14B" in transformer_filename_t2v:
+         transformer_filename_t2v = transformer_choices_t2v[2]
+     lock_ui_transformer = False
+
+ if args.t2v_1_3B:
+     transformer_filename_t2v = transformer_choices_t2v[0]
+     use_image2video = False
+     lock_ui_transformer = False
+
+ only_allow_edit_in_advanced = False

 lora_dir = args.lora_dir
 if use_image2video and len(lora_dir) == 0:
+     lora_dir = args.lora_dir_i2v
 if len(lora_dir) == 0:
     root_lora_dir = "loras_i2v" if use_image2video else "loras"
+ else:
+     root_lora_dir = lora_dir
 lora_dir = get_lora_dir(root_lora_dir)
 lora_preselected_preset = args.lora_preset
 default_tea_cache = 0
 
 dir_presets.sort()
 loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]

+ if check_loras:
+     loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map)

 if len(loras) > 0:
     loras_names = [ Path(lora).stem for lora in loras ]
 
     return 3.0 if "480p" in model_filename else 5.0

 def generate_header(model_filename, compile, attention_mode):
+     header = "<div class='title-with-lines'><div class=line></div><h2>"

     if "image" in model_filename:
         model_name = "Wan2.1 image2video"

     if compile:
         header += ", pytorch compilation ON"
+     header += ") </h2><div class=line></div> "

     return header
 
 
     # return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"

+ def update_defaults(state, num_inference_steps, flow_shift, lset_name, loras_choices):
     if "config_changes" not in state:
         return get_default_flow("")
     changes = state["config_changes"]
     server_config = state["config_new"]
     old_server_config = state["config_old"]
+     t2v_changed = False
     if not use_image2video:
         old_is_14B = "14B" in server_config["transformer_filename"]
         new_is_14B = "14B" in old_server_config["transformer_filename"]

         trans_file = server_config["transformer_filename"]
+         t2v_changed = old_is_14B != new_is_14B
         # num_inference_steps, flow_shift = get_default_flow(trans_file)
     else:
         old_is_720P = "720P" in server_config["transformer_filename_i2v"]

     header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"])
     new_loras_choices = [ (loras_name, str(i)) for i, loras_name in enumerate(loras_names)]
     lset_choices = [ (preset, preset) for preset in loras_presets]
+     lset_choices.append( (get_new_preset_msg(advanced), ""))
+     if t2v_changed:
+         return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])
+     else:
+         return num_inference_steps, flow_shift, header, lset_name, loras_choices

 from moviepy.editor import ImageSequenceClip
 
     else:
         return gr.Button(interactive= True)

+ def refresh_gallery(state, txt):
     file_list = state.get("file_list", None)
+     prompt = state.get("prompt", "")
+     if len(prompt) == 0:
+         return file_list, gr.Text(visible= False, value="")
+     else:
+         prompts_max = state.get("prompts_max", 0)
+         prompt_no = state.get("prompt_no", 0)
+         if prompts_max > 1:
+             label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
+         else:
+             label = "Current Prompt"
+         return file_list, gr.Text(visible= True, value=prompt, label=label)
 
 def finalize_gallery(state):
     choice = 0

     time.sleep(0.2)
     global gen_in_progress
     gen_in_progress = False
+     return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="")

 def select_video(state, event_data: gr.EventData):
     data = event_data._data
 
     extra_orders = state.get("extra_orders", 0)
     extra_orders += 1
     state["extra_orders"] = extra_orders
+     prompts_max = state.get("prompts_max", 0)
+     if prompts_max == 0:
+         return state
     prompt_no = state["prompt_no"]
     video_no = state["video_no"]
     total_video = state["total_video"]
 
     flow_shift,
     embedded_guidance_scale,
     repeat_generation,
+     multi_images_gen_type,
     tea_cache,
     tea_cache_start_step_perc,
     loras_choices,
 
     elif attention_mode in attention_modes_supported:
         attn = attention_mode
     else:
+         gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
+         return

+     if state.get("validate_success", 0) != 1:
+         return
     width, height = resolution.split("x")
     width, height = int(width), int(height)
 
 
     slg_layers = None
     if use_image2video:
         if "480p" in transformer_filename_i2v and width * height > 848*480:
+             gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
+             return

         resolution = str(width) + "*" + str(height)
         if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
+             gr.Info(f"Resolution {resolution} not supported by image 2 video")
+             return

     else:
         if "1.3B" in transformer_filename_t2v and width * height > 848*480:
+             gr.Info("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
+             return

     offload.shared_state["_attention"] = attn
 
 
     temp_filename = None
     if len(prompt) == 0:
         return
+     prompt, errors = prompt_parser.process_template(prompt)
+     if len(errors) > 0:
+         gr.Info("Error processing prompt template: " + errors)
     prompts = prompt.replace("\r", "").split("\n")
     prompts = [prompt.strip() for prompt in prompts if len(prompt.strip()) > 0 and not prompt.startswith("#")]
     if len(prompts) == 0:
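The `prompt_parser.process_template` call above expands `!` macro lines into multiple concrete prompts before generation. A toy sketch of the assumed semantics for the single-variable case (the real parser in `wan/utils/prompt_parser.py` also handles multi-variable `:`-separated macros and structured error reporting, so treat this as an illustration only):

```python
import re

def expand_macro(lines):
    # A '!' line binds {var} to a list of quoted values; every later
    # template line containing {var} is emitted once per value.
    # Lines starting with '#' are comments and are skipped.
    values, out = {}, []
    for line in lines:
        line = line.strip()
        if line.startswith("!"):
            for name, raw in re.findall(r'\{(\w+)\}\s*=\s*((?:"[^"]*"\s*,?\s*)+)', line):
                values[name] = re.findall(r'"([^"]*)"', raw)
        elif line and not line.startswith("#"):
            vars_used = [v for v in values if "{" + v + "}" in line]
            if vars_used:
                for i in range(len(values[vars_used[0]])):
                    s = line
                    for v in vars_used:
                        s = s.replace("{" + v + "}", values[v][i])
                    out.append(s)
            else:
                out.append(line)
    return out
```

So `['! {animal}="cat","dog"', 'a {animal} on a beach']` would expand into two prompts, one per value, each generating its own video.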
 
     image_to_continue = [ tup[0] for tup in image_to_continue ]
 else:
     image_to_continue = [image_to_continue]
+ if multi_images_gen_type == 0:
     new_prompts = []
+     new_image_to_continue = []
+     for i in range(len(prompts) * len(image_to_continue)):
+         new_prompts.append( prompts[ i % len(prompts)] )
+         new_image_to_continue.append(image_to_continue[i // len(prompts)] )
     prompts = new_prompts
+     image_to_continue = new_image_to_continue
+ else:
+     if len(prompts) >= len(image_to_continue):
+         if len(prompts) % len(image_to_continue) != 0:
+             raise gr.Error("If there are more text prompts than input images, the number of text prompts should be divisible by the number of images")
+         rep = len(prompts) // len(image_to_continue)
+         new_image_to_continue = []
+         for i, _ in enumerate(prompts):
+             new_image_to_continue.append(image_to_continue[i // rep] )
+         image_to_continue = new_image_to_continue
+     else:
+         if len(image_to_continue) % len(prompts) != 0:
+             raise gr.Error("If there are more input images than text prompts, the number of images should be divisible by the number of text prompts")
+         rep = len(image_to_continue) // len(prompts)
+         new_prompts = []
+         for i, _ in enumerate(image_to_continue):
+             new_prompts.append( prompts[ i // rep] )
+         prompts = new_prompts

 elif video_to_continue != None and len(video_to_continue) > 0:
     input_image_or_video_path = video_to_continue
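The new `multi_images_gen_type` branch above either crosses every input image with every text prompt, or matches the two lists one-to-one by repeating the shorter list (whose length must divide the longer one). A standalone sketch of that pairing logic, with a plain `ValueError` standing in for `gr.Error`:

```python
def pair_prompts_images(prompts, images, mode):
    # mode 0: every image is paired with every prompt (cross product,
    # image-major order as in the diff above);
    # mode 1: the lists are matched, repeating the shorter one, whose
    # length must divide the longer one.
    if mode == 0:
        return ([p for _ in images for p in prompts],
                [img for img in images for _ in prompts])
    longer, shorter = (prompts, images) if len(prompts) >= len(images) else (images, prompts)
    if len(longer) % len(shorter) != 0:
        raise ValueError("list lengths must divide evenly in match mode")
    rep = len(longer) // len(shorter)
    expanded = [shorter[i // rep] for i in range(len(longer))]
    if len(prompts) >= len(images):
        return prompts, expanded
    return expanded, images
```

With two prompts and one image, mode 0 and mode 1 both yield two generations; the modes differ once both lists have more than one entry.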
 
     # TeaCache
     trans.enable_teacache = tea_cache > 0
     if trans.enable_teacache:
+         trans.teacache_multiplier = tea_cache
+         trans.rel_l1_thresh = 0
+         trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
+
         if use_image2video:
             if '480p' in transformer_filename_i2v:
                 # teacache_thresholds = [0.13, .19, 0.26]
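The new `tea_cache_start_step_perc` setting above delays TeaCache skipping until a percentage of the denoising schedule has run; the conversion to an absolute step index is a simple truncating percentage (helper name hypothetical, the repo inlines the expression):

```python
def teacache_start_step(start_perc: int, num_steps: int) -> int:
    # First denoising step index at which TeaCache may start skipping,
    # mirroring `int(tea_cache_start_step_perc * num_inference_steps / 100)`.
    return int(start_perc * num_steps / 100)
```

For instance, with 30 inference steps and a 20% start threshold, TeaCache only becomes active from step 6 onward, protecting the early high-noise steps that matter most for quality.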
 
     start_time = time.time()
     state["prompts_max"] = len(prompts)
     for no, prompt in enumerate(prompts):
+         state["prompt"] = prompt
         repeat_no = 0
         state["prompt_no"] = no
         extra_generation = 0
+         yield f"Prompt No{no}"
         while True:
             extra_orders = state.get("extra_orders", 0)
             state["extra_orders"] = 0
 
     if trans.enable_teacache:
         trans.teacache_counter = 0
         trans.num_steps = num_inference_steps
         trans.teacache_skipped_steps = 0
         trans.previous_residual_uncond = None
 
1146
  if any( keyword in frame.name for keyword in keyword_list):
1147
  VRAM_crash = True
1148
  break
1149
+ state["prompt"] = ""
1150
  if VRAM_crash:
1151
  raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
1152
  else:
 
1166
  if samples == None:
1167
  end_time = time.time()
1168
  abort = True
1169
+ state["prompt"] = ""
1170
  yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
1171
  else:
1172
  sample = samples.cpu()
 
1189
  print(f"New video saved to Path: "+video_path)
1190
  file_list.append(video_path)
1191
  if video_no < total_video:
1192
+ yield status
1193
  else:
1194
  end_time = time.time()
1195
+ state["prompt"] = ""
1196
  yield f"Total Generation Time: {end_time-start_time:.1f}s"
1197
  seed += 1
1198
  repeat_no += 1
 
     offload.unload_loras_from_model(trans)

+ def get_new_preset_msg(advanced = True):
+     if advanced:
+         return "Enter here a Name for a Lora Preset or Choose one in the List"
+     else:
+         return "Choose a Lora Preset in this List to Apply a Special Effect"

 def validate_delete_lset(lset_name):
+     if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
         gr.Info("Choose a Preset to delete")
         return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
     else:
         return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)

 def validate_save_lset(lset_name):
+     if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
         gr.Info("Please enter a name for the preset")
         return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
     else:
 
 def cancel_lset():
     return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)

+ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
     global loras_presets

+     if state.get("validate_success", 0) == 0:
+         pass
+     if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
         gr.Info("Please enter a name for the preset")
         lset_choices = [("Please enter a name for a Lora Preset", "")]
     else:
 
         gr.Info(f"Lora Preset '{lset_name}' has been created")
         loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
         lset_choices = [ ( preset, preset) for preset in loras_presets ]
+         lset_choices.append( (get_new_preset_msg(), ""))

     return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)

 def delete_lset(lset_name):
     global loras_presets
     lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
+     if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
         if not os.path.isfile(lset_name_filename):
             raise gr.Error(f"Preset '{lset_name}' not found ")
         os.remove(lset_name_filename)
 
         gr.Info("Choose a Preset to delete")

     lset_choices = [ (preset, preset) for preset in loras_presets]
+     lset_choices.append((get_new_preset_msg(), ""))
     return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)

 def refresh_lora_list(lset_name, loras_choices):
 
         lora_names_selected.append(lora_id)

     lset_choices = [ (preset, preset) for preset in loras_presets]
+     lset_choices.append((get_new_preset_msg(advanced), ""))
     if lset_name in loras_presets:
         pos = loras_presets.index(lset_name)
     else:
         pos = len(loras_presets)
         lset_name = ""

+     errors = getattr(wan_model.model, "_loras_errors", "")
+     if errors != None and len(errors) > 0:
         error_files = [path for path, _ in errors]
         gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
     else:

     return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
 
+ def apply_lset(state, lset_name, loras_choices, loras_mult_choices, prompt):
+     state["apply_success"] = 0
+
+     if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
         gr.Info("Please choose a preset in the list or create one")
     else:
         loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)

         prompt = "\n".join(prompts)
         prompt = preset_prompt + '\n' + prompt
         gr.Info(f"Lora Preset '{lset_name}' has been applied")
+         state["apply_success"] = 1
+         state["wizard_prompt"] = 0

     return loras_choices, loras_mult_choices, prompt
 
 
 
 
+ def extract_prompt_from_wizard(state, prompt, wizard_prompt, allow_null_values, *args):
+     prompts = wizard_prompt.replace("\r", "").split("\n")
+
+     new_prompts = []
+     macro_already_written = False
+     for prompt in prompts:
+         if not macro_already_written and not prompt.startswith("#") and "{" in prompt and "}" in prompt:
+             variables = state["variables"]
+             values = args[:len(variables)]
+             macro = "! "
+             for i, (variable, value) in enumerate(zip(variables, values)):
+                 if len(value) == 0 and not allow_null_values:
+                     return prompt, "You need to provide a value for '" + variable + "'"
+                 sub_values = [ "\"" + sub_value + "\"" for sub_value in value.split("\n") ]
+                 value = ",".join(sub_values)
+                 if i > 0:
+                     macro += " : "
+                 macro += "{" + variable + "}" + f"={value}"
+             if len(variables) > 0:
+                 macro_already_written = True
+                 new_prompts.append(macro)
+             new_prompts.append(prompt)
+         else:
+             new_prompts.append(prompt)
+
+     prompt = "\n".join(new_prompts)
+     return prompt, ""
+
+ def validate_wizard_prompt(state, prompt, wizard_prompt, *args):
+     state["validate_success"] = 0
+
+     if state.get("wizard_prompt", 0) != 1:
+         state["validate_success"] = 1
+         return prompt
+
+     prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, False, *args)
+     if len(errors) > 0:
+         gr.Info(errors)
+         return prompt
+
+     state["validate_success"] = 1
+
+     return prompt
+
+ def fill_prompt_from_wizard(state, prompt, wizard_prompt, *args):
+     if state.get("wizard_prompt", 0) == 1:
+         prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, True, *args)
+         if len(errors) > 0:
+             gr.Info(errors)
+
+     state["wizard_prompt"] = 0
+
+     return gr.Textbox(visible= True, value= prompt), gr.Textbox(visible= False), gr.Column(visible= True), *[gr.Column(visible= False)] * 2, *[gr.Textbox(visible= False)] * PROMPT_VARS_MAX
+
+ def extract_wizard_prompt(prompt):
+     variables = []
+     values = {}
+     prompts = prompt.replace("\r", "").split("\n")
+     if sum(prompt.startswith("!") for prompt in prompts) > 1:
+         return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
+
+     new_prompts = []
+     errors = ""
+     for prompt in prompts:
+         if prompt.startswith("!"):
+             variables, errors = prompt_parser.extract_variable_names(prompt)
+             if len(errors) > 0:
+                 return "", variables, values, "Error parsing Prompt template: " + errors
+             if len(variables) > PROMPT_VARS_MAX:
+                 return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
+             values, errors = prompt_parser.extract_variable_values(prompt)
+             if len(errors) > 0:
+                 return "", variables, values, "Error parsing Prompt template: " + errors
+         else:
+             variables_extra, errors = prompt_parser.extract_variable_names(prompt)
+             if len(errors) > 0:
+                 return "", variables, values, "Error parsing Prompt template: " + errors
+             variables += variables_extra
+             variables = [var for pos, var in enumerate(variables) if var not in variables[:pos]]
+             if len(variables) > PROMPT_VARS_MAX:
+                 return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
+
+             new_prompts.append(prompt)
+     wizard_prompt = "\n".join(new_prompts)
+     return wizard_prompt, variables, values, errors
+
+ def fill_wizard_prompt(state, prompt, wizard_prompt):
+     def get_hidden_textboxes(num = PROMPT_VARS_MAX):
+         return [gr.Textbox(value="", visible=False)] * num
+
+     hidden_column = gr.Column(visible= False)
+     visible_column = gr.Column(visible= True)
+
+     if advanced or state.get("apply_success") != 1:
+         return prompt, wizard_prompt, gr.Column(), gr.Column(), hidden_column, *get_hidden_textboxes()
+     prompt_parts = []
+     state["wizard_prompt"] = 0
+
+     wizard_prompt, variables, values, errors = extract_wizard_prompt(prompt)
+     if len(errors) > 0:
+         gr.Info(errors)
+         return gr.Textbox(prompt, visible=True), gr.Textbox(wizard_prompt, visible=False), visible_column, *[hidden_column] * 2, *get_hidden_textboxes()
+
+     for variable in variables:
+         value = values.get(variable, "")
+         prompt_parts.append(gr.Textbox(placeholder=variable, info= variable, visible= True, value= "\n".join(value)))
+     any_macro = len(variables) > 0
+
+     prompt_parts += get_hidden_textboxes(PROMPT_VARS_MAX - len(prompt_parts))
+
+     state["variables"] = variables
+     state["wizard_prompt"] = 1
+
+     return gr.Textbox(prompt, visible= False), gr.Textbox(wizard_prompt, visible= True), hidden_column, visible_column, visible_column if any_macro else hidden_column, *prompt_parts
+
+ def switch_prompt_type(state, prompt, wizard_prompt, *prompt_vars):
+     if advanced:
+         return fill_prompt_from_wizard(state, prompt, wizard_prompt, *prompt_vars)
+     else:
+         state["apply_success"] = 1
+         return fill_wizard_prompt(state, prompt, wizard_prompt)
+
+ visible = False
+ def switch_advanced(new_advanced, lset_name):
+     global advanced
+     advanced = new_advanced
+     lset_choices = [ (preset, preset) for preset in loras_presets]
+     lset_choices.append((get_new_preset_msg(advanced), ""))
+     if lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False) or lset_name == "":
+         lset_name = get_new_preset_msg(advanced)
+
+     if only_allow_edit_in_advanced:
+         return gr.Row(visible=new_advanced), gr.Row(visible=new_advanced), gr.Button(visible=new_advanced), gr.Row(visible= not new_advanced), gr.Dropdown(choices=lset_choices, value= lset_name)
+     else:
+         return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
+
+ def download_loras():
+     from huggingface_hub import snapshot_download
+
+     yield "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>", *[gr.Column(visible=False)] * 2
+     log_path = os.path.join(lora_dir, "log.txt")
+     if not os.path.isfile(log_path):
+         tmp_path = os.path.join(lora_dir, "tmp_lora_dowload")
+
+         import shutil, glob
+         snapshot_download(repo_id="DeepBeepMeep/Wan2.1", allow_patterns="loras_i2v/*", local_dir= tmp_path)
+         [shutil.move(f, lora_dir) for f in glob.glob(os.path.join(tmp_path, "loras_i2v", "*.*")) if not "README.txt" in f ]
+
+     yield "<B><FONT SIZE=3>Loras have been completely downloaded</B></FONT>", *[gr.Column(visible=True)] * 2
+
+     from datetime import datetime
+     dt = datetime.today().strftime('%Y-%m-%d')
+     with open(log_path, "w", encoding="utf-8") as writer:
+         writer.write(f"Loras downloaded on the {dt}")
+
+     return
+
+ def create_demo():
+     css = """
+         .title-with-lines {
+             display: flex;
+             align-items: center;
+             margin: 30px 0;
+         }
+         .line {
+             flex-grow: 1;
+             height: 1px;
+             background-color: #333;
+         }
+         h2 {
+             margin: 0 20px;
+             white-space: nowrap;
+         }
+     """
     default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
+     with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="emerald", neutral_hue="slate", text_size= "md")) as demo:
+         state_dict = {}

         if use_image2video:
+             gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
         else:
+             gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")

+         gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP, a super fast and low VRAM AI Video Generator !</FONT>")

+         with gr.Accordion("Click here for some Info on how to use Wan2GP and to download 20+ Loras", open= False):
+             if use_image2video and False:
+                 pass
+             else:
+                 gr.Markdown("The VRAM requirements will depend greatly on the resolution and the duration of the video, for instance:")
+                 gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s): 8 GB of VRAM")
+                 gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s): 5 GB of VRAM")
+                 gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
+                 gr.Markdown("It is not recommended to generate a video longer than 8s (128 frames) even if there is still some VRAM left, as some artifacts may appear")
+                 gr.Markdown("Please note that if you turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
+
+         if use_image2video:
+             with gr.Row():
+                 with gr.Row(scale=3):
+                     gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download the i2v <B>Remade</B> Loras collection (and bonus Loras). Don't forget to make a backup of your Loras first, just in case.")
+                 with gr.Row(scale=1):
+                     download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale=1)
+             with gr.Row():
+                 download_status = gr.Markdown()

         # css = """<STYLE>
         # h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
 
         apply_btn = gr.Button("Apply Changes")

+
         with gr.Row():
             with gr.Column():
+                 with gr.Row(visible= len(loras) > 0) as presets_column:
+                     lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
+                     with gr.Column(scale=6):
+                         lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
+                     with gr.Column(scale=1):
+                         # with gr.Column():
                         with gr.Row(height=17):
+                             apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
+                             refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced or not only_allow_edit_in_advanced)
+                         # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
+                         save_lset_prompt_drop = gr.Dropdown(
+                             choices=[
+                                 ("Save Prompt Comments Only", 0),
+                                 ("Save Full Prompt", 1)
+                             ], show_label= False, container=False, value= 1, visible= False
+                         )
+
+                 with gr.Row(height=17, visible=False) as refresh2_row:
+                     refresh_lora_btn2 = gr.Button("Refresh", size="sm", min_width= 1)
+
+                 with gr.Row(height=17, visible=advanced or not only_allow_edit_in_advanced) as preset_buttons_rows:
+                     confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
+                     confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
+                     save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
+                     delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
+                     cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1, visible=False)
+
                 video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False)
                 if args.multiple_images:
                     image_to_continue = gr.Gallery(
 
                 else:
                     image_to_continue = gr.Image(label= "Image as a starting point for a new video", type= "pil", visible=use_image2video)

+                 advanced_prompt = advanced
+                 prompt_vars = []
+                 if not advanced_prompt:
+                     default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
+                     advanced_prompt = len(errors) > 0
+
+                 with gr.Column(visible= advanced_prompt) as prompt_column_advanced:
+                     prompt = gr.Textbox(visible= advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments, ! lines = macros)", value=default_prompt, lines=3)
+
+                 with gr.Column(visible= not advanced_prompt and len(variables) > 0) as prompt_column_wizard_vars:
+                     gr.Markdown("<B>Please fill the following input fields to adapt automatically the Prompt:</B>")
+                     with gr.Row():
+                         if not advanced_prompt:
+                             for variable in variables:
+                                 value = values.get(variable, "")
+                                 prompt_vars.append(gr.Textbox(placeholder=variable, min_width=80, show_label= False, info= variable, visible= True, value= "\n".join(value)))
+                             state_dict["wizard_prompt"] = 1
+                             state_dict["variables"] = variables
+                         for _ in range(PROMPT_VARS_MAX - len(prompt_vars)):
+                             prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
+                 with gr.Column(visible= not advanced_prompt) as prompt_column_wizard:
+                     wizard_prompt = gr.Textbox(visible= not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
+                 state = gr.State(state_dict)
+
                 with gr.Row():
                     if use_image2video:
                         resolution = gr.Dropdown(
 
                 with gr.Row():
                     with gr.Column():
+                         video_length = gr.Slider(5, 193, value= default_number_frames if default_number_frames > 0 else 81, step=4, label="Number of frames (16 = 1s)")
                     with gr.Column():
+                         num_inference_steps = gr.Slider(1, 100, value= default_inference_steps if default_inference_steps > 0 else 30, step=1, label="Number of Inference Steps")

                 with gr.Row():
                     max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False)
 
1786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1787
 
1788
+show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
+with gr.Row(visible=advanced) as advanced_row:
     with gr.Column():
+        seed = gr.Slider(-1, 999999999, value=default_seed, step=1, label="Seed (-1 for random)")
         with gr.Row():
+            repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
+            multi_images_gen_type = gr.Dropdown(
+                choices=[
+                    ("Generate every combination of images and text prompts", 0),
+                    ("Match images and text prompts", 1),
+                ], visible=args.multiple_images, label="Multiple Images as Prompts"
+            )
+
         with gr.Row():
             guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
             embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
             flow_shift = gr.Slider(0.0, 25.0, value=default_flow_shift, step=0.1, label="Shift Scale")
         with gr.Row():
+            negative_prompt = gr.Textbox(label="Negative Prompt", value="")
+        with gr.Row():
+            gr.Markdown("<B>Loras can be used to create special effects in the video by mentioning a trigger word in the Prompt. You can save Lora combinations in presets.</B>")
+        with gr.Column() as loras_column:
+            loras_choices = gr.Dropdown(
+                choices=[
+                    (lora_name, str(i)) for i, lora_name in enumerate(loras_names)
+                ],
+                value=default_loras_choices,
+                multiselect=True,
+                visible=len(loras) > 0,
+                label="Activated Loras"
+            )
+            loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default), separated by spaces or carriage returns; lines that start with # are ignored", value=default_loras_multis_str, visible=len(loras) > 0)
+
+        with gr.Row():
+            gr.Markdown("<B>Tea Cache accelerates generation by intelligently skipping some steps; the more steps are skipped, the lower the video quality (Tea Cache also consumes VRAM)</B>")
         with gr.Row():
             tea_cache_setting = gr.Dropdown(
                 choices=[

             )
             tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")

+        gr.Markdown("<B>With Riflex you can generate videos longer than 5s, which is the default duration of the videos used to train the model</B>")
         RIFLEx_setting = gr.Dropdown(
             choices=[
                 ("Auto (ON if Video longer than 5s)", 0),
 
         with gr.Row():
+            gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
         with gr.Row():
             slg_switch = gr.Dropdown(
                 choices=[

             slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")

+    show_advanced.change(fn=switch_advanced, inputs=[show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row, lset_name]).then(
+        fn=switch_prompt_type, inputs=[state, prompt, wizard_prompt, *prompt_vars], outputs=[prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])

     with gr.Column():
         gen_status = gr.Text(label="Status", interactive=False)
         output = gr.Gallery(
             label="Generated videos", show_label=False, elem_id="gallery"
+            , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive=False)
         generate_btn = gr.Button("Generate")
         onemore_btn = gr.Button("One More Please !", visible=False)
         abort_btn = gr.Button("Abort")
+        gen_info = gr.Text(label="Current prompt", visible=False, interactive=False)

     save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+    confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs=[state, prompt, wizard_prompt, *prompt_vars], outputs=[prompt]).then(
+        save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
     delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, cancel_lset_btn])
     confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, cancel_lset_btn])
     cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])

+    apply_lset_btn.click(apply_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt]).then(
+        fn=fill_wizard_prompt, inputs=[state, prompt, wizard_prompt], outputs=[prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
+    )

     refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name, loras_choices], outputs=[lset_name, loras_choices])
+    refresh_lora_btn2.click(refresh_lora_list, inputs=[lset_name, loras_choices], outputs=[lset_name, loras_choices])
+    download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[lset_name, loras_choices], outputs=[lset_name, loras_choices])

+    gen_status.change(refresh_gallery, inputs=[state, gen_info], outputs=[output, gen_info])

     abort_btn.click(abort_generation, state, abort_btn)
     output.select(select_video, state, None)
     onemore_btn.click(fn=one_more_video, inputs=[state], outputs=[state])
+    generate_btn.click(fn=prepare_generate_video, inputs=[], outputs=[generate_btn, onemore_btn]
+    ).then(
+        fn=validate_wizard_prompt, inputs=[state, prompt, wizard_prompt, *prompt_vars], outputs=[prompt]
+    ).then(
         fn=generate_video,
         inputs=[
             prompt,

             flow_shift,
             embedded_guidance_scale,
             repeat_generation,
+            multi_images_gen_type,
             tea_cache_setting,
             tea_cache_start_step_perc,
             loras_choices,

     ).then(
         finalize_gallery,
         [state],
+        [output, abort_btn, generate_btn, onemore_btn, gen_info]
     )

     apply_btn.click(

         outputs=msg
     ).then(
         update_defaults,
+        [state, num_inference_steps, flow_shift, lset_name, loras_choices],
         [num_inference_steps, flow_shift, header, lset_name, loras_choices]
     )
 
wan/modules/model.py CHANGED
@@ -676,6 +676,7 @@ class WanModel(ModelMixin, ConfigMixin):

         best_threshold = 0.01
         best_diff = 1000
+        best_signed_diff = 1000
         target_nb_steps= int(len(timesteps) / speed_factor)
         threshold = 0.01
         while threshold <= 0.6:
@@ -686,6 +687,8 @@ class WanModel(ModelMixin, ConfigMixin):
                 skip = False
                 if not (i<=start_step or i== len(timesteps)):
                     accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
+                    # self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
+
                     if accumulated_rel_l1_distance < threshold:
                         skip = True
                     else:
@@ -693,15 +696,17 @@ class WanModel(ModelMixin, ConfigMixin):
                         previous_modulated_input = e_list[i]
                 if not skip:
                     nb_steps += 1
-            diff = abs(target_nb_steps - nb_steps)
+            signed_diff = target_nb_steps - nb_steps
+            diff = abs(signed_diff)
             if diff < best_diff:
                 best_threshold = threshold
                 best_diff = diff
+                best_signed_diff = signed_diff
             elif diff > best_diff:
                 break
             threshold += 0.01
         self.rel_l1_thresh = best_threshold
-        print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(len(timesteps) - best_diff):0.1f} for a target of x{speed_factor}")
+        print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
         return best_threshold

     def forward(
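The change above tracks the signed difference so the reported speed-up gain reflects the actual number of executed steps, not just the absolute deviation. The search itself sweeps candidate thresholds and keeps the one whose executed-step count lands closest to the target. A simplified standalone sketch of that search (for illustration only, not the model code; `changes` is a hypothetical list of per-step relative L1 distances standing in for the rescaled modulated-input deltas):

```python
def find_best_threshold(changes, speed_factor):
    """Sweep thresholds; a step is skipped while the accumulated change
    since the last executed step stays below the threshold."""
    target_nb_steps = int(len(changes) / speed_factor)
    best_threshold, best_diff = 0.01, float("inf")
    threshold = 0.01
    while threshold <= 0.6:
        accumulated, nb_steps = 0.0, 0
        for change in changes:
            accumulated += change
            if accumulated < threshold:
                continue  # accumulated drift still small: skip this step
            accumulated = 0.0  # step executed: reset the accumulator
            nb_steps += 1
        diff = abs(target_nb_steps - nb_steps)
        if diff < best_diff:
            best_threshold, best_diff = threshold, diff
        elif diff > best_diff:
            break  # moving away from the optimum: stop the sweep
        threshold = round(threshold + 0.01, 2)  # rounded to keep the grid exact
    return best_threshold

# Uniform toy input: skipping every other step hits the x2 target exactly.
print(find_best_threshold([0.05] * 20, speed_factor=2))  # → 0.06
```

The early `break` relies on the step count being monotone in the threshold, so the |target - executed| curve is unimodal and the sweep can stop once it starts growing.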
wan/utils/prompt_parser.py ADDED
@@ -0,0 +1,291 @@
+import re
+
+def process_template(input_text):
+    """
+    Process a text template with macro instructions and variable substitution.
+    Supports multiple values for variables to generate multiple output versions.
+    Each section between macro lines is treated as a separate template.
+
+    Args:
+        input_text (str): The input template text
+
+    Returns:
+        tuple: (output_text, error_message)
+        - output_text: Processed output with variables substituted, or empty string if error
+        - error_message: Error description and problematic line, or empty string if no error
+    """
+    lines = input_text.strip().split('\n')
+    current_variables = {}
+    current_template_lines = []
+    all_output_lines = []
+    error_message = ""
+
+    # Process the input line by line
+    line_number = 0
+    while line_number < len(lines):
+        orig_line = lines[line_number]
+        line = orig_line.strip()
+        line_number += 1
+
+        # Skip empty lines or comments
+        if not line or line.startswith('#'):
+            continue
+
+        # Handle macro instructions
+        if line.startswith('!'):
+            # Process any accumulated template lines before starting a new macro
+            if current_template_lines:
+                # Process the current template with current variables
+                template_output, err = process_current_template(current_template_lines, current_variables)
+                if err:
+                    return "", err
+                all_output_lines.extend(template_output)
+                current_template_lines = []  # Reset template lines
+
+            # Reset variables for the new macro
+            current_variables = {}
+
+            # Parse the macro line
+            macro_line = line[1:].strip()
+
+            # Check for unmatched braces in the whole line
+            open_braces = macro_line.count('{')
+            close_braces = macro_line.count('}')
+            if open_braces != close_braces:
+                error_message = f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces\nLine: '{orig_line}'"
+                return "", error_message
+
+            # Check for unclosed quotes
+            if macro_line.count('"') % 2 != 0:
+                error_message = f"Unclosed double quotes\nLine: '{orig_line}'"
+                return "", error_message
+
+            # Split by optional colon separator
+            var_sections = re.split(r'\s*:\s*', macro_line)
+
+            for section in var_sections:
+                section = section.strip()
+                if not section:
+                    continue
+
+                # Extract variable name
+                var_match = re.search(r'\{([^}]+)\}', section)
+                if not var_match:
+                    if '{' in section or '}' in section:
+                        error_message = f"Malformed variable declaration\nLine: '{orig_line}'"
+                        return "", error_message
+                    continue
+
+                var_name = var_match.group(1).strip()
+                if not var_name:
+                    error_message = f"Empty variable name\nLine: '{orig_line}'"
+                    return "", error_message
+
+                # Check variable value format
+                value_part = section[section.find('}')+1:].strip()
+                if not value_part.startswith('='):
+                    error_message = f"Missing '=' after variable '{{{var_name}}}'\nLine: '{orig_line}'"
+                    return "", error_message
+
+                # Extract all quoted values
+                var_values = re.findall(r'"([^"]*)"', value_part)
+
+                # Check if there are values specified
+                if not var_values:
+                    error_message = f"No quoted values found for variable '{{{var_name}}}'\nLine: '{orig_line}'"
+                    return "", error_message
+
+                # Check for missing commas between values
+                # Look for patterns like "value""value" (missing comma)
+                if re.search(r'"[^,]*"[^,]*"', value_part):
+                    error_message = f"Missing comma between values for variable '{{{var_name}}}'\nLine: '{orig_line}'"
+                    return "", error_message
+
+                # Store the variable values
+                current_variables[var_name] = var_values
+
+        # Handle template lines
+        else:
+            # Check for unknown variables in template line
+            var_references = re.findall(r'\{([^}]+)\}', line)
+            for var_ref in var_references:
+                if var_ref not in current_variables:
+                    error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'"
+                    return "", error_message
+
+            # Add to current template lines
+            current_template_lines.append(line)
+
+    # Process any remaining template lines
+    if current_template_lines:
+        template_output, err = process_current_template(current_template_lines, current_variables)
+        if err:
+            return "", err
+        all_output_lines.extend(template_output)
+
+    return '\n'.join(all_output_lines), ""
+
+def process_current_template(template_lines, variables):
+    """
+    Process a set of template lines with the current variables.
+
+    Args:
+        template_lines (list): List of template lines to process
+        variables (dict): Dictionary of variable names to lists of values
+
+    Returns:
+        tuple: (output_lines, error_message)
+    """
+    if not variables or not template_lines:
+        return template_lines, ""
+
+    output_lines = []
+
+    # Find the maximum number of values for any variable
+    max_values = max(len(values) for values in variables.values())
+
+    # Generate each combination
+    for i in range(max_values):
+        for template in template_lines:
+            output_line = template
+            for var_name, var_values in variables.items():
+                # Use modulo to cycle through values if needed
+                value_index = i % len(var_values)
+                var_value = var_values[value_index]
+                output_line = output_line.replace(f"{{{var_name}}}", var_value)
+            output_lines.append(output_line)
+
+    return output_lines, ""
+
+
+def extract_variable_names(macro_line):
+    """
+    Extract all variable names from a macro line.
+
+    Args:
+        macro_line (str): A macro line (with or without the leading '!')
+
+    Returns:
+        tuple: (variable_names, error_message)
+        - variable_names: List of variable names found in the macro
+        - error_message: Error description if any, empty string if no error
+    """
+    # Remove leading '!' if present
+    if macro_line.startswith('!'):
+        macro_line = macro_line[1:].strip()
+
+    variable_names = []
+
+    # Check for unmatched braces
+    open_braces = macro_line.count('{')
+    close_braces = macro_line.count('}')
+    if open_braces != close_braces:
+        return [], f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
+
+    # Split by optional colon separator
+    var_sections = re.split(r'\s*:\s*', macro_line)
+
+    for section in var_sections:
+        section = section.strip()
+        if not section:
+            continue
+
+        # Extract variable names
+        var_matches = re.findall(r'\{([^}]+)\}', section)
+        for var_name in var_matches:
+            new_var = var_name.strip()
+            if new_var not in variable_names:
+                variable_names.append(new_var)
+
+    return variable_names, ""
+
+def extract_variable_values(macro_line):
+    """
+    Extract all variable names and their values from a macro line.
+
+    Args:
+        macro_line (str): A macro line (with or without the leading '!')
+
+    Returns:
+        tuple: (variables_dict, error_message)
+        - variables_dict: Dictionary mapping variable names to their values
+        - error_message: Error description if any, empty string if no error
+    """
+    # Remove leading '!' if present
+    if macro_line.startswith('!'):
+        macro_line = macro_line[1:].strip()
+
+    variables = {}
+
+    # Check for unmatched braces
+    open_braces = macro_line.count('{')
+    close_braces = macro_line.count('}')
+    if open_braces != close_braces:
+        return {}, f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
+
+    # Check for unclosed quotes
+    if macro_line.count('"') % 2 != 0:
+        return {}, "Unclosed double quotes"
+
+    # Split by optional colon separator
+    var_sections = re.split(r'\s*:\s*', macro_line)
+
+    for section in var_sections:
+        section = section.strip()
+        if not section:
+            continue
+
+        # Extract variable name
+        var_match = re.search(r'\{([^}]+)\}', section)
+        if not var_match:
+            if '{' in section or '}' in section:
+                return {}, "Malformed variable declaration"
+            continue
+
+        var_name = var_match.group(1).strip()
+        if not var_name:
+            return {}, "Empty variable name"
+
+        # Check variable value format
+        value_part = section[section.find('}')+1:].strip()
+        if not value_part.startswith('='):
+            return {}, f"Missing '=' after variable '{{{var_name}}}'"
+
+        # Extract all quoted values
+        var_values = re.findall(r'"([^"]*)"', value_part)
+
+        # Check if there are values specified
+        if not var_values:
+            return {}, f"No quoted values found for variable '{{{var_name}}}'"
+
+        # Check for missing commas between values
+        if re.search(r'"[^,]*"[^,]*"', value_part):
+            return {}, f"Missing comma between values for variable '{{{var_name}}}'"
+
+        variables[var_name] = var_values
+
+    return variables, ""
+
+def generate_macro_line(variables_dict):
+    """
+    Generate a macro line from a dictionary of variable names and their values.
+
+    Args:
+        variables_dict (dict): Dictionary mapping variable names to lists of values
+
+    Returns:
+        str: A formatted macro line (including the leading '!')
+    """
+    sections = []
+
+    for var_name, values in variables_dict.items():
+        # Format each value with quotes
+        quoted_values = [f'"{value}"' for value in values]
+        # Join values with commas
+        values_str = ','.join(quoted_values)
+        # Create the variable assignment
+        section = f"{{{var_name}}}={values_str}"
+        sections.append(section)
+
+    # Join sections with a colon and space for readability
+    return "! " + " : ".join(sections)
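For reference, a macro line parsed by this module looks like `! {color}="red","blue" : {animal}="cat"`, followed by template lines that reference the variables. The core substitution behaviour can be sketched as follows (a simplified reimplementation for illustration, not the module itself):

```python
def expand(template_lines, variables):
    """Substitute {name} placeholders; shorter value lists cycle via
    modulo indexing, mirroring process_current_template above."""
    max_values = max(len(values) for values in variables.values())
    out = []
    for i in range(max_values):
        for template in template_lines:
            line = template
            for name, values in variables.items():
                line = line.replace("{" + name + "}", values[i % len(values)])
            out.append(line)
    return out

# One template line, two colors, a single animal value that repeats:
prompts = expand(["A {color} {animal} runs."],
                 {"color": ["red", "blue"], "animal": ["cat"]})
print(prompts)  # → ['A red cat runs.', 'A blue cat runs.']
```

Each pass through the template list produces one prompt per template line, so a macro with N values per variable turns each template line into N generated prompts.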