# *************************************************************************
# This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo-
# difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B-
# ytedance Inc..
# *************************************************************************
# Adapted from https://github.com/guoyww/AnimateDiff
import os
from typing import Union

import imageio
import numpy as np
import torch
import torchvision
from einops import rearrange
from PIL import Image
from tqdm import tqdm


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=25):
    """Tile a batch of videos (b c t h w) into one grid per frame and save as an animation."""
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        x = (x * 255).cpu().numpy().astype(np.uint8)  # .cpu() so GPU tensors also work
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)


def save_images_grid(images: torch.Tensor, path: str):
    """Tile a batch of single-frame videos (b c 1 h w) into one image grid and save it."""
    assert images.shape[2] == 1  # time dimension must hold exactly one frame
    images = images.squeeze(2)
    grid = torchvision.utils.make_grid(images)
    grid = (grid * 255).cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    Image.fromarray(grid).save(path)


# DDIM Inversion
@torch.no_grad()
def init_prompt(prompt, pipeline):
    """Encode the unconditional ("") and conditional prompts into one stacked context tensor."""
    uncond_input = pipeline.tokenizer(
        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
        return_tensors="pt"
    )
    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
    text_input = pipeline.tokenizer(
        [prompt],
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
    context = torch.cat([uncond_embeddings, text_embeddings])
    return context


def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
    """One reversed DDIM step: move `sample` from `timestep` to the next, noisier timestep."""
    timestep, next_timestep = min(
        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
    beta_prod_t = 1 - alpha_prod_t
    # Predict x_0 from the noise estimate, then re-noise it to the next timestep.
    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
    return next_sample


def get_noise_pred_single(latents, t, context, unet):
    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
    return noise_pred


@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
    context = init_prompt(prompt, pipeline)
    uncond_embeddings, cond_embeddings = context.chunk(2)
    all_latent = [latent]
    latent = latent.clone().detach()
    for i in tqdm(range(num_inv_steps)):
        # Walk the scheduler's timesteps in reverse order (clean -> noisy).
        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
        latent = next_step(noise_pred, t, latent, ddim_scheduler)
        all_latent.append(latent)
    return all_latent


@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
    return ddim_latents
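

# Illustrative sketch (not part of the original file): one way `ddim_inversion`
# might be driven with a diffusers-style pipeline. The argument `pipe`, the step
# count, and the latent layout are assumptions for this example only.
def _ddim_inversion_example(pipe, video_latent, num_inv_steps=50):
    from diffusers import DDIMScheduler  # assumed available for this sketch

    # Build a DDIM scheduler from the pipeline's config and register the
    # inference timesteps that `ddim_loop` walks in reverse.
    inv_scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    inv_scheduler.set_timesteps(num_inv_steps)
    # video_latent: hypothetical (b, c, t, h, w) tensor, e.g. from the VAE encoder.
    latents = ddim_inversion(pipe, inv_scheduler, video_latent, num_inv_steps, prompt="")
    return latents[-1]  # most-noised latent; a starting point for editing/resampling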


def video2images(path, step=4, length=16, start=0):
    """Read a video and return `length` frames, taking every `step`-th frame from `start`."""
    reader = imageio.get_reader(path)
    frames = [np.array(frame) for frame in reader]
    frames = frames[start::step][:length]
    return frames


def images2video(video, path, fps=8):
    imageio.mimsave(path, video, fps=fps)


# Module-level hook selected by set_tensor_interpolation_method().
tensor_interpolation = None


def get_tensor_interpolation_method():
    return tensor_interpolation


def set_tensor_interpolation_method(is_slerp):
    global tensor_interpolation
    tensor_interpolation = slerp if is_slerp else linear


def linear(v1, v2, t):
    return (1.0 - t) * v1 + t * v2


def slerp(
    v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    """Spherical linear interpolation between v0 and v1."""
    u0 = v0 / v0.norm()
    u1 = v1 / v1.norm()
    dot = (u0 * u1).sum()
    if dot.abs() > DOT_THRESHOLD:
        # v0 and v1 are close to parallel; slerp is numerically unstable here,
        # so fall back to linear interpolation.
        return (1.0 - t) * v0 + t * v1
    omega = dot.acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
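

# Illustrative sketch (not part of the original file): how the interpolation
# hook above is typically used to blend latent tensors between keyframes. The
# shapes and blend weights below are arbitrary example values.
if __name__ == "__main__":
    set_tensor_interpolation_method(is_slerp=True)
    interp = get_tensor_interpolation_method()

    a = torch.randn(4, 64, 64)  # hypothetical latent "keyframes"
    b = torch.randn(4, 64, 64)
    # slerp follows the great circle between the two directions, which tends to
    # preserve the norm statistics of Gaussian latents better than lerp does.
    for t in (0.25, 0.5, 0.75):
        print(t, interp(a, b, t).shape)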