# RAVE / app.py
# Last commit by ozgurkara: dfa1bd7 ("bug fix")
# File size: 13.7 kB
import gradio as gr
import os
import torch
import argparse
import os
import sys
import yaml
import datetime
sys.path.append(os.path.dirname(os.getcwd()))
from pipelines.sd_controlnet_rave import RAVE
from pipelines.sd_multicontrolnet_rave import RAVE_MultiControlNet
import subprocess
import utils.constants as const
import utils.video_grid_utils as vgu
import warnings
warnings.filterwarnings("ignore")
import pprint
import glob
def init_device():
    """Return the torch device that inference should run on.

    Picks ``cuda`` whenever torch reports an available GPU and falls back to
    ``cpu`` otherwise.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')
def init_paths(input_ns):
    """Resolve, create and record every directory used by one editing run.

    Mutates ``input_ns`` in place and returns it. Reads ``save_folder``,
    ``video_name``, ``positive_prompts``, ``preprocess_name``, ``model_id``,
    ``grid_size`` and ``pad``; sets ``save_folder``, ``save_path``,
    ``hf_cn_path``, ``hf_path``, ``inverse_path`` and ``control_path``.

    Results go to ``const.OUTPUT_PATH/<save_folder>/<prompt>-<NNNNN>`` where
    ``NNNNN`` is one more than the highest numeric 5-character suffix already
    present in that folder (or 0 when there is none). Inversion and control
    data go under ``const.GENERATED_DATA_PATH``, keyed by video name,
    preprocessor, grid size and padding (plus model id for inversions).
    """
    if input_ns.save_folder is None or input_ns.save_folder == '':
        input_ns.save_folder = input_ns.video_name
    else:
        input_ns.save_folder = os.path.join(input_ns.save_folder, input_ns.video_name)
    save_dir = os.path.join(const.OUTPUT_PATH, input_ns.save_folder)
    os.makedirs(save_dir, exist_ok=True)

    # Only entries ending in a numeric counter take part in numbering; any other
    # file or folder in save_dir used to make int() raise ValueError.
    used_indices = [
        int(entry[-5:]) for entry in os.listdir(save_dir) if entry[-5:].isdecimal()
    ]
    save_idx = max(used_indices) + 1 if used_indices else 0

    # The prompt is free user text: neutralise path separators so it can neither
    # create nested folders nor escape save_dir (e.g. "../../x" or "/etc/x").
    prompt_dirname = input_ns.positive_prompts
    for sep in (os.sep, os.altsep):
        if sep:
            prompt_dirname = prompt_dirname.replace(sep, '_')
    input_ns.save_path = os.path.join(save_dir, f'{prompt_dirname}-{str(save_idx).zfill(5)}')

    if '-' in input_ns.preprocess_name:
        # Several preprocessors joined by '-': one PREPROCESSOR_DICT entry per name.
        input_ns.hf_cn_path = [const.PREPROCESSOR_DICT[i] for i in input_ns.preprocess_name.split('-')]
    else:
        input_ns.hf_cn_path = const.PREPROCESSOR_DICT[input_ns.preprocess_name]
    input_ns.hf_path = "runwayml/stable-diffusion-v1-5"
    input_ns.inverse_path = os.path.join(const.GENERATED_DATA_PATH, 'inverses', input_ns.video_name, f'{input_ns.preprocess_name}_{input_ns.model_id}_{input_ns.grid_size}x{input_ns.grid_size}_{input_ns.pad}')
    input_ns.control_path = os.path.join(const.GENERATED_DATA_PATH, 'controls', input_ns.video_name, f'{input_ns.preprocess_name}_{input_ns.grid_size}x{input_ns.grid_size}_{input_ns.pad}')
    os.makedirs(input_ns.control_path, exist_ok=True)
    os.makedirs(input_ns.inverse_path, exist_ok=True)
    os.makedirs(input_ns.save_path, exist_ok=True)
    return input_ns
def install_civitai_model(model_id):
    """Download a CivitAI checkpoint and convert it into a diffusers folder.

    If ``CIVIT_AI/diffusers_models/<model_id>/`` already contains an entry, the
    first one is returned and nothing is downloaded. Otherwise the checkpoint
    is fetched with ``wget`` into ``CIVIT_AI/safetensors/<model_id>/`` and
    converted with ``CIVIT_AI/convert.py --from_safetensors``. The whole
    ``CIVIT_AI/safetensors`` folder is then deleted.

    Args:
        model_id: CivitAI model-version id (coerced to ``str``).

    Returns:
        Path of the converted diffusers model directory.

    Raises:
        subprocess.CalledProcessError: if the download or conversion fails.
        FileNotFoundError: if the download produced no file.
    """
    import shutil

    model_id = str(model_id)
    civit_root = os.path.join(const.CWD, 'CIVIT_AI')
    diffusers_path = os.path.join(civit_root, 'diffusers_models', model_id)

    # Reuse a previously converted model when one is already on disk.
    converted = glob.glob(os.path.join(diffusers_path, '*'))
    if converted:
        return converted[0]

    install_path = os.path.join(civit_root, 'safetensors')
    install_path_model = os.path.join(install_path, model_id)
    convert_py_path = os.path.join(civit_root, 'convert.py')
    os.makedirs(install_path, exist_ok=True)
    os.makedirs(diffusers_path, exist_ok=True)

    # argv is passed as a list so paths containing spaces are not split apart.
    subprocess.run(
        ['wget', f'https://civitai.com/api/download/models/{model_id}',
         '--content-disposition', '--directory', install_path_model],
        check=True,
    )
    downloaded = glob.glob(os.path.join(install_path_model, '*'))
    if not downloaded:
        raise FileNotFoundError(f'No checkpoint was downloaded for CivitAI model {model_id}')
    model_name = downloaded[0]
    model_name2 = os.path.basename(model_name).replace('.safetensors', '')
    diffusers_path_model_name = os.path.join(diffusers_path, model_name2)
    print(model_name)

    # Run the converter with the interpreter running this app, not whatever
    # `python` happens to be first on PATH.
    subprocess.run(
        [sys.executable, convert_py_path, '--checkpoint_path', model_name,
         '--dump_path', diffusers_path_model_name, '--from_safetensors'],
        check=True,
    )
    # In-process equivalent of `rm -rf install_path`.
    shutil.rmtree(install_path, ignore_errors=True)
    return diffusers_path_model_name
def _build_run_namespace(list_of_inputs):
    """Translate the positional Gradio inputs into the RAVE argument namespace.

    ``list_of_inputs`` follows the order of the ``inputs`` list wired to the
    Run button: video, preprocess_name, controlnet_conditioning_scale,
    controlnet_guidance_end, controlnet_guidance_start, grid_size,
    sample_size, pad, guidance_scale, negative_prompts, positive_prompts,
    seed. Settings not exposed in the UI are filled with fixed values.

    Raises:
        ValueError: if no video was uploaded.
    """
    video_path = list_of_inputs[0]
    if video_path is None:
        raise ValueError('Please upload an input video before clicking Run.')
    # Depending on the Gradio version, gr.File may deliver a str path or a file
    # wrapper exposing the path as ``.name``; accept both.
    video_path = getattr(video_path, 'name', video_path)

    input_ns = argparse.Namespace()
    input_ns.video_path = video_path
    input_ns.video_name = os.path.basename(video_path).replace('.mp4', '').replace('.gif', '')
    input_ns.preprocess_name = list_of_inputs[1]
    input_ns.batch_size = 4
    input_ns.batch_size_vae = 1
    input_ns.cond_step_start = 0.0
    input_ns.controlnet_conditioning_scale = list_of_inputs[2]
    input_ns.controlnet_guidance_end = list_of_inputs[3]
    input_ns.controlnet_guidance_start = list_of_inputs[4]
    input_ns.give_control_inversion = True
    input_ns.grid_size = list_of_inputs[5]
    input_ns.sample_size = list_of_inputs[6]
    input_ns.pad = list_of_inputs[7]
    input_ns.guidance_scale = list_of_inputs[8]
    input_ns.inversion_prompt = ''
    input_ns.is_ddim_inversion = True
    input_ns.is_shuffle = True
    input_ns.negative_prompts = list_of_inputs[9]
    input_ns.num_inference_steps = 20
    input_ns.num_inversion_step = 20
    input_ns.positive_prompts = list_of_inputs[10]
    input_ns.save_folder = ''
    input_ns.seed = list_of_inputs[11]
    input_ns.model_id = const.MODEL_IDS['SD 1.5']
    return input_ns


def _save_gif(frames, path, **save_kwargs):
    """Write ``frames`` (images exposing PIL's ``save``) to ``path`` as an animated GIF."""
    frames[0].save(path, save_all=True, append_images=frames[1:], loop=10000, **save_kwargs)


def run(*args):
    """Gradio callback: edit the uploaded video with RAVE and save the results.

    Positional ``args`` arrive in the order documented on
    ``_build_run_namespace``. The edited GIF, the control GIF(s) and a
    ``config.yaml`` (all settings plus timing) are written to the run folder
    created by ``init_paths``.

    Returns:
        ``(edited_gif_path, control_gif_path)``.

    Raises:
        ValueError: if no input video was provided.
    """
    input_ns = _build_run_namespace(list(args))

    diffusers_model_path = os.path.join(const.CWD, 'CIVIT_AI', 'diffusers_models')
    os.makedirs(diffusers_model_path, exist_ok=True)
    if str(input_ns.model_id) != 'None':
        input_ns.model_id = install_civitai_model(input_ns.model_id)

    device = init_device()
    input_ns = init_paths(input_ns)

    input_ns.image_pil_list = vgu.prepare_video_to_grid(input_ns.video_path, input_ns.sample_size, input_ns.grid_size, input_ns.pad)
    print(input_ns.video_path)
    input_ns.sample_size = len(input_ns.image_pil_list)
    print(f'Frame count: {len(input_ns.image_pil_list)}')

    # A '-' in the conditioning scale selects the multi-ControlNet pipeline,
    # whose call returns two control videos instead of one.
    is_multi = '-' in str(input_ns.controlnet_conditioning_scale)
    controlnet_class = RAVE_MultiControlNet if is_multi else RAVE
    CN = controlnet_class(device)
    CN.init_models(input_ns.hf_cn_path, input_ns.hf_path, input_ns.preprocess_name, input_ns.model_id)

    input_dict = vars(input_ns)
    pprint.PrettyPrinter(indent=4).pprint(input_dict)
    # Snapshot taken before the pipeline runs; the frames are left out of config.yaml.
    yaml_dict = {k: v for k, v in input_dict.items() if k != 'image_pil_list'}

    start_time = datetime.datetime.now()
    if is_multi:
        # Previously `control_vid` was never bound in this branch -> NameError.
        res_vid, control_vid, control_vid_2 = CN(input_dict)
    else:
        res_vid, control_vid = CN(input_dict)
        control_vid_2 = None
    end_time = datetime.datetime.now()

    # The prompt is free user text; keep path separators out of the file name.
    prompt_slug = '-'.join(input_ns.positive_prompts.split())
    for sep in (os.sep, os.altsep):
        if sep:
            prompt_slug = prompt_slug.replace(sep, '_')
    preprocess_slug = '-'.join(input_ns.preprocess_name.replace('-', '+').split('_'))
    save_name = (
        f"{prompt_slug}_cstart-{input_ns.controlnet_guidance_start}"
        f"_gs-{input_ns.guidance_scale}"
        f"_pre-{preprocess_slug}"
        f"_cscale-{input_ns.controlnet_conditioning_scale}"
        f"_grid-{input_ns.grid_size}_pad-{input_ns.pad}"
        f"_model-{os.path.basename(str(input_ns.model_id))}"
    )

    result_path = os.path.join(input_ns.save_path, f'{save_name}.gif')
    control_path = os.path.join(input_ns.save_path, f'control_{save_name}.gif')
    _save_gif(res_vid, result_path)
    _save_gif(control_vid, control_path, optimize=False)
    if control_vid_2 is not None:
        _save_gif(control_vid_2, os.path.join(input_ns.save_path, f'control2_{save_name}.gif'), optimize=False)

    total_time = (end_time - start_time).total_seconds()
    yaml_dict['total_time'] = total_time
    yaml_dict['total_number_of_frames'] = len(res_vid)
    yaml_dict['sec_per_frame'] = total_time / len(res_vid)
    with open(os.path.join(input_ns.save_path, 'config.yaml'), 'w') as yaml_file:
        yaml.dump(yaml_dict, yaml_file)
    return result_path, control_path
def output_video_fn(video_path, text_prompt):
    """Map an example input video to its pre-rendered edited GIF.

    Used as the ``fn`` of ``gr.Examples``. ``.../exp_input_<i>.mp4`` maps to
    ``<this file's dir>/example_videos/exp_output_<i>.gif``. ``text_prompt`` is
    unused but must be accepted, because Examples passes one argument per
    input component.

    Only the file name is rewritten. Previously ``.mp4`` was replaced across
    the whole joined path, which corrupted checkouts whose directory names
    contain ``.mp4``.
    """
    fold_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_videos")
    output_name = os.path.basename(video_path).replace('input', 'output').replace('.mp4', '.gif')
    return os.path.join(fold_path, output_name)
# Gradio UI. Components render in declaration order inside their enclosing
# Row/Column/Accordion contexts, so the statement order below is the layout.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown('## RAVE: Randomized Noise Shuffling for Fast and Consistent Video Editing with Diffusion Models')
    with gr.Row():
        # Left column: video upload plus pre-rendered example edits.
        with gr.Column():
            with gr.Row():
                input_path = gr.File(label='Upload Input Video', file_types=['.mp4'], scale=1)
                inputs = gr.Video(label='Input Video',
                                  format='mp4',
                                  visible=True,
                                  interactive=False,
                                  scale=5)
            # Mirror the uploaded file into the read-only preview player.
            input_path.upload(lambda x:x, inputs=[input_path], outputs=[inputs])
            gr.Markdown('# Example Video Edits')
            with gr.Row():
                example_input = gr.Video(label='Input Example',
                                         format='mp4',
                                         visible=True,
                                         interactive=False)
                example_output = gr.Video(label='Output Example',
                                          format='mp4',
                                          visible=True,
                                          interactive=False)
                # input(os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_videos", "exp_input_1.mp4"))
            ex_prompt = gr.Textbox(label='Text Prompt', interactive=False)
            with gr.Row():
                # Example rows: (example_videos/exp_input_<i>.mp4, prompt) for i = 1..5.
                ex_list = []
                ex_prompt_dict = {
                    '1': "A black panther",
                    '2': "A medieval knight",
                    '3': "Swarovski blue crystal swan",
                    '4': "Switzerland SBB CFF FFS train",
                    '5': "White cupcakes, moving on the table",
                }
                for i in range(1,6):
                    ex_list.append([os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_videos", f"exp_input_{i}.mp4"), ex_prompt_dict[str(i)]])
                # output_video_fn maps each example input to its stored output GIF;
                # with cache_examples=True those outputs are cached rather than recomputed.
                ex = gr.Examples(
                    examples=ex_list,
                    inputs=[example_input, ex_prompt],
                    outputs=example_output,
                    fn=output_video_fn,
                    cache_examples=True,)
        # Right column: results, prompts and run controls.
        with gr.Column():
            with gr.Row():
                # The results are GIF files, shown in gr.Image components.
                result_video = gr.Image(label='Edited Video',
                                        interactive=False)
                control_video = gr.Image(label='Control Video',
                                         interactive=False)
            with gr.Row():
                positive_prompts = gr.Textbox(label='Positive prompts')
                negative_prompts = gr.Textbox(label='Negative prompts')
            with gr.Row():
                # Each name must be a key of const.PREPROCESSOR_DICT (looked up in init_paths).
                preprocess_list = ['depth_zoe', 'lineart_realistic', 'lineart_standard', 'softedge_hed']
                preprocess_name = gr.Dropdown(preprocess_list,
                                              label='Control type',
                                              value='depth_zoe')
                guidance_scale = gr.Slider(label='Guidance scale',
                                           minimum=0,
                                           maximum=40,
                                           step=0.1,
                                           value=7.5)
                seed = gr.Slider(label='Seed',
                                 minimum=0,
                                 maximum=2147483647,
                                 step=1,
                                 value=0,
                                 randomize=True)
            run_button = gr.Button(value='Run All')
            with gr.Accordion('Configuration',
                              open=False):
                with gr.Row():
                    controlnet_conditioning_scale = gr.Slider(label='ControlNet conditioning scale',
                                                              minimum=0.0,
                                                              maximum=1.0,
                                                              value=1.0,
                                                              step=0.01)
                    controlnet_guidance_end = gr.Slider(label='ControlNet guidance end',
                                                        minimum=0.0,
                                                        maximum=1.0,
                                                        value=1.0,
                                                        step=0.01)
                    controlnet_guidance_start = gr.Slider(label='ControlNet guidance start',
                                                          minimum=0.0,
                                                          maximum=1.0,
                                                          value=0.0,
                                                          step=0.01)
                with gr.Row():
                    # grid_size n arranges frames into n x n grids; sample_size is the number of grids.
                    grid_size = gr.Slider(label='Grid size (n x n)',
                                          minimum=2,
                                          maximum=3,
                                          value=3,
                                          step=1)
                    sample_size = gr.Slider(label='Number of grids',
                                            minimum=1,
                                            maximum=10,
                                            value=2,
                                            step=1)
                    pad = gr.Slider(label='Pad',
                                    minimum=1,
                                    maximum=10,
                                    value=1,
                                    step=1)
    # NOTE: rebinds `inputs`, which previously named the preview gr.Video. The
    # upload listener above already captured that component, so this is safe.
    # The order here must match the positional indices read by run(*args).
    inputs = [input_path, preprocess_name, controlnet_conditioning_scale, controlnet_guidance_end, controlnet_guidance_start, grid_size, sample_size, pad, guidance_scale, negative_prompts, positive_prompts, seed]
    # run() returns (edited_gif_path, control_gif_path).
    run_button.click(fn=run,
                     inputs=inputs,
                     outputs=[result_video, control_video])
if __name__ == "__main__":
    # share=True requests a public share link in addition to the local server.
    block.launch(share=True)