from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler

# suppress partial model loading warning
logging.set_verbosity_error()

import os
from tqdm import tqdm, trange
import torch
import torch.nn as nn
import argparse
from torchvision.io import write_video
from pathlib import Path
from util import *
import torchvision.transforms as T


def get_timesteps(scheduler, num_inference_steps, strength, device):
    """Select the trailing portion of the scheduler's timesteps.

    Args:
        scheduler: Object exposing an indexable ``timesteps`` attribute.
        num_inference_steps: Total number of inference steps.
        strength: Fraction of ``num_inference_steps`` to keep; the resulting
            step count is capped at ``num_inference_steps``.
        device: Not read by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A ``(timesteps, count)`` tuple where ``timesteps`` is
        ``scheduler.timesteps[t_start:]`` and ``count`` is
        ``num_inference_steps - t_start``.
    """
    # get the original timestep using init_timestep
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

    t_start = max(num_inference_steps - init_timestep, 0)
    timesteps = scheduler.timesteps[t_start:]

    return timesteps, num_inference_steps - t_start


@torch.no_grad()
def decode_latents(pipe, latents, batch_size=8):
    """Decode latents to images with ``pipe.vae``, ``batch_size`` at a time.

    Args:
        pipe: Pipeline object exposing ``vae.decode``.
        latents: Tensor of latents, batched along dim 0.
        batch_size: Number of latents decoded per VAE call (default 8, the
            value that used to be hard-coded).

    Returns:
        The decoded images concatenated along dim 0, rescaled with
        ``x / 2 + 0.5`` and clamped to ``[0, 1]``.
    """
    decoded = []
    for b in range(0, latents.shape[0], batch_size):
        # 0.18215 is the inverse of the factor applied in encode_imgs
        # (presumably the Stable Diffusion VAE scaling factor -- confirm).
        latents_batch = 1 / 0.18215 * latents[b:b + batch_size]
        imgs = pipe.vae.decode(latents_batch).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)
        decoded.append(imgs)
    return torch.cat(decoded)


@torch.no_grad()
def ddim_inversion(pipe, cond, latent_frames, batch_size, save_latents=True, timesteps_to_save=None):
    """Run the DDIM update over the scheduler's timesteps in reverse order.

    ``latent_frames`` is updated IN PLACE, batch by batch, and the same
    tensor is returned.

    Args:
        pipe: Pipeline object exposing ``scheduler`` (with ``timesteps``,
            ``alphas_cumprod`` and ``final_alpha_cumprod``) and ``unet``.
        cond: Text conditioning; repeated along dim 0 to match each batch
            (assumes a leading batch dim of 1 -- TODO confirm with callers).
        latent_frames: Latents to invert, batched along dim 0.
        batch_size: Number of frames passed to the UNet per call.
        save_latents: Currently only referenced by the commented-out saving
            code below; has no effect.
        timesteps_to_save: Same as ``save_latents`` -- currently no effect.

    Returns:
        ``latent_frames`` after the final step.
    """
    timesteps = reversed(pipe.scheduler.timesteps)
    # Only consumed by the commented-out saving code at the end of the loop.
    timesteps_to_save = timesteps_to_save if timesteps_to_save is not None else timesteps
    for i, t in enumerate(tqdm(timesteps)):
        # The coefficients depend only on the timestep, so compute them once
        # per step rather than once per batch.
        alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            pipe.scheduler.alphas_cumprod[timesteps[i - 1]]
            if i > 0 else pipe.scheduler.final_alpha_cumprod
        )

        mu = alpha_prod_t ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        for b in range(0, latent_frames.shape[0], batch_size):
            x_batch = latent_frames[b:b + batch_size]
            model_input = x_batch
            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
            # remove comment from commented block to support controlnet
            # if self.sd_version == 'depth':
            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
            #     model_input = torch.cat([x_batch, depth_maps],dim=1)

            # remove line below and replace with commented block to support controlnet
            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # if self.sd_version != 'ControlNet':
            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # else:
            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))

            # Estimate x0 from the previous step's coefficients, then
            # re-noise it with the current step's coefficients.
            pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
            latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps

        # if save_latents and t in timesteps_to_save:
        #     torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
    # torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
    return latent_frames


@torch.no_grad()
def ddim_sample(pipe, x, cond, batch_size):
    """Run the DDIM update over the scheduler's timesteps in their own order.

    ``x`` is updated IN PLACE, batch by batch, and the same tensor is
    returned.

    Args:
        pipe: Pipeline object exposing ``scheduler`` (with ``timesteps``,
            ``alphas_cumprod`` and ``final_alpha_cumprod``) and ``unet``.
        x: Latents to sample from, batched along dim 0.
        cond: Text conditioning; repeated along dim 0 to match each batch
            (assumes a leading batch dim of 1 -- TODO confirm with callers).
        batch_size: Number of frames passed to the UNet per call.

    Returns:
        ``x`` after the final step.
    """
    timesteps = pipe.scheduler.timesteps
    for i, t in enumerate(tqdm(timesteps)):
        # The coefficients depend only on the timestep, so compute them once
        # per step rather than once per batch.
        alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            pipe.scheduler.alphas_cumprod[timesteps[i + 1]]
            if i < len(timesteps) - 1 else pipe.scheduler.final_alpha_cumprod
        )

        mu = alpha_prod_t ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        for b in range(0, x.shape[0], batch_size):
            x_batch = x[b:b + batch_size]
            model_input = x_batch
            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
            # remove comment from commented block to support controlnet
            # if self.sd_version == 'depth':
            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
            #     model_input = torch.cat([x_batch, depth_maps],dim=1)

            # remove line below and replace with commented block to support controlnet
            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # if self.sd_version != 'ControlNet':
            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
            # else:
            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))

            # Estimate x0 from the current step's coefficients, then move to
            # the next timestep in the schedule.
            pred_x0 = (x_batch - sigma * eps) / mu
            x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps
    return x


