|
import gradio as gr |
|
import os |
|
import spaces |
|
import sys |
|
from copy import deepcopy |
|
sys.path.append('./VADER-VideoCrafter/scripts/main') |
|
sys.path.append('./VADER-VideoCrafter/scripts') |
|
sys.path.append('./VADER-VideoCrafter') |
|
|
|
|
|
from train_t2v_lora import main_fn, setup_model |
|
|
|
# Generation settings shared by every example row, in ``gr.Examples`` input
# order: height, width, guidance scale, DDIM steps, DDIM eta, frames, save FPS.
_EXAMPLE_SHARED_SETTINGS = (384, 512, 12.0, 25, 1.0, 24, 10)

# Example rows for ``gr.Examples``: (prompt, VADER model, LoRA rank, seed)
# followed by the shared settings above, giving one value per input component.
examples = [
    [example_prompt, example_model, example_rank, example_seed, *_EXAMPLE_SHARED_SETTINGS]
    for example_prompt, example_model, example_rank, example_seed in (
        ("A fairy tends to enchanted, glowing flowers.",
         "huggingface-hps-aesthetic", 8, 400),
        ("A cat playing an electric guitar in a loft with industrial-style decor and soft, multicolored lights.",
         "huggingface-hps-aesthetic", 8, 206),
        ("A raccoon playing a guitar under a blossoming cherry tree.",
         "huggingface-hps-aesthetic", 8, 204),
        ("A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.",
         "huggingface-pickscore", 16, 205),
        ("A talking bird with shimmering feathers and a melodious voice leads an adventure to find a legendary treasure, guiding through enchanted forests, ancient ruins, and mystical challenges.",
         "huggingface-pickscore", 16, 204),
    )
]
|
|
|
# Build the model once when the app module is loaded; gradio_main_fn passes a
# deep copy of it to main_fn on every request rather than this shared instance.
model = setup_model()
|
|
|
@spaces.GPU(duration=120)
def gradio_main_fn(prompt, lora_model, lora_rank, seed, height, width, unconditional_guidance_scale, ddim_steps, ddim_eta,
                   frames, savefps):
    """Generate a video for *prompt* using the selected VADER model and settings.

    Wired to the "Run Inference" button and to ``gr.Examples``; arguments arrive
    in the order of those components' ``inputs`` lists. Slider values are
    coerced to ``int``/``float`` before being forwarded to ``main_fn``.

    Returns:
        The value returned by ``main_fn``, which is shown in the ``gr.Video``
        output (presumably a path to the rendered video file -- confirm in
        ``train_t2v_lora.main_fn``).

    Raises:
        gr.Error: if the module-level ``model`` is ``None``. Raising (rather than
            returning the message) makes Gradio display it as an error instead
            of handing a plain string to the video component.
    """
    # Reading the module-level ``model`` needs no ``global`` declaration.
    if model is None:
        raise gr.Error("Model is not loaded. Please load the model first.")

    # main_fn gets a deep copy so anything it does to the model for this request
    # (NOTE(review): presumably attaching LoRA weights -- confirm) cannot leak
    # into the shared instance used by later requests.
    video_path = main_fn(prompt=prompt,
                         lora_model=lora_model,
                         lora_rank=int(lora_rank),
                         seed=int(seed),
                         height=int(height),
                         width=int(width),
                         unconditional_guidance_scale=float(unconditional_guidance_scale),
                         ddim_steps=int(ddim_steps),
                         ddim_eta=float(ddim_eta),
                         frames=int(frames),
                         savefps=int(savefps),
                         model=deepcopy(model))

    return video_path
|
|
|
def reset_fn():
    """Return the default value for every control cleared by the Reset button.

    The tuple order matches the ``outputs`` list wired to ``reset_btn.click``:
    prompt, seed, height, width, guidance scale, DDIM steps, DDIM eta,
    frames, LoRA rank, save FPS, and VADER model.
    """
    default_prompt = (
        "A mermaid with flowing hair and a shimmering tail discovers a hidden "
        "underwater kingdom adorned with coral palaces, glowing pearls, and "
        "schools of colorful fish, encountering both wonders and dangers along "
        "the way."
    )
    seed_value, height_px, width_px = 200, 384, 512
    guidance, steps, eta = 12.0, 25, 1.0
    n_frames, rank, fps = 24, 16, 10
    default_model = "huggingface-pickscore"
    return (default_prompt, seed_value, height_px, width_px, guidance, steps, eta,
            n_frames, rank, fps, default_model)
