# RAVE / app.py (last commit e686ff4: "bug fix", by ozgurkara)
import gradio as gr
import os
import torch
import argparse
import os
import sys
import yaml
import datetime
sys.path.append(os.path.dirname(os.getcwd()))
from pipelines.sd_controlnet_rave import RAVE
from pipelines.sd_multicontrolnet_rave import RAVE_MultiControlNet
import subprocess
import utils.constants as const
import utils.video_grid_utils as vgu
import warnings
warnings.filterwarnings("ignore")
import pprint
import glob
def init_device():
    """Pick the torch device for inference: CUDA when available, otherwise CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def init_paths(input_ns):
    """Resolve and create every output / cache directory for one run.

    Mutates and returns ``input_ns``.

    Reads: ``save_folder``, ``video_name``, ``positive_prompts``,
    ``preprocess_name``, ``model_id``, ``grid_size``, ``pad``.
    Sets: ``save_folder``, ``save_path``, ``hf_cn_path``, ``hf_path``,
    ``inverse_path``, ``control_path`` (all directories are created).
    """
    # An empty/None save_folder means "group outputs by video name only".
    if not input_ns.save_folder:
        input_ns.save_folder = input_ns.video_name
    else:
        input_ns.save_folder = os.path.join(input_ns.save_folder, input_ns.video_name)
    save_dir = os.path.join(const.OUTPUT_PATH, input_ns.save_folder)
    os.makedirs(save_dir, exist_ok=True)

    # Run folders are named '<prompt>-NNNNN'; choose the next free index.
    # Entries that do not end in a 5-digit index (stray files, or the first
    # component of a prompt that contained a path separator) are skipped
    # instead of crashing int().
    existing_indices = [
        int(entry[-5:]) for entry in os.listdir(save_dir) if entry[-5:].isdecimal()
    ]
    save_idx = max(existing_indices) + 1 if existing_indices else 0
    input_ns.save_path = os.path.join(save_dir, f'{input_ns.positive_prompts}-{str(save_idx).zfill(5)}')

    # Several preprocessors joined with '-' map to a list of ControlNet paths.
    if '-' in input_ns.preprocess_name:
        input_ns.hf_cn_path = [const.PREPROCESSOR_DICT[name] for name in input_ns.preprocess_name.split('-')]
    else:
        input_ns.hf_cn_path = const.PREPROCESSOR_DICT[input_ns.preprocess_name]
    input_ns.hf_path = "runwayml/stable-diffusion-v1-5"

    # Cache folders are keyed by every setting that affects their contents.
    input_ns.inverse_path = os.path.join(const.GENERATED_DATA_PATH, 'inverses', input_ns.video_name, f'{input_ns.preprocess_name}_{input_ns.model_id}_{input_ns.grid_size}x{input_ns.grid_size}_{input_ns.pad}')
    input_ns.control_path = os.path.join(const.GENERATED_DATA_PATH, 'controls', input_ns.video_name, f'{input_ns.preprocess_name}_{input_ns.grid_size}x{input_ns.grid_size}_{input_ns.pad}')
    os.makedirs(input_ns.control_path, exist_ok=True)
    os.makedirs(input_ns.inverse_path, exist_ok=True)
    os.makedirs(input_ns.save_path, exist_ok=True)
    return input_ns
def install_civitai_model(model_id):
    """Return a local diffusers model directory for a CivitAI model version.

    If ``CIVIT_AI/diffusers_models/<model_id>/`` already contains an entry,
    the first one found is returned without downloading anything. Otherwise
    the checkpoint is downloaded from the CivitAI API with ``wget`` and
    converted with ``CIVIT_AI/convert.py`` (``--from_safetensors``). The
    staging folder ``CIVIT_AI/safetensors`` is then removed.

    Args:
        model_id: CivitAI model-version id (used in the download URL and as
            folder name).

    Returns:
        Path of the converted diffusers model directory.

    Raises:
        subprocess.CalledProcessError: if the download or the conversion fails.
        FileNotFoundError: if the download produced no file.
    """
    import shutil  # local import: only needed on the install path

    model_id = str(model_id)
    diffusers_path = os.path.join(const.CWD, 'CIVIT_AI', 'diffusers_models', model_id)
    already_converted = glob.glob(os.path.join(diffusers_path, '*'))
    if already_converted:
        return already_converted[0]

    install_path = os.path.join(const.CWD, 'CIVIT_AI', 'safetensors')
    install_path_model = os.path.join(install_path, model_id)
    convert_py_path = os.path.join(const.CWD, 'CIVIT_AI', 'convert.py')
    os.makedirs(install_path, exist_ok=True)
    os.makedirs(diffusers_path, exist_ok=True)

    try:
        # argv lists (not str.split) keep paths containing spaces intact.
        subprocess.run(['wget', f'https://civitai.com/api/download/models/{model_id}',
                        '--content-disposition', '--directory', install_path_model],
                       check=True)
        downloaded = glob.glob(os.path.join(install_path_model, '*'))
        if not downloaded:
            raise FileNotFoundError(
                f'Download of CivitAI model {model_id} produced no file in {install_path_model}')
        model_name = downloaded[0]
        model_name2 = os.path.basename(model_name).replace('.safetensors', '')
        diffusers_path_model_name = os.path.join(diffusers_path, model_name2)
        print(model_name)
        # Run the converter with the current interpreter, not whichever
        # `python` happens to be first on PATH.
        subprocess.run([sys.executable, convert_py_path,
                        '--checkpoint_path', model_name,
                        '--dump_path', diffusers_path_model_name,
                        '--from_safetensors'],
                       check=True)
    finally:
        # The raw checkpoint is not needed after conversion (or after a
        # failure); drop the whole staging folder.
        shutil.rmtree(install_path, ignore_errors=True)
    return diffusers_path_model_name
def run(*args):
    """Gradio callback: edit the uploaded video with RAVE and return GIF paths.

    Positional ``args``, in order: input video file, preprocess (control)
    name, ControlNet conditioning scale, ControlNet guidance end, ControlNet
    guidance start, grid size, number of grids (``sample_size``), pad,
    guidance scale, negative prompts, positive prompts, seed, and the key of
    the selected model in ``const.MODEL_IDS``.

    Returns:
        tuple[str, str]: paths of the edited GIF and of the control GIF. Both
        are written to ``input_ns.save_path``, next to a ``config.yaml`` that
        holds the run settings and timings.

    Raises:
        gr.Error: if no input video was uploaded.
    """
    # Settings this limited demo does not expose in the UI.
    batch_size = 4
    batch_size_vae = 1
    is_ddim_inversion = True
    is_shuffle = True
    num_inference_steps = 20
    num_inversion_step = 20
    cond_step_start = 0.0
    give_control_inversion = True
    inversion_prompt = ''
    save_folder = ''

    list_of_inputs = list(args)
    video_file = list_of_inputs[0]
    if video_file is None:
        # Show a readable message in the UI instead of a TypeError further down.
        raise gr.Error('Please upload an input video first.')

    input_ns = argparse.Namespace()
    # Depending on the Gradio version, gr.File may deliver either a path string
    # or a tempfile-like object whose `.name` is the path; accept both.
    input_ns.video_path = getattr(video_file, 'name', video_file)
    input_ns.video_name = os.path.basename(input_ns.video_path).replace('.mp4', '').replace('.gif', '')
    input_ns.preprocess_name = list_of_inputs[1]
    input_ns.batch_size = batch_size
    input_ns.batch_size_vae = batch_size_vae
    input_ns.cond_step_start = cond_step_start
    input_ns.controlnet_conditioning_scale = list_of_inputs[2]
    input_ns.controlnet_guidance_end = list_of_inputs[3]
    input_ns.controlnet_guidance_start = list_of_inputs[4]
    input_ns.give_control_inversion = give_control_inversion
    input_ns.grid_size = list_of_inputs[5]
    input_ns.sample_size = list_of_inputs[6]
    input_ns.pad = list_of_inputs[7]
    input_ns.guidance_scale = list_of_inputs[8]
    input_ns.inversion_prompt = inversion_prompt
    input_ns.is_ddim_inversion = is_ddim_inversion
    input_ns.is_shuffle = is_shuffle
    input_ns.negative_prompts = list_of_inputs[9]
    input_ns.num_inference_steps = num_inference_steps
    input_ns.num_inversion_step = num_inversion_step
    input_ns.positive_prompts = list_of_inputs[10]
    input_ns.save_folder = save_folder
    input_ns.seed = list_of_inputs[11]
    input_ns.model_id = const.MODEL_IDS[list_of_inputs[12]]

    diffusers_model_path = os.path.join(const.CWD, 'CIVIT_AI', 'diffusers_models')
    os.makedirs(diffusers_model_path, exist_ok=True)
    # A model id that stringifies to 'None' skips the CivitAI install step;
    # any other id is downloaded/converted on demand.
    if str(input_ns.model_id) != 'None':
        input_ns.model_id = install_civitai_model(input_ns.model_id)

    device = init_device()
    input_ns = init_paths(input_ns)
    input_ns.image_pil_list = vgu.prepare_video_to_grid(input_ns.video_path, input_ns.sample_size, input_ns.grid_size, input_ns.pad)
    print(input_ns.video_path)
    # Overwrite the UI value with the number of items actually produced.
    input_ns.sample_size = len(input_ns.image_pil_list)
    print(f'Frame count: {len(input_ns.image_pil_list)}')

    # A '-' in the conditioning-scale value selects the multi-ControlNet pipeline.
    is_multi_controlnet = '-' in str(input_ns.controlnet_conditioning_scale)
    controlnet_class = RAVE_MultiControlNet if is_multi_controlnet else RAVE
    CN = controlnet_class(device)
    CN.init_models(input_ns.hf_cn_path, input_ns.hf_path, input_ns.preprocess_name, input_ns.model_id)

    input_dict = vars(input_ns)
    pp = pprint.PrettyPrinter(indent=4)
    pp.pprint(input_dict)
    # The frames themselves are not written to the saved config.
    yaml_dict = {k: v for k, v in input_dict.items() if k != 'image_pil_list'}

    start_time = datetime.datetime.now()
    if is_multi_controlnet:
        res_vid, control_vid_1, _control_vid_2 = CN(input_dict)
        # Bind control_vid on this branch too (it was left unbound, so saving
        # raised NameError); the UI previews the first control video only.
        control_vid = control_vid_1
    else:
        res_vid, control_vid = CN(input_dict)
    end_time = datetime.datetime.now()

    # str() keeps basename() working even if model_id is None.
    save_name = f"{'-'.join(input_ns.positive_prompts.split())}_cstart-{input_ns.controlnet_guidance_start}_gs-{input_ns.guidance_scale}_pre-{'-'.join((input_ns.preprocess_name.replace('-','+').split('_')))}_cscale-{input_ns.controlnet_conditioning_scale}_grid-{input_ns.grid_size}_pad-{input_ns.pad}_model-{os.path.basename(str(input_ns.model_id))}"
    result_gif_path = os.path.join(input_ns.save_path, f'{save_name}.gif')
    control_gif_path = os.path.join(input_ns.save_path, f'control_{save_name}.gif')
    res_vid[0].save(result_gif_path, save_all=True, append_images=res_vid[1:], loop=10000)
    control_vid[0].save(control_gif_path, save_all=True, append_images=control_vid[1:], optimize=False, loop=10000)

    yaml_dict['total_time'] = (end_time - start_time).total_seconds()
    yaml_dict['total_number_of_frames'] = len(res_vid)
    yaml_dict['sec_per_frame'] = yaml_dict['total_time'] / yaml_dict['total_number_of_frames']
    with open(os.path.join(input_ns.save_path, 'config.yaml'), 'w') as yaml_file:
        yaml.dump(yaml_dict, yaml_file)
    return result_gif_path, control_gif_path
def output_video_fn(video_path, text_prompt):
    """Map an example input clip to its precomputed output clip.

    ``example_videos/exp_input_3.mp4`` becomes ``example_videos/exp_output_3.mp4``,
    resolved next to this file. ``text_prompt`` is accepted only because
    gr.Examples passes one argument per input component; it is not used.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    output_name = os.path.basename(video_path).replace('input', 'output')
    return os.path.join(this_dir, "example_videos", output_name)
