import matplotlib.pyplot as plt
import torch
import torchvision
import wandb
from torch import nn
from tqdm import tqdm
from transformers import CLIPProcessor

from img_processing import get_pil, loop_post_process

# Module-level switch: when True, per-step losses are sent to wandb.
# (A `global` statement at module scope is a no-op, so a plain assignment suffices.)
log = False


class ProcessorGradientFlow:
    """
    This wraps the huggingface CLIP processor to allow backprop through the image
    processing step. The original processor forces conversion to numpy then PIL images,
    which is faster for image processing but breaks gradient flow.

    Text is still handled by the wrapped ``CLIPProcessor``; images are resized,
    center-cropped and normalized with torchvision transforms so they remain
    differentiable torch tensors.
    """

    def __init__(
        self, device="cuda", clip_model="openai/clip-vit-large-patch14"
    ) -> None:
        """
        Args:
            device: device every returned tensor is moved to.
            clip_model: name/path passed to ``CLIPProcessor.from_pretrained``.
        """
        self.device = device
        self.processor = CLIPProcessor.from_pretrained(clip_model)
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(
            self.image_mean, self.image_std
        )
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        """Resize, center-crop to 224x224 and normalize ``images`` as torch tensors."""
        # Resize the shorter side first, then crop. Cropping before resizing (as an
        # earlier version did) discards the border of any image larger than 224px
        # instead of scaling it down.
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, images=None, **kwargs):
        """
        Build CLIP model inputs while keeping ``pixel_values`` differentiable.

        Args:
            images: image tensor(s) accepted by the torchvision transforms.
            **kwargs: forwarded to ``CLIPProcessor`` (e.g. ``text``, ``padding``).

        Returns:
            dict mapping input names to tensors on ``self.device``.

        Raises:
            ValueError: if ``images`` is not provided.
        """
        if images is None:
            raise ValueError("ProcessorGradientFlow requires `images`")
        processed_inputs = self.processor(**kwargs)
        processed_inputs["pixel_values"] = self.preprocess_img(images)
        return {
            key: value.to(self.device) for key, value in processed_inputs.items()
        }


