Spaces:
Running
Running
import os | |
import random | |
from datetime import datetime | |
import gradio as gr | |
import numpy as np | |
import torch | |
from diffusers import AutoencoderKL, DDIMScheduler | |
from einops import repeat | |
from huggingface_hub import hf_hub_download, snapshot_download | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from torchvision import transforms | |
from transformers import CLIPVisionModelWithProjection | |
from src.models.pose_guider import PoseGuider | |
from src.models.unet_2d_condition import UNet2DConditionModel | |
from src.models.unet_3d import UNet3DConditionModel | |
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline | |
from src.utils.download_models import prepare_base_model, prepare_image_encoder | |
from src.utils.util import get_fps, read_frames, save_videos_grid | |
# Partial download: fetch only the weights this demo needs before the UI is built.
# The two project helpers presumably fetch the base model and the image encoder
# (inferred from their names); the VAE and the AnimateAnyone checkpoints are
# pulled straight from the Hugging Face Hub into ./pretrained_weights.
prepare_base_model()
prepare_image_encoder()
for _repo_id, _local_dir in (
    ("stabilityai/sd-vae-ft-mse", "./pretrained_weights/sd-vae-ft-mse"),
    ("patrolli/AnimateAnyone", "./pretrained_weights"),
):
    snapshot_download(repo_id=_repo_id, local_dir=_local_dir)
