# Control-A-Video / app.py
from model.video_diffusion.models.controlnet3d import ControlNet3DModel
from model.video_diffusion.models.unet_3d_condition import UNetPseudo3DConditionModel
from model.video_diffusion.pipelines.pipeline_stable_diffusion_controlnet3d import Controlnet3DStableDiffusionPipeline
from transformers import DPTForDepthEstimation
from model.annotator.hed import HEDNetwork
import torch
from einops import rearrange,repeat
import imageio
import numpy as np
import cv2
import torch.nn.functional as F
from PIL import Image
import argparse
import tempfile
import os
import gradio as gr
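
# Control signal used to condition generation: 'depth', 'canny', or 'hed'.
# The matching checkpoint is loaded from the wf-genius/controlavideo-* repo on the Hub.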
control_mode = 'depth'
control_net_path = f"wf-genius/controlavideo-{control_mode}"
unet = UNetPseudo3DConditionModel.from_pretrained(
    control_net_path,
    torch_dtype=torch.float16,
    subfolder='unet',
).to("cuda")
controlnet = ControlNet3DModel.from_pretrained(
    control_net_path,
    torch_dtype=torch.float16,
    subfolder='controlnet',
).to("cuda")
if control_mode == 'depth':
    annotator_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
elif control_mode == 'canny':
    annotator_model = None
elif control_mode == 'hed':
    # first download the weights from https://huggingface.co/wf-genius/controlavideo-hed/resolve/main/hed-network.pth
    annotator_model = HEDNetwork('hed-network.pth').to("cuda")
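
# Assemble the video pipeline from the checkpoint repo, overriding its UNet and
# ControlNet with the 3D modules loaded above and attaching the annotator.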
video_controlnet_pipe = Controlnet3DStableDiffusionPipeline.from_pretrained(
    control_net_path,
    unet=unet,
    controlnet=controlnet,
    annotator_model=annotator_model,
    torch_dtype=torch.float16,
).to("cuda")
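
# Write a sequence of frames to a temporary .mp4 file and return its path,
# which Gradio's Video component can display directly.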
def to_video(frames, fps: int) -> str:
    out_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    writer = imageio.get_writer(out_file.name, format='FFMPEG', fps=fps)
    for frame in frames:
        writer.append_data(np.array(frame))
    writer.close()
    return out_file.name
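
# Full video-to-video generation: sample frames from the input clip, extract
# control maps, then run the ControlNet 3D pipeline chunk by chunk.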
def inference(input_video,
              prompt,
              seed,
              num_inference_steps,
              guidance_scale,
              sampling_rate,
              video_scale,
              init_noise_thres,
              each_sample_frame,
              iter_times,
              h,
              w,
              ):
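    # The clip is generated in `iter_times` chunks of `each_sample_frame` frames;
    # every chunk after the first is conditioned on the last frame of the previous one.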
    num_sample_frames = iter_times * each_sample_frame
    testing_prompt = [prompt]
    np_frames, fps_vid = Controlnet3DStableDiffusionPipeline.get_frames_preprocess(
        input_video, num_frames=num_sample_frames, sampling_rate=sampling_rate, return_np=True)
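    # Extract per-frame control maps with the annotator matching `control_mode`.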
    if control_mode == 'depth':
        frames = torch.from_numpy(np_frames).div(255) * 2 - 1
        frames = rearrange(frames, "f h w c -> c f h w").unsqueeze(0)
        frames = rearrange(frames, 'b c f h w -> (b f) c h w')
        control_maps = video_controlnet_pipe.get_depth_map(frames, h, w, return_standard_norm=False)  # (b f) 1 h w
    elif control_mode == 'canny':
        control_maps = np.stack([cv2.Canny(inp, 100, 200) for inp in np_frames])
        control_maps = repeat(control_maps, 'f h w -> f c h w', c=1)
        control_maps = torch.from_numpy(control_maps).div(255)  # 0~1
    elif control_mode == 'hed':
        control_maps = np.stack([video_controlnet_pipe.get_hed_map(inp) for inp in np_frames])
        control_maps = repeat(control_maps, 'f h w -> f c h w', c=1)
        control_maps = torch.from_numpy(control_maps).div(255)  # 0~1
    control_maps = control_maps.to(dtype=controlnet.dtype, device=controlnet.device)
    control_maps = F.interpolate(control_maps, size=(h, w), mode='bilinear', align_corners=False)
    control_maps = rearrange(control_maps, "(b f) c h w -> b c f h w", f=num_sample_frames)
    if control_maps.shape[1] == 1:
        control_maps = repeat(control_maps, 'b c f h w -> b (n c) f h w', n=3)
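    # Resize the raw RGB frames to the working resolution; they serve as the
    # video-to-video conditioning input.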
    frames = torch.from_numpy(np_frames).div(255)
    frames = rearrange(frames, 'f h w c -> f c h w')
    v2v_input_frames = torch.nn.functional.interpolate(
        frames,
        size=(h, w),
        mode="bicubic",
        antialias=True,
    )
    v2v_input_frames = rearrange(v2v_input_frames, '(b f) c h w -> b c f h w', f=num_sample_frames)
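    # Generate chunk by chunk. Later chunks overlap the previous one by a single
    # frame (passed as first_frame_output) to keep the clip temporally consistent.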
    out = []
    for i in range(num_sample_frames // each_sample_frame):
        out1 = video_controlnet_pipe(
            # controlnet_hint=control_maps[:, :, :each_sample_frame, :, :],
            # images=v2v_input_frames[:, :, :each_sample_frame, :, :],
            controlnet_hint=control_maps[:, :, i*each_sample_frame-1:(i+1)*each_sample_frame-1, :, :] if i > 0 else control_maps[:, :, :each_sample_frame, :, :],
            images=v2v_input_frames[:, :, i*each_sample_frame-1:(i+1)*each_sample_frame-1, :, :] if i > 0 else v2v_input_frames[:, :, :each_sample_frame, :, :],
            first_frame_output=out[-1] if i > 0 else None,
            prompt=testing_prompt,
            num_inference_steps=num_inference_steps,
            width=w,
            height=h,
            guidance_scale=guidance_scale,
            generator=[torch.Generator(device="cuda").manual_seed(seed)],
            video_scale=video_scale,
            init_noise_by_residual_thres=init_noise_thres,  # residual-based init; a larger threshold gives a smoother video.
            controlnet_conditioning_scale=1.0,
            fix_first_frame=True,
            in_domain=True,
        )
        out1 = out1.images[0]
        if len(out1) > 1:
            out1 = out1[1:]  # drop the first frame
        out.extend(out1)
    return to_video(out, 8)
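
# Rough programmatic usage sketch (assumes a local clip such as 'bear.mp4' is available):
#   video_path = inference('bear.mp4', 'a bear walking through stars, artstation',
#                          seed=1, num_inference_steps=20, guidance_scale=7.5,
#                          sampling_rate=3, video_scale=1.5, init_noise_thres=0.1,
#                          each_sample_frame=8, iter_times=1, h=512, w=512)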
examples = [
    ["bear.mp4", "a bear walking through stars, artstation"],
    ["car-shadow.mp4", "a car, sunset, cartoon style, artstation."],
    ["libby.mp4", "a dog running, chinese ink painting."],
]
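
# Single-pass preview: overrides the temporal settings (video_scale, init_noise_thres,
# each_sample_frame, iter_times) so that only one frame is generated quickly.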
def preview_inference(
        input_video,
        prompt, seed,
        num_inference_steps, guidance_scale,
        sampling_rate, video_scale, init_noise_thres,
        each_sample_frame, iter_times, h, w,
):
    return inference(input_video,
                     prompt, seed,
                     num_inference_steps, guidance_scale,
                     sampling_rate, 0.0, 0.0, 1, 1, h, w)
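
# Gradio UI: input video plus parameter sliders on one row, wired to inference().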
if __name__ == '__main__':
    with gr.Blocks() as demo:
        with gr.Row():
            # with gr.Column(scale=1):
            input_video = gr.Video(
                label="Input Video", source='upload', format="mp4", visible=True)
            with gr.Column():
                init_noise_thres = gr.Slider(0, 1, value=0.1, step=0.1, label="init_noise_thres")
                each_sample_frame = gr.Slider(6, 16, value=8, step=1, label="each_sample_frame")
                iter_times = gr.Slider(1, 4, value=1, step=1, label="iter_times")
                sampling_rate = gr.Slider(1, 8, value=3, step=1, label="sampling_rate")
                h = gr.Slider(256, 768, value=512, step=64, label="height")
                w = gr.Slider(256, 768, value=512, step=64, label="width")
            with gr.Column():
                seed = gr.Slider(0, 6666, value=1, step=1, label="seed")
                num_inference_steps = gr.Slider(5, 50, value=20, step=1, label="num_inference_steps")
                guidance_scale = gr.Slider(1, 20, value=7.5, step=0.5, label="guidance_scale")
                video_scale = gr.Slider(0, 2.5, value=1.5, step=0.1, label="video_scale")
                prompt = gr.Textbox(label='Prompt')
                # preview_button = gr.Button('Preview')
                run_button = gr.Button('Generate Video')
            # with gr.Column(scale=1):
            result = gr.Video(label="Generated Video")
        inputs = [
            input_video,
            prompt,
            seed,
            num_inference_steps,
            guidance_scale,
            sampling_rate,
            video_scale,
            init_noise_thres,
            each_sample_frame,
            iter_times,
            h,
            w,
        ]
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=inference,
                    cache_examples=False,
                    run_on_click=False,
                    )
        run_button.click(fn=inference,
                         inputs=inputs,
                         outputs=result)
        # preview_button.click(fn=preview_inference,
        #                      inputs=inputs,
        #                      outputs=result)

    demo.launch(server_name="0.0.0.0", server_port=7860)