@torch.no_grad()
def get_text_embeds(pipe, prompt, negative_prompt, batch_size=1, device="cuda"):
    """Encode a prompt and a negative prompt with the pipeline's text encoder.

    Args:
        pipe: Pipeline object exposing ``tokenizer``, ``text_encoder`` and
            ``device``.
        prompt: Text passed to the tokenizer for the conditional branch.
        negative_prompt: Text passed to the tokenizer for the unconditional
            branch.
        batch_size: Number of times each embedding is repeated along dim 0.
        device: Not read by this function (token ids are moved to
            ``pipe.device``); kept so existing callers keep working.

    Returns:
        The unconditional embeddings repeated ``batch_size`` times followed
        by the prompt embeddings repeated ``batch_size`` times, concatenated
        along dim 0.
    """
    # Tokenize text and get embeddings
    text_input = pipe.tokenizer(prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
                                truncation=True, return_tensors='pt')
    text_embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0]

    # Do the same for unconditional embeddings. Truncate exactly like the
    # prompt above so that an over-long negative prompt cannot yield a
    # sequence longer than model_max_length.
    uncond_input = pipe.tokenizer(negative_prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
                                  truncation=True, return_tensors='pt')
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]

    # Cat for final embeddings
    text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
    return text_embeddings


@torch.no_grad()
def extract_latents(pipe, num_steps, latent_frames, batch_size, timesteps_to_save, inversion_prompt=''):
    """Configure the scheduler and invert ``latent_frames`` with DDIM.

    Args:
        pipe: Pipeline object forwarded to ``get_text_embeds`` and
            ``ddim_inversion``.
        num_steps: Passed to ``pipe.scheduler.set_timesteps``.
        latent_frames: Latents to invert; modified in place by
            ``ddim_inversion``.
        batch_size: Number of frames processed per UNet call.
        timesteps_to_save: Forwarded to ``ddim_inversion`` (which currently
            does not act on it).
        inversion_prompt: Prompt used as conditioning during inversion.

    Returns:
        The inverted latents returned by ``ddim_inversion``.
    """
    pipe.scheduler.set_timesteps(num_steps)
    # get_text_embeds returns [uncond, cond] stacked along dim 0; index 1 picks
    # the prompt embedding (for a single prompt string with batch_size=1).
    cond = get_text_embeds(pipe, inversion_prompt, "", device=pipe.device)[1].unsqueeze(0)
    # latent_frames = self.latents

    inverted_latents = ddim_inversion(pipe, cond,
                                      latent_frames,
                                      batch_size=batch_size,
                                      save_latents=False,
                                      timesteps_to_save=timesteps_to_save)

    # latent_reconstruction = ddim_sample(pipe, inverted_latents, cond, batch_size=batch_size)

    # rgb_reconstruction = decode_latents(pipe, latent_reconstruction)

    # return rgb_reconstruction
    return inverted_latents


@torch.no_grad()
def encode_imgs(pipe, imgs, batch_size=10, deterministic=True):
    """Encode images to latents with ``pipe.vae``, ``batch_size`` at a time.

    Args:
        pipe: Pipeline object exposing ``vae.encode``.
        imgs: Image tensor batched along dim 0; rescaled with ``2 * x - 1``
            before encoding (assumes values in ``[0, 1]`` -- TODO confirm).
        batch_size: Number of images encoded per VAE call.
        deterministic: If true, use the posterior mean; otherwise draw a
            sample from the posterior.

    Returns:
        The latents, each multiplied by 0.18215, concatenated along dim 0.
    """
    imgs = 2 * imgs - 1
    latents = []
    for i in range(0, len(imgs), batch_size):
        posterior = pipe.vae.encode(imgs[i:i + batch_size]).latent_dist
        latent = posterior.mean if deterministic else posterior.sample()
        latents.append(latent * 0.18215)
    latents = torch.cat(latents)
    return latents


def get_data(pipe, frames, n_frames):
    """Convert frames to tensors, move them to the device and encode them.

    Args:
        pipe: Pipeline object exposing ``device`` and the VAE used by
            ``encode_imgs``.
        frames: Sequence of image objects providing ``size``, ``convert``
            and ``resize`` (presumably PIL images -- confirm with callers).
        n_frames: Only the first ``n_frames`` frames are used.

    Returns:
        A ``(frames_tensor, latents)`` tuple, both cast to ``torch.float16``
        and placed on ``pipe.device``.
    """
    frames = frames[:n_frames]
    # Convert every frame to RGB, not only square ones: previously a
    # non-square RGBA / grayscale / palette frame skipped the conversion and
    # was stacked with the wrong number of channels.
    frames = [frame.convert("RGB") for frame in frames]
    if frames[0].size[0] == frames[0].size[1]:
        # NOTE(review): `Image` is not imported explicitly in this file; it is
        # presumably provided by `from util import *` -- confirm.
        frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
    stacked_tensor_frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(pipe.device)
    # encode to latents
    latents = encode_imgs(pipe, stacked_tensor_frames, deterministic=True).to(torch.float16).to(pipe.device)
    return stacked_tensor_frames, latents