del _repo_id, _local_dir
class AnimateController:
    """Lazily build the Moore-AnimateAnyone pose-to-video pipeline and run it.

    The pipeline (VAE, reference and denoising UNets, pose guider, CLIP image
    encoder, DDIM scheduler) is created on CUDA the first time :meth:`animate`
    is called and cached on ``self.pipeline`` for all later calls.
    """

    def __init__(
        self,
        config_path="./configs/prompts/animation.yaml",
        weight_dtype=torch.float16,
    ):
        """Read the model-path config; model loading is deferred to ``animate``.

        Args:
            config_path: YAML file (loaded with OmegaConf) that holds the paths
                of the pretrained weights and of the inference config.
            weight_dtype: dtype every model is cast to when moved to CUDA.
        """
        self.config = OmegaConf.load(config_path)
        self.pipeline = None
        self.weight_dtype = weight_dtype

    @staticmethod
    def _parse_seed(seed):
        """Return *seed* as an ``int``, or ``None`` when it is missing or blank.

        The Gradio seed box delivers text, so surrounding whitespace is ignored.

        Raises:
            ValueError: if *seed* is text that is not an integer.
        """
        if seed is None:
            return None
        if isinstance(seed, str):
            seed = seed.strip()
            if not seed:
                return None
            try:
                return int(seed)
            except ValueError as err:
                raise ValueError(f"Seed must be an integer, got {seed!r}") from err
        return int(seed)

    @staticmethod
    def _make_output_path(save_dir="./output/gradio"):
        """Create *save_dir* if needed and return a fresh ``.mp4`` path inside it.

        The file name combines a single second-resolution timestamp with a
        random suffix, so requests finishing in the same minute (or running
        concurrently) never overwrite each other's result.
        """
        import uuid  # local import: only this helper needs it

        os.makedirs(save_dir, exist_ok=True)
        stamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        return os.path.join(save_dir, f"{stamp}_{uuid.uuid4().hex[:8]}.mp4")

    def _load_pipeline(self):
        """Build the ``Pose2VideoPipeline`` on CUDA from the paths in ``self.config``.

        Reads the config keys ``pretrained_vae_path``, ``pretrained_base_model_path``,
        ``inference_config``, ``motion_module_path``, ``image_encoder_path``,
        ``denoising_unet_path``, ``reference_unet_path`` and ``pose_guider_path``.
        """
        conf = self.config
        dtype = self.weight_dtype

        vae = AutoencoderKL.from_pretrained(
            conf.pretrained_vae_path,
        ).to("cuda", dtype=dtype)
        reference_unet = UNet2DConditionModel.from_pretrained(
            conf.pretrained_base_model_path,
            subfolder="unet",
        ).to(dtype=dtype, device="cuda")

        infer_config = OmegaConf.load(conf.inference_config)
        denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            conf.pretrained_base_model_path,
            conf.motion_module_path,
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=dtype, device="cuda")
        pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
            dtype=dtype, device="cuda"
        )
        image_enc = CLIPVisionModelWithProjection.from_pretrained(
            conf.image_encoder_path
        ).to(dtype=dtype, device="cuda")
        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # Load the fine-tuned AnimateAnyone weights (read on CPU, copied into the
        # CUDA modules). strict=False for the denoising UNet presumably because
        # its checkpoint does not cover every key of the 3D UNet - confirm.
        denoising_unet.load_state_dict(
            torch.load(conf.denoising_unet_path, map_location="cpu"),
            strict=False,
        )
        reference_unet.load_state_dict(
            torch.load(conf.reference_unet_path, map_location="cpu"),
        )
        pose_guider.load_state_dict(
            torch.load(conf.pose_guider_path, map_location="cpu"),
        )

        pipe = Pose2VideoPipeline(
            vae=vae,
            image_encoder=image_enc,
            reference_unet=reference_unet,
            denoising_unet=denoising_unet,
            pose_guider=pose_guider,
            scheduler=scheduler,
        )
        return pipe.to("cuda", dtype=dtype)

    def animate(
        self,
        ref_image,
        pose_video_path,
        width=512,
        height=768,
        length=24,
        num_inference_steps=25,
        cfg=3.5,
        seed=123,
    ):
        """Animate *ref_image* following the poses in *pose_video_path*.

        Args:
            ref_image: reference character image, as a PIL image or a numpy
                array (converted with ``Image.fromarray``).
            pose_video_path: path of the pose video. Its frames drive the
                motion, and its fps is reused for the output.
            width, height: output frame size in pixels. Cast to ``int``
                because UI sliders may deliver floats.
            length: number of pose frames to animate. Clamped to the number
                of frames the pose video actually contains.
            num_inference_steps: denoising steps passed to the pipeline.
            cfg: classifier-free guidance scale (``guidance_scale``).
            seed: integer seed or its text form. Blank or ``None`` draws a
                fresh random seed.

        Returns:
            Path of the written ``.mp4``. The reference, pose and generated
            clips are concatenated along dim 0 and saved with
            ``save_videos_grid(..., n_rows=3)``.

        Raises:
            ValueError: missing inputs, a non-integer seed, a non-positive
                length, or a pose video without frames.
        """
        if ref_image is None:
            raise ValueError("A reference image is required.")
        if not pose_video_path:
            raise ValueError("A pose video is required.")
        width, height, length = int(width), int(height), int(length)
        if length < 1:
            raise ValueError(f"length must be positive, got {length}")
        # Validate the seed before the (expensive) first-time model load.
        seed = self._parse_seed(seed)
        if seed is None:
            seed = torch.seed()

        if isinstance(ref_image, np.ndarray):
            ref_image = Image.fromarray(ref_image)
        # ToTensor on an RGBA/greyscale image would change the channel count
        # and break the channel-wise concatenation with the pose frames below.
        ref_image = ref_image.convert("RGB")

        try:
            if self.pipeline is None:
                self.pipeline = self._load_pipeline()

            pose_images = read_frames(pose_video_path)
            if len(pose_images) == 0:
                raise ValueError(f"No frames could be read from {pose_video_path!r}")
            src_fps = get_fps(pose_video_path)
            # A shorter pose video would otherwise give the pipeline fewer pose
            # frames than `video_length` and break the frame-wise concat below.
            length = min(length, len(pose_images))
            pose_list = list(pose_images[:length])

            pose_transform = transforms.Compose(
                [transforms.Resize((height, width)), transforms.ToTensor()]
            )
            pose_tensor = torch.stack(
                [pose_transform(frame) for frame in pose_list], dim=0
            )  # (f, c, h, w)
            pose_tensor = pose_tensor.transpose(0, 1).unsqueeze(0)  # (1, c, f, h, w)

            ref_image_tensor = pose_transform(ref_image)  # (c, h, w)
            ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
            ref_image_tensor = repeat(
                ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=length
            )

            # Seed only now. torch.manual_seed seeds and returns the *global*
            # generator, and building fresh modules in _load_pipeline (e.g.
            # PoseGuider) can consume that RNG. Seeding before the first load
            # made the first request irreproducible.
            generator = torch.manual_seed(seed)
            video = self.pipeline(
                ref_image,
                pose_list,
                width=width,
                height=height,
                video_length=length,
                num_inference_steps=num_inference_steps,
                guidance_scale=cfg,
                generator=generator,
            ).videos

            video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)
            out_path = self._make_output_path()
            save_videos_grid(
                video,
                out_path,
                n_rows=3,
                fps=src_fps,
            )
        finally:
            # Release cached CUDA memory whether or not generation succeeded.
            torch.cuda.empty_cache()
        return out_path
