import os
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
import torch
import torch.nn.functional as F
import tqdm
import numpy as np
import safetensors
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline
from argparse import ArgumentParser

from utils.model_utils import get_img, slerp, do_replace_attn
from utils.lora_utils import train_lora, load_lora
from utils.alpha_scheduler import AlphaScheduler


class StoreProcessor():
    """Wraps an attention processor and caches self-attention hidden states per layer."""

    def __init__(self, original_processor, value_dict, name):
        self.original_processor = original_processor
        self.value_dict = value_dict
        self.name = name
        self.value_dict[self.name] = dict()
        self.id = 0

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        # Is self attention
        if encoder_hidden_states is None:
            self.value_dict[self.name][self.id] = hidden_states.detach()
            self.id += 1
        res = self.original_processor(attn, hidden_states, *args,
                                      encoder_hidden_states=encoder_hidden_states,
                                      attention_mask=attention_mask,
                                      **kwargs)

        return res


class LoadProcessor():
    """Replays cached self-attention features from the two endpoint images, blended by alpha."""

    def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6):
        super().__init__()
        self.original_processor = original_processor
        self.name = name
        self.img0_dict = img0_dict
        self.img1_dict = img1_dict
        self.alpha = alpha
        self.beta = beta
        self.lamd = lamd
        self.id = 0

    def parent_call(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * \
            self.original_processor.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * \
            self.original_processor.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * \
            self.original_processor.to_v_lora(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(
            query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](
            hidden_states) + scale * self.original_processor.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        # Is self attention
        if encoder_hidden_states is None:
            # hardcode timestep
            if self.id < 50 * self.lamd:
                map0 = self.img0_dict[self.name][self.id]
                map1 = self.img1_dict[self.name][self.id]
                cross_map = self.beta * hidden_states + \
                    (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
                # cross_map = self.beta * hidden_states + \
                #     (1 - self.beta) * slerp(map0, map1, self.alpha)
                # cross_map = slerp(slerp(map0, map1, self.alpha),
                #                   hidden_states, self.beta)
                # cross_map = hidden_states
                # cross_map = torch.cat(
                #     ((1 - self.alpha) * map0, self.alpha * map1), dim=1)

                # res = self.original_processor(attn, hidden_states, *args,
                #                               encoder_hidden_states=cross_map,
                #                               attention_mask=attention_mask,
                #                               temb=temb, **kwargs)
                res = self.parent_call(attn, hidden_states, *args,
                                       encoder_hidden_states=cross_map,
                                       attention_mask=attention_mask,
                                       **kwargs)
            else:
                res = self.original_processor(attn, hidden_states, *args,
                                              encoder_hidden_states=encoder_hidden_states,
                                              attention_mask=attention_mask,
                                              **kwargs)

            self.id += 1
            if self.id == len(self.img0_dict[self.name]):
                self.id = 0
        else:
            res = self.original_processor(attn, hidden_states, *args,
                                          encoder_hidden_states=encoder_hidden_states,
                                          attention_mask=attention_mask,
                                          **kwargs)

        return res


class DiffMorpherPipeline(StableDiffusionPipeline):

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True,
                 ):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)
        self.img0_dict = dict()
        self.img1_dict = dict()

    def inv_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
        """
        Inverse sampling for DDIM Inversion
        """
        if verbose:
            print("timestep: ", timestep)
        next_step = timestep
        timestep = min(timestep - self.scheduler.config.num_train_timesteps //
                       self.scheduler.num_inference_steps, 999)
        alpha_prod_t = self.scheduler.alphas_cumprod[
            timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0

    @torch.no_grad()
    def invert(
        self,
        image: torch.Tensor,
        prompt,
        num_inference_steps=50,
        num_actual_inference_steps=None,
        guidance_scale=1.,
        eta=0.0,
        **kwds):
        """
        Invert a real image into a noise map with deterministic DDIM inversion.
        """
        DEVICE = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        batch_size = image.shape[0]
        if isinstance(prompt, list):
            if batch_size == 1:
                image = image.expand(len(prompt), -1, -1, -1)
        elif isinstance(prompt, str):
            if batch_size > 1:
                prompt = [prompt] * batch_size

        # text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
        print("input text embeddings :", text_embeddings.shape)

        # define initial latents
        latents = self.image2latent(image)
        # unconditional embedding for classifier-free guidance
        if guidance_scale > 1.:
            max_length = text_input.input_ids.shape[-1]
            unconditional_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            unconditional_embeddings = self.text_encoder(
                unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat(
                [unconditional_embeddings, text_embeddings], dim=0)

        print("latents shape: ", latents.shape)
        # iterative sampling
        self.scheduler.set_timesteps(num_inference_steps)
        print("Valid timesteps: ", reversed(self.scheduler.timesteps))
        # print("attributes: ", self.scheduler.__dict__)
        latents_list = [latents]
        pred_x0_list = [latents]
        for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
            if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
                continue

            if guidance_scale > 1.:
                model_inputs = torch.cat([latents] * 2)
            else:
                model_inputs = latents

            # predict the noise
            noise_pred = self.unet(
                model_inputs, t, encoder_hidden_states=text_embeddings).sample
            if guidance_scale > 1.:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * \
                    (noise_pred_con - noise_pred_uncon)
            # compute the previous noise sample x_t-1 -> x_t
            latents, pred_x0 = self.inv_step(noise_pred, t, latents)
            latents_list.append(latents)
            pred_x0_list.append(pred_x0)

        return latents

    @torch.no_grad()
    def ddim_inversion(self, latent, cond):
        timesteps = reversed(self.scheduler.timesteps)
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")):
                cond_batch = cond.repeat(latent.shape[0], 1, 1)

                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i - 1]]
                    if i > 0 else self.scheduler.final_alpha_cumprod
                )

                mu = alpha_prod_t ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(
                    latent, t, encoder_hidden_states=cond_batch).sample

                pred_x0 = (latent - sigma_prev * eps) / mu_prev
                latent = mu * pred_x0 + sigma * eps
            # if save_latents:
            #     torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))
        return latent

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
    ):
        """
        Predict the sample of the next step in the denoising process.
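
        With eta = 0 this is the deterministic DDIM update computed below:
        pred_x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t), then
        x_{t-1} = sqrt(alpha_bar_{t-1}) * pred_x0 + sqrt(1 - alpha_bar_{t-1}) * eps.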
""" prev_timestep = timestep - \ self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[timestep] alpha_prod_t_prev = self.scheduler.alphas_cumprod[ prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod beta_prod_t = 1 - alpha_prod_t pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir return x_prev, pred_x0 @torch.no_grad() def image2latent(self, image): DEVICE = torch.device( "cuda") if torch.cuda.is_available() else torch.device("cpu") if type(image) is Image: image = np.array(image) image = torch.from_numpy(image).float() / 127.5 - 1 image = image.permute(2, 0, 1).unsqueeze(0) # input image density range [-1, 1] latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean latents = latents * 0.18215 return latents @torch.no_grad() def latent2image(self, latents, return_type='np'): latents = 1 / 0.18215 * latents.detach() image = self.vae.decode(latents)['sample'] if return_type == 'np': image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy()[0] image = (image * 255).astype(np.uint8) elif return_type == "pt": image = (image / 2 + 0.5).clamp(0, 1) return image def latent2image_grad(self, latents): latents = 1 / 0.18215 * latents image = self.vae.decode(latents)['sample'] return image # range [-1, 1] @torch.no_grad() def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None): # latents = torch.cos(alpha * torch.pi / 2) * img_noise_0 + \ # torch.sin(alpha * torch.pi / 2) * img_noise_1 # latents = (1 - alpha) * img_noise_0 + alpha * img_noise_1 # latents = latents / ((1 - alpha) ** 2 + alpha ** 2) latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain) text_embeddings = (1 - alpha) * text_embeddings_0 + \ alpha * text_embeddings_1 self.scheduler.set_timesteps(num_inference_steps) if use_lora: if fix_lora is not None: self.unet = load_lora(self.unet, lora_0, lora_1, fix_lora) else: self.unet = load_lora(self.unet, lora_0, lora_1, alpha) for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")): if guidance_scale > 1.: model_inputs = torch.cat([latents] * 2) else: model_inputs = latents if unconditioning is not None and isinstance(unconditioning, list): _, text_embeddings = text_embeddings.chunk(2) text_embeddings = torch.cat( [unconditioning[i].expand(*text_embeddings.shape), text_embeddings]) # predict the noise noise_pred = self.unet( model_inputs, t, encoder_hidden_states=text_embeddings).sample if guidance_scale > 1.0: noise_pred_uncon, noise_pred_con = noise_pred.chunk( 2, dim=0) noise_pred = noise_pred_uncon + guidance_scale * \ (noise_pred_con - noise_pred_uncon) # compute the previous noise sample x_t -> x_t-1 latents = self.scheduler.step( noise_pred, t, latents, return_dict=False)[0] return latents @torch.no_grad() def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size): DEVICE = torch.device( "cuda") if torch.cuda.is_available() else torch.device("cpu") # text embeddings text_input = self.tokenizer( prompt, padding="max_length", max_length=77, return_tensors="pt" ) text_embeddings = self.text_encoder(text_input.input_ids.cuda())[0] if guidance_scale > 1.: if neg_prompt: uc_text = neg_prompt else: uc_text = "" unconditional_input = 
            unconditional_input = self.tokenizer(
                [uc_text] * batch_size,
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            unconditional_embeddings = self.text_encoder(
                unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat(
                [unconditional_embeddings, text_embeddings], dim=0)

        return text_embeddings

    def __call__(
        self,
        img_0=None,
        img_1=None,
        img_path_0=None,
        img_path_1=None,
        prompt_0="",
        prompt_1="",
        imgs=[],
        img_paths=None,
        prompts=None,
        save_lora_dir="./lora",
        load_lora_path_0=None,
        load_lora_path_1=None,
        load_lora_paths=None,
        lora_steps=200,
        lora_lr=2e-4,
        lora_rank=16,
        batch_size=1,
        height=512,
        width=512,
        num_inference_steps=50,
        num_actual_inference_steps=None,
        guidance_scale=1,
        attn_beta=0,
        lamd=0.6,
        use_lora=True,
        use_adain=True,
        use_reschedule=True,
        output_path="./results",
        num_frames=50,
        fix_lora=None,
        progress=tqdm,
        unconditioning=None,
        neg_prompt=None,
        save_intermediates=False,
        **kwds):

        self.scheduler.set_timesteps(num_inference_steps)
        self.use_lora = use_lora
        self.use_adain = use_adain
        self.use_reschedule = use_reschedule
        self.output_path = output_path

        imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths]
        assert len(prompts) == len(imgs)

        # if img_path_0 or img_0:
        #     img_paths = [img_path_0, img_path_1]
        #     prompts = [prompt_0, prompt_1]
        #     load_lora_paths = [load_lora_path_0, load_lora_path_1]
        #     if img_0:
        #         imgs.append(Image.fromarray(img_0).convert("RGB"))
        #     if img_1:
        #         imgs.append(Image.fromarray(img_1).convert("RGB"))
        # if imgs is None:
        #     imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths]
        # if len(prompts) < len(imgs):
        #     prompts += ["" for _ in range(len(imgs) - len(prompts))]

        # Placeholders so the pairwise loop below also runs when LoRA is disabled.
        loras = [None] * len(imgs)
        if self.use_lora:
            loras = []
            print("Loading lora...")
            for i, (img, prompt) in enumerate(zip(imgs, prompts)):
                if len(load_lora_paths) == i:
                    weight_name = f"{output_path.split('/')[-1]}_lora_{i}.ckpt"
                    load_lora_paths.append(save_lora_dir + "/" + weight_name)
                    if not os.path.exists(load_lora_paths[i]):
                        train_lora(img, prompt, save_lora_dir, None, self.tokenizer, self.text_encoder,
                                   self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank,
                                   weight_name=weight_name)
                print(f"Load from {load_lora_paths[i]}.")
                if load_lora_paths[i].endswith(".safetensors"):
                    loras.append(safetensors.torch.load_file(
                        load_lora_paths[i], device="cpu"))
                else:
                    loras.append(torch.load(load_lora_paths[i], map_location="cpu"))

            # if not load_lora_path_1:
            #     weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
            #     load_lora_path_1 = save_lora_dir + "/" + weight_name
            #     if not os.path.exists(load_lora_path_1):
            #         train_lora(img_1, prompt_1, save_lora_dir, None, self.tokenizer, self.text_encoder,
            #                    self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
            # print(f"Load from {load_lora_path_1}.")
            # if load_lora_path_1.endswith(".safetensors"):
            #     lora_1 = safetensors.torch.load_file(
            #         load_lora_path_1, device="cpu")
            # else:
            #     lora_1 = torch.load(load_lora_path_1, map_location="cpu")

        def morph(alpha_list, progress, desc, img_noise_0, img_noise_1,
                  text_embeddings_0, text_embeddings_1, lora_0, lora_1):
            images = []
            if attn_beta is not None:
                self.unet = load_lora(
                    self.unet, lora_0, lora_1, 0 if fix_lora is None else fix_lora)
                attn_processor_dict = {}
                for k in self.unet.attn_processors.keys():
                    if do_replace_attn(k):
                        attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
                                                                self.img0_dict, k)
                    else:
                        attn_processor_dict[k] = self.unet.attn_processors[k]
                self.unet.set_attn_processor(attn_processor_dict)

                latents = self.cal_latent(
                    num_inference_steps,
                    guidance_scale,
                    unconditioning,
                    img_noise_0,
                    img_noise_1,
                    text_embeddings_0,
                    text_embeddings_1,
                    lora_0,
                    lora_1,
                    alpha_list[0],
                    False,
                    fix_lora
                )
                first_image = self.latent2image(latents)
                first_image = Image.fromarray(first_image)
                # if save_intermediates:
                #     first_image.save(f"{self.output_path}/{0:02d}.png")

                self.unet = load_lora(
                    self.unet, lora_0, lora_1, 1 if fix_lora is None else fix_lora)
                attn_processor_dict = {}
                for k in self.unet.attn_processors.keys():
                    if do_replace_attn(k):
                        attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
                                                                self.img1_dict, k)
                    else:
                        attn_processor_dict[k] = self.unet.attn_processors[k]
                self.unet.set_attn_processor(attn_processor_dict)

                latents = self.cal_latent(
                    num_inference_steps,
                    guidance_scale,
                    unconditioning,
                    img_noise_0,
                    img_noise_1,
                    text_embeddings_0,
                    text_embeddings_1,
                    lora_0,
                    lora_1,
                    alpha_list[-1],
                    False,
                    fix_lora
                )
                last_image = self.latent2image(latents)
                last_image = Image.fromarray(last_image)
                # if save_intermediates:
                #     last_image.save(
                #         f"{self.output_path}/{num_frames - 1:02d}.png")

                for i in progress.tqdm(range(1, num_frames - 1), desc=desc):
                    alpha = alpha_list[i]
                    self.unet = load_lora(
                        self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora)
                    attn_processor_dict = {}
                    for k in self.unet.attn_processors.keys():
                        if do_replace_attn(k):
                            attn_processor_dict[k] = LoadProcessor(
                                self.unet.attn_processors[k], k, self.img0_dict, self.img1_dict,
                                alpha, attn_beta, lamd)
                        else:
                            attn_processor_dict[k] = self.unet.attn_processors[k]
                    self.unet.set_attn_processor(attn_processor_dict)

                    latents = self.cal_latent(
                        num_inference_steps,
                        guidance_scale,
                        unconditioning,
                        img_noise_0,
                        img_noise_1,
                        text_embeddings_0,
                        text_embeddings_1,
                        lora_0,
                        lora_1,
                        alpha_list[i],
                        False,
                        fix_lora
                    )
                    image = self.latent2image(latents)
                    image = Image.fromarray(image)
                    # if save_intermediates:
                    #     image.save(f"{self.output_path}/{i:02d}.png")
                    images.append(image)

                images = [first_image] + images + [last_image]

            else:
                for k, alpha in enumerate(alpha_list):
                    latents = self.cal_latent(
                        num_inference_steps,
                        guidance_scale,
                        unconditioning,
                        img_noise_0,
                        img_noise_1,
                        text_embeddings_0,
                        text_embeddings_1,
                        lora_0,
                        lora_1,
                        alpha_list[k],
                        self.use_lora,
                        fix_lora
                    )
                    image = self.latent2image(latents)
                    image = Image.fromarray(image)
                    # if save_intermediates:
                    #     image.save(f"{self.output_path}/{k:02d}.png")
                    images.append(image)

            return images

        images = []
        for img_0, img_1, prompt_0, prompt_1, lora_0, lora_1 in zip(
                imgs[:-1], imgs[1:], prompts[:-1], prompts[1:], loras[:-1], loras[1:]):
            text_embeddings_0 = self.get_text_embeddings(
                prompt_0, guidance_scale, neg_prompt, batch_size)
            text_embeddings_1 = self.get_text_embeddings(
                prompt_1, guidance_scale, neg_prompt, batch_size)
            img_0 = get_img(img_0)
            img_1 = get_img(img_1)
            if self.use_lora:
                self.unet = load_lora(self.unet, lora_0, lora_1, 0)
            img_noise_0 = self.ddim_inversion(
                self.image2latent(img_0), text_embeddings_0)
            if self.use_lora:
                self.unet = load_lora(self.unet, lora_0, lora_1, 1)
            img_noise_1 = self.ddim_inversion(
                self.image2latent(img_1), text_embeddings_1)

            print("latents shape: ", img_noise_0.shape)

            with torch.no_grad():
                if self.use_reschedule:
                    alpha_scheduler = AlphaScheduler()
                    alpha_list = list(torch.linspace(0, 1, num_frames))
                    images_pt = morph(alpha_list, progress, "Sampling...",
                                      img_noise_0, img_noise_1,
                                      text_embeddings_0, text_embeddings_1,
                                      lora_0, lora_1)
                    images_pt = [transforms.ToTensor()(img).unsqueeze(0)
                                 for img in images_pt]
                    alpha_scheduler.from_imgs(images_pt)
                    alpha_list = alpha_scheduler.get_list()
                    print(alpha_list)
                    images_ = morph(alpha_list,
progress, "Reschedule...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1) else: alpha_list = list(torch.linspace(0, 1, num_frames)) print(alpha_list) images_ = morph(alpha_list, progress, "Sampling...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1) if len(images) == 0: images = images_ else: images += images_[1:] if save_intermediates: for i, image in enumerate(images): image.save(f"{self.output_path}/{i:02d}.png") return images