Spaces:
Running
on
T4
Running
on
T4
DeepBeepMeep
commited on
Commit
·
0fd9e11
1
Parent(s):
7ecf9b0
Lora festival part 2, new macros, new user interface
Browse files- README.md +36 -3
- gradio_server.py +482 -124
- wan/modules/model.py +7 -2
- wan/utils/prompt_parser.py +291 -0
README.md
CHANGED
@@ -19,6 +19,13 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
|
19 |
|
20 |
|
21 |
## 🔥 Latest News!!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
* Mar 14, 2025: 👋 Wan2.1GP v1.7:
|
23 |
- Lora Fest special edition: very fast loading / unload of loras for those Loras collectors around. You can also now add / remove loras in the Lora folder without restarting the app. You will need to refresh the requirements *pip install -r requirements.txt*
|
24 |
- Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to the *AmericanPresidentJimmyCarter* for the original implementation
|
@@ -121,7 +128,11 @@ To run the text to video generator (in Low VRAM mode):
|
|
121 |
```bash
|
122 |
python gradio_server.py
|
123 |
#or
|
124 |
-
python gradio_server.py --t2v
|
|
|
|
|
|
|
|
|
125 |
|
126 |
```
|
127 |
|
@@ -191,10 +202,27 @@ python gradio_server.py --lora-preset mylorapreset.lset # where 'mylorapreset.l
|
|
191 |
|
192 |
You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
### Command line parameters for Gradio Server
|
196 |
--i2v : launch the image to video generator\
|
197 |
-
--t2v : launch the text to video generator\
|
|
|
|
|
198 |
--quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
|
199 |
--lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
|
200 |
--lora-preset preset : name of preset gile (without the extension) to preload
|
@@ -208,7 +236,12 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
|
|
208 |
--compile : turn on pytorch compilation\
|
209 |
--attention mode: force attention mode among, sdpa, flash, sage, sage2\
|
210 |
--profile no : default (4) : no of profile between 1 and 5\
|
211 |
-
--preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware. Works only with profile 2 and 4
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
### Profiles (for power users only)
|
214 |
You can choose between 5 profiles, but two are really relevant here :
|
|
|
19 |
|
20 |
|
21 |
## 🔥 Latest News!!
|
22 |
+
* Marc 17 2022: 👋 Wan2.1GP v2.0: The Lora festival continues:
|
23 |
+
- Clearer user interface
|
24 |
+
- Download 30 Loras in one click to try them all (expand the info section)
|
25 |
+
- Very to use Loras as now Lora presets can input the subject (or other need terms) of the Lora so that you dont have to modify manually a prompt
|
26 |
+
- Added basic macro prompt language to prefill prompts with differnent values. With one prompt template, you can generate multiple prompts.
|
27 |
+
- New Multiple images prompts: you can now combine any number of images with any number of text promtps (need to launch the app with --multiple-images)
|
28 |
+
- New command lines options to launch directly the 1.3B t2v model or the 14B t2v model
|
29 |
* Mar 14, 2025: 👋 Wan2.1GP v1.7:
|
30 |
- Lora Fest special edition: very fast loading / unload of loras for those Loras collectors around. You can also now add / remove loras in the Lora folder without restarting the app. You will need to refresh the requirements *pip install -r requirements.txt*
|
31 |
- Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to the *AmericanPresidentJimmyCarter* for the original implementation
|
|
|
128 |
```bash
|
129 |
python gradio_server.py
|
130 |
#or
|
131 |
+
python gradio_server.py --t2v #launch the default text 2 video model
|
132 |
+
#or
|
133 |
+
python gradio_server.py --t2v-14B #for the 14B model
|
134 |
+
#or
|
135 |
+
python gradio_server.py --t2v-1-3B #for the 1.3B model
|
136 |
|
137 |
```
|
138 |
|
|
|
202 |
|
203 |
You will find prebuilt Loras on https://civitai.com/ or you will be able to build them with tools such as kohya or onetrainer.
|
204 |
|
205 |
+
### Macros (basic)
|
206 |
+
In *Advanced Mode*, you can starts prompt lines with a "!" , for instance:\
|
207 |
+
```
|
208 |
+
! {Subject}="cat","woman","man", {Location}="forest","lake","city", {Possessive}="its", "her", "his"
|
209 |
+
In the video, a {Subject} is presented. The {Subject} is in a {Location} and looks at {Possessive} watch.
|
210 |
+
```
|
211 |
+
|
212 |
+
This will create automatically 3 prompts that will cause the generation of 3 videos:
|
213 |
+
```
|
214 |
+
In the video, a cat is presented. The cat is in a forest and looks at its watch.
|
215 |
+
In the video, a man is presented. The man is in a lake and looks at his watch.
|
216 |
+
In the video, a woman is presented. The woman is in a city and looks at her watch.
|
217 |
+
```
|
218 |
+
|
219 |
+
You can define multiple lines of macros. If there is only one macro line, the app will generate a simple user interface to enter the macro variables when getting back to *Normal Mode* (advanced mode turned off)
|
220 |
|
221 |
### Command line parameters for Gradio Server
|
222 |
--i2v : launch the image to video generator\
|
223 |
+
--t2v : launch the text to video generator (default defined in the configuration)\
|
224 |
+
--t2v-14B : launch the 14B model text to video generator\
|
225 |
+
--t2v-1-3B : launch the 1.3B model text to video generator\
|
226 |
--quantize-transformer bool: (default True) : enable / disable on the fly transformer quantization\
|
227 |
--lora-dir path : Path of directory that contains Loras in diffusers / safetensor format\
|
228 |
--lora-preset preset : name of preset gile (without the extension) to preload
|
|
|
236 |
--compile : turn on pytorch compilation\
|
237 |
--attention mode: force attention mode among, sdpa, flash, sage, sage2\
|
238 |
--profile no : default (4) : no of profile between 1 and 5\
|
239 |
+
--preload no : number in Megabytes to preload partially the diffusion model in VRAM , may offer slight speed gains especially on older hardware. Works only with profile 2 and 4.\
|
240 |
+
--seed no : set default seed value\
|
241 |
+
--frames no : set the default number of frames to generate\
|
242 |
+
--steps no : set the default number of denoising steps\
|
243 |
+
--check-loras : filter loras that are incompatible (will take a few seconds while refreshing the lora list or while starting the app)\
|
244 |
+
--advanced : turn on the advanced mode while launching the app
|
245 |
|
246 |
### Profiles (for power users only)
|
247 |
You can choose between 5 profiles, but two are really relevant here :
|
gradio_server.py
CHANGED
@@ -20,6 +20,9 @@ import gc
|
|
20 |
import traceback
|
21 |
import math
|
22 |
import asyncio
|
|
|
|
|
|
|
23 |
|
24 |
def _parse_args():
|
25 |
parser = argparse.ArgumentParser(
|
@@ -60,7 +63,7 @@ def _parse_args():
|
|
60 |
parser.add_argument(
|
61 |
"--lora-dir-i2v",
|
62 |
type=str,
|
63 |
-
default="
|
64 |
help="Path to a directory that contains Loras for i2v"
|
65 |
)
|
66 |
|
@@ -72,6 +75,13 @@ def _parse_args():
|
|
72 |
help="Path to a directory that contains Loras"
|
73 |
)
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
parser.add_argument(
|
77 |
"--lora-preset",
|
@@ -101,6 +111,34 @@ def _parse_args():
|
|
101 |
help="Verbose level"
|
102 |
)
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
parser.add_argument(
|
105 |
"--server-port",
|
106 |
type=str,
|
@@ -133,6 +171,18 @@ def _parse_args():
|
|
133 |
help="image to video mode"
|
134 |
)
|
135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
parser.add_argument(
|
137 |
"--compile",
|
138 |
action="store_true",
|
@@ -196,6 +246,13 @@ preload =int(args.preload)
|
|
196 |
force_profile_no = int(args.profile)
|
197 |
verbose_level = int(args.verbose)
|
198 |
quantizeTransformer = args.quantize_transformer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
199 |
|
200 |
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
|
201 |
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
|
@@ -247,12 +304,26 @@ if args.t2v:
|
|
247 |
use_image2video = False
|
248 |
if args.i2v:
|
249 |
use_image2video = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
251 |
lora_dir =args.lora_dir
|
252 |
if use_image2video and len(lora_dir)==0:
|
253 |
-
|
254 |
if len(lora_dir) ==0:
|
255 |
root_lora_dir = "loras_i2v" if use_image2video else "loras"
|
|
|
|
|
256 |
lora_dir = get_lora_dir(root_lora_dir)
|
257 |
lora_preselected_preset = args.lora_preset
|
258 |
default_tea_cache = 0
|
@@ -378,7 +449,8 @@ def setup_loras(transformer, lora_dir, lora_preselected_preset, split_linear_mo
|
|
378 |
dir_presets.sort()
|
379 |
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
|
380 |
|
381 |
-
|
|
|
382 |
|
383 |
if len(loras) > 0:
|
384 |
loras_names = [ Path(lora).stem for lora in loras ]
|
@@ -492,7 +564,9 @@ def get_default_flow(model_filename):
|
|
492 |
return 3.0 if "480p" in model_filename else 5.0
|
493 |
|
494 |
def generate_header(model_filename, compile, attention_mode):
|
495 |
-
|
|
|
|
|
496 |
|
497 |
if "image" in model_filename:
|
498 |
model_name = "Wan2.1 image2video"
|
@@ -508,7 +582,8 @@ def generate_header(model_filename, compile, attention_mode):
|
|
508 |
|
509 |
if compile:
|
510 |
header += ", pytorch compilation ON"
|
511 |
-
header += ")
|
|
|
512 |
|
513 |
return header
|
514 |
|
@@ -591,19 +666,19 @@ def apply_changes( state,
|
|
591 |
|
592 |
# return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"
|
593 |
|
594 |
-
def update_defaults(state, num_inference_steps,flow_shift):
|
595 |
if "config_changes" not in state:
|
596 |
return get_default_flow("")
|
597 |
changes = state["config_changes"]
|
598 |
server_config = state["config_new"]
|
599 |
old_server_config = state["config_old"]
|
600 |
-
|
601 |
if not use_image2video:
|
602 |
old_is_14B = "14B" in server_config["transformer_filename"]
|
603 |
new_is_14B = "14B" in old_server_config["transformer_filename"]
|
604 |
|
605 |
trans_file = server_config["transformer_filename"]
|
606 |
-
|
607 |
# num_inference_steps, flow_shift = get_default_flow(trans_file)
|
608 |
else:
|
609 |
old_is_720P = "720P" in server_config["transformer_filename_i2v"]
|
@@ -615,9 +690,11 @@ def update_defaults(state, num_inference_steps,flow_shift):
|
|
615 |
header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
|
616 |
new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
|
617 |
lset_choices = [ (preset, preset) for preset in loras_presets]
|
618 |
-
lset_choices.append( (
|
619 |
-
|
620 |
-
|
|
|
|
|
621 |
|
622 |
|
623 |
from moviepy.editor import ImageSequenceClip
|
@@ -661,9 +738,20 @@ def abort_generation(state):
|
|
661 |
else:
|
662 |
return gr.Button(interactive= True)
|
663 |
|
664 |
-
def refresh_gallery(state):
|
665 |
file_list = state.get("file_list", None)
|
666 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
667 |
|
668 |
def finalize_gallery(state):
|
669 |
choice = 0
|
@@ -675,7 +763,7 @@ def finalize_gallery(state):
|
|
675 |
time.sleep(0.2)
|
676 |
global gen_in_progress
|
677 |
gen_in_progress = False
|
678 |
-
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False)
|
679 |
|
680 |
def select_video(state , event_data: gr.EventData):
|
681 |
data= event_data._data
|
@@ -697,7 +785,9 @@ def one_more_video(state):
|
|
697 |
extra_orders = state.get("extra_orders", 0)
|
698 |
extra_orders += 1
|
699 |
state["extra_orders"] = extra_orders
|
700 |
-
prompts_max = state
|
|
|
|
|
701 |
prompt_no = state["prompt_no"]
|
702 |
video_no = state["video_no"]
|
703 |
total_video = state["total_video"]
|
@@ -730,6 +820,7 @@ def generate_video(
|
|
730 |
flow_shift,
|
731 |
embedded_guidance_scale,
|
732 |
repeat_generation,
|
|
|
733 |
tea_cache,
|
734 |
tea_cache_start_step_perc,
|
735 |
loras_choices,
|
@@ -759,8 +850,11 @@ def generate_video(
|
|
759 |
elif attention_mode in attention_modes_supported:
|
760 |
attn = attention_mode
|
761 |
else:
|
762 |
-
|
|
|
763 |
|
|
|
|
|
764 |
width, height = resolution.split("x")
|
765 |
width, height = int(width), int(height)
|
766 |
|
@@ -768,17 +862,18 @@ def generate_video(
|
|
768 |
slg_layers = None
|
769 |
if use_image2video:
|
770 |
if "480p" in transformer_filename_i2v and width * height > 848*480:
|
771 |
-
|
|
|
772 |
|
773 |
resolution = str(width) + "*" + str(height)
|
774 |
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
775 |
-
|
776 |
-
|
777 |
|
778 |
else:
|
779 |
if "1.3B" in transformer_filename_t2v and width * height > 848*480:
|
780 |
-
|
781 |
-
|
782 |
|
783 |
offload.shared_state["_attention"] = attn
|
784 |
|
@@ -808,6 +903,9 @@ def generate_video(
|
|
808 |
temp_filename = None
|
809 |
if len(prompt) ==0:
|
810 |
return
|
|
|
|
|
|
|
811 |
prompts = prompt.replace("\r", "").split("\n")
|
812 |
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
813 |
if len(prompts) ==0:
|
@@ -818,22 +916,31 @@ def generate_video(
|
|
818 |
image_to_continue = [ tup[0] for tup in image_to_continue ]
|
819 |
else:
|
820 |
image_to_continue = [image_to_continue]
|
821 |
-
if
|
822 |
-
if len(prompts) % len(image_to_continue) !=0:
|
823 |
-
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
|
824 |
-
rep = len(prompts) // len(image_to_continue)
|
825 |
-
new_image_to_continue = []
|
826 |
-
for i, _ in enumerate(prompts):
|
827 |
-
new_image_to_continue.append(image_to_continue[i//rep] )
|
828 |
-
image_to_continue = new_image_to_continue
|
829 |
-
else:
|
830 |
-
if len(image_to_continue) % len(prompts) !=0:
|
831 |
-
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
|
832 |
-
rep = len(image_to_continue) // len(prompts)
|
833 |
new_prompts = []
|
834 |
-
|
835 |
-
|
|
|
|
|
836 |
prompts = new_prompts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
837 |
|
838 |
elif video_to_continue != None and len(video_to_continue) >0 :
|
839 |
input_image_or_video_path = video_to_continue
|
@@ -900,6 +1007,10 @@ def generate_video(
|
|
900 |
# TeaCache
|
901 |
trans.enable_teacache = tea_cache > 0
|
902 |
if trans.enable_teacache:
|
|
|
|
|
|
|
|
|
903 |
if use_image2video:
|
904 |
if '480p' in transformer_filename_i2v:
|
905 |
# teacache_thresholds = [0.13, .19, 0.26]
|
@@ -935,9 +1046,11 @@ def generate_video(
|
|
935 |
start_time = time.time()
|
936 |
state["prompts_max"] = len(prompts)
|
937 |
for no, prompt in enumerate(prompts):
|
|
|
938 |
repeat_no = 0
|
939 |
state["prompt_no"] = no
|
940 |
extra_generation = 0
|
|
|
941 |
while True:
|
942 |
extra_orders = state.get("extra_orders",0)
|
943 |
state["extra_orders"] = 0
|
@@ -950,8 +1063,6 @@ def generate_video(
|
|
950 |
|
951 |
if trans.enable_teacache:
|
952 |
trans.teacache_counter = 0
|
953 |
-
trans.teacache_multiplier = tea_cache
|
954 |
-
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
|
955 |
trans.num_steps = num_inference_steps
|
956 |
trans.teacache_skipped_steps = 0
|
957 |
trans.previous_residual_uncond = None
|
@@ -1035,6 +1146,7 @@ def generate_video(
|
|
1035 |
if any( keyword in frame.name for keyword in keyword_list):
|
1036 |
VRAM_crash = True
|
1037 |
break
|
|
|
1038 |
if VRAM_crash:
|
1039 |
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
|
1040 |
else:
|
@@ -1054,6 +1166,7 @@ def generate_video(
|
|
1054 |
if samples == None:
|
1055 |
end_time = time.time()
|
1056 |
abort = True
|
|
|
1057 |
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
1058 |
else:
|
1059 |
sample = samples.cpu()
|
@@ -1076,9 +1189,10 @@ def generate_video(
|
|
1076 |
print(f"New video saved to Path: "+video_path)
|
1077 |
file_list.append(video_path)
|
1078 |
if video_no < total_video:
|
1079 |
-
yield status
|
1080 |
else:
|
1081 |
end_time = time.time()
|
|
|
1082 |
yield f"Total Generation Time: {end_time-start_time:.1f}s"
|
1083 |
seed += 1
|
1084 |
repeat_no += 1
|
@@ -1089,18 +1203,22 @@ def generate_video(
|
|
1089 |
offload.unload_loras_from_model(trans)
|
1090 |
|
1091 |
|
1092 |
-
|
|
|
|
|
|
|
|
|
1093 |
|
1094 |
|
1095 |
def validate_delete_lset(lset_name):
|
1096 |
-
if len(lset_name) == 0 or lset_name ==
|
1097 |
gr.Info(f"Choose a Preset to delete")
|
1098 |
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
|
1099 |
else:
|
1100 |
return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
|
1101 |
|
1102 |
def validate_save_lset(lset_name):
|
1103 |
-
if len(lset_name) == 0 or lset_name ==
|
1104 |
gr.Info("Please enter a name for the preset")
|
1105 |
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
|
1106 |
else:
|
@@ -1109,10 +1227,12 @@ def validate_save_lset(lset_name):
|
|
1109 |
def cancel_lset():
|
1110 |
return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
|
1111 |
|
1112 |
-
def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
|
1113 |
global loras_presets
|
1114 |
|
1115 |
-
if
|
|
|
|
|
1116 |
gr.Info("Please enter a name for the preset")
|
1117 |
lset_choices =[("Please enter a name for a Lora Preset","")]
|
1118 |
else:
|
@@ -1142,14 +1262,14 @@ def save_lset(lset_name, loras_choices, loras_mult_choices, prompt, save_lset_pr
|
|
1142 |
gr.Info(f"Lora Preset '{lset_name}' has been created")
|
1143 |
loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
|
1144 |
lset_choices = [ ( preset, preset) for preset in loras_presets ]
|
1145 |
-
lset_choices.append( (
|
1146 |
|
1147 |
return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
|
1148 |
|
1149 |
def delete_lset(lset_name):
|
1150 |
global loras_presets
|
1151 |
lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
|
1152 |
-
if len(lset_name) > 0 and lset_name !=
|
1153 |
if not os.path.isfile(lset_name_filename):
|
1154 |
raise gr.Error(f"Preset '{lset_name}' not found ")
|
1155 |
os.remove(lset_name_filename)
|
@@ -1161,7 +1281,7 @@ def delete_lset(lset_name):
|
|
1161 |
gr.Info(f"Choose a Preset to delete")
|
1162 |
|
1163 |
lset_choices = [ (preset, preset) for preset in loras_presets]
|
1164 |
-
lset_choices.append((
|
1165 |
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
|
1166 |
|
1167 |
def refresh_lora_list(lset_name, loras_choices):
|
@@ -1179,15 +1299,15 @@ def refresh_lora_list(lset_name, loras_choices):
|
|
1179 |
lora_names_selected.append(lora_id)
|
1180 |
|
1181 |
lset_choices = [ (preset, preset) for preset in loras_presets]
|
1182 |
-
lset_choices.append((
|
1183 |
if lset_name in loras_presets:
|
1184 |
pos = loras_presets.index(lset_name)
|
1185 |
else:
|
1186 |
pos = len(loras_presets)
|
1187 |
lset_name =""
|
1188 |
|
1189 |
-
errors = wan_model.model
|
1190 |
-
if len(errors) > 0:
|
1191 |
error_files = [path for path, _ in errors]
|
1192 |
gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
|
1193 |
else:
|
@@ -1196,9 +1316,11 @@ def refresh_lora_list(lset_name, loras_choices):
|
|
1196 |
|
1197 |
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
|
1198 |
|
1199 |
-
def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
|
1200 |
|
1201 |
-
|
|
|
|
|
1202 |
gr.Info("Please choose a preset in the list or create one")
|
1203 |
else:
|
1204 |
loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
|
@@ -1213,34 +1335,221 @@ def apply_lset(lset_name, loras_choices, loras_mult_choices, prompt):
|
|
1213 |
prompt = "\n".join(prompts)
|
1214 |
prompt = preset_prompt + '\n' + prompt
|
1215 |
gr.Info(f"Lora Preset '{lset_name}' has been applied")
|
|
|
|
|
1216 |
|
1217 |
return loras_choices, loras_mult_choices, prompt
|
1218 |
|
1219 |
-
def create_demo():
|
1220 |
-
|
1221 |
-
default_inference_steps = 30
|
1222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1223 |
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
|
1224 |
-
with gr.Blocks() as demo:
|
1225 |
-
|
1226 |
|
1227 |
if use_image2video:
|
1228 |
-
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP>
|
1229 |
else:
|
1230 |
-
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP>
|
1231 |
|
1232 |
-
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP
|
1233 |
-
|
1234 |
-
if use_image2video and False:
|
1235 |
-
pass
|
1236 |
-
else:
|
1237 |
-
gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
|
1238 |
-
gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
|
1239 |
-
gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
|
1240 |
-
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
|
1241 |
-
gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
|
1242 |
-
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
1243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1244 |
|
1245 |
# css = """<STYLE>
|
1246 |
# h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
|
@@ -1370,8 +1679,36 @@ def create_demo():
|
|
1370 |
apply_btn = gr.Button("Apply Changes")
|
1371 |
|
1372 |
|
|
|
1373 |
with gr.Row():
|
1374 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1375 |
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
|
1376 |
if args.multiple_images:
|
1377 |
image_to_continue = gr.Gallery(
|
@@ -1380,8 +1717,30 @@ def create_demo():
|
|
1380 |
else:
|
1381 |
image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=use_image2video)
|
1382 |
|
1383 |
-
|
1384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1385 |
with gr.Row():
|
1386 |
if use_image2video:
|
1387 |
resolution = gr.Dropdown(
|
@@ -1417,64 +1776,51 @@ def create_demo():
|
|
1417 |
|
1418 |
with gr.Row():
|
1419 |
with gr.Column():
|
1420 |
-
video_length = gr.Slider(5, 193, value=81, step=4, label="Number of frames (16 = 1s)")
|
1421 |
with gr.Column():
|
1422 |
-
num_inference_steps = gr.Slider(1, 100, value= default_inference_steps, step=1, label="Number of Inference Steps")
|
1423 |
|
1424 |
with gr.Row():
|
1425 |
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False) #########
|
1426 |
|
1427 |
|
1428 |
-
with gr.Row(visible= len(loras)>0):
|
1429 |
-
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(new_preset_msg, "")]
|
1430 |
-
with gr.Column(scale=5):
|
1431 |
-
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
|
1432 |
-
with gr.Column(scale=1):
|
1433 |
-
# with gr.Column():
|
1434 |
-
with gr.Row(height=17):
|
1435 |
-
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
|
1436 |
-
refresh_lora_btn = gr.Button("Refresh Lora List", size="sm", min_width= 1)
|
1437 |
-
# save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
|
1438 |
-
save_lset_prompt_drop= gr.Dropdown(
|
1439 |
-
choices=[
|
1440 |
-
("Save Prompt Comments Only", 0),
|
1441 |
-
("Save Full Prompt", 1)
|
1442 |
-
], show_label= False, container=False, visible= False
|
1443 |
-
)
|
1444 |
-
|
1445 |
-
|
1446 |
-
with gr.Row(height=17):
|
1447 |
-
confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
|
1448 |
-
confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
|
1449 |
-
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
|
1450 |
-
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
|
1451 |
-
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
|
1452 |
-
|
1453 |
-
|
1454 |
-
loras_choices = gr.Dropdown(
|
1455 |
-
choices=[
|
1456 |
-
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
|
1457 |
-
],
|
1458 |
-
value= default_loras_choices,
|
1459 |
-
multiselect= True,
|
1460 |
-
visible= len(loras)>0,
|
1461 |
-
label="Activated Loras"
|
1462 |
-
)
|
1463 |
-
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
|
1464 |
|
1465 |
-
show_advanced = gr.Checkbox(label="
|
1466 |
-
with gr.Row(visible=
|
1467 |
with gr.Column():
|
1468 |
-
seed = gr.Slider(-1, 999999999, value
|
1469 |
-
repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
|
1470 |
with gr.Row():
|
1471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1472 |
with gr.Row():
|
1473 |
guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
|
1474 |
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
|
1475 |
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
|
1476 |
with gr.Row():
|
1477 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1478 |
with gr.Row():
|
1479 |
tea_cache_setting = gr.Dropdown(
|
1480 |
choices=[
|
@@ -1491,6 +1837,7 @@ def create_demo():
|
|
1491 |
)
|
1492 |
tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")
|
1493 |
|
|
|
1494 |
RIFLEx_setting = gr.Dropdown(
|
1495 |
choices=[
|
1496 |
("Auto (ON if Video longer than 5s)", 0),
|
@@ -1503,7 +1850,7 @@ def create_demo():
|
|
1503 |
|
1504 |
|
1505 |
with gr.Row():
|
1506 |
-
gr.Markdown("Experimental: Skip Layer guidance,should improve video quality")
|
1507 |
with gr.Row():
|
1508 |
slg_switch = gr.Dropdown(
|
1509 |
choices=[
|
@@ -1529,33 +1876,43 @@ def create_demo():
|
|
1529 |
slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
|
1530 |
|
1531 |
|
1532 |
-
show_advanced.change(fn=
|
|
|
1533 |
|
1534 |
with gr.Column():
|
1535 |
gen_status = gr.Text(label="Status", interactive= False)
|
1536 |
output = gr.Gallery(
|
1537 |
label="Generated videos", show_label=False, elem_id="gallery"
|
1538 |
-
, columns=[3], rows=[1], object_fit="contain", height=
|
1539 |
generate_btn = gr.Button("Generate")
|
1540 |
onemore_btn = gr.Button("One More Please !", visible= False)
|
1541 |
abort_btn = gr.Button("Abort")
|
|
|
1542 |
|
1543 |
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
1544 |
-
confirm_save_lset_btn.click(
|
|
|
1545 |
delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
|
1546 |
confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
|
1547 |
cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
|
1548 |
|
1549 |
-
apply_lset_btn.click(apply_lset, inputs=[lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt])
|
|
|
|
|
1550 |
|
1551 |
refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
|
|
|
|
1552 |
|
1553 |
-
gen_status.change(refresh_gallery, inputs = [state], outputs = output )
|
1554 |
|
1555 |
abort_btn.click(abort_generation,state,abort_btn )
|
1556 |
output.select(select_video, state, None )
|
1557 |
onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
|
1558 |
-
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]
|
|
|
|
|
|
|
1559 |
fn=generate_video,
|
1560 |
inputs=[
|
1561 |
prompt,
|
@@ -1568,6 +1925,7 @@ def create_demo():
|
|
1568 |
flow_shift,
|
1569 |
embedded_guidance_scale,
|
1570 |
repeat_generation,
|
|
|
1571 |
tea_cache_setting,
|
1572 |
tea_cache_start_step_perc,
|
1573 |
loras_choices,
|
@@ -1587,7 +1945,7 @@ def create_demo():
|
|
1587 |
).then(
|
1588 |
finalize_gallery,
|
1589 |
[state],
|
1590 |
-
[output , abort_btn, generate_btn, onemore_btn]
|
1591 |
)
|
1592 |
|
1593 |
apply_btn.click(
|
@@ -1607,7 +1965,7 @@ def create_demo():
|
|
1607 |
outputs= msg
|
1608 |
).then(
|
1609 |
update_defaults,
|
1610 |
-
[state, num_inference_steps, flow_shift],
|
1611 |
[num_inference_steps, flow_shift, header, lset_name , loras_choices ]
|
1612 |
)
|
1613 |
|
|
|
20 |
import traceback
|
21 |
import math
|
22 |
import asyncio
|
23 |
+
from wan.utils import prompt_parser
|
24 |
+
PROMPT_VARS_MAX = 10
|
25 |
+
|
26 |
|
27 |
def _parse_args():
|
28 |
parser = argparse.ArgumentParser(
|
|
|
63 |
parser.add_argument(
|
64 |
"--lora-dir-i2v",
|
65 |
type=str,
|
66 |
+
default="",
|
67 |
help="Path to a directory that contains Loras for i2v"
|
68 |
)
|
69 |
|
|
|
75 |
help="Path to a directory that contains Loras"
|
76 |
)
|
77 |
|
78 |
+
parser.add_argument(
|
79 |
+
"--check-loras",
|
80 |
+
type=str,
|
81 |
+
default=0,
|
82 |
+
help="Filter Loras that are not valid"
|
83 |
+
)
|
84 |
+
|
85 |
|
86 |
parser.add_argument(
|
87 |
"--lora-preset",
|
|
|
111 |
help="Verbose level"
|
112 |
)
|
113 |
|
114 |
+
parser.add_argument(
|
115 |
+
"--steps",
|
116 |
+
type=int,
|
117 |
+
default=0,
|
118 |
+
help="default denoising steps"
|
119 |
+
)
|
120 |
+
|
121 |
+
parser.add_argument(
|
122 |
+
"--frames",
|
123 |
+
type=int,
|
124 |
+
default=0,
|
125 |
+
help="default number of frames"
|
126 |
+
)
|
127 |
+
|
128 |
+
parser.add_argument(
|
129 |
+
"--seed",
|
130 |
+
type=int,
|
131 |
+
default=-1,
|
132 |
+
help="default generation seed"
|
133 |
+
)
|
134 |
+
|
135 |
+
parser.add_argument(
|
136 |
+
"--advanced",
|
137 |
+
action="store_true",
|
138 |
+
help="Access advanced options by default"
|
139 |
+
)
|
140 |
+
|
141 |
+
|
142 |
parser.add_argument(
|
143 |
"--server-port",
|
144 |
type=str,
|
|
|
171 |
help="image to video mode"
|
172 |
)
|
173 |
|
174 |
+
parser.add_argument(
|
175 |
+
"--t2v-14B",
|
176 |
+
action="store_true",
|
177 |
+
help="text to video mode 14B model"
|
178 |
+
)
|
179 |
+
|
180 |
+
parser.add_argument(
|
181 |
+
"--t2v-1-3B",
|
182 |
+
action="store_true",
|
183 |
+
help="text to video mode 1.3B model"
|
184 |
+
)
|
185 |
+
|
186 |
parser.add_argument(
|
187 |
"--compile",
|
188 |
action="store_true",
|
|
|
246 |
force_profile_no = int(args.profile)
|
247 |
verbose_level = int(args.verbose)
|
248 |
quantizeTransformer = args.quantize_transformer
|
249 |
+
default_seed = args.seed
|
250 |
+
default_number_frames = int(args.frames)
|
251 |
+
if default_number_frames > 0:
|
252 |
+
default_number_frames = ((default_number_frames - 1) // 4) * 4 + 1
|
253 |
+
default_inference_steps = args.steps
|
254 |
+
check_loras = args.check_loras ==1
|
255 |
+
advanced = args.advanced
|
256 |
|
257 |
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
|
258 |
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors"]
|
|
|
304 |
use_image2video = False
|
305 |
if args.i2v:
|
306 |
use_image2video = True
|
307 |
+
if args.t2v_14B:
|
308 |
+
use_image2video = False
|
309 |
+
if not "14B" in transformer_filename_t2v:
|
310 |
+
transformer_filename_t2v = transformer_choices_t2v[2]
|
311 |
+
lock_ui_transformer = False
|
312 |
+
|
313 |
+
if args.t2v_1_3B:
|
314 |
+
transformer_filename_t2v = transformer_choices_t2v[0]
|
315 |
+
use_image2video = False
|
316 |
+
lock_ui_transformer = False
|
317 |
+
|
318 |
+
only_allow_edit_in_advanced = False
|
319 |
|
320 |
lora_dir =args.lora_dir
|
321 |
if use_image2video and len(lora_dir)==0:
|
322 |
+
lora_dir =args.lora_dir_i2v
|
323 |
if len(lora_dir) ==0:
|
324 |
root_lora_dir = "loras_i2v" if use_image2video else "loras"
|
325 |
+
else:
|
326 |
+
root_lora_dir = lora_dir
|
327 |
lora_dir = get_lora_dir(root_lora_dir)
|
328 |
lora_preselected_preset = args.lora_preset
|
329 |
default_tea_cache = 0
|
|
|
449 |
dir_presets.sort()
|
450 |
loras_presets = [ Path(Path(file_path).parts[-1]).stem for file_path in dir_presets]
|
451 |
|
452 |
+
if check_loras:
|
453 |
+
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
|
454 |
|
455 |
if len(loras) > 0:
|
456 |
loras_names = [ Path(lora).stem for lora in loras ]
|
|
|
564 |
return 3.0 if "480p" in model_filename else 5.0
|
565 |
|
566 |
def generate_header(model_filename, compile, attention_mode):
|
567 |
+
|
568 |
+
|
569 |
+
header = "<div class='title-with-lines'><div class=line></div><h2>"
|
570 |
|
571 |
if "image" in model_filename:
|
572 |
model_name = "Wan2.1 image2video"
|
|
|
582 |
|
583 |
if compile:
|
584 |
header += ", pytorch compilation ON"
|
585 |
+
header += ") </h2><div class=line></div> "
|
586 |
+
|
587 |
|
588 |
return header
|
589 |
|
|
|
666 |
|
667 |
# return "<DIV ALIGN=CENTER>New Config file created. Please restart the Gradio Server</DIV>"
|
668 |
|
669 |
+
def update_defaults(state, num_inference_steps,flow_shift, lset_name , loras_choices):
|
670 |
if "config_changes" not in state:
|
671 |
return get_default_flow("")
|
672 |
changes = state["config_changes"]
|
673 |
server_config = state["config_new"]
|
674 |
old_server_config = state["config_old"]
|
675 |
+
t2v_changed = False
|
676 |
if not use_image2video:
|
677 |
old_is_14B = "14B" in server_config["transformer_filename"]
|
678 |
new_is_14B = "14B" in old_server_config["transformer_filename"]
|
679 |
|
680 |
trans_file = server_config["transformer_filename"]
|
681 |
+
t2v_changed = old_is_14B != new_is_14B
|
682 |
# num_inference_steps, flow_shift = get_default_flow(trans_file)
|
683 |
else:
|
684 |
old_is_720P = "720P" in server_config["transformer_filename_i2v"]
|
|
|
690 |
header = generate_header(trans_file, server_config["compile"], server_config["attention_mode"] )
|
691 |
new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)]
|
692 |
lset_choices = [ (preset, preset) for preset in loras_presets]
|
693 |
+
lset_choices.append( (get_new_preset_msg(advanced), ""))
|
694 |
+
if t2v_changed:
|
695 |
+
return num_inference_steps, flow_shift, header, gr.Dropdown(choices=lset_choices, value= ""), gr.Dropdown(choices=new_loras_choices, value= [])
|
696 |
+
else:
|
697 |
+
return num_inference_steps, flow_shift, header, lset_name , loras_choices
|
698 |
|
699 |
|
700 |
from moviepy.editor import ImageSequenceClip
|
|
|
738 |
else:
|
739 |
return gr.Button(interactive= True)
|
740 |
|
741 |
+
def refresh_gallery(state, txt):
|
742 |
file_list = state.get("file_list", None)
|
743 |
+
prompt = state.get("prompt", "")
|
744 |
+
if len(prompt) == 0:
|
745 |
+
return file_list, gr.Text(visible= False, value="")
|
746 |
+
else:
|
747 |
+
prompts_max = state.get("prompts_max",0)
|
748 |
+
prompt_no = state.get("prompt_no",0)
|
749 |
+
if prompts_max >1 :
|
750 |
+
label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
|
751 |
+
else:
|
752 |
+
label = f"Current Prompt"
|
753 |
+
return file_list, gr.Text(visible= True, value=prompt, label=label)
|
754 |
+
|
755 |
|
756 |
def finalize_gallery(state):
|
757 |
choice = 0
|
|
|
763 |
time.sleep(0.2)
|
764 |
global gen_in_progress
|
765 |
gen_in_progress = False
|
766 |
+
return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Checkbox(visible= False), gr.Text(visible= False, value="")
|
767 |
|
768 |
def select_video(state , event_data: gr.EventData):
|
769 |
data= event_data._data
|
|
|
785 |
extra_orders = state.get("extra_orders", 0)
|
786 |
extra_orders += 1
|
787 |
state["extra_orders"] = extra_orders
|
788 |
+
prompts_max = state.get("prompts_max",0)
|
789 |
+
if prompts_max == 0:
|
790 |
+
return state
|
791 |
prompt_no = state["prompt_no"]
|
792 |
video_no = state["video_no"]
|
793 |
total_video = state["total_video"]
|
|
|
820 |
flow_shift,
|
821 |
embedded_guidance_scale,
|
822 |
repeat_generation,
|
823 |
+
multi_images_gen_type,
|
824 |
tea_cache,
|
825 |
tea_cache_start_step_perc,
|
826 |
loras_choices,
|
|
|
850 |
elif attention_mode in attention_modes_supported:
|
851 |
attn = attention_mode
|
852 |
else:
|
853 |
+
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed on your system. You should either install it or switch to the default 'sdpa' attention.")
|
854 |
+
return
|
855 |
|
856 |
+
if state.get("validate_success",0) != 1:
|
857 |
+
return
|
858 |
width, height = resolution.split("x")
|
859 |
width, height = int(width), int(height)
|
860 |
|
|
|
862 |
slg_layers = None
|
863 |
if use_image2video:
|
864 |
if "480p" in transformer_filename_i2v and width * height > 848*480:
|
865 |
+
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
866 |
+
return
|
867 |
|
868 |
resolution = str(width) + "*" + str(height)
|
869 |
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
870 |
+
gr.Info(f"Resolution {resolution} not supported by image 2 video")
|
871 |
+
return
|
872 |
|
873 |
else:
|
874 |
if "1.3B" in transformer_filename_t2v and width * height > 848*480:
|
875 |
+
gr.Info("You must use the 14B text to video model to generate videos with a resolution equivalent to 720P")
|
876 |
+
return
|
877 |
|
878 |
offload.shared_state["_attention"] = attn
|
879 |
|
|
|
903 |
temp_filename = None
|
904 |
if len(prompt) ==0:
|
905 |
return
|
906 |
+
prompt, errors = prompt_parser.process_template(prompt)
|
907 |
+
if len(errors) > 0:
|
908 |
+
gr.Info(f"Error processing prompt template: " + errors)
|
909 |
prompts = prompt.replace("\r", "").split("\n")
|
910 |
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
911 |
if len(prompts) ==0:
|
|
|
916 |
image_to_continue = [ tup[0] for tup in image_to_continue ]
|
917 |
else:
|
918 |
image_to_continue = [image_to_continue]
|
919 |
+
if multi_images_gen_type == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
920 |
new_prompts = []
|
921 |
+
new_image_to_continue = []
|
922 |
+
for i in range(len(prompts) * len(image_to_continue) ):
|
923 |
+
new_prompts.append( prompts[ i % len(prompts)] )
|
924 |
+
new_image_to_continue.append(image_to_continue[i // len(prompts)] )
|
925 |
prompts = new_prompts
|
926 |
+
image_to_continue = new_image_to_continue
|
927 |
+
else:
|
928 |
+
if len(prompts) >= len(image_to_continue):
|
929 |
+
if len(prompts) % len(image_to_continue) !=0:
|
930 |
+
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
|
931 |
+
rep = len(prompts) // len(image_to_continue)
|
932 |
+
new_image_to_continue = []
|
933 |
+
for i, _ in enumerate(prompts):
|
934 |
+
new_image_to_continue.append(image_to_continue[i//rep] )
|
935 |
+
image_to_continue = new_image_to_continue
|
936 |
+
else:
|
937 |
+
if len(image_to_continue) % len(prompts) !=0:
|
938 |
+
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
|
939 |
+
rep = len(image_to_continue) // len(prompts)
|
940 |
+
new_prompts = []
|
941 |
+
for i, _ in enumerate(image_to_continue):
|
942 |
+
new_prompts.append( prompts[ i//rep] )
|
943 |
+
prompts = new_prompts
|
944 |
|
945 |
elif video_to_continue != None and len(video_to_continue) >0 :
|
946 |
input_image_or_video_path = video_to_continue
|
|
|
1007 |
# TeaCache
|
1008 |
trans.enable_teacache = tea_cache > 0
|
1009 |
if trans.enable_teacache:
|
1010 |
+
trans.teacache_multiplier = tea_cache
|
1011 |
+
trans.rel_l1_thresh = 0
|
1012 |
+
trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
|
1013 |
+
|
1014 |
if use_image2video:
|
1015 |
if '480p' in transformer_filename_i2v:
|
1016 |
# teacache_thresholds = [0.13, .19, 0.26]
|
|
|
1046 |
start_time = time.time()
|
1047 |
state["prompts_max"] = len(prompts)
|
1048 |
for no, prompt in enumerate(prompts):
|
1049 |
+
state["prompt"] = prompt
|
1050 |
repeat_no = 0
|
1051 |
state["prompt_no"] = no
|
1052 |
extra_generation = 0
|
1053 |
+
yield f"Prompt No{no}"
|
1054 |
while True:
|
1055 |
extra_orders = state.get("extra_orders",0)
|
1056 |
state["extra_orders"] = 0
|
|
|
1063 |
|
1064 |
if trans.enable_teacache:
|
1065 |
trans.teacache_counter = 0
|
|
|
|
|
1066 |
trans.num_steps = num_inference_steps
|
1067 |
trans.teacache_skipped_steps = 0
|
1068 |
trans.previous_residual_uncond = None
|
|
|
1146 |
if any( keyword in frame.name for keyword in keyword_list):
|
1147 |
VRAM_crash = True
|
1148 |
break
|
1149 |
+
state["prompt"] = ""
|
1150 |
if VRAM_crash:
|
1151 |
raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
|
1152 |
else:
|
|
|
1166 |
if samples == None:
|
1167 |
end_time = time.time()
|
1168 |
abort = True
|
1169 |
+
state["prompt"] = ""
|
1170 |
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
1171 |
else:
|
1172 |
sample = samples.cpu()
|
|
|
1189 |
print(f"New video saved to Path: "+video_path)
|
1190 |
file_list.append(video_path)
|
1191 |
if video_no < total_video:
|
1192 |
+
yield status
|
1193 |
else:
|
1194 |
end_time = time.time()
|
1195 |
+
state["prompt"] = ""
|
1196 |
yield f"Total Generation Time: {end_time-start_time:.1f}s"
|
1197 |
seed += 1
|
1198 |
repeat_no += 1
|
|
|
1203 |
offload.unload_loras_from_model(trans)
|
1204 |
|
1205 |
|
1206 |
+
def get_new_preset_msg(advanced = True):
|
1207 |
+
if advanced:
|
1208 |
+
return "Enter here a Name for a Lora Preset or Choose one in the List"
|
1209 |
+
else:
|
1210 |
+
return "Choose a Lora Preset in this List to Apply a Special Effect"
|
1211 |
|
1212 |
|
1213 |
def validate_delete_lset(lset_name):
|
1214 |
+
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
|
1215 |
gr.Info(f"Choose a Preset to delete")
|
1216 |
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False)
|
1217 |
else:
|
1218 |
return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True)
|
1219 |
|
1220 |
def validate_save_lset(lset_name):
|
1221 |
+
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
|
1222 |
gr.Info("Please enter a name for the preset")
|
1223 |
return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False)
|
1224 |
else:
|
|
|
1227 |
def cancel_lset():
|
1228 |
return gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
|
1229 |
|
1230 |
+
def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_cbox):
|
1231 |
global loras_presets
|
1232 |
|
1233 |
+
if state.get("validate_success",0) == 0:
|
1234 |
+
pass
|
1235 |
+
if len(lset_name) == 0 or lset_name == get_new_preset_msg(True) or lset_name == get_new_preset_msg(False):
|
1236 |
gr.Info("Please enter a name for the preset")
|
1237 |
lset_choices =[("Please enter a name for a Lora Preset","")]
|
1238 |
else:
|
|
|
1262 |
gr.Info(f"Lora Preset '{lset_name}' has been created")
|
1263 |
loras_presets.append(Path(Path(lset_name_filename).parts[-1]).stem )
|
1264 |
lset_choices = [ ( preset, preset) for preset in loras_presets ]
|
1265 |
+
lset_choices.append( (get_new_preset_msg(), ""))
|
1266 |
|
1267 |
return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False)
|
1268 |
|
1269 |
def delete_lset(lset_name):
|
1270 |
global loras_presets
|
1271 |
lset_name_filename = os.path.join(lora_dir, sanitize_file_name(lset_name) + ".lset" )
|
1272 |
+
if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
|
1273 |
if not os.path.isfile(lset_name_filename):
|
1274 |
raise gr.Error(f"Preset '{lset_name}' not found ")
|
1275 |
os.remove(lset_name_filename)
|
|
|
1281 |
gr.Info(f"Choose a Preset to delete")
|
1282 |
|
1283 |
lset_choices = [ (preset, preset) for preset in loras_presets]
|
1284 |
+
lset_choices.append((get_new_preset_msg(), ""))
|
1285 |
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False)
|
1286 |
|
1287 |
def refresh_lora_list(lset_name, loras_choices):
|
|
|
1299 |
lora_names_selected.append(lora_id)
|
1300 |
|
1301 |
lset_choices = [ (preset, preset) for preset in loras_presets]
|
1302 |
+
lset_choices.append((get_new_preset_msg(advanced), ""))
|
1303 |
if lset_name in loras_presets:
|
1304 |
pos = loras_presets.index(lset_name)
|
1305 |
else:
|
1306 |
pos = len(loras_presets)
|
1307 |
lset_name =""
|
1308 |
|
1309 |
+
errors = getattr(wan_model.model, "_loras_errors", "")
|
1310 |
+
if errors !=None and len(errors) > 0:
|
1311 |
error_files = [path for path, _ in errors]
|
1312 |
gr.Info("Error while refreshing Lora List, invalid Lora files: " + ", ".join(error_files))
|
1313 |
else:
|
|
|
1316 |
|
1317 |
return gr.Dropdown(choices=lset_choices, value= lset_choices[pos][1]), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected)
|
1318 |
|
1319 |
+
def apply_lset(state, lset_name, loras_choices, loras_mult_choices, prompt):
|
1320 |
|
1321 |
+
state["apply_success"] = 0
|
1322 |
+
|
1323 |
+
if len(lset_name) == 0 or lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False):
|
1324 |
gr.Info("Please choose a preset in the list or create one")
|
1325 |
else:
|
1326 |
loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
|
|
|
1335 |
prompt = "\n".join(prompts)
|
1336 |
prompt = preset_prompt + '\n' + prompt
|
1337 |
gr.Info(f"Lora Preset '{lset_name}' has been applied")
|
1338 |
+
state["apply_success"] = 1
|
1339 |
+
state["wizard_prompt"] = 0
|
1340 |
|
1341 |
return loras_choices, loras_mult_choices, prompt
|
1342 |
|
|
|
|
|
|
|
1343 |
|
1344 |
+
def extract_prompt_from_wizard(state, prompt, wizard_prompt, allow_null_values, *args):
|
1345 |
+
|
1346 |
+
prompts = wizard_prompt.replace("\r" ,"").split("\n")
|
1347 |
+
|
1348 |
+
new_prompts = []
|
1349 |
+
macro_already_written = False
|
1350 |
+
for prompt in prompts:
|
1351 |
+
if not macro_already_written and not prompt.startswith("#") and "{" in prompt and "}" in prompt:
|
1352 |
+
variables = state["variables"]
|
1353 |
+
values = args[:len(variables)]
|
1354 |
+
macro = "! "
|
1355 |
+
for i, (variable, value) in enumerate(zip(variables, values)):
|
1356 |
+
if len(value) == 0 and not allow_null_values:
|
1357 |
+
return prompt, "You need to provide a value for '" + variable + "'"
|
1358 |
+
sub_values= [ "\"" + sub_value + "\"" for sub_value in value.split("\n") ]
|
1359 |
+
value = ",".join(sub_values)
|
1360 |
+
if i>0:
|
1361 |
+
macro += " : "
|
1362 |
+
macro += "{" + variable + "}"+ f"={value}"
|
1363 |
+
if len(variables) > 0:
|
1364 |
+
macro_already_written = True
|
1365 |
+
new_prompts.append(macro)
|
1366 |
+
new_prompts.append(prompt)
|
1367 |
+
else:
|
1368 |
+
new_prompts.append(prompt)
|
1369 |
+
|
1370 |
+
prompt = "\n".join(new_prompts)
|
1371 |
+
return prompt, ""
|
1372 |
+
|
1373 |
+
def validate_wizard_prompt(state, prompt, wizard_prompt, *args):
|
1374 |
+
state["validate_success"] = 0
|
1375 |
+
|
1376 |
+
if state.get("wizard_prompt",0) != 1:
|
1377 |
+
state["validate_success"] = 1
|
1378 |
+
return prompt
|
1379 |
+
|
1380 |
+
prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, False, *args)
|
1381 |
+
if len(errors) > 0:
|
1382 |
+
gr.Info(errors)
|
1383 |
+
return prompt
|
1384 |
+
|
1385 |
+
state["validate_success"] = 1
|
1386 |
+
|
1387 |
+
return prompt
|
1388 |
+
|
1389 |
+
def fill_prompt_from_wizard(state, prompt, wizard_prompt, *args):
|
1390 |
+
|
1391 |
+
if state.get("wizard_prompt",0) == 1:
|
1392 |
+
prompt, errors = extract_prompt_from_wizard(state, prompt, wizard_prompt, True, *args)
|
1393 |
+
if len(errors) > 0:
|
1394 |
+
gr.Info(errors)
|
1395 |
+
|
1396 |
+
state["wizard_prompt"] = 0
|
1397 |
+
|
1398 |
+
return gr.Textbox(visible= True, value =prompt) , gr.Textbox(visible= False), gr.Column(visible = True), *[gr.Column(visible = False)] * 2, *[gr.Textbox(visible= False)] * PROMPT_VARS_MAX
|
1399 |
+
|
1400 |
+
def extract_wizard_prompt(prompt):
|
1401 |
+
variables = []
|
1402 |
+
values = {}
|
1403 |
+
prompts = prompt.replace("\r" ,"").split("\n")
|
1404 |
+
if sum(prompt.startswith("!") for prompt in prompts) > 1:
|
1405 |
+
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
|
1406 |
+
|
1407 |
+
new_prompts = []
|
1408 |
+
errors = ""
|
1409 |
+
for prompt in prompts:
|
1410 |
+
if prompt.startswith("!"):
|
1411 |
+
variables, errors = prompt_parser.extract_variable_names(prompt)
|
1412 |
+
if len(errors) > 0:
|
1413 |
+
return "", variables, values, "Error parsing Prompt templace: " + errors
|
1414 |
+
if len(variables) > PROMPT_VARS_MAX:
|
1415 |
+
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
|
1416 |
+
values, errors = prompt_parser.extract_variable_values(prompt)
|
1417 |
+
if len(errors) > 0:
|
1418 |
+
return "", variables, values, "Error parsing Prompt templace: " + errors
|
1419 |
+
else:
|
1420 |
+
variables_extra, errors = prompt_parser.extract_variable_names(prompt)
|
1421 |
+
if len(errors) > 0:
|
1422 |
+
return "", variables, values, "Error parsing Prompt templace: " + errors
|
1423 |
+
variables += variables_extra
|
1424 |
+
variables = [var for pos, var in enumerate(variables) if var not in variables[:pos]]
|
1425 |
+
if len(variables) > PROMPT_VARS_MAX:
|
1426 |
+
return "", variables, values, "Prompt is too complex for basic Prompt editor, switching to Advanced Prompt"
|
1427 |
+
|
1428 |
+
new_prompts.append(prompt)
|
1429 |
+
wizard_prompt = "\n".join(new_prompts)
|
1430 |
+
return wizard_prompt, variables, values, errors
|
1431 |
+
|
1432 |
+
def fill_wizard_prompt(state, prompt, wizard_prompt):
|
1433 |
+
def get_hidden_textboxes(num = PROMPT_VARS_MAX ):
|
1434 |
+
return [gr.Textbox(value="", visible=False)] * num
|
1435 |
+
|
1436 |
+
hidden_column = gr.Column(visible = False)
|
1437 |
+
visible_column = gr.Column(visible = True)
|
1438 |
+
|
1439 |
+
if advanced or state.get("apply_success") != 1:
|
1440 |
+
return prompt, wizard_prompt, gr.Column(), gr.Column(), hidden_column, *get_hidden_textboxes()
|
1441 |
+
prompt_parts= []
|
1442 |
+
state["wizard_prompt"] = 0
|
1443 |
+
|
1444 |
+
wizard_prompt, variables, values, errors = extract_wizard_prompt(prompt)
|
1445 |
+
if len(errors) > 0:
|
1446 |
+
gr.Info( errors )
|
1447 |
+
return gr.Textbox(prompt, visible=True), gr.Textbox(wizard_prompt, visible=False), visible_column, *[hidden_column] * 2, *get_hidden_textboxes()
|
1448 |
+
|
1449 |
+
for variable in variables:
|
1450 |
+
value = values.get(variable, "")
|
1451 |
+
prompt_parts.append(gr.Textbox( placeholder=variable, info= variable, visible= True, value= "\n".join(value) ))
|
1452 |
+
any_macro = len(variables) > 0
|
1453 |
+
|
1454 |
+
prompt_parts += get_hidden_textboxes(PROMPT_VARS_MAX-len(prompt_parts))
|
1455 |
+
|
1456 |
+
state["variables"] = variables
|
1457 |
+
state["wizard_prompt"] = 1
|
1458 |
+
|
1459 |
+
return gr.Textbox(prompt, visible = False), gr.Textbox(wizard_prompt, visible = True), hidden_column, visible_column, visible_column if any_macro else hidden_column, *prompt_parts
|
1460 |
+
|
1461 |
+
def switch_prompt_type(state, prompt, wizard_prompt, *prompt_vars):
|
1462 |
+
if advanced:
|
1463 |
+
return fill_prompt_from_wizard(state, prompt, wizard_prompt, *prompt_vars)
|
1464 |
+
else:
|
1465 |
+
state["apply_success"] = 1
|
1466 |
+
return fill_wizard_prompt(state, prompt, wizard_prompt)
|
1467 |
+
|
1468 |
+
|
1469 |
+
visible= False
|
1470 |
+
def switch_advanced(new_advanced, lset_name):
|
1471 |
+
global advanced
|
1472 |
+
advanced= new_advanced
|
1473 |
+
lset_choices = [ (preset, preset) for preset in loras_presets]
|
1474 |
+
lset_choices.append((get_new_preset_msg(advanced), ""))
|
1475 |
+
if lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False) or lset_name=="":
|
1476 |
+
lset_name = get_new_preset_msg(advanced)
|
1477 |
+
|
1478 |
+
if only_allow_edit_in_advanced:
|
1479 |
+
return gr.Row(visible=new_advanced), gr.Row(visible=new_advanced), gr.Button(visible=new_advanced), gr.Row(visible= not new_advanced), gr.Dropdown(choices=lset_choices, value= lset_name)
|
1480 |
+
else:
|
1481 |
+
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
|
1482 |
+
|
1483 |
+
def download_loras():
|
1484 |
+
from huggingface_hub import snapshot_download
|
1485 |
+
|
1486 |
+
|
1487 |
+
yield "<B><FONT SIZE=3>Please wait while the Loras are being downloaded</B></FONT>", *[gr.Column(visible=False)] * 2
|
1488 |
+
log_path = os.path.join(lora_dir, "log.txt")
|
1489 |
+
if not os.path.isfile(log_path):
|
1490 |
+
import shutil
|
1491 |
+
tmp_path = os.path.join(lora_dir, "tmp_lora_dowload")
|
1492 |
+
|
1493 |
+
import shutil, glob
|
1494 |
+
snapshot_download(repo_id="DeepBeepMeep/Wan2.1", allow_patterns="loras_i2v/*", local_dir= tmp_path)
|
1495 |
+
[shutil.move(f, lora_dir) for f in glob.glob(os.path.join(tmp_path, "loras_i2v", "*.*")) if not "README.txt" in f ]
|
1496 |
+
|
1497 |
+
|
1498 |
+
yield "<B><FONT SIZE=3>Loras have been completely downloaded</B></FONT>", *[gr.Column(visible=True)] * 2
|
1499 |
+
|
1500 |
+
from datetime import datetime
|
1501 |
+
dt = datetime.today().strftime('%Y-%m-%d')
|
1502 |
+
with open( log_path, "w", encoding="utf-8") as writer:
|
1503 |
+
writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
|
1504 |
+
|
1505 |
+
return
|
1506 |
+
def create_demo():
|
1507 |
+
css= """
|
1508 |
+
.title-with-lines {
|
1509 |
+
display: flex;
|
1510 |
+
align-items: center;
|
1511 |
+
margin: 30px 0;
|
1512 |
+
}
|
1513 |
+
.line {
|
1514 |
+
flex-grow: 1;
|
1515 |
+
height: 1px;
|
1516 |
+
background-color: #333;
|
1517 |
+
}
|
1518 |
+
h2 {
|
1519 |
+
margin: 0 20px;
|
1520 |
+
white-space: nowrap;
|
1521 |
+
}
|
1522 |
+
"""
|
1523 |
default_flow_shift = get_default_flow(transformer_filename_i2v if use_image2video else transformer_filename_t2v)
|
1524 |
+
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="emerald", neutral_hue="slate", text_size= "md")) as demo:
|
1525 |
+
state_dict = {}
|
1526 |
|
1527 |
if use_image2video:
|
1528 |
+
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Image To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
1529 |
else:
|
1530 |
+
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v2.0 - Text To Video <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A> / <A HREF='https://github.com/Wan-Video/Wan2.1'>Original by Alibaba</A>)</FONT SIZE=3></H1></div>")
|
1531 |
|
1532 |
+
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1533 |
|
1534 |
+
with gr.Accordion("Click here for some Info on how to use Wan2GP and to download 20+ Loras", open = False):
|
1535 |
+
if use_image2video and False:
|
1536 |
+
pass
|
1537 |
+
else:
|
1538 |
+
gr.Markdown("The VRAM requirements will depend greatly of the resolution and the duration of the video, for instance :")
|
1539 |
+
gr.Markdown("- 848 x 480 with a 14B model: 80 frames (5s) : 8 GB of VRAM")
|
1540 |
+
gr.Markdown("- 848 x 480 with the 1.3B model: 80 frames (5s) : 5 GB of VRAM")
|
1541 |
+
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
|
1542 |
+
gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
|
1543 |
+
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
1544 |
+
|
1545 |
+
if use_image2video:
|
1546 |
+
with gr.Row():
|
1547 |
+
with gr.Row(scale =3):
|
1548 |
+
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras). Dont't forget first to make a backup of your Loras just in case.")
|
1549 |
+
with gr.Row(scale =1):
|
1550 |
+
download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
|
1551 |
+
with gr.Row():
|
1552 |
+
download_status = gr.Markdown()
|
1553 |
|
1554 |
# css = """<STYLE>
|
1555 |
# h2 { width: 100%; text-align: center; border-bottom: 1px solid #000; line-height: 0.1em; margin: 10px 0 20px; }
|
|
|
1679 |
apply_btn = gr.Button("Apply Changes")
|
1680 |
|
1681 |
|
1682 |
+
|
1683 |
with gr.Row():
|
1684 |
with gr.Column():
|
1685 |
+
with gr.Row(visible= len(loras)>0) as presets_column:
|
1686 |
+
lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
|
1687 |
+
with gr.Column(scale=6):
|
1688 |
+
lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=default_lora_preset)
|
1689 |
+
with gr.Column(scale=1):
|
1690 |
+
# with gr.Column():
|
1691 |
+
with gr.Row(height=17):
|
1692 |
+
apply_lset_btn = gr.Button("Apply Lora Preset", size="sm", min_width= 1)
|
1693 |
+
refresh_lora_btn = gr.Button("Refresh", size="sm", min_width= 1, visible=advanced or not only_allow_edit_in_advanced)
|
1694 |
+
# save_lset_prompt_cbox = gr.Checkbox(label="Save Prompt Comments in Preset", value=False, visible= False)
|
1695 |
+
save_lset_prompt_drop= gr.Dropdown(
|
1696 |
+
choices=[
|
1697 |
+
("Save Prompt Comments Only", 0),
|
1698 |
+
("Save Full Prompt", 1)
|
1699 |
+
], show_label= False, container=False, value =1, visible= False
|
1700 |
+
)
|
1701 |
+
|
1702 |
+
with gr.Row(height=17, visible=False) as refresh2_row:
|
1703 |
+
refresh_lora_btn2 = gr.Button("Refresh", size="sm", min_width= 1)
|
1704 |
+
|
1705 |
+
with gr.Row(height=17, visible=advanced or not only_allow_edit_in_advanced) as preset_buttons_rows:
|
1706 |
+
confirm_save_lset_btn = gr.Button("Go Ahead Save it !", size="sm", min_width= 1, visible=False)
|
1707 |
+
confirm_delete_lset_btn = gr.Button("Go Ahead Delete it !", size="sm", min_width= 1, visible=False)
|
1708 |
+
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
|
1709 |
+
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
|
1710 |
+
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
|
1711 |
+
|
1712 |
video_to_continue = gr.Video(label= "Video to continue", visible= use_image2video and False) #######
|
1713 |
if args.multiple_images:
|
1714 |
image_to_continue = gr.Gallery(
|
|
|
1717 |
else:
|
1718 |
image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=use_image2video)
|
1719 |
|
1720 |
+
advanced_prompt = advanced
|
1721 |
+
prompt_vars=[]
|
1722 |
+
if not advanced_prompt:
|
1723 |
+
default_wizard_prompt, variables, values, errors = extract_wizard_prompt(default_prompt)
|
1724 |
+
advanced_prompt = len(errors) > 0
|
1725 |
+
|
1726 |
+
with gr.Column(visible= advanced_prompt) as prompt_column_advanced: #visible= False
|
1727 |
+
prompt = gr.Textbox( visible= advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments, ! lines = macros)", value=default_prompt, lines=3)
|
1728 |
+
|
1729 |
+
with gr.Column(visible=not advanced_prompt and len(variables) > 0) as prompt_column_wizard_vars: #visible= False
|
1730 |
+
gr.Markdown("<B>Please fill the following input fields to adapt automatically the Prompt:</B>")
|
1731 |
+
with gr.Row(): #visible= not advanced_prompt and len(variables) > 0
|
1732 |
+
if not advanced_prompt:
|
1733 |
+
for variable in variables:
|
1734 |
+
value = values.get(variable, "")
|
1735 |
+
prompt_vars.append(gr.Textbox( placeholder=variable, min_width=80, show_label= False, info= variable, visible= True, value= "\n".join(value) ))
|
1736 |
+
state_dict["wizard_prompt"] = 1
|
1737 |
+
state_dict["variables"] = variables
|
1738 |
+
for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
|
1739 |
+
prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
|
1740 |
+
with gr.Column(not advanced_prompt) as prompt_column_wizard:
|
1741 |
+
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
|
1742 |
+
state = gr.State(state_dict)
|
1743 |
+
|
1744 |
with gr.Row():
|
1745 |
if use_image2video:
|
1746 |
resolution = gr.Dropdown(
|
|
|
1776 |
|
1777 |
with gr.Row():
|
1778 |
with gr.Column():
|
1779 |
+
video_length = gr.Slider(5, 193, value=default_number_frames if default_number_frames > 0 else 81, step=4, label="Number of frames (16 = 1s)")
|
1780 |
with gr.Column():
|
1781 |
+
num_inference_steps = gr.Slider(1, 100, value= default_inference_steps if default_inference_steps > 0 else 30, step=1, label="Number of Inference Steps")
|
1782 |
|
1783 |
with gr.Row():
|
1784 |
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=use_image2video and False) #########
|
1785 |
|
1786 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1787 |
|
1788 |
+
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
|
1789 |
+
with gr.Row(visible=advanced) as advanced_row:
|
1790 |
with gr.Column():
|
1791 |
+
seed = gr.Slider(-1, 999999999, value=default_seed, step=1, label="Seed (-1 for random)")
|
|
|
1792 |
with gr.Row():
|
1793 |
+
repeat_generation = gr.Slider(1, 25.0, value=1.0, step=1, label="Default Number of Generated Videos per Prompt")
|
1794 |
+
multi_images_gen_type = gr.Dropdown(
|
1795 |
+
choices=[
|
1796 |
+
("Generate every combination of images and texts prompts", 0),
|
1797 |
+
("Match images and text prompts", 1),
|
1798 |
+
], visible= args.multiple_images, label= "Multiple Images as Prompts"
|
1799 |
+
)
|
1800 |
+
|
1801 |
with gr.Row():
|
1802 |
guidance_scale = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Guidance Scale", visible=True)
|
1803 |
embedded_guidance_scale = gr.Slider(1.0, 20.0, value=6.0, step=0.5, label="Embedded Guidance Scale", visible=False)
|
1804 |
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
|
1805 |
with gr.Row():
|
1806 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
|
1807 |
+
with gr.Row():
|
1808 |
+
gr.Markdown("<B>Loras can be used to create special effects on the video by mentioned a trigger word in the Prompt. You can save Loras combinations in presets.</B>")
|
1809 |
+
with gr.Column() as loras_column:
|
1810 |
+
loras_choices = gr.Dropdown(
|
1811 |
+
choices=[
|
1812 |
+
(lora_name, str(i) ) for i, lora_name in enumerate(loras_names)
|
1813 |
+
],
|
1814 |
+
value= default_loras_choices,
|
1815 |
+
multiselect= True,
|
1816 |
+
visible= len(loras)>0,
|
1817 |
+
label="Activated Loras"
|
1818 |
+
)
|
1819 |
+
loras_mult_choices = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by space characters or carriage returns, line that starts with # are ignored", value=default_loras_multis_str, visible= len(loras)>0 )
|
1820 |
+
|
1821 |
+
|
1822 |
+
with gr.Row():
|
1823 |
+
gr.Markdown("<B>Tea Cache accelerates by skipping intelligently some steps, the more steps are skipped the lower the quality of the video (Tea Cache consumes also VRAM)</B>")
|
1824 |
with gr.Row():
|
1825 |
tea_cache_setting = gr.Dropdown(
|
1826 |
choices=[
|
|
|
1837 |
)
|
1838 |
tea_cache_start_step_perc = gr.Slider(0, 100, value=0, step=1, label="Tea Cache starting moment in % of generation")
|
1839 |
|
1840 |
+
gr.Markdown("<B>With Riflex you can generate videos longer than 5s which is the default duration of videos used to train the model</B>")
|
1841 |
RIFLEx_setting = gr.Dropdown(
|
1842 |
choices=[
|
1843 |
("Auto (ON if Video longer than 5s)", 0),
|
|
|
1850 |
|
1851 |
|
1852 |
with gr.Row():
|
1853 |
+
gr.Markdown("<B>Experimental: Skip Layer guidance,should improve video quality</B>")
|
1854 |
with gr.Row():
|
1855 |
slg_switch = gr.Dropdown(
|
1856 |
choices=[
|
|
|
1876 |
slg_end_perc = gr.Slider(0, 100, value=90, step=1, label="Denoising Steps % end")
|
1877 |
|
1878 |
|
1879 |
+
show_advanced.change(fn=switch_advanced, inputs=[show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
1880 |
+
fn=switch_prompt_type, inputs = [state, prompt, wizard_prompt, *prompt_vars], outputs = [prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
1881 |
|
1882 |
with gr.Column():
|
1883 |
gen_status = gr.Text(label="Status", interactive= False)
|
1884 |
output = gr.Gallery(
|
1885 |
label="Generated videos", show_label=False, elem_id="gallery"
|
1886 |
+
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
|
1887 |
generate_btn = gr.Button("Generate")
|
1888 |
onemore_btn = gr.Button("One More Please !", visible= False)
|
1889 |
abort_btn = gr.Button("Abort")
|
1890 |
+
gen_info = gr.Text(label="Current prompt", visible= False , interactive= False) #gr.Markdown("Current prompt") #, ,
|
1891 |
|
1892 |
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
1893 |
+
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
1894 |
+
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
1895 |
delete_lset_btn.click(validate_delete_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
|
1896 |
confirm_delete_lset_btn.click(delete_lset, inputs=[lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
|
1897 |
cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ])
|
1898 |
|
1899 |
+
apply_lset_btn.click(apply_lset, inputs=[state, lset_name,loras_choices, loras_mult_choices, prompt], outputs=[loras_choices, loras_mult_choices, prompt]).then(
|
1900 |
+
fn = fill_wizard_prompt, inputs = [state, prompt, wizard_prompt], outputs = [ prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]
|
1901 |
+
)
|
1902 |
|
1903 |
refresh_lora_btn.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
1904 |
+
refresh_lora_btn2.click(refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
1905 |
+
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
1906 |
|
1907 |
+
gen_status.change(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info] )
|
1908 |
|
1909 |
abort_btn.click(abort_generation,state,abort_btn )
|
1910 |
output.select(select_video, state, None )
|
1911 |
onemore_btn.click(fn=one_more_video,inputs=[state], outputs= [state])
|
1912 |
+
generate_btn.click(fn=prepare_generate_video,inputs=[], outputs= [generate_btn, onemore_btn]
|
1913 |
+
).then(
|
1914 |
+
fn=validate_wizard_prompt, inputs =[state, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
|
1915 |
+
).then(
|
1916 |
fn=generate_video,
|
1917 |
inputs=[
|
1918 |
prompt,
|
|
|
1925 |
flow_shift,
|
1926 |
embedded_guidance_scale,
|
1927 |
repeat_generation,
|
1928 |
+
multi_images_gen_type,
|
1929 |
tea_cache_setting,
|
1930 |
tea_cache_start_step_perc,
|
1931 |
loras_choices,
|
|
|
1945 |
).then(
|
1946 |
finalize_gallery,
|
1947 |
[state],
|
1948 |
+
[output , abort_btn, generate_btn, onemore_btn, gen_info]
|
1949 |
)
|
1950 |
|
1951 |
apply_btn.click(
|
|
|
1965 |
outputs= msg
|
1966 |
).then(
|
1967 |
update_defaults,
|
1968 |
+
[state, num_inference_steps, flow_shift,lset_name , loras_choices],
|
1969 |
[num_inference_steps, flow_shift, header, lset_name , loras_choices ]
|
1970 |
)
|
1971 |
|
wan/modules/model.py
CHANGED
@@ -676,6 +676,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
676 |
|
677 |
best_threshold = 0.01
|
678 |
best_diff = 1000
|
|
|
679 |
target_nb_steps= int(len(timesteps) / speed_factor)
|
680 |
threshold = 0.01
|
681 |
while threshold <= 0.6:
|
@@ -686,6 +687,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
686 |
skip = False
|
687 |
if not (i<=start_step or i== len(timesteps)):
|
688 |
accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
|
|
|
|
|
689 |
if accumulated_rel_l1_distance < threshold:
|
690 |
skip = True
|
691 |
else:
|
@@ -693,15 +696,17 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
693 |
previous_modulated_input = e_list[i]
|
694 |
if not skip:
|
695 |
nb_steps += 1
|
696 |
-
|
|
|
697 |
if diff < best_diff:
|
698 |
best_threshold = threshold
|
699 |
best_diff = diff
|
|
|
700 |
elif diff > best_diff:
|
701 |
break
|
702 |
threshold += 0.01
|
703 |
self.rel_l1_thresh = best_threshold
|
704 |
-
print(f"Tea Cache, best threshold found:{best_threshold} with gain x{len(timesteps)/(
|
705 |
return best_threshold
|
706 |
|
707 |
def forward(
|
|
|
676 |
|
677 |
best_threshold = 0.01
|
678 |
best_diff = 1000
|
679 |
+
best_signed_diff = 1000
|
680 |
target_nb_steps= int(len(timesteps) / speed_factor)
|
681 |
threshold = 0.01
|
682 |
while threshold <= 0.6:
|
|
|
687 |
skip = False
|
688 |
if not (i<=start_step or i== len(timesteps)):
|
689 |
accumulated_rel_l1_distance += rescale_func(((e_list[i]-previous_modulated_input).abs().mean() / previous_modulated_input.abs().mean()).cpu().item())
|
690 |
+
# self.accumulated_rel_l1_distance_even += rescale_func(((e_list[i]-self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean()).cpu().item())
|
691 |
+
|
692 |
if accumulated_rel_l1_distance < threshold:
|
693 |
skip = True
|
694 |
else:
|
|
|
696 |
previous_modulated_input = e_list[i]
|
697 |
if not skip:
|
698 |
nb_steps += 1
|
699 |
+
signed_diff = target_nb_steps - nb_steps
|
700 |
+
diff = abs(signed_diff)
|
701 |
if diff < best_diff:
|
702 |
best_threshold = threshold
|
703 |
best_diff = diff
|
704 |
+
best_signed_diff = signed_diff
|
705 |
elif diff > best_diff:
|
706 |
break
|
707 |
threshold += 0.01
|
708 |
self.rel_l1_thresh = best_threshold
|
709 |
+
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
710 |
return best_threshold
|
711 |
|
712 |
def forward(
|
wan/utils/prompt_parser.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
def process_template(input_text):
|
4 |
+
"""
|
5 |
+
Process a text template with macro instructions and variable substitution.
|
6 |
+
Supports multiple values for variables to generate multiple output versions.
|
7 |
+
Each section between macro lines is treated as a separate template.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
input_text (str): The input template text
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
tuple: (output_text, error_message)
|
14 |
+
- output_text: Processed output with variables substituted, or empty string if error
|
15 |
+
- error_message: Error description and problematic line, or empty string if no error
|
16 |
+
"""
|
17 |
+
lines = input_text.strip().split('\n')
|
18 |
+
current_variables = {}
|
19 |
+
current_template_lines = []
|
20 |
+
all_output_lines = []
|
21 |
+
error_message = ""
|
22 |
+
|
23 |
+
# Process the input line by line
|
24 |
+
line_number = 0
|
25 |
+
while line_number < len(lines):
|
26 |
+
orig_line = lines[line_number]
|
27 |
+
line = orig_line.strip()
|
28 |
+
line_number += 1
|
29 |
+
|
30 |
+
# Skip empty lines or comments
|
31 |
+
if not line or line.startswith('#'):
|
32 |
+
continue
|
33 |
+
|
34 |
+
# Handle macro instructions
|
35 |
+
if line.startswith('!'):
|
36 |
+
# Process any accumulated template lines before starting a new macro
|
37 |
+
if current_template_lines:
|
38 |
+
# Process the current template with current variables
|
39 |
+
template_output, err = process_current_template(current_template_lines, current_variables)
|
40 |
+
if err:
|
41 |
+
return "", err
|
42 |
+
all_output_lines.extend(template_output)
|
43 |
+
current_template_lines = [] # Reset template lines
|
44 |
+
|
45 |
+
# Reset variables for the new macro
|
46 |
+
current_variables = {}
|
47 |
+
|
48 |
+
# Parse the macro line
|
49 |
+
macro_line = line[1:].strip()
|
50 |
+
|
51 |
+
# Check for unmatched braces in the whole line
|
52 |
+
open_braces = macro_line.count('{')
|
53 |
+
close_braces = macro_line.count('}')
|
54 |
+
if open_braces != close_braces:
|
55 |
+
error_message = f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces\nLine: '{orig_line}'"
|
56 |
+
return "", error_message
|
57 |
+
|
58 |
+
# Check for unclosed quotes
|
59 |
+
if macro_line.count('"') % 2 != 0:
|
60 |
+
error_message = f"Unclosed double quotes\nLine: '{orig_line}'"
|
61 |
+
return "", error_message
|
62 |
+
|
63 |
+
# Split by optional colon separator
|
64 |
+
var_sections = re.split(r'\s*:\s*', macro_line)
|
65 |
+
|
66 |
+
for section in var_sections:
|
67 |
+
section = section.strip()
|
68 |
+
if not section:
|
69 |
+
continue
|
70 |
+
|
71 |
+
# Extract variable name
|
72 |
+
var_match = re.search(r'\{([^}]+)\}', section)
|
73 |
+
if not var_match:
|
74 |
+
if '{' in section or '}' in section:
|
75 |
+
error_message = f"Malformed variable declaration\nLine: '{orig_line}'"
|
76 |
+
return "", error_message
|
77 |
+
continue
|
78 |
+
|
79 |
+
var_name = var_match.group(1).strip()
|
80 |
+
if not var_name:
|
81 |
+
error_message = f"Empty variable name\nLine: '{orig_line}'"
|
82 |
+
return "", error_message
|
83 |
+
|
84 |
+
# Check variable value format
|
85 |
+
value_part = section[section.find('}')+1:].strip()
|
86 |
+
if not value_part.startswith('='):
|
87 |
+
error_message = f"Missing '=' after variable '{{{var_name}}}'\nLine: '{orig_line}'"
|
88 |
+
return "", error_message
|
89 |
+
|
90 |
+
# Extract all quoted values
|
91 |
+
var_values = re.findall(r'"([^"]*)"', value_part)
|
92 |
+
|
93 |
+
# Check if there are values specified
|
94 |
+
if not var_values:
|
95 |
+
error_message = f"No quoted values found for variable '{{{var_name}}}'\nLine: '{orig_line}'"
|
96 |
+
return "", error_message
|
97 |
+
|
98 |
+
# Check for missing commas between values
|
99 |
+
# Look for patterns like "value""value" (missing comma)
|
100 |
+
if re.search(r'"[^,]*"[^,]*"', value_part):
|
101 |
+
error_message = f"Missing comma between values for variable '{{{var_name}}}'\nLine: '{orig_line}'"
|
102 |
+
return "", error_message
|
103 |
+
|
104 |
+
# Store the variable values
|
105 |
+
current_variables[var_name] = var_values
|
106 |
+
|
107 |
+
# Handle template lines
|
108 |
+
else:
|
109 |
+
# Check for unknown variables in template line
|
110 |
+
var_references = re.findall(r'\{([^}]+)\}', line)
|
111 |
+
for var_ref in var_references:
|
112 |
+
if var_ref not in current_variables:
|
113 |
+
error_message = f"Unknown variable '{{{var_ref}}}' in template\nLine: '{orig_line}'"
|
114 |
+
return "", error_message
|
115 |
+
|
116 |
+
# Add to current template lines
|
117 |
+
current_template_lines.append(line)
|
118 |
+
|
119 |
+
# Process any remaining template lines
|
120 |
+
if current_template_lines:
|
121 |
+
template_output, err = process_current_template(current_template_lines, current_variables)
|
122 |
+
if err:
|
123 |
+
return "", err
|
124 |
+
all_output_lines.extend(template_output)
|
125 |
+
|
126 |
+
return '\n'.join(all_output_lines), ""
|
127 |
+
|
128 |
+
def process_current_template(template_lines, variables):
|
129 |
+
"""
|
130 |
+
Process a set of template lines with the current variables.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
template_lines (list): List of template lines to process
|
134 |
+
variables (dict): Dictionary of variable names to lists of values
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
tuple: (output_lines, error_message)
|
138 |
+
"""
|
139 |
+
if not variables or not template_lines:
|
140 |
+
return template_lines, ""
|
141 |
+
|
142 |
+
output_lines = []
|
143 |
+
|
144 |
+
# Find the maximum number of values for any variable
|
145 |
+
max_values = max(len(values) for values in variables.values())
|
146 |
+
|
147 |
+
# Generate each combination
|
148 |
+
for i in range(max_values):
|
149 |
+
for template in template_lines:
|
150 |
+
output_line = template
|
151 |
+
for var_name, var_values in variables.items():
|
152 |
+
# Use modulo to cycle through values if needed
|
153 |
+
value_index = i % len(var_values)
|
154 |
+
var_value = var_values[value_index]
|
155 |
+
output_line = output_line.replace(f"{{{var_name}}}", var_value)
|
156 |
+
output_lines.append(output_line)
|
157 |
+
|
158 |
+
return output_lines, ""
|
159 |
+
|
160 |
+
|
161 |
+
def extract_variable_names(macro_line):
|
162 |
+
"""
|
163 |
+
Extract all variable names from a macro line.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
macro_line (str): A macro line (with or without the leading '!')
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
tuple: (variable_names, error_message)
|
170 |
+
- variable_names: List of variable names found in the macro
|
171 |
+
- error_message: Error description if any, empty string if no error
|
172 |
+
"""
|
173 |
+
# Remove leading '!' if present
|
174 |
+
if macro_line.startswith('!'):
|
175 |
+
macro_line = macro_line[1:].strip()
|
176 |
+
|
177 |
+
variable_names = []
|
178 |
+
|
179 |
+
# Check for unmatched braces
|
180 |
+
open_braces = macro_line.count('{')
|
181 |
+
close_braces = macro_line.count('}')
|
182 |
+
if open_braces != close_braces:
|
183 |
+
return [], f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
|
184 |
+
|
185 |
+
# Split by optional colon separator
|
186 |
+
var_sections = re.split(r'\s*:\s*', macro_line)
|
187 |
+
|
188 |
+
for section in var_sections:
|
189 |
+
section = section.strip()
|
190 |
+
if not section:
|
191 |
+
continue
|
192 |
+
|
193 |
+
# Extract variable name
|
194 |
+
var_matches = re.findall(r'\{([^}]+)\}', section)
|
195 |
+
for var_name in var_matches:
|
196 |
+
new_var = var_name.strip()
|
197 |
+
if not new_var in variable_names:
|
198 |
+
variable_names.append(new_var)
|
199 |
+
|
200 |
+
return variable_names, ""
|
201 |
+
|
202 |
+
def extract_variable_values(macro_line):
|
203 |
+
"""
|
204 |
+
Extract all variable names and their values from a macro line.
|
205 |
+
|
206 |
+
Args:
|
207 |
+
macro_line (str): A macro line (with or without the leading '!')
|
208 |
+
|
209 |
+
Returns:
|
210 |
+
tuple: (variables_dict, error_message)
|
211 |
+
- variables_dict: Dictionary mapping variable names to their values
|
212 |
+
- error_message: Error description if any, empty string if no error
|
213 |
+
"""
|
214 |
+
# Remove leading '!' if present
|
215 |
+
if macro_line.startswith('!'):
|
216 |
+
macro_line = macro_line[1:].strip()
|
217 |
+
|
218 |
+
variables = {}
|
219 |
+
|
220 |
+
# Check for unmatched braces
|
221 |
+
open_braces = macro_line.count('{')
|
222 |
+
close_braces = macro_line.count('}')
|
223 |
+
if open_braces != close_braces:
|
224 |
+
return {}, f"Unmatched braces: {open_braces} opening '{{' and {close_braces} closing '}}' braces"
|
225 |
+
|
226 |
+
# Check for unclosed quotes
|
227 |
+
if macro_line.count('"') % 2 != 0:
|
228 |
+
return {}, "Unclosed double quotes"
|
229 |
+
|
230 |
+
# Split by optional colon separator
|
231 |
+
var_sections = re.split(r'\s*:\s*', macro_line)
|
232 |
+
|
233 |
+
for section in var_sections:
|
234 |
+
section = section.strip()
|
235 |
+
if not section:
|
236 |
+
continue
|
237 |
+
|
238 |
+
# Extract variable name
|
239 |
+
var_match = re.search(r'\{([^}]+)\}', section)
|
240 |
+
if not var_match:
|
241 |
+
if '{' in section or '}' in section:
|
242 |
+
return {}, "Malformed variable declaration"
|
243 |
+
continue
|
244 |
+
|
245 |
+
var_name = var_match.group(1).strip()
|
246 |
+
if not var_name:
|
247 |
+
return {}, "Empty variable name"
|
248 |
+
|
249 |
+
# Check variable value format
|
250 |
+
value_part = section[section.find('}')+1:].strip()
|
251 |
+
if not value_part.startswith('='):
|
252 |
+
return {}, f"Missing '=' after variable '{{{var_name}}}'"
|
253 |
+
|
254 |
+
# Extract all quoted values
|
255 |
+
var_values = re.findall(r'"([^"]*)"', value_part)
|
256 |
+
|
257 |
+
# Check if there are values specified
|
258 |
+
if not var_values:
|
259 |
+
return {}, f"No quoted values found for variable '{{{var_name}}}'"
|
260 |
+
|
261 |
+
# Check for missing commas between values
|
262 |
+
if re.search(r'"[^,]*"[^,]*"', value_part):
|
263 |
+
return {}, f"Missing comma between values for variable '{{{var_name}}}'"
|
264 |
+
|
265 |
+
variables[var_name] = var_values
|
266 |
+
|
267 |
+
return variables, ""
|
268 |
+
|
269 |
+
def generate_macro_line(variables_dict):
|
270 |
+
"""
|
271 |
+
Generate a macro line from a dictionary of variable names and their values.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
variables_dict (dict): Dictionary mapping variable names to lists of values
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
str: A formatted macro line (including the leading '!')
|
278 |
+
"""
|
279 |
+
sections = []
|
280 |
+
|
281 |
+
for var_name, values in variables_dict.items():
|
282 |
+
# Format each value with quotes
|
283 |
+
quoted_values = [f'"{value}"' for value in values]
|
284 |
+
# Join values with commas
|
285 |
+
values_str = ','.join(quoted_values)
|
286 |
+
# Create the variable assignment
|
287 |
+
section = f"{{{var_name}}}={values_str}"
|
288 |
+
sections.append(section)
|
289 |
+
|
290 |
+
# Join sections with a colon and space for readability
|
291 |
+
return "! " + " : ".join(sections)
|