import os
import json
import torch
import random
import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from animatelcm.scheduler.lcm_scheduler import LCMScheduler
from animatelcm.models.unet import UNet3DConditionModel
from animatelcm.pipelines.pipeline_animation import AnimationPipeline
from animatelcm.utils.util import save_videos_grid
from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatelcm.utils.lcm_utils import convert_lcm_lora
import copy

sample_idx = 0
scheduler_dict = {
    "LCM": LCMScheduler,
}

css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

# Pick the best available device: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(
            self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(
            self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(
            self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
        os.makedirs(self.savedir, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference.yaml")

    def refresh_stable_diffusion(self):
        self.stable_diffusion_list = glob(
            os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        motion_module_list = glob(os.path.join(
            self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [
            os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        personalized_model_list = glob(os.path.join(
            self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [
            os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        stable_diffusion_dropdown = os.path.join(
            self.stable_diffusion_dir, stable_diffusion_dropdown)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            stable_diffusion_dropdown, subfolder="text_encoder").to(device)
        self.vae = AutoencoderKL.from_pretrained(
            stable_diffusion_dropdown, subfolder="vae").to(device)
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            stable_diffusion_dropdown, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
        return gr.Dropdown.update()
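    # NOTE (added comment): the motion-module checkpoint selected below is expected to
    # contain only the temporal "motion_modules." weights, so it is loaded non-strictly
    # on top of the UNet that from_pretrained_2d inflated from the 2D Stable Diffusion
    # weights; the assert verifies that no unexpected keys were loaded.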
    def update_motion_module(self, motion_module_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            motion_module_dropdown = os.path.join(
                self.motion_module_dir, motion_module_dropdown)
            motion_module_state_dict = torch.load(
                motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.unet.load_state_dict(
                motion_module_state_dict, strict=False)
            del motion_module_state_dict
            assert len(unexpected) == 0
            return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            base_model_dropdown = os.path.join(
                self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)

            converted_vae_checkpoint = convert_ldm_vae_checkpoint(
                base_model_state_dict, self.vae.config)
            self.vae.load_state_dict(converted_vae_checkpoint)

            converted_unet_checkpoint = convert_ldm_unet_checkpoint(
                base_model_state_dict, self.unet.config)
            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

            del converted_unet_checkpoint
            del converted_vae_checkpoint
            del base_model_state_dict
            # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
            return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        self.lora_model_state_dict = {}
        # "none" means no personalized LoRA is selected, so nothing is loaded.
        if lora_model_dropdown != "none":
            lora_model_dropdown = os.path.join(
                self.personalized_model_dir, lora_model_dropdown)
            with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    @torch.no_grad()
    def animate(
        self,
        lora_alpha_slider,
        spatial_lora_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](
                **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to(device)

        # Snapshot the spatial UNet weights so the LCM-LoRA merge can be undone later.
        original_state_dict = {k: v.cpu().clone() for k, v in pipeline.unet.state_dict().items()
                               if "motion_modules." not in k}
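        # NOTE (added comment): convert_lcm_lora merges the LCM-LoRA weights into the
        # spatial UNet layers, scaled by `spatial_lora_slider`; the snapshot cached in
        # `original_state_dict` above is restored after sampling so repeated clicks on
        # "Generate" do not stack LoRA deltas.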
        pipeline.unet = convert_lcm_lora(
            pipeline.unet, self.lcm_lora_path, spatial_lora_slider)

        pipeline.enable_sequential_cpu_offload()
        pipeline.to(device)

        # A seed of -1 (or an empty box) means "pick a random seed".
        if seed_textbox not in ("", "-1", -1):
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        # with torch.autocast(device):
        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=sample_step_slider,
            guidance_scale=cfg_scale_slider,
            width=width_slider,
            height=height_slider,
            video_length=length_slider,
        ).videos

        # Restore the original spatial weights now that sampling is done.
        pipeline.unet.load_state_dict(original_state_dict, strict=False)
        del original_state_dict

        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")

        return gr.Video.update(value=save_sample_path)


controller = AnimateController()
controller.update_stable_diffusion("stable-diffusion-v1-5")
controller.update_motion_module("sd15_t2v_beta_motion.ckpt")
controller.update_base_model("realistic2.safetensors")


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
            Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)

            [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm)

            Important Notes:
            1. Generation itself takes only a few seconds; the Space queue may add some extra delay.
            2. Increase the sampling steps and the CFG scale if you want more elaborate videos.
            """
        )
        with gr.Column(variant="panel"):
            with gr.Row():
                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                    value="realistic2.safetensors"
                )
                base_model_dropdown.change(fn=controller.update_base_model, inputs=[
                                           base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    choices=["none"],
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
                                           lora_model_dropdown], outputs=[lora_model_dropdown])

                lora_alpha_slider = gr.Slider(
                    label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
                spatial_lora_slider = gr.Slider(
                    label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)

                personalized_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(
                            choices=controller.personalized_model_list),
                        gr.Dropdown.update(
                            choices=["none"] + controller.personalized_model_list)
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
                                                  base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateLCM.
""" ) prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="a boy holding a rabbit") negative_prompt_textbox = gr.Textbox( label="Negative prompt", lines=2, value="bad quality") with gr.Row().style(equal_height=False): with gr.Column(): with gr.Row(): sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list( scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) sample_step_slider = gr.Slider( label="Sampling steps", value=6, minimum=1, maximum=25, step=1) width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64) height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64) length_slider = gr.Slider( label="Animation length", value=16, minimum=12, maximum=20, step=1) cfg_scale_slider = gr.Slider( label="CFG Scale", value=1.5, minimum=1, maximum=2) with gr.Row(): seed_textbox = gr.Textbox(label="Seed", value=-1) seed_button = gr.Button( value="\U0001F3B2", elem_classes="toolbutton") seed_button.click(fn=lambda: gr.Textbox.update( value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox]) generate_button = gr.Button( value="Generate", variant='primary') result_video = gr.Video( label="Generated Animation", interactive=False) generate_button.click( fn=controller.animate, inputs=[ lora_alpha_slider, spatial_lora_slider, prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, width_slider, length_slider, height_slider, cfg_scale_slider, seed_textbox, ], outputs=[result_video] ) examples = [ [0.8, 0.8, "a boy is holding a rabbit", "bad quality", "LCM", 8, 512, 16, 512, 1.5, 123], [0.8, 0.8, "1girl smiling", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1233], [0.8, 0.8, "1girl,face,white background,", "bad quality", "LCM", 6, 512, 16, 512, 1.5, 1234], [0.8, 0.8, "clouds in the sky, best quality", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1234], ] gr.Examples( examples = examples, inputs=[ lora_alpha_slider, spatial_lora_slider, prompt_textbox, negative_prompt_textbox, sampler_dropdown, sample_step_slider, width_slider, length_slider, height_slider, cfg_scale_slider, seed_textbox, ], outputs=[result_video], fn=controller.animate, # cache_examples=True, ) return demo if __name__ == "__main__": demo = ui() # gr.close_all() demo.queue(api_open=False) demo.launch()