QinOwen committed
Commit: 5098655
Parent(s): 2846498

load-base-model-first

Files changed:
- VADER-VideoCrafter/scripts/main/train_t2v_lora.py +29 -32
- app.py +21 -49
- app_bk.py +273 -0
- gradio_cached_examples/32/indices.csv +1 -0
- gradio_cached_examples/32/log.csv +2 -0
- gradio_cached_examples/34/indices.csv +1 -0
- gradio_cached_examples/34/log.csv +2 -0
VADER-VideoCrafter/scripts/main/train_t2v_lora.py
CHANGED
@@ -567,7 +567,7 @@ def should_sample(global_step, validation_steps, is_sample_preview):
                 and is_sample_preview
 
 
-def run_training(args, peft_model, **kwargs):
+def run_training(args, model, **kwargs):
     ## ---------------------step 1: accelerator setup---------------------------
     accelerator = Accelerator(                                      # Initialize Accelerator
         gradient_accumulation_steps=args.gradient_accumulation_steps,
@@ -576,6 +576,29 @@ def run_training(args, peft_model, **kwargs):
     )
     output_dir = args.project_dir
 
+
+    # step 2.1: add LoRA using peft
+    config = peft.LoraConfig(
+            r=args.lora_rank,
+            target_modules=["to_k", "to_v", "to_q"],        # only diffusion_model has these modules
+            lora_dropout=0.01,
+        )
+
+    peft_model = peft.get_peft_model(model, config)
+
+    peft_model.print_trainable_parameters()
+
+    # load the pretrained LoRA model
+    if args.lora_ckpt_path != "Base Model":
+        if args.lora_ckpt_path == "huggingface-hps-aesthetic":      # download the pretrained LoRA model from huggingface
+            snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
+            args.lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_hps_aesthetic.pt'
+        elif args.lora_ckpt_path == "huggingface-pickscore":        # download the pretrained LoRA model from huggingface
+            snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
+            args.lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_pickscore.pt'
+        # load the pretrained LoRA model
+        peft.set_peft_model_state_dict(peft_model, torch.load(args.lora_ckpt_path))
+
     # Make one log on every process with the configuration for debugging.
     create_logging(logging, logger, accelerator)
 
@@ -698,7 +721,7 @@ def run_training(args, peft_model, **kwargs):
     # ==================================================================
 
 
-def setup_model(lora_ckpt_path="huggingface-pickscore", lora_rank=16):
+def setup_model():
     parser = get_parser()
     args = parser.parse_args()
 
@@ -721,41 +744,13 @@ def setup_model(lora_ckpt_path="huggingface-pickscore", lora_rank=16):
     model.first_stage_model = model.first_stage_model.half()
     model.cond_stage_model = model.cond_stage_model.half()
 
-    # step 2.1: add LoRA using peft
-    config = peft.LoraConfig(
-            r=lora_rank,
-            target_modules=["to_k", "to_v", "to_q"],        # only diffusion_model has these modules
-            lora_dropout=0.01,
-        )
 
-    peft_model = peft.get_peft_model(model, config)
-
-    peft_model.print_trainable_parameters()
-
-    # load the pretrained LoRA model
-    if lora_ckpt_path != "Base Model":
-        if lora_ckpt_path == "huggingface-hps-aesthetic":       # download the pretrained LoRA model from huggingface
-            snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
-            lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_hps_aesthetic.pt'
-        elif lora_ckpt_path == "huggingface-pickscore":         # download the pretrained LoRA model from huggingface
-            snapshot_download(repo_id='zheyangqin/VADER', local_dir ='VADER-VideoCrafter/checkpoints/pretrained_lora')
-            lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/vader_videocrafter_pickscore.pt'
-        elif lora_ckpt_path == "peft_model_532":
-            lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_532.pt'
-        elif lora_ckpt_path == "peft_model_548":
-            lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_548.pt'
-        elif lora_ckpt_path == "peft_model_536":
-            lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_536.pt'
-        elif lora_ckpt_path == "peft_model_400":
-            lora_ckpt_path = 'VADER-VideoCrafter/checkpoints/pretrained_lora/peft_model_400.pt'
-        # load the pretrained LoRA model
-        peft.set_peft_model_state_dict(peft_model, torch.load(lora_ckpt_path))
 
     print("Model setup complete!")
-    return peft_model
+    return model
 
 
-def main_fn(prompt, seed=200, height=320, width=512, unconditional_guidance_scale=12, ddim_steps=25, ddim_eta=1.0,
+def main_fn(prompt, lora_model, lora_rank, seed=200, height=320, width=512, unconditional_guidance_scale=12, ddim_steps=25, ddim_eta=1.0,
             frames=24, savefps=10, model=None):
 
     parser = get_parser()
@@ -765,6 +760,8 @@ def main_fn(prompt, seed=200, height=320, width=512, unconditional_guidance_scal
 
     # overwrite the default arguments
     args.prompt_str = prompt
+    args.lora_ckpt_path = lora_model
+    args.lora_rank = lora_rank
    args.seed = seed
     args.height = height
     args.width = width
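Net effect in this file: setup_model() now builds only the base VideoCrafter model (loaded once, no arguments), while LoRA injection and adapter loading move into run_training(), driven by args.lora_rank and args.lora_ckpt_path. A minimal sketch of the resulting flow, condensed from the hunks above; load_base_checkpoint is a hypothetical stand-in for the loading code not shown in this diff:

    # Sketch only; load_base_checkpoint is hypothetical.
    import torch
    import peft
    from huggingface_hub import snapshot_download

    def setup_model():
        model = load_base_checkpoint()          # hypothetical: build the base VideoCrafter model
        model.first_stage_model = model.first_stage_model.half()
        model.cond_stage_model = model.cond_stage_model.half()
        return model                            # no LoRA attached yet

    def run_training(args, model, **kwargs):
        # LoRA is attached per call, so one resident base model can serve
        # any adapter choice (pickscore, hps-aesthetic, or plain base).
        config = peft.LoraConfig(r=args.lora_rank,
                                 target_modules=["to_k", "to_v", "to_q"],
                                 lora_dropout=0.01)
        peft_model = peft.get_peft_model(model, config)
        if args.lora_ckpt_path != "Base Model":
            # The real code first maps the dropdown value to a local .pt path
            # fetched via snapshot_download, then restores the adapter weights.
            peft.set_peft_model_state_dict(peft_model, torch.load(args.lora_ckpt_path))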
app.py
CHANGED
@@ -2,6 +2,7 @@ import gradio as gr
 import os
 import spaces
 import sys
+from copy import deepcopy
 sys.path.append('./VADER-VideoCrafter/scripts/main')
 sys.path.append('./VADER-VideoCrafter/scripts')
 sys.path.append('./VADER-VideoCrafter')
@@ -19,24 +20,26 @@ examples = [
      "huggingface-pickscore", 16, 204, 384, 512, 12.0, 25, 1.0, 24, 10]
 ]
 
-model = None # Placeholder for model
+model = setup_model()
 
 @spaces.GPU(duration=70)
-def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
+def gradio_main_fn(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
                     frames, savefps):
     global model
     if model is None:
         return "Model is not loaded. Please load the model first."
     video_path = main_fn(prompt=prompt,
+                        lora_model=lora_model,
+                        lora_rank=int(lora_rank),
                         seed=int(seed),
                         height=int(height),
-                        width=int(width),
-                        unconditional_guidance_scale=float(unconditional_guidance_scale),
-                        ddim_steps=int(ddim_steps),
+                        width=int(width),
+                        unconditional_guidance_scale=float(unconditional_guidance_scale),
+                        ddim_steps=int(ddim_steps),
                         ddim_eta=float(ddim_eta),
-                        frames=int(frames),
+                        frames=int(frames),
                         savefps=int(savefps),
-                        model=model)
+                        model=deepcopy(model))
 
     return video_path
 
@@ -60,35 +63,6 @@ def update_dropdown(lora_rank):
     else: # 0
         return gr.update(value="Base Model")
 
-@spaces.GPU(duration=180)
-def setup_model_progress(lora_model, lora_rank):
-    global model
-
-    # Disable buttons and show loading indicator
-    yield (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), "Loading model...")
-
-    model = setup_model(lora_model, lora_rank) # Ensure you pass the necessary parameters to the setup_model function
-
-    # Enable buttons after loading and update indicator
-    yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), "Model loaded successfully")
-
-@spaces.GPU(duration=300)
-def generate_example(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
-                     frames, savefps):
-    global model
-    model = setup_model(lora_model, lora_rank)
-    video_path = main_fn(prompt=prompt,
-                        seed=int(seed),
-                        height=int(height),
-                        width=int(width),
-                        unconditional_guidance_scale=float(unconditional_guidance_scale),
-                        ddim_steps=int(ddim_steps),
-                        ddim_eta=float(ddim_eta),
-                        frames=int(frames),
-                        savefps=int(savefps),
-                        model=model)
-    return video_path
-
 custom_css = """
 #centered {
     display: flex;
@@ -215,23 +189,19 @@ with gr.Blocks(css=custom_css) as demo:
                 value="huggingface-pickscore"
             )
             lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step = 8, value=16)
-            load_btn = gr.Button("Load Model")
-            # Add a label to show the loading indicator
-            loading_indicator = gr.Label(value="", label="Loading Indicator")
+            prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
+                            value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
+            run_btn = gr.Button("Run Inference")
 
         with gr.Column(scale=0.3):
             output_video = gr.Video(elem_id="image-upload")
 
     with gr.Row(elem_id="centered"):
         with gr.Column(scale=0.6):
-            prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
-                            value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
+
 
             seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step = 1, value=200)
 
-            run_btn = gr.Button("Run Inference")
-
-
             with gr.Row():
                 height = gr.Slider(minimum=0, maximum=1024, label="Height", step = 16, value=384)
                 width = gr.Slider(minimum=0, maximum=1024, label="Width", step = 16, value=512)
@@ -252,10 +222,10 @@ with gr.Blocks(css=custom_css) as demo:
     reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])
 
 
-
-    load_btn.click(fn=setup_model_progress, inputs=[lora_model, lora_rank], outputs=[load_btn, run_btn, reset_btn, loading_indicator])
     run_btn.click(fn=gradio_main_fn,
-                  inputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
+                  inputs=[prompt, lora_model, lora_rank,
+                          seed, height, width, unconditional_guidance_scale,
+                          DDIM_Steps, DDIM_Eta, frames, savefps],
                   outputs=output_video
                   )
 
@@ -263,9 +233,11 @@ with gr.Blocks(css=custom_css) as demo:
     lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)
 
     gr.Examples(examples=examples,
-                inputs=[prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
+                inputs=[prompt, lora_model, lora_rank, seed,
+                        height, width, unconditional_guidance_scale,
+                        DDIM_Steps, DDIM_Eta, frames, savefps],
                 outputs=output_video,
-                fn=generate_example,
+                fn=gradio_main_fn,
                 run_on_click=False,
                 cache_examples="lazy",
     )
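On the app side, the base model is now built once at import time and every request hands main_fn a deepcopy of it, so the per-request LoRA wrapping done inside run_training never mutates the shared global. A sketch of the pattern, with the remaining UI parameters trimmed for brevity; it assumes setup_model/main_fn behave as in the diffs above:

    from copy import deepcopy

    model = setup_model()   # base model, built once before any request arrives

    def gradio_main_fn(prompt, lora_model, lora_rank, **ui_kwargs):
        # Each request works on its own copy: peft.get_peft_model wraps and
        # mutates the model it is handed, so the global stays pristine.
        return main_fn(prompt=prompt,
                       lora_model=lora_model,
                       lora_rank=int(lora_rank),
                       model=deepcopy(model),
                       **ui_kwargs)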
app_bk.py
ADDED
@@ -0,0 +1,273 @@
+import gradio as gr
+import os
+import spaces
+import sys
+sys.path.append('./VADER-VideoCrafter/scripts/main')
+sys.path.append('./VADER-VideoCrafter/scripts')
+sys.path.append('./VADER-VideoCrafter')
+
+
+from train_t2v_lora import main_fn, setup_model
+
+examples = [
+    ["A fairy tends to enchanted, glowing flowers.", 'huggingface-hps-aesthetic', 8, 400, 384, 512, 12.0, 25, 1.0, 24, 10],
+    ["A cat playing an electric guitar in a loft with industrial-style decor and soft, multicolored lights.", 'huggingface-hps-aesthetic', 8, 206, 384, 512, 12.0, 25, 1.0, 24, 10],
+    ["A raccoon playing a guitar under a blossoming cherry tree.", 'huggingface-hps-aesthetic', 8, 204, 384, 512, 12.0, 25, 1.0, 24, 10],
+    ["A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
+     "huggingface-pickscore", 16, 205, 384, 512, 12.0, 25, 1.0, 24, 10],
+    ["A talking bird with shimmering feathers and a melodious voice leads an adventure to find a legendary treasure, guiding through enchanted forests, ancient ruins, and mystical challenges.",
+     "huggingface-pickscore", 16, 204, 384, 512, 12.0, 25, 1.0, 24, 10]
+]
+
+model = None # Placeholder for model
+
+@spaces.GPU(duration=70)
+def gradio_main_fn(prompt, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
+                    frames, savefps):
+    global model
+    if model is None:
+        return "Model is not loaded. Please load the model first."
+    video_path = main_fn(prompt=prompt,
+                        seed=int(seed),
+                        height=int(height),
+                        width=int(width),
+                        unconditional_guidance_scale=float(unconditional_guidance_scale),
+                        ddim_steps=int(ddim_steps),
+                        ddim_eta=float(ddim_eta),
+                        frames=int(frames),
+                        savefps=int(savefps),
+                        model=model)
+
+    return video_path
+
+def reset_fn():
+    return ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
+            200, 384, 512, 12.0, 25, 1.0, 24, 16, 10, "huggingface-pickscore")
+
+def update_lora_rank(lora_model):
+    if lora_model == "huggingface-pickscore":
+        return gr.update(value=16)
+    elif lora_model == "huggingface-hps-aesthetic":
+        return gr.update(value=8)
+    else: # "Base Model"
+        return gr.update(value=8)
+
+def update_dropdown(lora_rank):
+    if lora_rank == 16:
+        return gr.update(value="huggingface-pickscore")
+    elif lora_rank == 8:
+        return gr.update(value="huggingface-hps-aesthetic")
+    else: # 0
+        return gr.update(value="Base Model")
+
+@spaces.GPU(duration=120)
+def setup_model_progress(lora_model, lora_rank):
+    global model
+
+    # Disable buttons and show loading indicator
+    yield (gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False), "Loading model...")
+
+    model = setup_model(lora_model, lora_rank) # Ensure you pass the necessary parameters to the setup_model function
+
+    # Enable buttons after loading and update indicator
+    yield (gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), "Model loaded successfully")
+
+@spaces.GPU(duration=180)
+def generate_example(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
+                     frames, savefps):
+    global model
+    model = setup_model(lora_model, lora_rank)
+    video_path = main_fn(prompt=prompt,
+                        seed=int(seed),
+                        height=int(height),
+                        width=int(width),
+                        unconditional_guidance_scale=float(unconditional_guidance_scale),
+                        ddim_steps=int(ddim_steps),
+                        ddim_eta=float(ddim_eta),
+                        frames=int(frames),
+                        savefps=int(savefps),
+                        model=model)
+    return video_path
+
+custom_css = """
+#centered {
+    display: flex;
+    justify-content: center;
+}
+.column-centered {
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    width: 60%;
+}
+#image-upload {
+    flex-grow: 1;
+}
+#params .tabs {
+    display: flex;
+    flex-direction: column;
+    flex-grow: 1;
+}
+#params .tabitem[style="display: block;"] {
+    flex-grow: 1;
+    display: flex !important;
+}
+#params .gap {
+    flex-grow: 1;
+}
+#params .form {
+    flex-grow: 1 !important;
+}
+#params .form > :last-child{
+    flex-grow: 1;
+}
+"""
+
+with gr.Blocks(css=custom_css) as demo:
+    with gr.Row():
+        with gr.Column():
+            gr.HTML(
+                """
+                <h1 style='text-align: center; font-size: 3.2em; margin-bottom: 0.5em; font-family: Arial, sans-serif; margin: 20px;'>
+                Video Diffusion Alignment via Reward Gradient
+                </h1>
+                """
+            )
+            gr.HTML(
+                """
+                <style>
+                    body {
+                        font-family: Arial, sans-serif;
+                        text-align: center;
+                        margin: 50px;
+                    }
+                    a {
+                        text-decoration: none !important;
+                        color: black !important;
+                    }
+
+                </style>
+                <body>
+                    <div style="font-size: 1.4em; margin-bottom: 0.5em; ">
+                        <a href="https://mihirp1998.github.io">Mihir Prabhudesai</a><sup>*</sup>
+                        <a href="https://russellmendonca.github.io/">Russell Mendonca</a><sup>*</sup>
+                        <a href="mailto: zheyangqin.qzy@gmail.com">Zheyang Qin</a><sup>*</sup>
+                        <a href="https://www.cs.cmu.edu/~katef/">Katerina Fragkiadaki</a><sup></sup>
+                        <a href="https://www.cs.cmu.edu/~dpathak/">Deepak Pathak</a><sup></sup>
+
+
+                    </div>
+                    <div style="font-size: 1.3em; font-style: italic;">
+                        Carnegie Mellon University
+                    </div>
+                </body>
+                """
+            )
+            gr.HTML(
+                """
+                <head>
+                    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
+
+                    <style>
+                        .button-container {
+                            display: flex;
+                            justify-content: center;
+                            gap: 10px;
+                            margin-top: 10px;
+                        }
+
+                        .button-container a {
+                            display: inline-flex;
+                            align-items: center;
+                            padding: 10px 20px;
+                            border-radius: 30px;
+                            border: 1px solid #ccc;
+                            text-decoration: none;
+                            color: #333 !important;
+                            font-size: 16px;
+                            text-decoration: none !important;
+                        }
+
+                        .button-container a i {
+                            margin-right: 8px;
+                        }
+                    </style>
+                </head>
+
+                <div class="button-container">
+                    <a href="https://arxiv.org/abs/2407.08737" class="btn btn-outline-primary">
+                        <i class="fa-solid fa-file-pdf"></i> Paper
+                    </a>
+                    <a href="https://vader-vid.github.io/" class="btn btn-outline-danger">
+                        <i class="fa-solid fa-video"></i> Website
+                    <a href="https://github.com/mihirp1998/VADER" class="btn btn-outline-secondary">
+                        <i class="fa-brands fa-github"></i> Code
+                    </a>
+                </div>
+                """
+            )
+
+    with gr.Row(elem_id="centered"):
+        with gr.Column(scale=0.3, elem_id="params"):
+            lora_model = gr.Dropdown(
+                label="VADER Model",
+                choices=["huggingface-pickscore", "huggingface-hps-aesthetic", "Base Model"],
+                value="huggingface-pickscore"
+            )
+            lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step = 8, value=16)
+            load_btn = gr.Button("Load Model")
+            # Add a label to show the loading indicator
+            loading_indicator = gr.Label(value="", label="Loading Indicator")
+
+        with gr.Column(scale=0.3):
+            output_video = gr.Video(elem_id="image-upload")
+
+    with gr.Row(elem_id="centered"):
+        with gr.Column(scale=0.6):
+            prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
+                            value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
+
+            seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step = 1, value=200)
+
+            run_btn = gr.Button("Run Inference")
+
+
+            with gr.Row():
+                height = gr.Slider(minimum=0, maximum=1024, label="Height", step = 16, value=384)
+                width = gr.Slider(minimum=0, maximum=1024, label="Width", step = 16, value=512)
+
+            with gr.Row():
+                frames = gr.Slider(minimum=0, maximum=50, label="Frames", step = 1, value=24)
+                savefps = gr.Slider(minimum=0, maximum=60, label="Save FPS", step = 1, value=10)
+
+
+            with gr.Row():
+                DDIM_Steps = gr.Slider(minimum=0, maximum=100, label="DDIM Steps", step = 1, value=25)
+                unconditional_guidance_scale = gr.Slider(minimum=0, maximum=50, label="Guidance Scale", step = 0.1, value=12.0)
+                DDIM_Eta = gr.Slider(minimum=0, maximum=1, label="DDIM Eta", step = 0.01, value=1.0)
+
+            # reset button
+            reset_btn = gr.Button("Reset")
+
+            reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])
+
+
+
+    load_btn.click(fn=setup_model_progress, inputs=[lora_model, lora_rank], outputs=[load_btn, run_btn, reset_btn, loading_indicator])
+    run_btn.click(fn=gradio_main_fn,
+                  inputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
+                  outputs=output_video
+                  )
+
+    lora_model.change(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
+    lora_rank.change(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)
+
+    gr.Examples(examples=examples,
+                inputs=[prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, savefps],
+                outputs=output_video,
+                fn=generate_example,
+                run_on_click=False,
+                cache_examples="lazy",
+    )
+
+demo.launch(share=True)
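For reference, app_bk.py preserves the old two-step flow this commit replaces: nothing is loaded until the user clicks Load Model, whose callback is a generator, so a single click can disable the buttons, run the slow load, then re-enable them, with Gradio streaming each yielded tuple to the four output components in turn. The core of that pattern, as it appears in the file above:

    import gradio as gr

    def setup_model_progress(lora_model, lora_rank):
        global model
        # First yield: grey out load/run/reset and show a status message.
        yield (gr.update(interactive=False), gr.update(interactive=False),
               gr.update(interactive=False), "Loading model...")
        model = setup_model(lora_model, lora_rank)   # the slow part
        # Second yield: re-enable everything once loading is done.
        yield (gr.update(interactive=True), gr.update(interactive=True),
               gr.update(interactive=True), "Model loaded successfully")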
gradio_cached_examples/32/indices.csv
ADDED
@@ -0,0 +1 @@
+0

gradio_cached_examples/32/log.csv
ADDED
@@ -0,0 +1,2 @@
+component 0,flag,username,timestamp
+"{""video"": {""path"": ""gradio_cached_examples/32/component 0/fd156c6a458fa048724e/temporal.mp4"", ""url"": ""/file=/tmp/gradio/4bc133becbc469de8da700250f7f7df1103c6f56/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-19 00:00:10.509808

gradio_cached_examples/34/indices.csv
ADDED
@@ -0,0 +1 @@
+0

gradio_cached_examples/34/log.csv
ADDED
@@ -0,0 +1,2 @@
+component 0,flag,username,timestamp
+"{""video"": {""path"": ""gradio_cached_examples/34/component 0/d2ac1c9664e80f60d50f/temporal.mp4"", ""url"": ""/file=/tmp/gradio/4bc133becbc469de8da700250f7f7df1103c6f56/temporal.mp4"", ""size"": null, ""orig_name"": ""temporal.mp4"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""subtitles"": null}",,,2024-07-18 23:33:26.912888
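These CSV pairs appear to be what cache_examples="lazy" in gr.Examples writes: the first time an example is requested, Gradio runs it through fn and records the produced video as a FileData JSON blob in log.csv, with indices.csv tracking which examples are cached. A quick way to inspect one of these logs, using only the files shown above:

    import csv, json

    with open("gradio_cached_examples/32/log.csv", newline="") as f:
        rows = list(csv.DictReader(f))

    record = json.loads(rows[0]["component 0"])   # the FileData payload
    print(record["video"]["path"])                # .../component 0/.../temporal.mp4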