# A single controller shared by every request: its diffusion pipeline is built
# lazily on the first `animate` call and reused afterwards.
controller = AnimateController(
    config_path="./configs/prompts/animation.yaml",
    weight_dtype=torch.float16,
)
def ui():
    """Build the Gradio Blocks interface of the AnimateAnyone demo.

    The "Animate" button forwards the reference image, pose video and slider
    and seed values to the module-level ``controller.animate`` and shows the
    returned video.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style="color:#dc5b1c;text-align:center">
            Moore-AnimateAnyone Gradio Demo
            </h1>
            <div style="text-align:center">
            <div style="display: inline-block; text-align: left;">
            <p> This is a quick preview demo of Moore-AnimateAnyone. We appreciate the assistance provided by the HuggingFace team in setting up this demo. </p>
            <p> To reduce waiting time, we limit the size (width, height and length) and inference steps when generating videos. </p>
            <p> If you like this project, please consider giving a star on <a href="https://github.com/MooreThreads/Moore-AnimateAnyone"> our GitHub repo </a> \U0001F917. </p>
            </div>
            </div>
            """
        )
        animation = gr.Video(
            format="mp4",
            label="Animation Results",
            height=448,
            autoplay=True,
        )

        with gr.Row():
            reference_image = gr.Image(label="Reference Image")
            motion_sequence = gr.Video(
                format="mp4", label="Motion Sequence", height=512
            )

            with gr.Column():
                width_slider = gr.Slider(
                    label="Width", minimum=256, maximum=448, value=448, step=64
                )
                height_slider = gr.Slider(
                    label="Height", minimum=256, maximum=512, value=512, step=64
                )
                length_slider = gr.Slider(
                    label="Video Length", minimum=24, maximum=24, value=24, step=1
                )
                with gr.Row():
                    seed_textbox = gr.Textbox(label="Seed", value=-1)
                    seed_button = gr.Button(
                        value="\U0001F3B2", elem_classes="toolbutton"
                    )
                    # Integer bounds: random.randint rejects float arguments
                    # on Python >= 3.12. gr.update works on Gradio 3.x and
                    # 4.x, whereas gr.Textbox.update was removed in 4.x.
                    seed_button.click(
                        fn=lambda: gr.update(value=random.randint(1, 10**8)),
                        inputs=[],
                        outputs=[seed_textbox],
                    )
                with gr.Row():
                    sampling_steps = gr.Slider(
                        label="Sampling steps",
                        value=15,
                        info="default: 15",
                        step=5,
                        maximum=15,
                        minimum=10,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        value=3.5,
                        info="default: 3.5",
                        step=0.5,
                        maximum=6.5,
                        minimum=2.0,
                    )
                submit = gr.Button("Animate")

        def read_video(video):
            return video

        def read_image(image):
            return Image.fromarray(image)

        # when user uploads a new video
        motion_sequence.upload(
            read_video, motion_sequence, motion_sequence, queue=False
        )
        # when a new reference image is uploaded
        reference_image.upload(
            read_image, reference_image, reference_image, queue=False
        )
        # when the `submit` button is clicked
        submit.click(
            controller.animate,
            [
                reference_image,
                motion_sequence,
                width_slider,
                height_slider,
                length_slider,
                sampling_steps,
                guidance_scale,
                seed_textbox,
            ],
            animation,
        )

        # Examples
        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                [
                    "./configs/inference/ref_images/anyone-5.png",
                    "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
                ],
                [
                    "./configs/inference/ref_images/anyone-10.png",
                    "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
                ],
                [
                    "./configs/inference/ref_images/anyone-2.png",
                    "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
                ],
            ],
            inputs=[reference_image, motion_sequence],
            outputs=animation,
        )

    return demo
# Build the interface and enable request queuing (queue size capped at 10),
# then start the server with a share link and the API page hidden.
demo = ui().queue(max_size=10)
demo.launch(share=True, show_api=False)