class ImagePromptEditor(nn.Module):
    """
    Edit a VQGAN latent by optimizing an additive offset ``vector``.

    A CLIP loss pulls the decoded image towards the positive prompts and away from
    the negative prompts. A perceptual (LPIPS) loss keeps the image close to the
    decoding of the unedited latent. When ``attn_mask`` is set, gradients from the
    CLIP loss are multiplied by the mask, and gradients from the perceptual loss by
    ``1 - mask``.
    """

    def __init__(
        self,
        vqgan,
        clip,
        clip_preprocessor,
        lpips_fn,
        iterations=100,
        lr=0.01,
        save_vector=True,  # accepted for backward compatibility; currently unused
        return_val="vector",
        quantize=True,
        make_grid=False,
        lpips_weight=6.2,
        reconstruction_steps=0,
        attn_mask=None,
    ) -> None:
        """
        Args:
            vqgan: model exposing ``.device``, ``.quantize`` and ``.decode``.
            clip: CLIP model called with the preprocessor's outputs.
            clip_preprocessor: e.g. a ``ProcessorGradientFlow`` instance.
            lpips_fn: callable ``(img_a, img_b) -> loss`` used as the perceptual loss.
            iterations: number of joint CLIP + LPIPS optimization steps.
            lr: Adam learning rate.
            return_val: ``"vector"`` to finally yield the offset; anything else
                yields ``latent + vector``.
            quantize: quantize the edited latent before decoding.
            make_grid: when True, ``visualize`` draws into successive subplots.
            lpips_weight: multiplier for the perceptual loss.
            reconstruction_steps: number of LPIPS-only steps after the joint phase.
            attn_mask: optional gradient mask; see ``set_params``.
        """
        super().__init__()
        self.latent = None
        self.device = vqgan.device
        vqgan.eval()
        self.vqgan = vqgan
        self.clip = clip
        self.iterations = iterations
        self.lr = lr
        self.clip_preprocessor = clip_preprocessor
        self.make_grid = make_grid
        self.return_val = return_val
        self.quantize = quantize
        self.lpips_weight = lpips_weight
        self.perceptual_loss = lpips_fn
        self.reconstruction_steps = reconstruction_steps
        # Stored as `attn_mask`, NOT `_attn_mask`: an instance attribute named
        # `_attn_mask` would shadow the gradient-hook method of the same name.
        self.attn_mask = attn_mask
        # Subplot counter used by `visualize` when `make_grid` is True.
        self.index = 0

    def set_latent(self, latent):
        """Store a detached copy of ``latent`` on ``self.device`` as the base to edit."""
        self.latent = latent.detach().to(self.device)

    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
        """
        Update optimization hyper-parameters and the optional gradient mask.

        ``attn_mask`` multiplies the gradient of the post-processed image. Assumes it
        broadcasts against that image's shape; TODO confirm with callers.
        ``None`` disables masking.
        """
        self.attn_mask = attn_mask
        self.iterations = iterations
        self.lr = lr
        self.lpips_weight = lpips_weight
        self.reconstruction_steps = reconstruction_steps

    def forward(self, vector):
        """Decode ``self.latent + vector`` (quantized first when ``self.quantize``)."""
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        return self.vqgan.decode(z_q)

    def _get_clip_similarity(self, prompts, image, weights=None):
        """
        Return the sum of CLIP image-text logits for ``prompts`` against ``image``.

        Args:
            prompts: a string or a list of strings.
            image: image tensor passed to ``self.clip_preprocessor``.
            weights: optional per-prompt weights (list or tensor) broadcast against
                ``logits_per_image`` before summing.

        Raises:
            TypeError: if ``prompts`` is neither a str nor a list.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list):
            raise TypeError("Provide prompts as string or list of strings")
        clip_inputs = self.clip_preprocessor(
            text=prompts, images=image, return_tensors="pt", padding=True
        )
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights is not None:
            # Out-of-place so an autograd output is never modified in place;
            # `as_tensor` also accepts a plain Python list of weights.
            similarity_logits = similarity_logits * torch.as_tensor(
                weights,
                dtype=similarity_logits.dtype,
                device=similarity_logits.device,
            )
        return similarity_logits.sum()

    def _get_CLIP_loss(self, pos_prompts, neg_prompts, image):
        """Return ``-log(pos similarity) + log(neg similarity)`` for ``image``."""
        pos_logits = self._get_clip_similarity(pos_prompts, image)
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts, image)
        else:
            # log(1) == 0, i.e. no negative term.
            neg_logits = torch.tensor(1.0, device=self.device)
        # NOTE(review): torch.log of a non-positive similarity sum gives NaN/-inf;
        # confirm the summed logits stay positive for the prompts in use.
        loss = -torch.log(pos_logits) + torch.log(neg_logits)
        return loss

    def visualize(self, processed_img):
        """Plot the first image of ``processed_img`` (into a 1x13 grid if ``make_grid``)."""
        if self.make_grid:
            self.index += 1
            plt.subplot(1, 13, self.index)
            plt.imshow(get_pil(processed_img[0]).detach().cpu())
        else:
            plt.imshow(get_pil(processed_img[0]).detach().cpu())
            plt.show()

    def _attn_mask(self, grad):
        """Gradient hook: keep only the gradient where ``self.attn_mask`` is set."""
        if self.attn_mask is None:
            return grad
        return grad * self.attn_mask

    def _attn_mask_inverse(self, grad):
        """Gradient hook: keep only the gradient outside ``self.attn_mask``."""
        if self.attn_mask is None:
            return grad
        return grad * (1 - self.attn_mask)

    def _get_next_inputs(self, transformed_img):
        """
        Post-process a decoded image and create two clones for the two loss branches.

        Returns ``(processed_img, lpips_input, clip_input)``. The LPIPS clone's
        gradient is masked by ``1 - attn_mask``; the CLIP clone's gradient by
        ``attn_mask``.
        """
        processed_img = loop_post_process(transformed_img)
        processed_img.retain_grad()
        lpips_input = processed_img.clone()
        lpips_input.register_hook(self._attn_mask_inverse)
        lpips_input.retain_grad()
        clip_input = processed_img.clone()
        clip_input.register_hook(self._attn_mask)
        clip_input.retain_grad()
        return (processed_img, lpips_input, clip_input)

    def _optimize_CLIP_LPIPS(self, optim, original_img, vector, pos_prompts, neg_prompts):
        """Run one joint CLIP + LPIPS step on ``vector`` and return it."""
        optim.zero_grad()
        transformed_img = self(vector)
        _, lpips_input, clip_input = self._get_next_inputs(transformed_img)
        with torch.autocast("cuda"):
            clip_loss = self._get_CLIP_loss(pos_prompts, neg_prompts, clip_input)
            perceptual_loss = (
                self.perceptual_loss(lpips_input, original_img.clone())
                * self.lpips_weight
            )
        print("CLIP loss", clip_loss)
        print("LPIPS loss: ", perceptual_loss)
        print("Sum Loss", perceptual_loss + clip_loss)
        if log:
            wandb.log(
                {
                    "Perceptual Loss": perceptual_loss.item(),
                    "CLIP Loss": clip_loss.item(),
                }
            )
        # One backward over the sum gives the same gradients as two separate passes;
        # the per-branch hooks still mask each loss's gradient independently.
        (clip_loss + perceptual_loss).backward()
        optim.step()
        return vector

    def _optimize_LPIPS(self, vector, original_img, optim):
        """Run one LPIPS-only (reconstruction) step on ``vector`` and return it."""
        optim.zero_grad()
        transformed_img = self(vector)
        _, lpips_input, _ = self._get_next_inputs(transformed_img)
        with torch.autocast("cuda"):
            perceptual_loss = (
                self.perceptual_loss(lpips_input, original_img.clone())
                * self.lpips_weight
            )
        if log:
            wandb.log({"Perceptual Loss": perceptual_loss.item()})
        print("LPIPS loss: ", perceptual_loss)
        perceptual_loss.backward()
        optim.step()
        return vector

    def optimize(self, latent, pos_prompts, neg_prompts):
        """
        Generator that edits ``latent`` towards the prompts.

        Yields the (live, in-place updated) offset ``vector`` after every joint step,
        then after every LPIPS-only step. Finally yields ``vector`` if
        ``return_val == "vector"``, otherwise ``latent + vector``. Callers that keep
        intermediate results should ``clone()`` them.
        """
        self.set_latent(latent)
        # The reference image is a constant target: decode it without autograd so
        # later backward passes do not traverse (and need to retain) this graph.
        with torch.no_grad():
            original_img = loop_post_process(
                self(torch.zeros_like(self.latent, device=self.device))
            )
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)
        for _ in tqdm(range(self.iterations)):
            yield self._optimize_CLIP_LPIPS(
                optim, original_img, vector, pos_prompts, neg_prompts
            )
        print("Running LPIPS optim only")
        for _ in range(self.reconstruction_steps):
            yield self._optimize_LPIPS(vector, original_img, optim)
        yield vector if self.return_val == "vector" else self.latent + vector