# Gradio UI. Left column: video upload plus cached example edits.
# Right column: output previews, prompts, settings and the Run button.
block = gr.Blocks().queue()
with block:
    gr.HTML(
        """
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
<a href="https://rave-video.github.io/" style="color:blue;">
RAVE: Randomized Noise Shuffling for Fast and Consistent Video Editing with Diffusion Models</a>
</h1>
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
Ozgur Kara<sup>1</sup>, Bariscan Kurtkaya<sup>2</sup>, Hidir Yesiltepe<sup>4</sup>, James M. Rehg<sup>1,3</sup>, Pinar Yanardag<sup>4</sup>
</h2>
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
<sup>1</sup>Georgia Institute of Technology, <sup>2</sup>KUIS AI Center, <sup>3</sup>University of Illinois Urbana-Champaign, <sup>4</sup>Virginia Tech
</h2>
<h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
[<a href="https://arxiv.org/abs/2312.04524" style="color:blue;">arXiv</a>]
[<a href="https://github.com/rehg-lab/RAVE" style="color:blue;">GitHub</a>]
[<a href="https://rave-video.github.io/" style="color:blue;">Project Webpage</a>]
</h2>
<h2 style="font-weight: 450; font-size: 1rem;">
TL; DR: RAVE is a zero-shot, lightweight, and fast framework for text-guided video editing, supporting videos of any length utilizing text-to-image pretrained diffusion models.
</h2>
<h2 style="font-weight: 450; font-size: 1rem;">
Note that this page is a limited demo of RAVE. To run with more configurations, please check out our GitHub page.
</h2>
</div>
""")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                input_path = gr.File(label='Upload Input Video', file_types=['.mp4'], scale=1)
                # Read-only preview of the upload. It has its own name so it is
                # not confused with the `inputs` list wired to the Run button below.
                input_video = gr.Video(label='Input Video',
                                       format='mp4',
                                       visible=True,
                                       interactive=False,
                                       scale=5)
                # Mirror the uploaded file into the preview player.
                input_path.upload(lambda x: x, inputs=[input_path], outputs=[input_video])
            gr.Markdown('# Example Video Edits')
            with gr.Row():
                example_input = gr.Video(label='Input Example',
                                         format='mp4',
                                         visible=True,
                                         interactive=False)
                example_output = gr.Video(label='Output Example',
                                          format='mp4',
                                          visible=True,
                                          interactive=False)
            ex_prompt = gr.Textbox(label='Text Prompt', interactive=False)
            with gr.Row():
                example_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_videos")
                ex_prompt_dict = {
                    '1': "A black panther",
                    '2': "A medieval knight",
                    '3': "Swarovski blue crystal swan",
                    '4': "Switzerland SBB CFF FFS train",
                    '5': "White cupcakes, moving on the table",
                }
                # One example per prompt entry: example_videos/exp_input_<key>.mp4.
                ex_list = [
                    [os.path.join(example_dir, f"exp_input_{key}.mp4"), prompt]
                    for key, prompt in ex_prompt_dict.items()
                ]
                # Outputs are precomputed clips looked up by output_video_fn and
                # cached at startup.
                ex = gr.Examples(
                    examples=ex_list,
                    inputs=[example_input, ex_prompt],
                    outputs=example_output,
                    fn=output_video_fn,
                    cache_examples=True,)
        with gr.Column():
            with gr.Row():
                result_video = gr.Image(label='Edited Video',
                                        interactive=False)
                control_video = gr.Image(label='Control Video',
                                         interactive=False)
            with gr.Row():
                positive_prompts = gr.Textbox(label='Positive prompts')
                negative_prompts = gr.Textbox(label='Negative prompts')
                model_id = gr.Dropdown(const.MODEL_IDS,
                                       label='Model id',
                                       value='SD 1.5')
            with gr.Row():
                preprocess_list = ['depth_zoe', 'lineart_realistic', 'lineart_standard', 'softedge_hed']
                preprocess_name = gr.Dropdown(preprocess_list,
                                              label='Control type',
                                              value='depth_zoe')
                guidance_scale = gr.Slider(label='Guidance scale',
                                           minimum=0,
                                           maximum=40,
                                           step=0.1,
                                           value=7.5)
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=2147483647,
                                 step=1,
                                 value=0,
                                 randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Accordion('Configuration',
                              open=False):
                with gr.Row():
                    controlnet_conditioning_scale = gr.Slider(label='ControlNet conditioning scale',
                                                              minimum=0.0,
                                                              maximum=1.0,
                                                              value=1.0,
                                                              step=0.01)
                    controlnet_guidance_end = gr.Slider(label='ControlNet guidance end',
                                                        minimum=0.0,
                                                        maximum=1.0,
                                                        value=1.0,
                                                        step=0.01)
                    controlnet_guidance_start = gr.Slider(label='ControlNet guidance start',
                                                          minimum=0.0,
                                                          maximum=1.0,
                                                          value=0.0,
                                                          step=0.01)
                with gr.Row():
                    grid_size = gr.Slider(label='Grid size (n x n)',
                                          minimum=2,
                                          maximum=3,
                                          value=3,
                                          step=1)
                    sample_size = gr.Slider(label='Number of grids',
                                            minimum=1,
                                            maximum=10,
                                            value=1,
                                            step=1)
                    pad = gr.Slider(label='Pad',
                                    minimum=1,
                                    maximum=5,
                                    value=2,
                                    step=1)
    # Order must match the positional indices read by run().
    inputs = [input_path, preprocess_name, controlnet_conditioning_scale, controlnet_guidance_end, controlnet_guidance_start, grid_size, sample_size, pad, guidance_scale, negative_prompts, positive_prompts, seed, model_id]
    run_button.click(fn=run,
                     inputs=inputs,
                     outputs=[result_video, control_video])
if __name__ == "__main__":
    # Cap the number of pending jobs, then start the server with a share link.
    # queue() returns the Blocks instance itself, so the calls can be chained.
    block.queue(max_size=20).launch(share=True)