Spaces:
Configuration error
Configuration error
# from functools import cache | |
import importlib | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import torch | |
import torchvision | |
import wandb | |
from icecream import ic | |
from torch import nn | |
from torchvision.transforms.functional import resize | |
from tqdm import tqdm | |
from transformers import CLIPModel, CLIPProcessor | |
import lpips | |
from edit import blend_paths | |
from img_processing import * | |
from img_processing import custom_to_pil | |
from loaders import load_default | |
import glob | |
# Module-level switch: when True, ImagePromptOptimizer.optimize also reports
# its loss values through wandb.log.
log = False
def get_resized_tensor(x):
    """Shrink ``x`` to a 10x10 spatial size (used by the commented-out gradient debug prints).

    A 2-D input is first given a leading axis of size 1; any other input is
    passed to ``resize`` unchanged.
    """
    ndim = len(x.shape)
    batched = x.unsqueeze(0) if ndim == 2 else x
    return resize(batched, (10, 10))
class ProcessorGradientFlow:
    """
    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
    The original processor forces conversion to PIL images, which breaks gradient flow.

    Text still goes through ``CLIPProcessor``; images stay torch tensors and are
    cropped/resized to 224 and normalized with torchvision transforms, so
    gradients can flow back into them.
    """

    def __init__(self, device="cuda") -> None:
        """
        Args:
            device: device every returned input tensor is moved to.
        """
        self.device = device
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
        # Per-channel mean/std used by CLIP's own image preprocessing.
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(
            self.image_mean,
            self.image_std,
        )
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        """Crop/resize ``images`` to 224x224 and normalize them without leaving autograd."""
        # NOTE(review): CLIP's reference pipeline is resize -> center-crop. The extra
        # leading center-crop means inputs larger than 224px are cropped rather than
        # downscaled, so CLIP only sees the central region — confirm this is intended.
        images = self.center_crop(images)
        images = self.resize(images)
        images = self.center_crop(images)
        return self.normalize(images)

    def __call__(self, images=None, **kwargs):
        """Build CLIP model inputs, preprocessing ``images`` differentiably.

        Args:
            images: image tensor(s) turned into ``pixel_values`` via
                ``preprocess_img``. When ``None`` (the default) only the
                processor's text outputs are returned.
            **kwargs: forwarded unchanged to ``CLIPProcessor`` (e.g. ``text=``,
                ``return_tensors="pt"``, ``padding=True``).

        Returns:
            dict mapping input names to tensors moved to ``self.device``.
        """
        processed_inputs = self.processor(**kwargs)
        # None instead of a shared mutable [] default; [] could never be preprocessed.
        if images is not None:
            processed_inputs["pixel_values"] = self.preprocess_img(images)
        return {key: value.to(self.device) for key, value in processed_inputs.items()}
