# from functools import cache
import gc
import glob
import importlib

import gradio as gr
import lpips
import matplotlib.pyplot as plt
import torch
import torchvision
import wandb
from icecream import ic
from torch import nn
from torchvision.transforms.functional import resize
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

from edit import blend_paths
from img_processing import *
from img_processing import custom_to_pil
from loaders import load_default

log = False  # set to True to enable wandb logging
# ic.disable()
# ic.enable()


def get_resized_tensor(x):
    """Return `x` resized to 10x10 in its last two dims, adding a channel dim if needed."""
    if len(x.shape) == 2:
        re = x.unsqueeze(0)
    else:
        re = x
    re = resize(re, (10, 10))
    return re


class ProcessorGradientFlow:
    """
    Wraps the Hugging Face CLIP processor so that gradients can flow through the
    image-preprocessing step. The original processor converts images to numpy
    arrays and then PIL images, which is faster but breaks gradient flow.
    """

    def __init__(self, device="cuda") -> None:
        self.device = device
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(self.image_mean, self.image_std)
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        # Mirror the CLIP preprocessing pipeline (resize, center-crop, normalize)
        # with differentiable torchvision transforms.
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, images=None, **kwargs):
        # Tokenize the text with the stock processor, but preprocess the images
        # differentiably so gradients reach the pixels.
        processed_inputs = self.processor(**kwargs)
        processed_inputs["pixel_values"] = self.preprocess_img(images)
        processed_inputs = {
            key: value.to(self.device) for (key, value) in processed_inputs.items()
        }
        return processed_inputs


class ImagePromptOptimizer(nn.Module):
    def __init__(
        self,
        vqgan,
        clip,
        clip_preprocessor,
        lpips_fn,
        iterations=100,
        lr=0.01,
        save_vector=True,
        return_val="vector",
        quantize=True,
        make_grid=False,
        lpips_weight=6.2,
    ) -> None:
        super().__init__()
        self.latent = None
        self.device = vqgan.device
        vqgan.eval()
        self.vqgan = vqgan
        self.clip = clip
        self.iterations = iterations
        self.lr = lr
        self.clip_preprocessor = clip_preprocessor
        self.make_grid = make_grid
        self.return_val = return_val
        self.quantize = quantize
        self.lpips_weight = lpips_weight
        self.perceptual_loss = lpips_fn
        self.save_vector = save_vector
        self.attn_mask = None  # set via set_params(); used by the gradient hooks below

    def set_latent(self, latent):
        self.latent = latent.detach().to(self.device)

    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
        # Store the mask under `self.attn_mask` so it does not shadow the
        # `_attn_mask` gradient-hook method defined below.
        self.attn_mask = attn_mask
        self.iterations = iterations
        self.lr = lr
        self.lpips_weight = lpips_weight
        self.reconstruction_steps = reconstruction_steps

    def forward(self, vector):
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        dec = self.vqgan.decode(z_q)
        return dec

    def _get_clip_similarity(self, prompts, image, weights=None):
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list):
            raise TypeError("Provide prompts as a string or a list of strings")
        clip_inputs = self.clip_preprocessor(
            text=prompts, images=image, return_tensors="pt", padding=True
        )
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights is not None:
            similarity_logits *= weights
        return similarity_logits.sum()

    def get_similarity_loss(self, pos_prompts, neg_prompts, image):
        pos_logits = self._get_clip_similarity(pos_prompts, image)
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts, image)
        else:
            neg_logits = torch.tensor([1], device=self.device)
        loss = -torch.log(pos_logits) + torch.log(neg_logits)
        return loss

    def visualize(self, processed_img):
        if self.make_grid:
            self.index += 1
            plt.subplot(1, 13, self.index)
            plt.imshow(get_pil(processed_img[0]).detach().cpu())
        else:
            plt.imshow(get_pil(processed_img[0]).detach().cpu())
            plt.show()

    def _attn_mask(self, grad):
        # Gradient hook: keep gradients only inside the attention mask.
        newgrad = grad
        if self.attn_mask is not None:
            newgrad = grad * self.attn_mask
        return newgrad

    def _attn_mask_inverse(self, grad):
        # Gradient hook: keep gradients only outside the attention mask.
        newgrad = grad
        if self.attn_mask is not None:
            newgrad = grad * ((self.attn_mask - 1) * -1)
        return newgrad

    def _get_next_inputs(self, transformed_img):
        processed_img = loop_post_process(transformed_img)  # * self.attn_mask
        processed_img.retain_grad()

        lpips_input = processed_img.clone()
        lpips_input.register_hook(self._attn_mask_inverse)
        lpips_input.retain_grad()

        clip_input = processed_img.clone()
        clip_input.register_hook(self._attn_mask)
        clip_input.retain_grad()

        return processed_img, lpips_input, clip_input

    def optimize(self, latent, pos_prompts, neg_prompts):
        self.set_latent(latent)
        transformed_img = self(
            torch.zeros_like(self.latent, requires_grad=True, device=self.device)
        )
        original_img = loop_post_process(transformed_img)
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)

        if self.make_grid:
            plt.figure(figsize=(35, 25))
            self.index = 1

        for i in tqdm(range(self.iterations)):
            optim.zero_grad()
            transformed_img = self(vector)
            processed_img, lpips_input, clip_input = self._get_next_inputs(transformed_img)
            with torch.autocast("cuda"):
                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_input)
                print("CLIP loss", clip_loss)
                perceptual_loss = (
                    self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
                )
                print("LPIPS loss: ", perceptual_loss)
                if log:
                    wandb.log({"Perceptual Loss": perceptual_loss})
                    wandb.log({"CLIP Loss": clip_loss})
            clip_loss.backward(retain_graph=True)
            perceptual_loss.backward(retain_graph=True)
            p2 = processed_img.grad  # kept for inspection: gradient w.r.t. the processed image
            print("Sum Loss", perceptual_loss + clip_loss)
            optim.step()
            # if i % self.iterations // 10 == 0:
            #     self.visualize(transformed_img)
            yield vector

        if self.make_grid:
            plt.savefig(f"plot {pos_prompts[0]}.png")
            plt.show()

        print("Running LPIPS-only reconstruction steps")
        for i in range(self.reconstruction_steps):
            optim.zero_grad()
            transformed_img = self(vector)
            processed_img = loop_post_process(transformed_img)  # * self.attn_mask
            processed_img.retain_grad()

            lpips_input = processed_img.clone()
            lpips_input.register_hook(self._attn_mask_inverse)
            lpips_input.retain_grad()

            with torch.autocast("cuda"):
                perceptual_loss = (
                    self.perceptual_loss(lpips_input, original_img.clone()) * self.lpips_weight
                )
                if log:
                    wandb.log({"Perceptual Loss": perceptual_loss})
                print("LPIPS loss: ", perceptual_loss)
            perceptual_loss.backward(retain_graph=True)
            optim.step()
            yield vector

        yield vector if self.return_val == "vector" else self.latent + vector
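# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original pipeline): shows how
# ProcessorGradientFlow and ImagePromptOptimizer might be wired together.
# It assumes a CUDA device, that `load_default` returns a VQGAN already on
# that device, and that LPIPS with a VGG backbone is an acceptable perceptual
# loss. The latent shape, prompts, and hyperparameters below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda"  # optimize() uses torch.autocast("cuda"), so a GPU is assumed

    vqgan = load_default(device)  # assumption: returns a VQGAN moved to `device`
    clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    clip_preprocessor = ProcessorGradientFlow(device=device)
    lpips_fn = lpips.LPIPS(net="vgg").to(device)

    image_optimizer = ImagePromptOptimizer(
        vqgan, clip, clip_preprocessor, lpips_fn, quantize=True, make_grid=False
    )
    # set_params must run before optimize(): it sets reconstruction_steps and
    # the (optional) attention mask used by the gradient hooks.
    image_optimizer.set_params(
        lr=0.01, iterations=25, lpips_weight=6.2, reconstruction_steps=10, attn_mask=None
    )

    # A real latent would come from encoding an image with the VQGAN; a random
    # tensor of a plausible (placeholder) shape stands in for it here.
    latent = torch.randn(1, 256, 16, 16, device=device)

    for vector in image_optimizer.optimize(latent, ["a smiling face"], ["a frowning face"]):
        pass  # each intermediate `vector` can be decoded with image_optimizer(vector)

    final_img = image_optimizer(vector)  # decoded image tensor for the final edit vector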