|
|
|
def update_lora_rank(lora_model):
    """Return a Gradio update that sets the LoRA-rank slider for *lora_model*.

    "huggingface-pickscore" maps to rank 16; every other choice
    ("huggingface-hps-aesthetic", "Base Model", or anything else) maps to 8.
    """
    rank = 16 if lora_model == "huggingface-pickscore" else 8
    return gr.update(value=rank)
|
|
|
def update_dropdown(lora_rank):
    """Return a Gradio update that selects the VADER model matching *lora_rank*.

    Rank 16 selects "huggingface-pickscore", rank 8 selects
    "huggingface-hps-aesthetic", and any other value selects "Base Model".
    """
    rank_to_model = ((16, "huggingface-pickscore"),
                     (8, "huggingface-hps-aesthetic"))
    # First pair whose rank equals the slider value wins; otherwise fall back.
    selected = next((name for rank, name in rank_to_model if lora_rank == rank),
                    "Base Model")
    return gr.update(value=selected)
|
|
|
# Page CSS passed to gr.Blocks(css=...). Selector -> component in this file:
#   #centered      -> the gr.Row(elem_id="centered") rows, laid out as a centred flexbox
#   #params        -> the gr.Column(elem_id="params") holding model, rank and prompt;
#                     its inner tabs/gap/form containers are given flex-grow
#   #image-upload  -> the gr.Video output (elem_id="image-upload"), allowed to grow
#   .column-centered is not referenced by any component in this file.
custom_css = """
#centered {
display: flex;
justify-content: center;
}
.column-centered {
display: flex;
flex-direction: column;
align-items: center;
width: 60%;
}
#image-upload {
flex-grow: 1;
}
#params .tabs {
display: flex;
flex-direction: column;
flex-grow: 1;
}
#params .tabitem[style="display: block;"] {
flex-grow: 1;
display: flex !important;
}
#params .gap {
flex-grow: 1;
}
#params .form {
flex-grow: 1 !important;
}
#params .form > :last-child{
flex-grow: 1;
}
"""
|
|
|
# Build and launch the Gradio UI: a header (title, authors, links), the model /
# prompt / output row, the generation-parameter panel, event wiring, and
# clickable examples.
with gr.Blocks(css=custom_css) as demo:
    # ---- Header: paper title, authors/affiliation, Paper/Website/Code buttons ----
    with gr.Row():
        with gr.Column():
            gr.HTML(
                """
                <h1 style='text-align: center; font-size: 3.2em; margin-bottom: 0.5em; font-family: Arial, sans-serif; margin: 20px;'>
                Video Diffusion Alignment via Reward Gradient
                </h1>
                """
            )
            gr.HTML(
                """
                <style>
                body {
                font-family: Arial, sans-serif;
                text-align: center;
                margin: 50px;
                }
                a {
                text-decoration: none !important;
                color: black !important;
                }
                </style>
                <body>
                <div style="font-size: 1.4em; margin-bottom: 0.5em; ">
                <a href="https://mihirp1998.github.io">Mihir Prabhudesai</a><sup>*</sup>
                <a href="https://russellmendonca.github.io/">Russell Mendonca</a><sup>*</sup>
                <a href="mailto:zheyangqin.qzy@gmail.com">Zheyang Qin</a><sup>*</sup>
                <a href="https://www.cs.cmu.edu/~katef/">Katerina Fragkiadaki</a><sup></sup>
                <a href="https://www.cs.cmu.edu/~dpathak/">Deepak Pathak</a><sup></sup>
                </div>
                <div style="font-size: 1.3em; font-style: italic;">
                Carnegie Mellon University
                </div>
                </body>
                """
            )
            # Each button anchor is closed before the next one opens; nested
            # <a> elements are invalid HTML.
            gr.HTML(
                """
                <head>
                <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0-beta3/css/all.min.css">
                <style>
                .button-container {
                display: flex;
                justify-content: center;
                gap: 10px;
                margin-top: 10px;
                }
                .button-container a {
                display: inline-flex;
                align-items: center;
                padding: 10px 20px;
                border-radius: 30px;
                border: 1px solid #ccc;
                text-decoration: none;
                color: #333 !important;
                font-size: 16px;
                text-decoration: none !important;
                }
                .button-container a i {
                margin-right: 8px;
                }
                </style>
                </head>
                <div class="button-container">
                <a href="https://arxiv.org/abs/2407.08737" class="btn btn-outline-primary">
                <i class="fa-solid fa-file-pdf"></i> Paper
                </a>
                <a href="https://vader-vid.github.io/" class="btn btn-outline-danger">
                <i class="fa-solid fa-video"></i> Website
                </a>
                <a href="https://github.com/mihirp1998/VADER" class="btn btn-outline-secondary">
                <i class="fa-brands fa-github"></i> Code
                </a>
                </div>
                """
            )

    # ---- Model selection, prompt and output video ----
    # NOTE(review): Gradio documents Column `scale` as an integer; confirm how
    # the fractional values below are handled by the installed version.
    with gr.Row(elem_id="centered"):
        with gr.Column(scale=0.3, elem_id="params"):
            lora_model = gr.Dropdown(
                label="VADER Model",
                choices=["huggingface-pickscore", "huggingface-hps-aesthetic", "Base Model"],
                value="huggingface-pickscore"
            )
            lora_rank = gr.Slider(minimum=8, maximum=16, label="LoRA Rank", step=8, value=16)
            prompt = gr.Textbox(placeholder="Enter prompt text here", lines=4, label="Text Prompt",
                                value="A mermaid with flowing hair and a shimmering tail discovers a hidden underwater kingdom adorned with coral palaces, glowing pearls, and schools of colorful fish, encountering both wonders and dangers along the way.")
            run_btn = gr.Button("Run Inference")

        with gr.Column(scale=0.3):
            output_video = gr.Video(elem_id="image-upload")

    # ---- Generation parameters ----
    # Sizes, frame count, FPS and step count start at their first non-zero
    # value: a 0 for any of them is not a generatable request.
    with gr.Row(elem_id="centered"):
        with gr.Column(scale=0.6):
            seed = gr.Slider(minimum=0, maximum=65536, label="Seed", step=1, value=200)

            with gr.Row():
                height = gr.Slider(minimum=16, maximum=512, label="Height", step=16, value=384)
                width = gr.Slider(minimum=16, maximum=512, label="Width", step=16, value=512)

            with gr.Row():
                frames = gr.Slider(minimum=1, maximum=50, label="Frames", step=1, value=24)
                savefps = gr.Slider(minimum=1, maximum=30, label="Save FPS", step=1, value=10)

            with gr.Row():
                DDIM_Steps = gr.Slider(minimum=1, maximum=50, label="DDIM Steps", step=1, value=25)
                unconditional_guidance_scale = gr.Slider(minimum=0, maximum=50, label="Guidance Scale", step=0.1, value=12.0)
                DDIM_Eta = gr.Slider(minimum=0, maximum=1, label="DDIM Eta", step=0.01, value=1.0)

            reset_btn = gr.Button("Reset")

            # Output order must match the tuple returned by reset_fn.
            reset_btn.click(fn=reset_fn, outputs=[prompt, seed, height, width, unconditional_guidance_scale, DDIM_Steps, DDIM_Eta, frames, lora_rank, savefps, lora_model])

    # ---- Event wiring ----
    # Input order must match the gradio_main_fn signature.
    run_btn.click(fn=gradio_main_fn,
                  inputs=[prompt, lora_model, lora_rank,
                          seed, height, width, unconditional_guidance_scale,
                          DDIM_Steps, DDIM_Eta, frames, savefps],
                  outputs=output_video
                  )

    # Keep the model dropdown and rank slider in sync on USER edits only.
    # `.change` also fires on programmatic updates, so the two handlers used
    # to trigger each other: picking "Base Model" set the rank to 8, which then
    # reset the dropdown to "huggingface-hps-aesthetic". `.input` breaks that loop.
    lora_model.input(fn=update_lora_rank, inputs=lora_model, outputs=lora_rank)
    lora_rank.input(fn=update_dropdown, inputs=lora_rank, outputs=lora_model)

    # Each example row supplies one value per component in `inputs`.
    gr.Examples(examples=examples,
                inputs=[prompt, lora_model, lora_rank, seed,
                        height, width, unconditional_guidance_scale,
                        DDIM_Steps, DDIM_Eta, frames, savefps],
                outputs=output_video,
                fn=gradio_main_fn,
                run_on_click=False,
                cache_examples="lazy",
                )

demo.launch(share=True)