class ImagePromptOptimizer(nn.Module):
    """Optimize an additive edit vector on a VQGAN latent against CLIP prompts.

    Each step decodes ``latent + vector``, scores it against positive/negative
    prompts with CLIP, and penalises drift from the unedited decode with a
    weighted perceptual (LPIPS) loss. When ``attn_mask`` is set, CLIP gradients
    are multiplied by the mask and perceptual gradients by ``1 - attn_mask``.
    """

    def __init__(
        self,
        vqgan,
        clip,
        clip_preprocessor,
        lpips_fn,
        iterations=100,
        lr=0.01,
        save_vector=True,
        return_val="vector",
        quantize=True,
        make_grid=False,
        lpips_weight=6.2,
    ) -> None:
        """
        Args:
            vqgan: model exposing ``device``, ``eval()``, ``quantize(z)`` and ``decode(z_q)``.
            clip: CLIP model called with the preprocessor's outputs; its result
                must have ``logits_per_image``.
            clip_preprocessor: callable accepting ``text=``, ``images=``,
                ``return_tensors=`` and ``padding=`` and returning CLIP inputs.
            lpips_fn: perceptual distance, called as ``lpips_fn(img, target)``.
            iterations: number of CLIP + perceptual optimisation steps.
            lr: Adam learning rate.
            save_vector: accepted for compatibility; not used by this class.
            return_val: ``"vector"`` to finally yield the edit vector; any other
                value yields ``latent + vector`` instead.
            quantize: pass the edited latent through ``vqgan.quantize`` before decoding.
            make_grid: draw into a subplot grid and save it after the CLIP phase.
            lpips_weight: multiplier applied to the perceptual loss.
        """
        super().__init__()
        self.latent = None
        self.device = vqgan.device
        vqgan.eval()
        self.vqgan = vqgan
        self.clip = clip
        self.iterations = iterations
        self.lr = lr
        self.clip_preprocessor = clip_preprocessor
        self.make_grid = make_grid
        self.return_val = return_val
        self.quantize = quantize
        self.lpips_weight = lpips_weight
        self.perceptual_loss = lpips_fn
        # Defaults so the gradient hooks and optimize() work before set_params().
        self.attn_mask = None
        self.reconstruction_steps = 0
        self.index = 0

    def set_latent(self, latent):
        """Store a detached copy of ``latent`` on ``self.device``."""
        self.latent = latent.detach().to(self.device)

    def set_params(self, lr, iterations, lpips_weight, reconstruction_steps, attn_mask):
        """Configure the next ``optimize`` run.

        Args:
            lr: Adam learning rate.
            iterations: number of CLIP + perceptual steps.
            lpips_weight: multiplier on the perceptual loss.
            reconstruction_steps: number of perceptual-only steps run afterwards.
            attn_mask: gradient mask (or ``None`` to disable masking).
        """
        self.attn_mask = attn_mask
        self.iterations = iterations
        self.lr = lr
        self.lpips_weight = lpips_weight
        self.reconstruction_steps = reconstruction_steps

    def forward(self, vector):
        """Decode ``self.latent + vector``, quantizing first when ``self.quantize``."""
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        return self.vqgan.decode(z_q)

    def _get_clip_similarity(self, prompts, image, weights=None):
        """Return the summed CLIP image-text logits of ``prompts`` against ``image``.

        Args:
            prompts: a prompt string or a list of prompt strings.
            image: image tensor forwarded to ``clip_preprocessor``.
            weights: optional multipliers (scalar, list or tensor) broadcast
                against ``logits_per_image`` before summing.

        Raises:
            TypeError: if ``prompts`` is neither a str nor a list.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list):
            raise TypeError("Provide prompts as string or list of strings")
        clip_inputs = self.clip_preprocessor(
            text=prompts, images=image, return_tensors="pt", padding=True
        )
        similarity_logits = self.clip(**clip_inputs).logits_per_image
        # `is not None` rather than truthiness: a multi-element tensor has no bool value.
        if weights is not None:
            weights = torch.as_tensor(
                weights, dtype=similarity_logits.dtype, device=similarity_logits.device
            )
            # Out-of-place so the model's output tensor is not mutated under autograd.
            similarity_logits = similarity_logits * weights
        return similarity_logits.sum()

    def get_similarity_loss(self, pos_prompts, neg_prompts, image):
        """Return ``-log(pos) + log(neg)`` of the summed CLIP logits for ``image``.

        When ``neg_prompts`` is falsy the negative term is ``log(1) == 0``.
        """
        pos_logits = self._get_clip_similarity(pos_prompts, image)
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts, image)
        else:
            # Float 0-d tensor so the loss keeps pos_logits' scalar shape.
            neg_logits = torch.tensor(1.0, device=self.device)
        return -torch.log(pos_logits) + torch.log(neg_logits)

    def visualize(self, processed_img):
        """Plot the first image of ``processed_img`` with matplotlib.

        With ``make_grid`` the image goes into the next cell of a 1x13 subplot
        row (``self.index`` tracks the cell); otherwise it is shown immediately.
        """
        # NOTE(review): get_pil's result is .detach()ed, so it is assumed to be a
        # tensor rather than a PIL image — confirm against img_processing.
        frame = get_pil(processed_img[0]).detach().cpu()
        if self.make_grid:
            self.index += 1
            plt.subplot(1, 13, self.index)
            plt.imshow(frame)
        else:
            plt.imshow(frame)
            plt.show()

    def attn_masking(self, grad):
        """Gradient hook for the CLIP branch: multiply ``grad`` by ``attn_mask``."""
        if self.attn_mask is None:
            return grad
        return grad * self.attn_mask

    def attn_masking2(self, grad):
        """Gradient hook for the perceptual branch: multiply ``grad`` by ``1 - attn_mask``."""
        if self.attn_mask is None:
            return grad
        return grad * (1 - self.attn_mask)

    def _autocast(self):
        """CUDA autocast context, disabled when this optimizer's device is not CUDA."""
        return torch.autocast("cuda", enabled=torch.device(self.device).type == "cuda")

    def _render(self, vector):
        """Decode ``self.latent + vector`` and apply ``loop_post_process``."""
        return loop_post_process(self(vector))

    @staticmethod
    def _masked_copy(img, hook):
        """Return a clone of ``img`` whose incoming gradient is rewritten by ``hook``."""
        branch = img.clone()
        branch.register_hook(hook)
        return branch

    def _lpips(self, img, target):
        """Weighted perceptual loss between ``img`` and a copy of ``target``."""
        return self.perceptual_loss(img, target.clone()) * self.lpips_weight

    def optimize(self, latent, pos_prompts, neg_prompts):
        """Optimize an edit vector for ``latent``, yielding it as it evolves.

        This is a generator. It yields ``vector`` after each of
        ``self.iterations`` CLIP + perceptual steps, then after each of
        ``self.reconstruction_steps`` perceptual-only steps, and finally yields
        ``vector`` when ``return_val == "vector"`` or ``self.latent + vector``
        otherwise.

        Args:
            latent: starting latent; stored detached on ``self.device``.
            pos_prompts: prompt(s) whose CLIP similarity is pushed up.
            neg_prompts: prompt(s) whose CLIP similarity is pushed down, or a
                falsy value for none.
        """
        self.set_latent(latent)
        # The unedited decode is only a fixed perceptual target: build it without
        # a graph so backward never re-traverses the decoder for it.
        with torch.no_grad():
            original_img = self._render(torch.zeros_like(self.latent))
        # NOTE(review): the edit vector starts from Gaussian noise, not zero — confirm intended.
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)
        if self.make_grid:
            plt.figure(figsize=(35, 25))
            self.index = 1

        for _ in tqdm(range(self.iterations)):
            optim.zero_grad()
            processed_img = self._render(vector)
            # Separate clones so each loss's gradient is masked by its own hook
            # before both are summed into processed_img's gradient.
            lpips_input = self._masked_copy(processed_img, self.attn_masking2)
            clip_input = self._masked_copy(processed_img, self.attn_masking)
            with self._autocast():
                clip_loss = self.get_similarity_loss(pos_prompts, neg_prompts, clip_input)
                print("CLIP loss", clip_loss)
                perceptual_loss = self._lpips(lpips_input, original_img)
                print("LPIPS loss: ", perceptual_loss)
            if log:
                wandb.log({"Perceptual Loss": perceptual_loss})
                wandb.log({"CLIP Loss": clip_loss})
            total_loss = perceptual_loss + clip_loss
            print("Sum Loss", total_loss)
            # One backward over the summed loss replaces two retain_graph passes.
            total_loss.backward()
            optim.step()
            yield vector

        if self.make_grid:
            label = pos_prompts if isinstance(pos_prompts, str) else pos_prompts[0]
            plt.savefig(f"plot {label}.png")
            plt.show()

        print("lpips solo op")
        for _ in range(self.reconstruction_steps):
            optim.zero_grad()
            lpips_input = self._masked_copy(self._render(vector), self.attn_masking2)
            with self._autocast():
                perceptual_loss = self._lpips(lpips_input, original_img)
            if log:
                wandb.log({"Perceptual Loss": perceptual_loss})
            print("LPIPS loss: ", perceptual_loss)
            perceptual_loss.backward()
            optim.step()
            yield vector

        yield vector if self.return_val == "vector" else self.latent + vector