DeepBeepMeep committed on
Commit
acc356f
·
1 Parent(s): 92f2b6e

Refactored Loras

README.md CHANGED
@@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
19
 
20
 
21
  ## 🔥 Latest News!!
 
22
 * Mar 13, 2025: 👋 Wan2.1GP v1.6: Better Loras support, accelerated Lora loading. You will need to refresh the requirements: *pip install -r requirements.txt*
23
 * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved Lora presets (they can now include prompts and comments to guide the user)
24
 * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation; it is now really 20% faster when activated
@@ -157,6 +158,11 @@ python gradio_server.py --attention sdpa
157
 
158
 Every Lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be automatically loaded. You will then be able to activate / deactivate any of them when running the application by selecting them in the area below "Activated Loras".
159
160
 For each activated Lora, you may specify a *multiplier*, a float number that corresponds to its weight (default is 1.0). The multipliers for each Lora should be separated by a space character or a carriage return. For instance:\
161
 *1.2 0.8* means that the first Lora will have a multiplier of 1.2 and the second one 0.8.
162
 
 
19
 
20
 
21
  ## 🔥 Latest News!!
22
+ * Mar 14, 2025: 👋 Wan2.1GP v1.7: Lora Fest special edition: very fast loading / unloading of Loras, for all the Lora collectors out there. You will need to refresh the requirements: *pip install -r requirements.txt*
23
 * Mar 13, 2025: 👋 Wan2.1GP v1.6: Better Loras support, accelerated Lora loading. You will need to refresh the requirements: *pip install -r requirements.txt*
24
 * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved Lora presets (they can now include prompts and comments to guide the user)
25
 * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation; it is now really 20% faster when activated
 
158
 
159
 Every Lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be automatically loaded. You will then be able to activate / deactivate any of them when running the application by selecting them in the area below "Activated Loras".
160
 
161
+ If you want to keep the Loras for the 1.3B model and the 14B model in separate areas (they are not compatible), just create the following subfolders (see the sketch after this list):
162
+ - loras/1.3B
163
+ - loras/14B
164
+
165
+
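A minimal sketch (not part of this commit) of creating these subfolders and picking the right one for the loaded model; it mirrors the `get_lora_dir` helper added to `gradio_server.py` below, but the function name and the model filename used here are illustrative assumptions:

```python
import os

def pick_lora_dir(root_lora_dir="loras", model_filename="wan2.1_text2video_1.3B.safetensors"):
    # Hypothetical helper: prefer the model-specific subfolder when it exists,
    # otherwise fall back to the root 'loras' directory.
    sub = "1.3B" if "1.3B" in model_filename else "14B"
    candidate = os.path.join(root_lora_dir, sub)
    return candidate if os.path.isdir(candidate) else root_lora_dir

# Create the two optional subfolders once:
for sub in ("1.3B", "14B"):
    os.makedirs(os.path.join("loras", sub), exist_ok=True)
```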
166
 For each activated Lora, you may specify a *multiplier*, a float number that corresponds to its weight (default is 1.0). The multipliers for each Lora should be separated by a space character or a carriage return. For instance:\
167
 *1.2 0.8* means that the first Lora will have a multiplier of 1.2 and the second one 0.8.
168
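To illustrate the multiplier format above, here is a rough sketch of how such a string can be turned into one weight per activated Lora (the helper name is made up; the actual parsing happens inside `generate_video` in `gradio_server.py`):

```python
def parse_lora_multipliers(mult_text: str, num_loras: int) -> list:
    # Split on spaces and carriage returns, convert to floats, and pad with the
    # default weight 1.0 for any activated Lora left unspecified.
    tokens = mult_text.replace("\r", " ").replace("\n", " ").split()
    mults = [float(t) for t in tokens]
    mults += [1.0] * (num_loras - len(mults))
    return mults[:num_loras]

print(parse_lora_multipliers("1.2 0.8", 3))  # -> [1.2, 0.8, 1.0]
```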
 
gradio_server.py CHANGED
@@ -170,6 +170,18 @@ def _parse_args():
170
 
171
  return args
172
173
  attention_modes_supported = get_attention_modes()
174
 
175
  args = _parse_args()
@@ -238,9 +250,10 @@ if args.i2v:
238
 
239
  lora_dir =args.lora_dir
240
  if use_image2video and len(lora_dir)==0:
241
- lora_dir =args.lora_dir_i2v
242
  if len(lora_dir) ==0:
243
- lora_dir = "loras_i2v" if use_image2video else "loras"
 
244
  lora_preselected_preset = args.lora_preset
245
  default_tea_cache = 0
246
  # if args.fast : #or args.fastest
@@ -295,35 +308,42 @@ def sanitize_file_name(file_name, rep =""):
295
  return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
296
 
297
  def extract_preset(lset_name, loras):
 
 
 
 
 
298
  lset_name = sanitize_file_name(lset_name)
299
  if not lset_name.endswith(".lset"):
300
  lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
301
  else:
302
  lset_name_filename = os.path.join(lora_dir, lset_name )
303
-
304
  if not os.path.isfile(lset_name_filename):
305
- raise gr.Error(f"Preset '{lset_name}' not found ")
 
 
306
 
307
- with open(lset_name_filename, "r", encoding="utf-8") as reader:
308
- text = reader.read()
309
- lset = json.loads(text)
310
 
311
- loras_choices_files = lset["loras"]
312
- loras_choices = []
313
- missing_loras = []
314
- for lora_file in loras_choices_files:
315
- loras_choice_no = loras.index(os.path.join(lora_dir, lora_file))
316
- if loras_choice_no < 0:
317
- missing_loras.append(lora_file)
318
- else:
319
- loras_choices.append(str(loras_choice_no))
320
 
321
- if len(missing_loras) > 0:
322
- raise gr.Error(f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing: {missing_loras}")
323
-
324
- loras_mult_choices = lset["loras_mult"]
325
- prompt = lset.get("prompt", "")
326
- return loras_choices, loras_mult_choices, prompt, lset.get("full_prompt", False)
 
327
 
328
  def get_default_prompt(i2v):
329
  if i2v:
@@ -332,7 +352,7 @@ def get_default_prompt(i2v):
332
  return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
333
 
334
 
335
- def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
336
  loras =[]
337
  loras_names = []
338
  default_loras_choices = []
@@ -341,32 +361,35 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
341
  default_lora_preset = ""
342
  default_prompt = ""
343
 
344
- if use_image2video or not "1.3B" in transformer_filename_t2v:
345
- from pathlib import Path
 
 
 
346
 
347
- if lora_dir != None :
348
- if not os.path.isdir(lora_dir):
349
- raise Exception("--lora-dir should be a path to a directory that contains Loras")
350
 
351
- if lora_dir != None:
352
- import glob
353
- dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
354
- dir_loras.sort()
355
- loras += [element for element in dir_loras if element not in loras ]
356
 
357
- dir_presets = glob.glob( os.path.join(lora_dir , "*.lset") )
358
- dir_presets.sort()
359
- loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
360
 
361
- if len(loras) > 0:
362
- loras_names = [ Path(lora).stem for lora in loras ]
363
- offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
364
 
365
- if len(lora_preselected_preset) > 0:
366
- if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
367
- raise Exception(f"Unknown preset '{lora_preselected_preset}'")
368
- default_lora_preset = lora_preselected_preset
369
- default_loras_choices, default_loras_multis_str, default_prompt, _ = extract_preset(default_lora_preset, loras)
 
 
 
 
 
370
  if len(default_prompt) == 0:
371
  default_prompt = get_default_prompt(use_image2video)
372
  return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
@@ -450,8 +473,8 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
450
  kwargs["budgets"] = { "*" : "70%" }
451
 
452
 
453
- loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe, lora_dir, lora_preselected_preset, None)
454
- offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, **kwargs)
455
 
456
 
457
  return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
@@ -542,7 +565,7 @@ def apply_changes( state,
542
  state["config_new"] = server_config
543
  state["config_old"] = old_server_config
544
 
545
- global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
546
  attention_mode = server_config["attention_mode"]
547
  profile = server_config["profile"]
548
  compile = server_config["compile"]
@@ -560,7 +583,7 @@ def apply_changes( state,
560
  offloadobj.release()
561
  offloadobj = None
562
  yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
563
-
564
  wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
565
 
566
 
@@ -590,7 +613,11 @@ def update_defaults(state, num_inference_steps,flow_shift):
590
  num_inference_steps, flow_shift = get_default_flow(trans_file)
591
 
592
  header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
593
- return num_inference_steps, flow_shift, header
 
 
 
 
594
 
595
 
596
  from moviepy.editor import ImageSequenceClip
@@ -603,23 +630,32 @@ def save_video(final_frames, output_path, fps=24):
603
  ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
604
 
605
  def build_callback(state, pipe, progress, status, num_inference_steps):
606
- def callback(step_idx, latents):
607
- step_idx += 1
608
- if state.get("abort", False):
609
- # pipe._interrupt = True
610
- status_msg = status + " - Aborting"
611
- elif step_idx == num_inference_steps:
612
- status_msg = status + " - VAE Decoding"
613
  else:
614
- status_msg = status + " - Denoising"
615
-
616
- progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
617
 
618
  return callback
619
 
620
  def abort_generation(state):
621
  if "in_progress" in state:
622
  state["abort"] = True
 
623
  wan_model._interrupt= True
624
  return gr.Button(interactive= False)
625
  else:
@@ -634,11 +670,12 @@ def finalize_gallery(state):
634
  if "in_progress" in state:
635
  del state["in_progress"]
636
  choice = state.get("selected",0)
637
-
 
638
  time.sleep(0.2)
639
  global gen_in_progress
640
  gen_in_progress = False
641
- return gr.Gallery(selected_index=choice), gr.Button(interactive= True)
642
 
643
  def select_video(state , event_data: gr.EventData):
644
  data= event_data._data
@@ -656,6 +693,32 @@ def expand_slist(slist, num_inference_steps ):
656
  return new_slist
657
 
658
659
  def generate_video(
660
  prompt,
661
  negative_prompt,
@@ -733,6 +796,7 @@ def generate_video(
733
  else:
734
  VAE_tile_size = 128
735
 
 
736
 
737
  global gen_in_progress
738
  gen_in_progress = True
@@ -740,7 +804,7 @@ def generate_video(
740
  if len(prompt) ==0:
741
  return
742
  prompts = prompt.replace("\r", "").split("\n")
743
- prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
744
  if len(prompts) ==0:
745
  return
746
  if use_image2video:
@@ -808,9 +872,13 @@ def generate_video(
808
  list_mult_choices_nums.append(float(mult))
809
  if len(list_mult_choices_nums ) < len(loras_choices):
810
  list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
811
-
812
- offload.activate_loras(wan_model.model, loras_choices, list_mult_choices_nums)
813
-
 
 
 
 
814
  seed = None if seed == -1 else seed
815
  # negative_prompt = "" # not applicable in the inference
816
 
@@ -825,7 +893,6 @@ def generate_video(
825
 
826
  joint_pass = boost ==1
827
  # TeaCache
828
- trans = wan_model.model
829
  trans.enable_teacache = tea_cache > 0
830
  if trans.enable_teacache:
831
  if use_image2video:
@@ -857,11 +924,23 @@ def generate_video(
857
  os.makedirs(save_path, exist_ok=True)
858
  video_no = 0
859
  total_video = repeat_generation * len(prompts)
 
 
860
  abort = False
861
  start_time = time.time()
862
- for prompt in prompts:
863
- for _ in range(repeat_generation):
864
- if abort:
  break
866
 
867
  if trans.enable_teacache:
@@ -875,9 +954,12 @@ def generate_video(
875
 
876
  video_no += 1
877
  status = f"Video {video_no}/{total_video}"
 
 
 
878
  progress(0, desc=status + " - Encoding Prompt" )
879
-
880
  callback = build_callback(state, trans, progress, status, num_inference_steps)
 
881
 
882
 
883
  gc.collect()
@@ -887,7 +969,7 @@ def generate_video(
887
  if use_image2video:
888
  samples = wan_model.generate(
889
  prompt,
890
- image_to_continue[ (video_no-1) % len(image_to_continue)].convert('RGB'),
891
  frame_num=(video_length // 4)* 4 + 1,
892
  max_area=MAX_AREA_CONFIGS[resolution],
893
  shift=flow_shift,
@@ -923,6 +1005,7 @@ def generate_video(
923
  if temp_filename!= None and os.path.isfile(temp_filename):
924
  os.remove(temp_filename)
925
  offload.last_offload_obj.unload_all()
 
926
  # if compile:
927
  # cache_size = torch._dynamo.config.cache_size_limit
928
  # torch.compiler.reset()
@@ -967,9 +1050,9 @@ def generate_video(
967
 
968
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
969
  if os.name == 'nt':
970
- file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50])}.mp4"
971
  else:
972
- file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100])}.mp4"
973
  video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
974
  cache_video(
975
  tensor=sample[None],
@@ -987,10 +1070,13 @@ def generate_video(
987
  end_time = time.time()
988
  yield f"Total Generation Time: {end_time-start_time:.1f}s"
989
  seed += 1
 
990
 
991
  if temp_filename!= None and os.path.isfile(temp_filename):
992
  os.remove(temp_filename)
993
  gen_in_progress = False
 
 
994
 
995
  new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
996
 
@@ -998,19 +1084,19 @@ new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
998
  def validate_delete_lset(lset_name):
999
  if len(lset_name) == 0 or lset_name == new_preset_msg:
1000
  gr.Info(f"Choose a Preset to delete")
1001
- return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
1002
  else:
1003
- return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
1004
 
1005
  def validate_save_lset(lset_name):
1006
  if len(lset_name) == 0 or lset_name == new_preset_msg:
1007
  gr.Info("Please enter a name for the preset")
1008
- return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
1009
  else:
1010
- return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
1011
 
1012
  def cancel_lset():
1013
- return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1014
 
1015
  def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
1016
  global loras_presets
@@ -1047,7 +1133,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_pr
1047
  lset_choices = [ ( preset, preset) for preset in loras_presets ]
1048
  lset_choices.append( (new_preset_msg, ""))
1049
 
1050
- return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1051
 
1052
  def delete_lset(lset_name):
1053
  global loras_presets
@@ -1065,23 +1151,57 @@ def delete_lset(lset_name):
1065
 
1066
  lset_choices = [ (preset, preset) for preset in loras_presets]
1067
  lset_choices.append((new_preset_msg, ""))
1068
- return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
1069
 
1070
 
 
 
1071
  def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
1072
 
1073
  if len(lset_name) == 0 or lset_name== new_preset_msg:
1074
  gr.Info("Please choose a preset in the list or create one")
1075
  else:
1076
- loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
1077
- if full_prompt:
1078
- prompt = preset_prompt
1079
- elif len(preset_prompt) > 0:
1080
- prompts = prompt.replace("\r", "").split("\n")
1081
- prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
1082
- prompt = "\n".join(prompts)
1083
- prompt = preset_prompt + '\n' + prompt
1084
- gr.Info(f"Lora Preset '{lset_name}' has been applied")
 
 
 
1085
 
1086
  return loras_choices, loras_mult_choices, prompt
1087
 
@@ -1094,21 +1214,21 @@ def create_demo():
1094
  state = gr.State({})
1095
 
1096
  if use_image2video:
1097
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
1098
  else:
1099
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
1100
 
1101
- gr.Markdown("<FONT SIZE=3>With this first release of Wan 2.1GP by <B>DeepBeepMeep</B>, the VRAM requirements have been divided by more than 2 with no quality loss</FONT>")
1102
 
1103
  if use_image2video and False:
1104
  pass
1105
  else:
1106
- gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance : 24 GB of VRAM (RTX 3090 / RTX 4090), the limits are as follows:")
1107
  gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
1108
  gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
1109
  gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
1110
  gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
1111
- gr.Markdown("Please note that if your turn on compilation, the first generation step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
1112
 
1113
 
1114
  # css = """<STYLE>
@@ -1302,6 +1422,7 @@ def create_demo():
1302
  # with gr.Column():
1303
  with gr.Row(height=17):
1304
  apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
 
1305
  # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
1306
  save_lset_prompt_drop= gr.Dropdown(
1307
  choices=[
@@ -1334,7 +1455,7 @@ def create_demo():
1334
  with gr.Row(visible=False) as advanced_row:
1335
  with gr.Column():
1336
  seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
1337
- repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Number of Generated Video per prompt")
1338
  with gr.Row():
1339
  negative_prompt = gr.Textbox(label="Negative Prompt", value="")
1340
  with gr.Row():
@@ -1377,22 +1498,25 @@ def create_demo():
1377
  label="Generated videos", show_label=False, elem_id="gallery"
1378
  , columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
1379
  generate_btn = gr.Button("Generate")
 
1380
  abort_btn = gr.Button("Abort")
1381
 
1382
- save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
1383
- confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
1384
- delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1385
- confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1386
- cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
1387
 
1388
  apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
1389
 
 
 
1390
  gen_status.change(refresh_gallery, inputs = [state], outputs = output )
1391
 
1392
  abort_btn.click(abort_generation,state,abort_btn )
1393
  output.select(select_video, state, None )
1394
-
1395
- generate_btn.click(
1396
  fn=generate_video,
1397
  inputs=[
1398
  prompt,
@@ -1420,7 +1544,7 @@ def create_demo():
1420
  ).then(
1421
  finalize_gallery,
1422
  [state],
1423
- [output , abort_btn]
1424
  )
1425
 
1426
  apply_btn.click(
@@ -1441,7 +1565,7 @@ def create_demo():
1441
  ).then(
1442
  update_defaults,
1443
  [state, num_inference_steps, flow_shift],
1444
- [num_inference_steps, flow_shift, header]
1445
  )
1446
 
1447
  return demo
 
170
 
171
  return args
172
 
173
+ def get_lora_dir(root_lora_dir):
174
+ if not use_image2video:
175
+ if "1.3B" in transformer_filename_t2v:
176
+ lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B")
177
+ if os.path.isdir(lora_dir_1_3B ):
178
+ return lora_dir_1_3B
179
+ else:
180
+ lora_dir_14B = os.path.join(root_lora_dir, "14B")
181
+ if os.path.isdir(lora_dir_14B ):
182
+ return lora_dir_14B
183
+ return root_lora_dir
184
+
185
  attention_modes_supported = get_attention_modes()
186
 
187
  args = _parse_args()
 
250
 
251
  lora_dir =args.lora_dir
252
  if use_image2video and len(lora_dir)==0:
253
+ root_lora_dir =args.lora_dir_i2v
254
  if len(lora_dir) ==0:
255
+ root_lora_dir = "loras_i2v" if use_image2video else "loras"
256
+ lora_dir = get_lora_dir(root_lora_dir)
257
  lora_preselected_preset = args.lora_preset
258
  default_tea_cache = 0
259
  # if args.fast : #or args.fastest
 
308
  return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
309
 
310
  def extract_preset(lset_name, loras):
311
+ loras_choices = []
312
+ loras_choices_files = []
313
+ loras_mult_choices = ""
314
+ prompt =""
315
+ full_prompt =""
316
  lset_name = sanitize_file_name(lset_name)
317
  if not lset_name.endswith(".lset"):
318
  lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
319
  else:
320
  lset_name_filename = os.path.join(lora_dir, lset_name )
321
+ error = ""
322
  if not os.path.isfile(lset_name_filename):
323
+ error = f"Preset '{lset_name}' not found "
324
+ else:
325
+ missing_loras = []
326
 
327
+ with open(lset_name_filename, "r", encoding="utf-8") as reader:
328
+ text = reader.read()
329
+ lset = json.loads(text)
330
 
331
+ loras_choices_files = lset["loras"]
332
+ for lora_file in loras_choices_files:
333
+ choice = os.path.join(lora_dir, lora_file)
334
+ if choice not in loras:
335
+ missing_loras.append(lora_file)
336
+ else:
337
+ loras_choice_no = loras.index(choice)
338
+ loras_choices.append(str(loras_choice_no))
 
339
 
340
+ if len(missing_loras) > 0:
341
+ error = f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing or invalid: {missing_loras}"
342
+
343
+ loras_mult_choices = lset["loras_mult"]
344
+ prompt = lset.get("prompt", "")
345
+ full_prompt = lset.get("full_prompt", False)
346
+ return loras_choices, loras_mult_choices, prompt, full_prompt, error
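For reference, a Lora preset (`.lset`) is a small JSON file whose keys match what `extract_preset` reads above; a hypothetical example, with made-up file and Lora names, could be written like this:

```python
import json

# Keys read by extract_preset: "loras" (Lora file names), "loras_mult"
# (multiplier string), plus the optional "prompt" and "full_prompt".
preset = {
    "loras": ["my_style_lora.safetensors", "detail_boost.safetensors"],
    "loras_mult": "1.2 0.8",
    "prompt": "# comment lines guide the user; other lines are prompt text",
    "full_prompt": False,
}
with open("loras/my_style.lset", "w", encoding="utf-8") as writer:
    writer.write(json.dumps(preset, indent=4))
```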
347
 
348
  def get_default_prompt(i2v):
349
  if i2v:
 
352
  return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
353
 
354
 
355
+ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
356
  loras =[]
357
  loras_names = []
358
  default_loras_choices = []
 
361
  default_lora_preset = ""
362
  default_prompt = ""
363
 
364
+ from pathlib import Path
365
+
366
+ if lora_dir != None :
367
+ if not os.path.isdir(lora_dir):
368
+ raise Exception("--lora-dir should be a path to a directory that contains Loras")
369
 
 
 
 
370
 
371
+ if lora_dir != None:
372
+ import glob
373
+ dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
374
+ dir_loras.sort()
375
+ loras += [element for element in dir_loras if element not in loras ]
376
 
377
+ dir_presets = glob.glob( os.path.join(lora_dir , "*.lset") )
378
+ dir_presets.sort()
379
+ loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
380
 
381
+ loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
 
 
382
 
383
+ if len(loras) > 0:
384
+ loras_names = [ Path(lora).stem for lora in loras ]
385
+
386
+ if len(lora_preselected_preset) > 0:
387
+ if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
388
+ raise Exception(f"Unknown preset '{lora_preselected_preset}'")
389
+ default_lora_preset = lora_preselected_preset
390
+ default_loras_choices, default_loras_multis_str, default_prompt, _ , error = extract_preset(default_lora_preset, loras)
391
+ if len(error) > 0:
392
+ print(error[:200])
393
  if len(default_prompt) == 0:
394
  default_prompt = get_default_prompt(use_image2video)
395
  return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
 
473
  kwargs["budgets"] = { "*" : "70%" }
474
 
475
 
476
+ offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
477
+ loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe["transformer"], lora_dir, lora_preselected_preset, None)
478
 
479
 
480
  return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
 
565
  state["config_new"] = server_config
566
  state["config_old"] = old_server_config
567
 
568
+ global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir
569
  attention_mode = server_config["attention_mode"]
570
  profile = server_config["profile"]
571
  compile = server_config["compile"]
 
583
  offloadobj.release()
584
  offloadobj = None
585
  yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
586
+ lora_dir = get_lora_dir(root_lora_dir)
587
  wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )
588
 
589
 
 
613
  num_inference_steps, flow_shift = get_default_flow(trans_file)
614
 
615
  header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
616
+ new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
617
+ lset_choices = [ (preset, preset) for preset in loras_presets]
618
+ lset_choices.append( (new_preset_msg, ""))
619
+
620
+ return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])
621
 
622
 
623
  from moviepy.editor import ImageSequenceClip
 
630
  ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
631
 
632
  def build_callback(state, pipe, progress, status, num_inference_steps):
633
+ def callback(step_idx, latents, read_state = False):
634
+ status = state["progress_status"]
635
+ if read_state:
636
+ phase, step_idx = state["progress_phase"]
 
 
 
637
  else:
638
+ step_idx += 1
639
+ if state.get("abort", False):
640
+ # pipe._interrupt = True
641
+ phase = " - Aborting"
642
+ elif step_idx == num_inference_steps:
643
+ phase = " - VAE Decoding"
644
+ else:
645
+ phase = " - Denoising"
646
+ state["progress_phase"] = (phase, step_idx)
647
+ status_msg = status + phase
648
+ if step_idx >= 0:
649
+ progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
650
+ else:
651
+ progress(0, status_msg)
652
 
653
  return callback
654
 
655
  def abort_generation(state):
656
  if "in_progress" in state:
657
  state["abort"] = True
658
+ state["extra_orders"] = 0
659
  wan_model._interrupt= True
660
  return gr.Button(interactive= False)
661
  else:
 
670
  if "in_progress" in state:
671
  del state["in_progress"]
672
  choice = state.get("selected",0)
673
+
674
+ state["extra_orders"] = 0
675
  time.sleep(0.2)
676
  global gen_in_progress
677
  gen_in_progress = False
678
+ return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False)
679
 
680
  def select_video(state , event_data: gr.EventData):
681
  data= event_data._data
 
693
  return new_slist
694
 
695
 
696
+ def one_more_video(state):
697
+ extra_orders = state.get("extra_orders", 0)
698
+ extra_orders += 1
699
+ state["extra_orders"] = extra_orders
700
+ prompts_max = state["prompts_max"]
701
+ prompt_no = state["prompt_no"]
702
+ video_no = state["video_no"]
703
+ total_video = state["total_video"]
704
+ # total_video += (prompts_max- prompt_no)
705
+ total_video += 1
706
+ total_generation = state["total_generation"] + extra_orders
707
+ state["total_video"] = total_video
708
+
709
+ state["progress_status"] = f"Video {video_no}/{total_video}"
710
+ offload.shared_state["refresh"] = 1
711
+ # if (prompts_max - prompt_no) > 1:
712
+ # gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts")
713
+ # else:
714
+ gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt")
715
+
716
+ return state
717
+
718
+ def prepare_generate_video():
719
+
720
+ return gr.Button(visible= False), gr.Checkbox(visible= True)
721
+
722
  def generate_video(
723
  prompt,
724
  negative_prompt,
 
796
  else:
797
  VAE_tile_size = 128
798
 
799
+ trans = wan_model.model
800
 
801
  global gen_in_progress
802
  gen_in_progress = True
 
804
  if len(prompt) ==0:
805
  return
806
  prompts = prompt.replace("\r", "").split("\n")
807
+ prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
808
  if len(prompts) ==0:
809
  return
810
  if use_image2video:
 
872
  list_mult_choices_nums.append(float(mult))
873
  if len(list_mult_choices_nums ) < len(loras_choices):
874
  list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
875
+ loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
876
+ pinnedLora = False # profile !=5
877
+ offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map = None)
878
+ errors = trans._loras_errors
879
+ if len(errors) > 0:
880
+ error_files = [msg for _ , msg in errors]
881
+ raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
882
  seed = None if seed == -1 else seed
883
  # negative_prompt = "" # not applicable in the inference
884
 
 
893
 
894
  joint_pass = boost ==1
895
  # TeaCache
 
896
  trans.enable_teacache = tea_cache > 0
897
  if trans.enable_teacache:
898
  if use_image2video:
 
924
  os.makedirs(save_path, exist_ok=True)
925
  video_no = 0
926
  total_video = repeat_generation * len(prompts)
927
+ state["total_video"] = total_video
928
+ extra_generation = 0
929
  abort = False
930
  start_time = time.time()
931
+ state["prompts_max"] = len(prompts)
932
+ for no, prompt in enumerate(prompts):
933
+ repeat_no = 0
934
+ state["prompt_no"] = no
935
+ extra_generation = 0
936
+ while True:
937
+ extra_orders = state.get("extra_orders",0)
938
+ state["extra_orders"] = 0
939
+ extra_generation += extra_orders
940
+ state["total_generation"] = repeat_generation + extra_generation
941
+ # total_video += (len(prompts)- no) * extra_orders
942
+ total_video += extra_orders
943
+ if abort or repeat_no >= (repeat_generation + extra_generation):
944
  break
945
 
946
  if trans.enable_teacache:
 
954
 
955
  video_no += 1
956
  status = f"Video {video_no}/{total_video}"
957
+ state["video_no"] = video_no
958
+ state["progress_status"] = status
959
+ state["progress_phase"] = (" - Encoding Prompt", -1 )
960
  progress(0, desc=status + " - Encoding Prompt" )
 
961
  callback = build_callback(state, trans, progress, status, num_inference_steps)
962
+ offload.shared_state["callback"] = callback
963
 
964
 
965
  gc.collect()
 
969
  if use_image2video:
970
  samples = wan_model.generate(
971
  prompt,
972
+ image_to_continue[no].convert('RGB'),
973
  frame_num=(video_length // 4)* 4 + 1,
974
  max_area=MAX_AREA_CONFIGS[resolution],
975
  shift=flow_shift,
 
1005
  if temp_filename!= None and os.path.isfile(temp_filename):
1006
  os.remove(temp_filename)
1007
  offload.last_offload_obj.unload_all()
1008
+ offload.unload_loras_from_model(trans)
1009
  # if compile:
1010
  # cache_size = torch._dynamo.config.cache_size_limit
1011
  # torch.compiler.reset()
 
1050
 
1051
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
1052
  if os.name == 'nt':
1053
+ file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
1054
  else:
1055
+ file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
1056
  video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
1057
  cache_video(
1058
  tensor=sample[None],
 
1070
  end_time = time.time()
1071
  yield f"Total Generation Time: {end_time-start_time:.1f}s"
1072
  seed += 1
1073
+ repeat_no += 1
1074
 
1075
  if temp_filename!= None and os.path.isfile(temp_filename):
1076
  os.remove(temp_filename)
1077
  gen_in_progress = False
1078
+ offload.unload_loras_from_model(trans)
1079
+
1080
 
1081
  new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
1082
 
 
1084
  def validate_delete_lset(lset_name):
1085
  if len(lset_name) == 0 or lset_name == new_preset_msg:
1086
  gr.Info(f"Choose a Preset to delete")
1087
+ return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
1088
  else:
1089
+ return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
1090
 
1091
  def validate_save_lset(lset_name):
1092
  if len(lset_name) == 0 or lset_name == new_preset_msg:
1093
  gr.Info("Please enter a name for the preset")
1094
+ return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
1095
  else:
1096
+ return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
1097
 
1098
  def cancel_lset():
1099
+ return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1100
 
1101
  def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
1102
  global loras_presets
 
1133
  lset_choices = [ ( preset, preset) for preset in loras_presets ]
1134
  lset_choices.append( (new_preset_msg, ""))
1135
 
1136
+ return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
1137
 
1138
  def delete_lset(lset_name):
1139
  global loras_presets
 
1151
 
1152
  lset_choices = [ (preset, preset) for preset in loras_presets]
1153
  lset_choices.append((new_preset_msg, ""))
1154
+ return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
1155
+
1156
+ def refresh_lora_list(lset_name, loras_choices):
1157
+ global loras,loras_names, loras_presets
1158
+ prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
1159
+
1160
+ loras, loras_names, _, _, _, _, loras_presets = setup_loras(wan_model.model, lora_dir, lora_preselected_preset, None)
1161
+ gc.collect()
1162
+ new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
1163
+ new_loras_dict = { loras_name: str(i) for i,loras_name in enumerate(loras_names) }
1164
+ lora_names_selected = []
1165
+ for lora in prev_lora_names_selected:
1166
+ lora_id = new_loras_dict.get(lora, None)
1167
+ if lora_id!= None:
1168
+ lora_names_selected.append(lora_id)
1169
+
1170
+ lset_choices = [ (preset, preset) for preset in loras_presets]
1171
+ lset_choices.append((new_preset_msg, ""))
1172
+ if lset_name in loras_presets:
1173
+ pos = loras_presets.index(lset_name)
1174
+ else:
1175
+ pos = len(loras_presets)
1176
+ lset_name =""
1177
+
1178
+ errors = wan_model.model._loras_errors
1179
+ if len(errors) > 0:
1180
+ error_files = [path for path, _ in errors]
1181
+ gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
1182
+ else:
1183
+ gr.Info("Lora List has been refreshed")
1184
 
1185
 
1186
+ return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
1187
+
1188
  def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
1189
 
1190
  if len(lset_name) == 0 or lset_name== new_preset_msg:
1191
  gr.Info("Please choose a preset in the list or create one")
1192
  else:
1193
+ loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
1194
+ if len(error) > 0:
1195
+ gr.Info(error)
1196
+ else:
1197
+ if full_prompt:
1198
+ prompt = preset_prompt
1199
+ elif len(preset_prompt) > 0:
1200
+ prompts = prompt.replace("\r", "").split("\n")
1201
+ prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
1202
+ prompt = "\n".join(prompts)
1203
+ prompt = preset_prompt + '\n' + prompt
1204
+ gr.Info(f"Lora Preset '{lset_name}' has been applied")
1205
 
1206
  return loras_choices, loras_mult_choices, prompt
1207
 
 
1214
  state = gr.State({})
1215
 
1216
  if use_image2video:
1217
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
1218
  else:
1219
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
1220
 
1221
+ gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP by <B>DeepBeepMeep</B>, a super fast and low VRAM Video Generator !</FONT>")
1222
 
1223
  if use_image2video and False:
1224
  pass
1225
  else:
1226
+ gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
1227
  gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
1228
  gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
1229
  gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
1230
  gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
1231
+ gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
1232
 
1233
 
1234
  # css = """<STYLE>
 
1422
  # with gr.Column():
1423
  with gr.Row(height=17):
1424
  apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
1425
+ refresh_lora_btn = gr.Button("Refresh Lora List", size="sm", min_width= 1)
1426
  # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
1427
  save_lset_prompt_drop= gr.Dropdown(
1428
  choices=[
 
1455
  with gr.Row(visible=False) as advanced_row:
1456
  with gr.Column():
1457
  seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
1458
+ repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
1459
  with gr.Row():
1460
  negative_prompt = gr.Textbox(label="Negative Prompt", value="")
1461
  with gr.Row():
 
1498
  label="Generated videos", show_label=False, elem_id="gallery"
1499
  , columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
1500
  generate_btn = gr.Button("Generate")
1501
+ onemore_btn = gr.Button("One More Please !", visible= False)
1502
  abort_btn = gr.Button("Abort")
1503
 
1504
+ save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
1505
+ confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
1506
+ delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1507
+ confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
1508
+ cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
1509
 
1510
  apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
1511
 
1512
+ refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
1513
+
1514
  gen_status.change(refresh_gallery, inputs = [state], outputs = output )
1515
 
1516
  abort_btn.click(abort_generation,state,abort_btn )
1517
  output.select(select_video, state, None )
1518
+ onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
1519
+ generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]).then(
1520
  fn=generate_video,
1521
  inputs=[
1522
  prompt,
 
1544
  ).then(
1545
  finalize_gallery,
1546
  [state],
1547
+ [output , abort_btn, generate_btn, onemore_btn]
1548
  )
1549
 
1550
  apply_btn.click(
 
1565
  ).then(
1566
  update_defaults,
1567
  [state, num_inference_steps, flow_shift],
1568
+ [num_inference_steps, flow_shift, header, lset_name , loras_choices ]
1569
  )
1570
 
1571
  return demo
requirements.txt CHANGED
@@ -16,5 +16,5 @@ gradio>=5.0.0
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
- mmgp==3.2.6
20
  peft==0.14.0
 
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
+ mmgp==3.2.7
20
  peft==0.14.0
wan/image2video.py CHANGED
@@ -331,7 +331,7 @@ class WanI2V:
331
  callback(-1, None)
332
 
333
  for i, t in enumerate(tqdm(timesteps)):
334
- offload.set_step_no_for_lora(i)
335
  latent_model_input = [latent.to(self.device)]
336
  timestep = [t]
337
 
 
331
  callback(-1, None)
332
 
333
  for i, t in enumerate(tqdm(timesteps)):
334
+ offload.set_step_no_for_lora(self.model, i)
335
  latent_model_input = [latent.to(self.device)]
336
  timestep = [t]
337
 
wan/modules/attention.py CHANGED
@@ -60,6 +60,30 @@ try:
60
  except ImportError:
61
  sageattn = None
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  @torch.compiler.disable()
65
  def sdpa_wrapper(
@@ -119,7 +143,8 @@ def pay_attention(
119
  deterministic=False,
120
  dtype=torch.bfloat16,
121
  version=None,
122
- force_attention= None
 
123
  ):
124
  """
125
  q: [B, Lq, Nq, C1].
@@ -194,9 +219,67 @@ def pay_attention(
194
  max_seqlen_kv=lk,
195
  ).unflatten(0, (b, lq))
196
  elif attn=="sage2":
197
- qkv_list = [q,k,v]
198
- del q,k,v
199
- x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
200
  elif attn=="sdpa":
201
  qkv_list = [q, k, v]
202
  del q, k , v
 
60
  except ImportError:
61
  sageattn = None
62
 
63
+ # # try:
64
+ # if True:
65
+ # from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
66
+ # @torch.compiler.disable()
67
+ # def sageattn_window_wrapper(
68
+ # qkv_list,
69
+ # attention_length,
70
+ # window
71
+ # ):
72
+ # q,k, v = qkv_list
73
+ # padding_length = q.shape[0] -attention_length
74
+ # q = q[:attention_length, :, : ].unsqueeze(0)
75
+ # k = k[:attention_length, :, : ].unsqueeze(0)
76
+ # v = v[:attention_length, :, : ].unsqueeze(0)
77
+ # o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
78
+ # del q, k ,v
79
+ # qkv_list.clear()
80
+
81
+ # if padding_length > 0:
82
+ # o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
83
+
84
+ # return o
85
+ # # except ImportError:
86
+ # # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
87
 
88
  @torch.compiler.disable()
89
  def sdpa_wrapper(
 
143
  deterministic=False,
144
  dtype=torch.bfloat16,
145
  version=None,
146
+ force_attention= None,
147
+ cross_attn= False
148
  ):
149
  """
150
  q: [B, Lq, Nq, C1].
 
219
  max_seqlen_kv=lk,
220
  ).unflatten(0, (b, lq))
221
  elif attn=="sage2":
222
+ import math
223
+ if cross_attn or True:
224
+ qkv_list = [q,k,v]
225
+ del q,k,v
226
+
227
+ x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
228
+ # else:
229
+ # layer = offload.shared_state["layer"]
230
+ # embed_sizes = offload.shared_state["embed_sizes"]
231
+ # current_step = offload.shared_state["step_no"]
232
+ # max_steps = offload.shared_state["max_steps"]
233
+
234
+
235
+ # nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
236
+
237
+ # window = 0
238
+ # start_window_step = int(max_steps * 0.4)
239
+ # start_layer = 10
240
+ # if (layer < start_layer ) or current_step <start_window_step:
241
+ # window = 0
242
+ # else:
243
+ # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
244
+ # print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
245
+ # window = math.ceil(coef* nb_latents)
246
+
247
+ # invert_spaces = (layer + current_step) % 2 == 0 and window > 0
248
+
249
+ # def flip(q):
250
+ # q = q.reshape(*embed_sizes, *q.shape[-2:])
251
+ # q = q.transpose(0,2)
252
+ # q = q.contiguous()
253
+ # q = q.transpose(0,2)
254
+ # q = q.reshape( -1, *q.shape[-2:])
255
+ # return q
256
+
257
+ # def flop(q):
258
+ # q = q.reshape(embed_sizes[2], embed_sizes[1], embed_sizes[0] , *q.shape[-2:])
259
+ # q = q.transpose(0,2)
260
+ # q = q.contiguous()
261
+ # q = q.transpose(0,2)
262
+ # q = q.reshape( -1, *q.shape[-2:])
263
+ # return q
264
+
265
+
266
+ # if invert_spaces:
267
+
268
+ # q = flip(q)
269
+ # k = flip(k)
270
+ # v = flip(v)
271
+ # qkv_list = [q,k,v]
272
+ # del q,k,v
273
+
274
+
275
+
276
+ # x = sageattn_window_wrapper(qkv_list, lq, window= window) #.unsqueeze(0)
277
+
278
+ # if invert_spaces:
279
+ # x = flop(x)
280
+ # x = x.unsqueeze(0)
281
+
282
+
283
  elif attn=="sdpa":
284
  qkv_list = [q, k, v]
285
  del q, k , v
wan/modules/model.py CHANGED
@@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
  from diffusers.models.modeling_utils import ModelMixin
9
  import numpy as np
10
  from typing import Union,Optional
11
-
12
  from .attention import pay_attention
13
 
14
  __all__ = ['WanModel']
@@ -302,7 +302,7 @@ class WanT2VCrossAttention(WanSelfAttention):
302
  # compute attention
303
  qvl_list=[q, k, v]
304
  del q, k, v
305
- x = pay_attention(qvl_list, k_lens=context_lens)
306
 
307
  # output
308
  x = x.flatten(2)
@@ -716,7 +716,9 @@ class WanModel(ModelMixin, ConfigMixin):
716
  pipeline = None,
717
  current_step = 0,
718
  context2 = None,
719
- is_uncond=False
 
 
720
  ):
721
  r"""
722
  Forward pass through the diffusion model
@@ -755,6 +757,12 @@ class WanModel(ModelMixin, ConfigMixin):
755
  # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
756
 
757
  grid_sizes = [ list(u.shape[2:]) for u in x]
 
 
 
 
 
 
758
 
759
  x = [u.flatten(2).transpose(1, 2) for u in x]
760
  seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
@@ -843,7 +851,11 @@ class WanModel(ModelMixin, ConfigMixin):
843
  # context=context,
844
  context_lens=context_lens)
845
 
846
- for block in self.blocks:
 
 
 
 
847
  if pipeline._interrupt:
848
  if joint_pass:
849
  return None, None
 
8
  from diffusers.models.modeling_utils import ModelMixin
9
  import numpy as np
10
  from typing import Union,Optional
11
+ from mmgp import offload
12
  from .attention import pay_attention
13
 
14
  __all__ = ['WanModel']
 
302
  # compute attention
303
  qvl_list=[q, k, v]
304
  del q, k, v
305
+ x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True)
306
 
307
  # output
308
  x = x.flatten(2)
 
716
  pipeline = None,
717
  current_step = 0,
718
  context2 = None,
719
+ is_uncond=False,
720
+ max_steps = 0
721
+
722
  ):
723
  r"""
724
  Forward pass through the diffusion model
 
757
  # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
758
 
759
  grid_sizes = [ list(u.shape[2:]) for u in x]
760
+ embed_sizes = grid_sizes[0]
761
+
762
+ offload.shared_state["embed_sizes"] = embed_sizes
763
+ offload.shared_state["step_no"] = current_step
764
+ offload.shared_state["max_steps"] = max_steps
765
+
766
 
767
  x = [u.flatten(2).transpose(1, 2) for u in x]
768
  seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
 
851
  # context=context,
852
  context_lens=context_lens)
853
 
854
+ for l, block in enumerate(self.blocks):
855
+ offload.shared_state["layer"] = l
856
+ if "refresh" in offload.shared_state:
857
+ del offload.shared_state["refresh"]
858
+ offload.shared_state["callback"](-1, -1, True)
859
  if pipeline._interrupt:
860
  if joint_pass:
861
  return None, None
wan/text2video.py CHANGED
@@ -243,6 +243,10 @@ class WanT2V:
243
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
244
  arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
245
 
 
 
 
 
246
  if self.model.enable_teacache:
247
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
248
  if callback != None:
@@ -250,7 +254,7 @@ class WanT2V:
250
  for i, t in enumerate(tqdm(timesteps)):
251
  latent_model_input = latents
252
  timestep = [t]
253
- offload.set_step_no_for_lora(i)
254
  timestep = torch.stack(timestep)
255
 
256
  # self.model.to(self.device)
 
243
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
244
  arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
245
 
246
+ # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
247
+ # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
248
+ # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
249
+
250
  if self.model.enable_teacache:
251
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
252
  if callback != None:
 
254
  for i, t in enumerate(tqdm(timesteps)):
255
  latent_model_input = latents
256
  timestep = [t]
257
+ offload.set_step_no_for_lora(self.model, i)
258
  timestep = torch.stack(timestep)
259
 
260
  # self.model.to(self.device)