DeepBeepMeep committed · Commit acc356f · 1 Parent(s): 92f2b6e
Refactored Loras
Files changed:
- README.md +6 -0
- gradio_server.py +228 -104
- requirements.txt +1 -1
- wan/image2video.py +1 -1
- wan/modules/attention.py +87 -4
- wan/modules/model.py +16 -4
- wan/text2video.py +5 -1
README.md CHANGED

@@ -19,6 +19,7 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
 ## 🔥 Latest News!!
+* Mar 14, 2025: 👋 Wan2.1GP v1.7: Lora Fest special edition: very fast loading / unloading of Loras, for the Lora collectors out there. You will need to refresh the requirements: *pip install -r requirements.txt*
 * Mar 13, 2025: 👋 Wan2.1GP v1.6: Better Lora support, accelerated Lora loading. You will need to refresh the requirements: *pip install -r requirements.txt*
 * Mar 10, 2025: 👋 Wan2.1GP v1.5: Official TeaCache support + Smart TeaCache (automatically finds the best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved Lora presets (they can now include prompts and comments to guide the user)
 * Mar 07, 2025: 👋 Wan2.1GP v1.4: Fixed PyTorch compilation, it is now really 20% faster when activated

@@ -157,6 +158,11 @@ python gradio_server.py --attention sdpa
 Every Lora stored in the subfolder 'loras' (for t2v) or 'loras_i2v' (for i2v) will be loaded automatically. You will then be able to activate / deactivate any of them when running the application by selecting them in the area below "Activated Loras".
+If you want to manage the Loras for the 1.3B model and the 14B model in separate areas (they are not compatible with each other), just create the following subfolders:
+- loras/1.3B
+- loras/14B
+
 For each activated Lora, you may specify a *multiplier*, a single float that corresponds to its weight (the default is 1.0). The multipliers should be separated by a space or a newline. For instance:
 *1.2 0.8* means the first Lora will have a multiplier of 1.2 and the second one 0.8.
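For clarity, this is roughly how a multiplier string such as *1.2 0.8* maps onto the activated Loras — one float per Lora, with missing values defaulting to 1.0. A minimal sketch, not the exact application code:

    def parse_lora_multipliers(mult_str, num_loras):
        # "1.2 0.8" (separated by spaces or newlines) -> [1.2, 0.8], padded with 1.0
        mults = [float(m) for m in mult_str.replace("\r", " ").split()]
        return (mults + [1.0] * num_loras)[:num_loras]

    print(parse_lora_multipliers("1.2 0.8", 3))  # -> [1.2, 0.8, 1.0]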
gradio_server.py CHANGED

@@ -170,6 +170,18 @@ def _parse_args():
     return args
 
+def get_lora_dir(root_lora_dir):
+    if not use_image2video:
+        if "1.3B" in transformer_filename_t2v:
+            lora_dir_1_3B = os.path.join(root_lora_dir, "1.3B")
+            if os.path.isdir(lora_dir_1_3B):
+                return lora_dir_1_3B
+        else:
+            lora_dir_14B = os.path.join(root_lora_dir, "14B")
+            if os.path.isdir(lora_dir_14B):
+                return lora_dir_14B
+    return root_lora_dir
+
 attention_modes_supported = get_attention_modes()
 
 args = _parse_args()
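A quick usage sketch of the folder fallback introduced above, assuming a hypothetical layout where only loras/1.3B exists (the file names below are invented for illustration):

    import os

    def pick_lora_dir(root_lora_dir, transformer_filename, use_image2video=False):
        # Same idea as get_lora_dir above: t2v looks for a size-specific subfolder
        # first and falls back to the root Lora folder when it is absent.
        if not use_image2video:
            sub = "1.3B" if "1.3B" in transformer_filename else "14B"
            candidate = os.path.join(root_lora_dir, sub)
            if os.path.isdir(candidate):
                return candidate
        return root_lora_dir

    print(pick_lora_dir("loras", "wan_t2v_1.3B.safetensors"))  # -> loras/1.3B if that folder exists
    print(pick_lora_dir("loras", "wan_t2v_14B.safetensors"))   # -> loras (no loras/14B folder)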
@@ -238,9 +250,10 @@ if args.i2v:
 lora_dir = args.lora_dir
 if use_image2video and len(lora_dir)==0:
+    root_lora_dir = args.lora_dir_i2v
 if len(lora_dir) ==0:
+    root_lora_dir = "loras_i2v" if use_image2video else "loras"
+lora_dir = get_lora_dir(root_lora_dir)
 lora_preselected_preset = args.lora_preset
 default_tea_cache = 0
 # if args.fast : #or args.fastest
@@ -295,35 +308,42 @@ def sanitize_file_name(file_name, rep =""):
     return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
 
 def extract_preset(lset_name, loras):
+    loras_choices = []
+    loras_choices_files = []
+    loras_mult_choices = ""
+    prompt =""
+    full_prompt =""
     lset_name = sanitize_file_name(lset_name)
     if not lset_name.endswith(".lset"):
         lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
     else:
         lset_name_filename = os.path.join(lora_dir, lset_name )
+    error = ""
     if not os.path.isfile(lset_name_filename):
+        error = f"Preset '{lset_name}' not found "
+    else:
+        missing_loras = []
+
+        with open(lset_name_filename, "r", encoding="utf-8") as reader:
+            text = reader.read()
+        lset = json.loads(text)
+
+        loras_choices_files = lset["loras"]
+        for lora_file in loras_choices_files:
+            choice = os.path.join(lora_dir, lora_file)
+            if choice not in loras:
+                missing_loras.append(lora_file)
+            else:
+                loras_choice_no = loras.index(choice)
+                loras_choices.append(str(loras_choice_no))
+
+        if len(missing_loras) > 0:
+            error = f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing or invalid: {missing_loras}"
+
+        loras_mult_choices = lset["loras_mult"]
+        prompt = lset.get("prompt", "")
+        full_prompt = lset.get("full_prompt", False)
+    return loras_choices, loras_mult_choices, prompt, full_prompt, error
 
 def get_default_prompt(i2v):
     if i2v:
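Given the fields read back above (loras, loras_mult, prompt, full_prompt, parsed with json.loads), an .lset preset is plain JSON. A hypothetical example of writing one, with invented Lora file names:

    import json

    preset = {
        "loras": ["detail_booster.safetensors", "film_grain.safetensors"],  # invented names
        "loras_mult": "1.2 0.8",
        "prompt": "# use close-up shots for best results",
        "full_prompt": False,
    }
    with open("my_preset.lset", "w", encoding="utf-8") as writer:
        json.dump(preset, writer, indent=4)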
@@ -332,7 +352,7 @@ def get_default_prompt(i2v):
     return "A large orange octopus is seen resting on the bottom of the ocean floor, blending in with the sandy and rocky terrain. Its tentacles are spread out around its body, and its eyes are closed. The octopus is unaware of a king crab that is crawling towards it from behind a rock, its claws raised and ready to attack. The crab is brown and spiny, with long legs and antennae. The scene is captured from a wide angle, showing the vastness and depth of the ocean. The water is clear and blue, with rays of sunlight filtering through. The shot is sharp and crisp, with a high dynamic range. The octopus and the crab are in focus, while the background is slightly blurred, creating a depth of field effect."
 
 
-def setup_loras(
+def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
     loras =[]
     loras_names = []
     default_loras_choices = []

@@ -341,32 +361,35 @@ def setup_loras(pipe, lora_dir, lora_preselected_preset, split_linear_modules_m
     default_lora_preset = ""
     default_prompt = ""
 
+    from pathlib import Path
+
+    if lora_dir != None :
+        if not os.path.isdir(lora_dir):
+            raise Exception("--lora-dir should be a path to a directory that contains Loras")
 
+    if lora_dir != None:
+        import glob
+        dir_loras = glob.glob( os.path.join(lora_dir , "*.sft") ) + glob.glob( os.path.join(lora_dir , "*.safetensors") )
+        dir_loras.sort()
+        loras += [element for element in dir_loras if element not in loras ]
 
+        dir_presets = glob.glob( os.path.join(lora_dir , "*.lset") )
+        dir_presets.sort()
+        loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
 
-    offload.load_loras_into_model(pipe["transformer"], loras, activate_all_loras=False, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
+    loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
 
+    if len(loras) > 0:
+        loras_names = [ Path(lora).stem for lora in loras ]
 
+    if len(lora_preselected_preset) > 0:
+        if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
+            raise Exception(f"Unknown preset '{lora_preselected_preset}'")
+        default_lora_preset = lora_preselected_preset
+        default_loras_choices, default_loras_multis_str, default_prompt, _ , error = extract_preset(default_lora_preset, loras)
+        if len(error) > 0:
+            print(error[:200])
     if len(default_prompt) == 0:
         default_prompt = get_default_prompt(use_image2video)
     return loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets
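As a side note, the names offered in the "Activated Loras" list are simply the file stems collected above. A minimal sketch with a made-up folder:

    from pathlib import Path

    lora_dir = Path("loras/14B")  # hypothetical folder
    lora_files = sorted([*lora_dir.glob("*.sft"), *lora_dir.glob("*.safetensors")])
    loras_names = [p.stem for p in lora_files]
    # city_night.safetensors and hand_fix.sft would show up as "city_night" and "hand_fix"
    print(loras_names)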
@@ -450,8 +473,8 @@ def load_models(i2v, lora_dir, lora_preselected_preset ):
     kwargs["budgets"] = { "*" : "70%" }
 
+    offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
+    loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = setup_loras(pipe["transformer"], lora_dir, lora_preselected_preset, None)
 
     return wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets

@@ -542,7 +565,7 @@ def apply_changes( state,
     state["config_new"] = server_config
     state["config_old"] = old_server_config
 
-    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost
+    global attention_mode, profile, compile, transformer_filename_t2v, transformer_filename_i2v, text_encoder_filename, vae_config, boost, lora_dir
     attention_mode = server_config["attention_mode"]
     profile = server_config["profile"]
     compile = server_config["compile"]

@@ -560,7 +583,7 @@ def apply_changes( state,
     offloadobj.release()
     offloadobj = None
     yield "<DIV ALIGN=CENTER>Please wait while the new configuration is being applied</DIV>"
+    lora_dir = get_lora_dir(root_lora_dir)
     wan_model, offloadobj, loras, loras_names, default_loras_choices, default_loras_multis_str, default_prompt, default_lora_preset, loras_presets = load_models(use_image2video, lora_dir, lora_preselected_preset )

@@ -590,7 +613,11 @@ def update_defaults(state, num_inference_steps,flow_shift):
     num_inference_steps, flow_shift = get_default_flow(trans_file)
 
     header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
+    new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
+    lset_choices = [ (preset, preset) for preset in loras_presets]
+    lset_choices.append( (new_preset_msg, ""))
+
+    return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])
 
 from moviepy.editor import ImageSequenceClip
@@ -603,23 +630,32 @@ def save_video(final_frames, output_path, fps=24):
     ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
 
 def build_callback(state, pipe, progress, status, num_inference_steps):
-    def callback(step_idx, latents):
-        if
-            status_msg = status + " - Aborting"
-        elif step_idx == num_inference_steps:
-            status_msg = status + " - VAE Decoding"
-        else:
+    def callback(step_idx, latents, read_state = False):
+        status = state["progress_status"]
+        if read_state:
+            phase, step_idx = state["progress_phase"]
+        else:
+            step_idx += 1
+            if state.get("abort", False):
+                # pipe._interrupt = True
+                phase = " - Aborting"
+            elif step_idx == num_inference_steps:
+                phase = " - VAE Decoding"
+            else:
+                phase = " - Denoising"
+            state["progress_phase"] = (phase, step_idx)
+        status_msg = status + phase
+        if step_idx >= 0:
+            progress( (step_idx , num_inference_steps) , status_msg , num_inference_steps)
+        else:
+            progress(0, status_msg)
 
     return callback
 
 def abort_generation(state):
     if "in_progress" in state:
         state["abort"] = True
+        state["extra_orders"] = 0
         wan_model._interrupt= True
         return gr.Button(interactive= False)
     else:
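To make the callback contract easier to follow: it is normally called once per denoising step with a step index, and can now also be called with read_state=True to re-emit the phase stored in the UI state (this is how the refresh requested by one_more_video reaches the progress bar). A self-contained sketch with a stub in place of the Gradio progress object:

    state = {"progress_status": "Video 1/1", "progress_phase": (" - Denoising", 3)}
    num_inference_steps = 10

    def progress(value, desc=None, total=None):
        # Stand-in for the gr.Progress object used by the real callback.
        print(value, desc, total)

    def callback(step_idx, latents, read_state=False):
        status = state["progress_status"]
        if read_state:
            phase, step_idx = state["progress_phase"]   # re-read what was stored earlier
        else:
            step_idx += 1
            phase = " - VAE Decoding" if step_idx == num_inference_steps else " - Denoising"
            state["progress_phase"] = (phase, step_idx)
        progress((max(step_idx, 0), num_inference_steps), status + phase, num_inference_steps)

    callback(4, None)        # regular per-step update
    callback(-1, -1, True)   # refresh request, as issued from the model's forward pass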
@@ -634,11 +670,12 @@ def finalize_gallery(state):
     if "in_progress" in state:
         del state["in_progress"]
         choice = state.get("selected",0)
+
+        state["extra_orders"] = 0
     time.sleep(0.2)
     global gen_in_progress
     gen_in_progress = False
-    return gr.Gallery(selected_index=choice), gr.Button(interactive= True)
+    return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False)
 
 def select_video(state , event_data: gr.EventData):
     data= event_data._data

@@ -656,6 +693,32 @@ def expand_slist(slist, num_inference_steps ):
     return new_slist
 
 
+def one_more_video(state):
+    extra_orders = state.get("extra_orders", 0)
+    extra_orders += 1
+    state["extra_orders"] = extra_orders
+    prompts_max = state["prompts_max"]
+    prompt_no = state["prompt_no"]
+    video_no = state["video_no"]
+    total_video = state["total_video"]
+    # total_video += (prompts_max- prompt_no)
+    total_video += 1
+    total_generation = state["total_generation"] + extra_orders
+    state["total_video"] = total_video
+
+    state["progress_status"] = f"Video {video_no}/{total_video}"
+    offload.shared_state["refresh"] = 1
+    # if (prompts_max - prompt_no) > 1:
+    #     gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for the next {prompts_max - prompt_no} prompts")
+    # else:
+    gr.Info(f"An extra video generation is planned for a total of {total_generation} videos for this prompt")
+
+    return state
+
+def prepare_generate_video():
+
+    return gr.Button(visible= False), gr.Checkbox(visible= True)
+
 def generate_video(
     prompt,
     negative_prompt,
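The "One More Please !" button only bumps a counter in the UI state; the generation loop further down re-reads that counter on every iteration and extends the number of repeats accordingly. The same idea stripped of the Gradio plumbing:

    state = {"extra_orders": 0}

    def one_more():
        # Called from the UI while a generation is still running.
        state["extra_orders"] += 1

    repeat_generation = 2
    extra_generation = 0
    repeat_no = 0
    while True:
        extra_generation += state["extra_orders"]
        state["extra_orders"] = 0
        if repeat_no >= repeat_generation + extra_generation:
            break
        repeat_no += 1
        print(f"generating video {repeat_no} of {repeat_generation + extra_generation}")
        if repeat_no == 1:
            one_more()  # simulate the user clicking the button mid-run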
@@ -733,6 +796,7 @@ def generate_video(
     else:
         VAE_tile_size = 128
 
+    trans = wan_model.model
 
     global gen_in_progress
     gen_in_progress = True

@@ -740,7 +804,7 @@ def generate_video(
     if len(prompt) ==0:
         return
     prompts = prompt.replace("\r", "").split("\n")
-    prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+    prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
     if len(prompts) ==0:
         return
     if use_image2video:

@@ -808,9 +872,13 @@ def generate_video(
             list_mult_choices_nums.append(float(mult))
     if len(list_mult_choices_nums ) < len(loras_choices):
         list_mult_choices_nums += [1.0] * ( len(loras_choices) - len(list_mult_choices_nums ) )
+    loras_selected = [ lora for i, lora in enumerate(loras) if str(i) in loras_choices]
+    pinnedLora = False # profile !=5
+    offload.load_loras_into_model(trans, loras_selected, list_mult_choices_nums, activate_all_loras=True, pinnedLora=pinnedLora, split_linear_modules_map = None)
+    errors = trans._loras_errors
+    if len(errors) > 0:
+        error_files = [msg for _ , msg in errors]
+        raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
     seed = None if seed == -1 else seed
     # negative_prompt = "" # not applicable in the inference
@@ -825,7 +893,6 @@ def generate_video(
 
     joint_pass = boost ==1
     # TeaCache
-    trans = wan_model.model
     trans.enable_teacache = tea_cache > 0
     if trans.enable_teacache:
         if use_image2video:

@@ -857,11 +924,23 @@ def generate_video(
     os.makedirs(save_path, exist_ok=True)
     video_no = 0
     total_video = repeat_generation * len(prompts)
+    state["total_video"] = total_video
+    extra_generation = 0
     abort = False
     start_time = time.time()
+    state["prompts_max"] = len(prompts)
+    for no, prompt in enumerate(prompts):
+        repeat_no = 0
+        state["prompt_no"] = no
+        extra_generation = 0
+        while True:
+            extra_orders = state.get("extra_orders",0)
+            state["extra_orders"] = 0
+            extra_generation += extra_orders
+            state["total_generation"] = repeat_generation + extra_generation
+            # total_video += (len(prompts)- no) * extra_orders
+            total_video += extra_orders
+            if abort or repeat_no >= (repeat_generation + extra_generation):
                 break
 
             if trans.enable_teacache:

@@ -875,9 +954,12 @@ def generate_video(
             video_no += 1
             status = f"Video {video_no}/{total_video}"
+            state["video_no"] = video_no
+            state["progress_status"] = status
+            state["progress_phase"] = (" - Encoding Prompt", -1 )
             progress(0, desc=status + " - Encoding Prompt" )
             callback = build_callback(state, trans, progress, status, num_inference_steps)
+            offload.shared_state["callback"] = callback
 
             gc.collect()

@@ -887,7 +969,7 @@ def generate_video(
             if use_image2video:
                 samples = wan_model.generate(
                     prompt,
-                    image_to_continue[
+                    image_to_continue[no].convert('RGB'),
                     frame_num=(video_length // 4)* 4 + 1,
                     max_area=MAX_AREA_CONFIGS[resolution],
                     shift=flow_shift,

@@ -923,6 +1005,7 @@ def generate_video(
             if temp_filename!= None and os.path.isfile(temp_filename):
                 os.remove(temp_filename)
             offload.last_offload_obj.unload_all()
+            offload.unload_loras_from_model(trans)
             # if compile:
             #     cache_size = torch._dynamo.config.cache_size_limit
             #     torch.compiler.reset()
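Taken together, the changes above give each generation a short Lora lifecycle: the ticked Loras are loaded into the transformer with their multipliers just before sampling and unloaded as soon as the run ends or aborts. A sketch that simply wraps the mmgp calls as they are used in this diff (error reporting trimmed):

    from mmgp import offload

    def run_with_loras(trans, loras, loras_choices, multipliers, generate):
        # Keep only the Loras ticked in the UI, preserving their original order.
        selected = [lora for i, lora in enumerate(loras) if str(i) in loras_choices]
        offload.load_loras_into_model(trans, selected, multipliers, activate_all_loras=True,
                                      pinnedLora=False, split_linear_modules_map=None)
        try:
            return generate()
        finally:
            # Always detach the Loras so the next run starts from a clean model.
            offload.unload_loras_from_model(trans)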
@@ -967,9 +1050,9 @@ def generate_video(
             time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
             if os.name == 'nt':
-                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50])}.mp4"
+                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:50]).strip()}.mp4"
             else:
-                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100])}.mp4"
+                file_name = f"{time_flag}_seed{seed}_{sanitize_file_name(prompt[:100]).strip()}.mp4"
             video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
             cache_video(
                 tensor=sample[None],

@@ -987,10 +1070,13 @@ def generate_video(
             end_time = time.time()
             yield f"Total Generation Time: {end_time-start_time:.1f}s"
             seed += 1
+            repeat_no += 1
 
     if temp_filename!= None and os.path.isfile(temp_filename):
         os.remove(temp_filename)
     gen_in_progress = False
+    offload.unload_loras_from_model(trans)
+
 
 new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
@@ -998,19 +1084,19 @@ new_preset_msg = "Enter a Name for a Lora Preset or Choose One Above"
 def validate_delete_lset(lset_name):
     if len(lset_name) == 0 or lset_name == new_preset_msg:
         gr.Info(f"Choose a Preset to delete")
-        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
     else:
-        return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
+        return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
 
 def validate_save_lset(lset_name):
     if len(lset_name) == 0 or lset_name == new_preset_msg:
         gr.Info("Please enter a name for the preset")
-        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
+        return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
     else:
-        return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
+        return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True)
 
 def cancel_lset():
-    return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
+    return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
 
 def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
     global loras_presets

@@ -1047,7 +1133,7 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_pr
     lset_choices = [ ( preset, preset) for preset in loras_presets ]
     lset_choices.append( (new_preset_msg, ""))
 
-    return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
+    return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
 
 def delete_lset(lset_name):
     global loras_presets
@@ -1065,23 +1151,57 @@ def delete_lset(lset_name):
 
     lset_choices = [ (preset, preset) for preset in loras_presets]
     lset_choices.append((new_preset_msg, ""))
-    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
+    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
+
+def refresh_lora_list(lset_name, loras_choices):
+    global loras,loras_names, loras_presets
+    prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
+
+    loras, loras_names, _, _, _, _, loras_presets = setup_loras(wan_model.model, lora_dir, lora_preselected_preset, None)
+    gc.collect()
+    new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
+    new_loras_dict = { loras_name: str(i) for i,loras_name in enumerate(loras_names) }
+    lora_names_selected = []
+    for lora in prev_lora_names_selected:
+        lora_id = new_loras_dict.get(lora, None)
+        if lora_id!= None:
+            lora_names_selected.append(lora_id)
+
+    lset_choices = [ (preset, preset) for preset in loras_presets]
+    lset_choices.append((new_preset_msg, ""))
+    if lset_name in loras_presets:
+        pos = loras_presets.index(lset_name)
+    else:
+        pos = len(loras_presets)
+        lset_name =""
+
+    errors = wan_model.model._loras_errors
+    if len(errors) > 0:
+        error_files = [path for path, _ in errors]
+        gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
+    else:
+        gr.Info("Lora List has been refreshed")
+
+    return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
+
 def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
 
     if len(lset_name) == 0 or lset_name== new_preset_msg:
         gr.Info("Please choose a preset in the list or create one")
     else:
-        loras_choices, loras_mult_choices, preset_prompt, full_prompt = extract_preset(lset_name, loras)
-        if
+        loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
+        if len(error) > 0:
+            gr.Info(error)
+        else:
+            if full_prompt:
+                prompt = preset_prompt
+            elif len(preset_prompt) > 0:
+                prompts = prompt.replace("\r", "").split("\n")
+                prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")]
+                prompt = "\n".join(prompts)
+                prompt = preset_prompt + '\n' + prompt
+            gr.Info(f"Lora Preset '{lset_name}' has been applied")
 
     return loras_choices, loras_mult_choices, prompt
@@ -1094,21 +1214,21 @@ def create_demo():
     state = gr.State({})
 
     if use_image2video:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Image To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
     else:
-        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
+        gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v1.7 - AI Text To Video Generator (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</H1></div>")
 
-    gr.Markdown("<FONT SIZE=3>
+    gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP by <B>DeepBeepMeep</B>, a super fast and low VRAM Video Generator !</FONT>")
 
     if use_image2video and False:
         pass
     else:
-        gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :
+        gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
         gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
         gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
         gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
         gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
-        gr.Markdown("Please note that if your turn on compilation, the first
+        gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
 
 
     # css = """<STYLE>

@@ -1302,6 +1422,7 @@ def create_demo():
             # with gr.Column():
             with gr.Row(height=17):
                 apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
+                refresh_lora_btn = gr.Button("Refresh Lora List", size="sm", min_width= 1)
                 # save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
                 save_lset_prompt_drop= gr.Dropdown(
                     choices=[

@@ -1334,7 +1455,7 @@ def create_demo():
             with gr.Row(visible=False) as advanced_row:
                 with gr.Column():
                     seed = gr.Slider(-1, 999999999, value=-1, step=1, label="Seed (-1 for random)")
-                    repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Number of Generated
+                    repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
                     with gr.Row():
                         negative_prompt = gr.Textbox(label="Negative Prompt", value="")
                     with gr.Row():
@@ -1377,22 +1498,25 @@ def create_demo():
                 label="Generated videos", show_label=False, elem_id="gallery"
                 , columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= False)
             generate_btn = gr.Button("Generate")
+            onemore_btn = gr.Button("One More Please !", visible= False)
             abort_btn = gr.Button("Abort")
 
-        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
-        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
-        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
-        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
-        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
+        save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        confirm_save_lset_btn.click(save_lset, inputs=[lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
+        delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
+        cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
 
         apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
 
+        refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
+
         gen_status.change(refresh_gallery, inputs = [state], outputs = output )
 
         abort_btn.click(abort_generation,state,abort_btn )
         output.select(select_video, state, None )
-        generate_btn.click(
+        onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
+        generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]).then(
             fn=generate_video,
             inputs=[
                 prompt,

@@ -1420,7 +1544,7 @@ def create_demo():
         ).then(
             finalize_gallery,
             [state],
-            [output , abort_btn]
+            [output , abort_btn, generate_btn, onemore_btn]
         )
 
         apply_btn.click(

@@ -1441,7 +1565,7 @@ def create_demo():
         ).then(
             update_defaults,
             [state, num_inference_steps, flow_shift],
-            [num_inference_steps, flow_shift, header]
+            [num_inference_steps, flow_shift, header, lset_name , loras_choices ]
        )
 
     return demo
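The wiring above relies on Gradio's event chaining: the Generate button first swaps which buttons are visible, then runs the long job, then restores the UI. A minimal standalone sketch of that pattern (component names here are illustrative, not the ones from this app):

    import time
    import gradio as gr

    def prepare():
        # Hide the trigger and reveal the extra button while the job runs.
        return gr.Button(visible=False), gr.Button(visible=True)

    def long_job():
        time.sleep(2)
        return "done"

    def finalize():
        return gr.Button(visible=True), gr.Button(visible=False)

    with gr.Blocks() as demo:
        status = gr.Textbox(label="Status")
        run_btn = gr.Button("Generate")
        more_btn = gr.Button("One More Please !", visible=False)
        run_btn.click(prepare, outputs=[run_btn, more_btn]
            ).then(long_job, outputs=[status]
            ).then(finalize, outputs=[run_btn, more_btn])

    demo.launch()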
requirements.txt CHANGED

@@ -16,5 +16,5 @@ gradio>=5.0.0
 numpy>=1.23.5,<2
 einops
 moviepy==1.0.3
-mmgp==3.2.
+mmgp==3.2.7
 peft==0.14.0
wan/image2video.py CHANGED

@@ -331,7 +331,7 @@ class WanI2V:
         callback(-1, None)
 
         for i, t in enumerate(tqdm(timesteps)):
-            offload.set_step_no_for_lora(i)
+            offload.set_step_no_for_lora(self.model, i)
             latent_model_input = [latent.to(self.device)]
             timestep = [t]
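The only change here is the signature of the step notification: it now receives the model as well as the step index (the same edit appears in text2video.py below), which suggests the step counter is tracked per model instead of globally. A stand-in illustration of that difference — this is not mmgp's implementation, just the shape of the behaviour:

    # Hypothetical stand-in keyed by model, mirroring set_step_no_for_lora(model, i).
    _step_no = {}

    def set_step_no_for_lora(model, step_no):
        _step_no[id(model)] = step_no

    class DummyModel: ...

    t2v_model, i2v_model = DummyModel(), DummyModel()
    for i in range(3):
        set_step_no_for_lora(t2v_model, i)
    set_step_no_for_lora(i2v_model, 0)
    assert _step_no[id(t2v_model)] == 2 and _step_no[id(i2v_model)] == 0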
wan/modules/attention.py CHANGED

@@ -60,6 +60,30 @@ try:
 except ImportError:
     sageattn = None
 
+# # try:
+# if True:
+#     from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
+#     @torch.compiler.disable()
+#     def sageattn_window_wrapper(
+#             qkv_list,
+#             attention_length,
+#             window
+#         ):
+#         q,k, v = qkv_list
+#         padding_length = q.shape[0] -attention_length
+#         q = q[:attention_length, :, : ].unsqueeze(0)
+#         k = k[:attention_length, :, : ].unsqueeze(0)
+#         v = v[:attention_length, :, : ].unsqueeze(0)
+#         o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
+#         del q, k ,v
+#         qkv_list.clear()
+#
+#         if padding_length > 0:
+#             o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
+#
+#         return o
+# # except ImportError:
+# #     sageattn = sageattn_qk_int8_pv_fp8_window_cuda
 
 @torch.compiler.disable()
 def sdpa_wrapper(

@@ -119,7 +143,8 @@ def pay_attention(
     deterministic=False,
     dtype=torch.bfloat16,
     version=None,
-    force_attention= None
+    force_attention= None,
+    cross_attn= False
 ):
     """
     q: [B, Lq, Nq, C1].

@@ -194,9 +219,67 @@ def pay_attention(
             max_seqlen_kv=lk,
         ).unflatten(0, (b, lq))
     elif attn=="sage2":
+        import math
+        if cross_attn or True:
+            qkv_list = [q,k,v]
+            del q,k,v
+
+            x = sageattn_wrapper(qkv_list, lq).unsqueeze(0)
+        # else:
+        #     layer = offload.shared_state["layer"]
+        #     embed_sizes = offload.shared_state["embed_sizes"]
+        #     current_step = offload.shared_state["step_no"]
+        #     max_steps = offload.shared_state["max_steps"]
+
+        #     nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
+
+        #     window = 0
+        #     start_window_step = int(max_steps * 0.4)
+        #     start_layer = 10
+        #     if (layer < start_layer ) or current_step <start_window_step:
+        #         window = 0
+        #     else:
+        #         coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
+        #         print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
+        #         window = math.ceil(coef* nb_latents)
+
+        #     invert_spaces = (layer + current_step) % 2 == 0 and window > 0
+
+        #     def flip(q):
+        #         q = q.reshape(*embed_sizes, *q.shape[-2:])
+        #         q = q.transpose(0,2)
+        #         q = q.contiguous()
+        #         q = q.transpose(0,2)
+        #         q = q.reshape( -1, *q.shape[-2:])
+        #         return q
+
+        #     def flop(q):
+        #         q = q.reshape(embed_sizes[2], embed_sizes[1], embed_sizes[0] , *q.shape[-2:])
+        #         q = q.transpose(0,2)
+        #         q = q.contiguous()
+        #         q = q.transpose(0,2)
+        #         q = q.reshape( -1, *q.shape[-2:])
+        #         return q
+
+        #     if invert_spaces:
+        #         q = flip(q)
+        #         k = flip(k)
+        #         v = flip(v)
+        #     qkv_list = [q,k,v]
+        #     del q,k,v
+
+        #     x = sageattn_window_wrapper(qkv_list, lq, window= window) #.unsqueeze(0)
+
+        #     if invert_spaces:
+        #         x = flop(x)
+        #     x = x.unsqueeze(0)
+
     elif attn=="sdpa":
         qkv_list = [q, k, v]
         del q, k , v
wan/modules/model.py CHANGED

@@ -8,7 +8,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 import numpy as np
 from typing import Union,Optional
-
+from mmgp import offload
 from .attention import pay_attention
 
 __all__ = ['WanModel']

@@ -302,7 +302,7 @@ class WanT2VCrossAttention(WanSelfAttention):
         # compute attention
         qvl_list=[q, k, v]
         del q, k, v
-        x = pay_attention(qvl_list, k_lens=context_lens)
+        x = pay_attention(qvl_list, k_lens=context_lens, cross_attn= True)
 
         # output
         x = x.flatten(2)

@@ -716,7 +716,9 @@ class WanModel(ModelMixin, ConfigMixin):
         pipeline = None,
         current_step = 0,
         context2 = None,
-        is_uncond=False
+        is_uncond=False,
+        max_steps = 0
+
     ):
         r"""
         Forward pass through the diffusion model

@@ -755,6 +757,12 @@ class WanModel(ModelMixin, ConfigMixin):
         # [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
 
         grid_sizes = [ list(u.shape[2:]) for u in x]
+        embed_sizes = grid_sizes[0]
+
+        offload.shared_state["embed_sizes"] = embed_sizes
+        offload.shared_state["step_no"] = current_step
+        offload.shared_state["max_steps"] = max_steps
+
 
         x = [u.flatten(2).transpose(1, 2) for u in x]
         seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)

@@ -843,7 +851,11 @@ class WanModel(ModelMixin, ConfigMixin):
             # context=context,
             context_lens=context_lens)
 
-        for block in
+        for l, block in enumerate(self.blocks):
+            offload.shared_state["layer"] = l
+            if "refresh" in offload.shared_state:
+                del offload.shared_state["refresh"]
+                offload.shared_state["callback"](-1, -1, True)
             if pipeline._interrupt:
                 if joint_pass:
                     return None, None
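These additions use offload.shared_state as a scratch dictionary for cross-module signalling: the forward pass publishes the latent grid size, the current step and the current layer, and other code (the attention kernels, or the UI callback registered under the "callback" key in gradio_server.py) can read them without threading extra arguments everywhere. A minimal illustration of the pattern, with a plain dict standing in for offload.shared_state:

    shared_state = {}  # plain dict standing in for mmgp's offload.shared_state

    def forward(blocks, current_step, max_steps):
        shared_state.update(step_no=current_step, max_steps=max_steps)
        for layer_no, block in enumerate(blocks):
            shared_state["layer"] = layer_no
            if shared_state.pop("refresh", None) is not None:
                # Someone (e.g. one_more_video in the UI) asked for a progress refresh.
                shared_state["callback"](-1, -1, True)
            block()

    def on_refresh(step_idx, latents, read_state=False):
        print("progress refreshed, read_state =", read_state)

    shared_state["callback"] = on_refresh
    shared_state["refresh"] = 1
    forward([lambda: None] * 2, current_step=0, max_steps=30)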
wan/text2video.py CHANGED

@@ -243,6 +243,10 @@ class WanT2V:
         arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
         arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 
+        # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+        # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+        # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
+
         if self.model.enable_teacache:
             self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
         if callback != None:

@@ -250,7 +254,7 @@ class WanT2V:
         for i, t in enumerate(tqdm(timesteps)):
             latent_model_input = latents
             timestep = [t]
-            offload.set_step_no_for_lora(i)
+            offload.set_step_no_for_lora(self.model, i)
             timestep = torch.stack(timestep)
 
             # self.model.to(self.device)