import os
import random
from datetime import datetime

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from huggingface_hub import hf_hub_download, snapshot_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.download_models import prepare_base_model, prepare_image_encoder
from src.utils.util import get_fps, read_frames, save_videos_grid

# Partial download: fetch only the weights this demo needs into ./pretrained_weights
prepare_base_model()
prepare_image_encoder()
snapshot_download(
    repo_id="stabilityai/sd-vae-ft-mse", local_dir="./pretrained_weights/sd-vae-ft-mse"
)
snapshot_download(
    repo_id="patrolli/AnimateAnyone",
    local_dir="./pretrained_weights",
)
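
# AnimateController wraps the Pose2Video pipeline: the heavy models are loaded
# once, lazily, on the first call to animate(), and then reused for every request.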
class AnimateController:
    def __init__(
        self,
        config_path="./configs/prompts/animation.yaml",
        weight_dtype=torch.float16,
    ):
        # Read pretrained weights path from config
        self.config = OmegaConf.load(config_path)
        self.pipeline = None
        self.weight_dtype = weight_dtype

    def animate(
        self,
        ref_image,
        pose_video_path,
        width=512,
        height=768,
        length=24,
        num_inference_steps=25,
        cfg=3.5,
        seed=123,
    ):
        # The seed arrives as a string from the Gradio textbox, so cast it to int.
        generator = torch.manual_seed(int(seed))
        if isinstance(ref_image, np.ndarray):
            ref_image = Image.fromarray(ref_image)
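
        # Build the full diffusion pipeline only once; subsequent calls reuse it.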
        if self.pipeline is None:
            vae = AutoencoderKL.from_pretrained(
                self.config.pretrained_vae_path,
            ).to("cuda", dtype=self.weight_dtype)

            reference_unet = UNet2DConditionModel.from_pretrained(
                self.config.pretrained_base_model_path,
                subfolder="unet",
            ).to(dtype=self.weight_dtype, device="cuda")

            inference_config_path = self.config.inference_config
            infer_config = OmegaConf.load(inference_config_path)
            denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                self.config.pretrained_base_model_path,
                self.config.motion_module_path,
                subfolder="unet",
                unet_additional_kwargs=infer_config.unet_additional_kwargs,
            ).to(dtype=self.weight_dtype, device="cuda")

            pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
                dtype=self.weight_dtype, device="cuda"
            )

            image_enc = CLIPVisionModelWithProjection.from_pretrained(
                self.config.image_encoder_path
            ).to(dtype=self.weight_dtype, device="cuda")

            sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
            scheduler = DDIMScheduler(**sched_kwargs)

            # Load pretrained weights
            denoising_unet.load_state_dict(
                torch.load(self.config.denoising_unet_path, map_location="cpu"),
                strict=False,
            )
            reference_unet.load_state_dict(
                torch.load(self.config.reference_unet_path, map_location="cpu"),
            )
            pose_guider.load_state_dict(
                torch.load(self.config.pose_guider_path, map_location="cpu"),
            )

            pipe = Pose2VideoPipeline(
                vae=vae,
                image_encoder=image_enc,
                reference_unet=reference_unet,
                denoising_unet=denoising_unet,
                pose_guider=pose_guider,
                scheduler=scheduler,
            )
            pipe = pipe.to("cuda", dtype=self.weight_dtype)
            self.pipeline = pipe
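
        # Read the pose frames and clip the sequence to the requested video length.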
        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)
        pose_list = []
        # Sliders may deliver floats, so make the frame count an int before slicing.
        total_length = min(int(length), len(pose_images))
        for pose_image_pil in pose_images[:total_length]:
            pose_list.append(pose_image_pil)

        video = self.pipeline(
            ref_image,
            pose_list,
            width=width,
            height=height,
            video_length=total_length,
            num_inference_steps=num_inference_steps,
            guidance_scale=cfg,
            generator=generator,
        ).videos
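
        # Stack the reference image, the pose sequence, and the generated video into
        # one tensor so they can be saved side by side in a single grid video.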
        new_h, new_w = video.shape[-2:]
        pose_transform = transforms.Compose(
            [transforms.Resize((new_h, new_w)), transforms.ToTensor()]
        )
        pose_tensor_list = []
        for pose_image_pil in pose_images[:total_length]:
            pose_tensor_list.append(pose_transform(pose_image_pil))

        ref_image_tensor = pose_transform(ref_image)  # (c, h, w)
        ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
        ref_image_tensor = repeat(
            ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=total_length
        )

        pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
        pose_tensor = pose_tensor.transpose(0, 1)
        pose_tensor = pose_tensor.unsqueeze(0)

        video = torch.cat([ref_image_tensor, pose_tensor, video], dim=0)

        save_dir = "./output/gradio"
        os.makedirs(save_dir, exist_ok=True)
        date_str = datetime.now().strftime("%Y%m%d")
        time_str = datetime.now().strftime("%H%M")
        out_path = os.path.join(save_dir, f"{date_str}T{time_str}.mp4")
        save_videos_grid(
            video,
            out_path,
            n_rows=3,
            fps=src_fps,
        )
        torch.cuda.empty_cache()
        return out_path
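

# A single module-level controller is shared by all requests; its pipeline is
# created on the first animation and kept in GPU memory afterwards.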
controller = AnimateController()


def ui():
    from datasets import load_dataset
    import io
    from PIL import Image

    # Load the dataset and keep only images tagged as portraits (肖像) of characters (角色)
    image_ds = load_dataset("svjack/Genshin-Impact-Item-Image")
    image_df = image_ds["train"].to_pandas()
    image_df = image_df[
        image_df["tag"].map(
            lambda x: "肖像" in x and "角色" in x
        )
    ]

    def bytes_to_pil_image(byte_data):
        """
        Convert a byte array to a PIL Image.

        :param byte_data: A byte array containing image data.
        :return: A PIL Image object.
        """
        # Create a BytesIO object from the byte data
        image_stream = io.BytesIO(byte_data)
        # Open the image using PIL
        pil_image = Image.open(image_stream)
        return pil_image

    image_df["image"] = image_df["image"].map(lambda x: bytes_to_pil_image(x["bytes"]))
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style="color:#dc5b1c;text-align:center">
                Moore-AnimateAnyone Gradio Demo
            </h1>
            <div style="text-align:center">
                <div style="display: inline-block; text-align: left;">
                    <p> This is a quick preview demo of Moore-AnimateAnyone. We appreciate the assistance provided by the HuggingFace team in setting up this demo. </p>
                    <p> If you like this project, please consider giving it a star on <a href="https://github.com/MooreThreads/Moore-AnimateAnyone"> our GitHub repo </a> 🤗. </p>
                </div>
            </div>
            """
        )
        # Gallery for selecting a reference image from the dataset
        with gr.Row():
            gallery = gr.Gallery(
                image_df["image"].tolist(),
                label="Select Reference Image",
                show_label=True,
                elem_id="gallery",
                columns=[2, 3, 4, 5, 6, 6],  # Number of columns for different screen sizes
                rows=[2, 2, 2, 2, 2, 2],  # Number of rows for different screen sizes
                height="400px",  # Height of the gallery
                object_fit="contain",  # How images should fit in the grid
            )
        with gr.Row():
            reference_image = gr.Image(label="Reference Image")
            motion_sequence = gr.Video(
                format="mp4", label="Motion Sequence", height=512
            )
            with gr.Column():
                width_slider = gr.Slider(
                    label="Width", minimum=448, maximum=768, value=512, step=64
                )
                height_slider = gr.Slider(
                    label="Height", minimum=512, maximum=960, value=768, step=64
                )
                length_slider = gr.Slider(
                    label="Video Length", minimum=24, maximum=128, value=72, step=24
                )
                with gr.Row():
                    seed_textbox = gr.Textbox(label="Seed", value=-1)
                    seed_button = gr.Button(
                        value="\U0001F3B2", elem_classes="toolbutton"
                    )
                    # random.randint needs integer bounds, and gr.update works across
                    # Gradio versions (gr.Textbox.update was removed in Gradio 4).
                    seed_button.click(
                        fn=lambda: gr.update(value=random.randint(1, int(1e8))),
                        inputs=[],
                        outputs=[seed_textbox],
                    )
                with gr.Row():
                    sampling_steps = gr.Slider(
                        label="Sampling steps",
                        value=15,
                        info="default: 15",
                        step=5,
                        maximum=20,
                        minimum=10,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        value=3.5,
                        info="default: 3.5",
                        step=0.5,
                        maximum=6.5,
                        minimum=2.0,
                    )
                submit = gr.Button("Animate")

        with gr.Row():
            animation = gr.Video(
                format="mp4",
                label="Animation Results",
                height=448,
                autoplay=True,
            )
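
        # Helper callbacks: echo uploads back into their components and map a
        # gallery selection to the underlying image file path.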
        def read_video(video):
            return video

        def read_image(image):
            return Image.fromarray(image)

        def select_image(selection: gr.SelectData):
            print(selection.value["image"])
            return selection.value["image"]["path"]

        # when the user uploads a new video
        motion_sequence.upload(
            read_video, motion_sequence, motion_sequence, queue=False
        )
        # when the reference image is updated
        reference_image.upload(
            read_image, reference_image, reference_image, queue=False
        )
        # when the `submit` button is clicked
        submit.click(
            controller.animate,
            [
                reference_image,
                motion_sequence,
                width_slider,
                height_slider,
                length_slider,
                sampling_steps,
                guidance_scale,
                seed_textbox,
            ],
            animation,
        )
        # when an image is selected in the gallery
        gallery.select(fn=select_image, inputs=None, outputs=[reference_image])

        # Examples
        gr.Markdown("## Examples")
        gr.Examples(
            examples=[
                [
                    "./configs/inference/ref_images/anyone-5.png",
                    "./configs/inference/pose_videos/anyone-video-2_kps.mp4",
                    512,
                    768,
                    72,
                ],
                [
                    "./configs/inference/ref_images/anyone-10.png",
                    "./configs/inference/pose_videos/anyone-video-1_kps.mp4",
                    512,
                    768,
                    72,
                ],
                [
                    "./configs/inference/ref_images/anyone-2.png",
                    "./configs/inference/pose_videos/anyone-video-5_kps.mp4",
                    512,
                    768,
                    72,
                ],
            ],
            inputs=[reference_image, motion_sequence, width_slider, height_slider, length_slider],
            outputs=animation,
        )

    return demo
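

# Build the interface and serve it; queue(max_size=10) limits how many requests
# can wait for the GPU at once.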
demo = ui()
demo.queue(max_size=10)
demo.launch(share=True, show_api=False)