import os import imageio import numpy as np from typing import Union import torch import torchvision import torch.distributed as dist from PIL import Image from transformers import AutoProcessor, CLIPModel import torch.nn as nn from safetensors import safe_open from tqdm import tqdm from einops import rearrange from animatediff.utils.convert_from_ckpt import ( convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, #convert_ldm_vae_checkpoint, ) from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_vae_checkpoint from animatediff.utils.convert_lora_safetensor_to_diffusers import ( convert_lora, convert_motion_lora_ckpt_to_diffusers, ) def zero_rank_print(s): if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) def ToImage(videos: torch.Tensor, rescale=False, n_rows=6): videos = rearrange(videos, "b c t h w -> t b c h w") outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x + 1.0) / 2.0 # -1,1 -> 0,1 x = (x * 255).numpy().astype(np.uint8) outputs.append(Image.fromarray(x)) return outputs def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): videos = rearrange(videos, "b c t h w -> t b c h w") outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x + 1.0) / 2.0 # -1,1 -> 0,1 x = (x * 255).numpy().astype(np.uint8) outputs.append(x) os.makedirs(os.path.dirname(path), exist_ok=True) imageio.mimsave(path, outputs, fps=fps) # DDIM Inversion @torch.no_grad() def init_prompt(prompt, pipeline): uncond_input = pipeline.tokenizer( [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt", ) uncond_embeddings = pipeline.text_encoder( uncond_input.input_ids.to(pipeline.device) )[0] text_input = pipeline.tokenizer( [prompt], padding="max_length", max_length=pipeline.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] context = torch.cat([uncond_embeddings, text_embeddings]) return context def next_step( model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler, ): timestep, next_timestep = ( min( timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999, ), timestep, ) alpha_prod_t = ( ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod ) alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] beta_prod_t = 1 - alpha_prod_t next_original_sample = ( sample - beta_prod_t**0.5 * model_output ) / alpha_prod_t**0.5 next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction return next_sample def get_noise_pred_single(latents, t, context, unet): noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] return noise_pred @torch.no_grad() def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): context = init_prompt(prompt, pipeline) uncond_embeddings, cond_embeddings = context.chunk(2) all_latent = [latent] latent = latent.clone().detach() for i in tqdm(range(num_inv_steps)): t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) latent = next_step(noise_pred, t, latent, ddim_scheduler) all_latent.append(latent) return all_latent @torch.no_grad() def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): ddim_latents = ddim_loop( pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt ) return ddim_latents def load_weights( animation_pipeline, # motion module motion_module_path="", motion_module_lora_configs=[], # image layers dreambooth_model_path="", lora_model_path="", lora_alpha=0.8, ): # 1.1 motion module unet_state_dict = {} if motion_module_path != "": print(f"load motion module from {motion_module_path}") motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") motion_module_state_dict = ( motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict ) unet_state_dict.update( { name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name } ) missing, unexpected = animation_pipeline.unet.load_state_dict( unet_state_dict, strict=False ) assert len(unexpected) == 0 del unet_state_dict if dreambooth_model_path != "": print(f"load dreambooth model from {dreambooth_model_path}") if dreambooth_model_path.endswith(".safetensors"): dreambooth_state_dict = {} with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: for key in f.keys(): dreambooth_state_dict[key] = f.get_tensor(key) elif dreambooth_model_path.endswith(".ckpt"): dreambooth_state_dict = torch.load( dreambooth_model_path, map_location="cpu" ) # 1. vae converted_vae_checkpoint = convert_ldm_vae_checkpoint( dreambooth_state_dict, animation_pipeline.vae.config ) animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) # 2. unet converted_unet_checkpoint = convert_ldm_unet_checkpoint( dreambooth_state_dict, animation_pipeline.unet.config ) animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) # 3. text_model animation_pipeline.text_encoder = convert_ldm_clip_checkpoint( dreambooth_state_dict ) del dreambooth_state_dict if lora_model_path != "": print(f"load lora model from {lora_model_path}") assert lora_model_path.endswith(".safetensors") lora_state_dict = {} with safe_open(lora_model_path, framework="pt", device="cpu") as f: for key in f.keys(): lora_state_dict[key] = f.get_tensor(key) animation_pipeline = convert_lora( animation_pipeline, lora_state_dict, alpha=lora_alpha ) del lora_state_dict for motion_module_lora_config in motion_module_lora_configs: path, alpha = ( motion_module_lora_config["path"], motion_module_lora_config["alpha"], ) print(f"load motion LoRA from {path}") motion_lora_state_dict = torch.load(path, map_location="cpu") motion_lora_state_dict = ( motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict ) animation_pipeline = convert_motion_lora_ckpt_to_diffusers( animation_pipeline, motion_lora_state_dict, alpha ) return animation_pipeline