# Page banner captured together with this file (not part of the module):
# "Spaces: Configuration error" / "Configuration error"
# from align import align_from_path
from functools import cache
import importlib
import os

import gradio as gr
import lpips
import matplotlib.pyplot as plt
import torch
import torchvision
import wandb
from icecream import ic
from torch import nn
from torchvision.transforms.functional import resize
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

from animation import clear_img_dir
from app_backend import ImagePromptOptimizer, log
from app_backend import get_resized_tensor
from edit import blend_paths
from img_processing import *
from img_processing import custom_to_pil
from loaders import load_default
# Index of the next frame written to ./img_history by
# ImageState._render_all_transformations; ImageState.clear_transforms resets it to 0.
num = 0
class PromptTransformHistory:
    """Record of the latent offsets produced during one prompt-optimisation run.

    Attributes are plain and public:
      * ``iterations`` -- the step count handed to the constructor.
      * ``transforms`` -- a list that starts empty; callers append one entry
        per recorded step.
    """

    def __init__(self, iterations) -> None:
        # Start with an empty snapshot list; the step count is stored verbatim.
        self.transforms: list = []
        self.iterations = iterations
class ImageState:
    """Hold the editable latent state of one editing session and render it.

    The state is made of:
      * ``blend_latent`` -- the base latent returned by ``blend_paths`` for the
        two selected image paths (``None`` until ``update_images``/``blend``
        has run);
      * four weighted offsets (``blue_eyes``, ``lip_size``, ``hair_gp``,
        ``asian_transform``) derived from vectors loaded from
        ``./latent_vectors``;
      * ``current_prompt_transforms`` -- a list of offsets produced by the
        prompt optimizer, summed at render time.

    Most public methods update one piece of that state and return the result
    of ``_render_all_transformations``.
    """

    # The annotation is a string so that defining the class does not require
    # ImagePromptOptimizer to be resolvable at definition time.
    def __init__(self, vqgan, prompt_optimizer: "ImagePromptOptimizer") -> None:
        """Store the model handles, load the edit vectors and zero all offsets.

        Args:
            vqgan: Model object. This class reads its ``device`` attribute,
                calls its ``decode`` and ``quantize`` methods and forwards it
                to ``blend_paths``.
            prompt_optimizer: Object whose ``set_params`` and ``optimize``
                methods are used by ``apply_prompts``.
        """
        self.vqgan = vqgan
        self.device = vqgan.device
        self.blend_latent = None
        self.quant = True
        self.path1 = None
        self.path2 = None
        self.transform_history = []
        self.attn_mask = None
        self.prompt_optim = prompt_optimizer
        self._load_vectors()
        self.init_transforms()

    def _load_vectors(self):
        """Load the four edit vectors from ``./latent_vectors`` onto ``self.device``."""
        # NOTE(review): attribute names and file names do not obviously agree
        # (e.g. green_purple_vector <- nose_vector.pt); confirm which is intended.
        self.lip_vector = torch.load("./latent_vectors/lipvector.pt", map_location=self.device)
        self.red_blue_vector = torch.load("./latent_vectors/2blue_eyes.pt", map_location=self.device)
        self.green_purple_vector = torch.load("./latent_vectors/nose_vector.pt", map_location=self.device)
        self.asian_vector = torch.load("./latent_vectors/asian10.pt", map_location=self.device)

    def init_transforms(self):
        """Reset every offset to zeros shaped like ``self.lip_vector``."""
        self.blue_eyes = torch.zeros_like(self.lip_vector)
        self.lip_size = torch.zeros_like(self.lip_vector)
        self.asian_transform = torch.zeros_like(self.lip_vector)
        self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]
        self.hair_gp = torch.zeros_like(self.lip_vector)

    def clear_transforms(self):
        """Zero all offsets, clear the image directory, restart the frame counter and re-render."""
        global num
        self.init_transforms()
        clear_img_dir()
        num = 0
        return self._render_all_transformations()

    def _apply_vector(self, src, vector):
        """Return ``src`` moved by ``vector`` (interpolation weight fixed at 1)."""
        new_latent = torch.lerp(src, src + vector, 1)
        return new_latent

    def _decode_latent_to_pil(self, latent):
        """Decode ``latent`` with the VQGAN and convert the first result via ``custom_to_pil``."""
        current_im = self.vqgan.decode(latent.to(self.device))[0]
        return custom_to_pil(current_im)

    def get_mask(self, img, mask=None):
        """Extract an attention mask from ``img`` or fall back to ``mask``.

        Args:
            img: Object supporting ``"mask" in img`` and ``img["mask"]``
                (presumably the Gradio sketch payload -- TODO confirm). May be
                ``None`` or empty.
            mask: Value returned when ``img`` carries no usable mask.

        Returns:
            When ``img["mask"]`` is present: channel 0 of its ``ToTensor``
            conversion, rounded up with ``torch.ceil`` and placed on
            ``self.device``. Otherwise ``mask`` unchanged (possibly ``None``).
        """
        if img and "mask" in img and img["mask"] is not None:
            attn_mask = torchvision.transforms.ToTensor()(img["mask"])
            attn_mask = torch.ceil(attn_mask[0].to(self.device))
            # Debug output kept from the original: plot, dump to disk, print.
            plt.imshow(attn_mask.detach().cpu(), cmap="Blues")
            plt.show()
            torch.save(attn_mask, "test_mask.pt")
            print("mask set successfully")
            print(type(attn_mask))
            print(attn_mask.shape)
        else:
            attn_mask = mask
            # Only inspect the fallback when there is one; with the default
            # mask=None there is nothing to resize or print.
            if attn_mask is not None:
                print("mask in apply ", get_resized_tensor(attn_mask), get_resized_tensor(attn_mask).shape)
        return attn_mask

    def set_mask(self, img):
        """Store the mask taken from ``img`` and return a greyscale preview of it.

        Returns:
            A mode ``"L"`` image built from the mask after clamping to
            ``[-1, 1]`` and rescaling to ``[0, 255]``, or ``None`` when ``img``
            carries no mask (in which case ``self.attn_mask`` is set to ``None``).
        """
        attn_mask = self.get_mask(img)
        self.attn_mask = attn_mask
        if attn_mask is None:
            # Nothing was drawn: there is no tensor to convert into a preview.
            return None
        x = attn_mask.clone()
        x = x.detach().cpu()
        x = torch.clamp(x, -1., 1.)
        x = (x + 1.)/2.
        x = x.numpy()
        x = (255*x).astype(np.uint8)
        x = Image.fromarray(x, "L")
        return x

    def _render_all_transformations(self, return_twice=True):
        """Combine the base latent with every offset, decode it and save the frame.

        The frame is written to ``./img_history/img_<num>.png`` and the
        module-level counter ``num`` is incremented.

        Args:
            return_twice: When true, return ``(image, image)``; otherwise the
                single image.

        Returns:
            The rendered image (once or twice, see above). When no base latent
            exists yet, ``None`` / ``(None, None)`` is returned and nothing is
            decoded or written.
        """
        global num
        if self.blend_latent is None:
            # No image has been blended yet, so there is nothing to offset.
            return (None, None) if return_twice else None
        current_vector_transforms = (
            self.blue_eyes,
            self.lip_size,
            self.hair_gp,
            self.asian_transform,
            sum(self.current_prompt_transforms),
        )
        new_latent = self.blend_latent + sum(current_vector_transforms)
        if self.quant:
            new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
        image = self._decode_latent_to_pil(new_latent)
        img_dir = "./img_history"
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(img_dir, exist_ok=True)
        image.save(f"{img_dir}/img_{num:06}.png")
        num += 1
        return (image, image) if return_twice else image

    def apply_gp_vector(self, weight):
        """Set ``hair_gp`` to ``weight * green_purple_vector`` and re-render."""
        self.hair_gp = weight * self.green_purple_vector
        return self._render_all_transformations()

    def apply_rb_vector(self, weight):
        """Set ``blue_eyes`` to ``weight * red_blue_vector`` and re-render."""
        self.blue_eyes = weight * self.red_blue_vector
        return self._render_all_transformations()

    def apply_lip_vector(self, weight):
        """Set ``lip_size`` to ``weight * lip_vector`` and re-render."""
        self.lip_size = weight * self.lip_vector
        return self._render_all_transformations()

    def update_requant(self, val):
        """Enable or disable the quantisation step used at render time, then re-render."""
        print(f"val = {val}")
        self.quant = val
        return self._render_all_transformations()

    def apply_gender_vector(self, weight):
        """Set ``asian_transform`` to ``weight * asian_vector`` and re-render."""
        self.asian_transform = weight * self.asian_vector
        return self._render_all_transformations()

    def update_images(self, path1, path2, blend_weight):
        """Remember the two image paths and blend them.

        A missing path is replaced by the other one; when both are ``None``
        the method returns ``None`` without touching any state.
        """
        if path1 is None and path2 is None:
            return None
        if path1 is None:
            path1 = path2
        if path2 is None:
            path2 = path1
        self.path1, self.path2 = path1, path2
        return self.blend(blend_weight)

    def blend(self, weight):
        """Recompute ``blend_latent`` from the stored paths via ``blend_paths`` and re-render."""
        _, latent = blend_paths(self.vqgan, self.path1, self.path2, weight=weight, show=False, device=self.device)
        self.blend_latent = latent
        return self._render_all_transformations()

    def rewind(self, index):
        """Replace the latest prompt offset with an earlier snapshot and re-render.

        Args:
            index: Position on a 0-100 scale; it is mapped linearly onto the
                step count of the most recent entry in ``transform_history``.

        Returns:
            The re-rendered image pair. When no history exists the current
            state is rendered unchanged.
        """
        if not self.transform_history:
            print("no history")
            return self._render_all_transformations()
        prompt_transform = self.transform_history[-1]
        latent_index = int(index / 100 * (prompt_transform.iterations - 1))
        # Clamp: the recorded snapshot list may be shorter than the declared
        # step count, and an out-of-range index must not raise or wrap around.
        last_valid = len(prompt_transform.transforms) - 1
        latent_index = max(0, min(latent_index, last_valid))
        print(latent_index)
        self.current_prompt_transforms[-1] = prompt_transform.transforms[latent_index]
        return self._render_all_transformations()

    def rescale_mask(self, mask):
        """Return a copy of ``mask`` with values below 0.03 set to -1000000 and the rest to 1."""
        rep = mask.clone()
        rep[mask < 0.03] = -1000000
        rep[mask >= 0.03] = 1
        return rep

    def apply_prompts(self, positive_prompts, negative_prompts, lr, iterations, lpips_weight, reconstruction_steps):
        """Optimise a new prompt offset, yielding a rendered frame after every step.

        This is a generator. Each offset produced by
        ``self.prompt_optim.optimize`` is snapshotted into a
        ``PromptTransformHistory``, installed as the latest entry of
        ``current_prompt_transforms`` and rendered.

        Args:
            positive_prompts: ``|``-separated prompt string.
            negative_prompts: ``|``-separated prompt string.
            lr, iterations, lpips_weight, reconstruction_steps: Forwarded to
                ``self.prompt_optim.set_params``; ``iterations +
                reconstruction_steps`` is stored as the history's step count.

        Yields:
            ``(image, image)`` for every optimisation step.

        After normal completion the stored attention mask is cleared and the
        history is appended to ``transform_history``.
        """
        transform_log = PromptTransformHistory(iterations + reconstruction_steps)
        transform_log.transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
        self.current_prompt_transforms.append(torch.zeros_like(self.blend_latent, requires_grad=False))
        if log:
            wandb.init(reinit=True, project="face-editor")
            wandb.config.update({"Positive Prompts": positive_prompts})
            wandb.config.update({"Negative Prompts": negative_prompts})
            wandb.config.update(dict(
                lr=lr,
                iterations=iterations,
                lpips_weight=lpips_weight
            ))
        positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
        negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
        self.prompt_optim.set_params(lr, iterations, lpips_weight, attn_mask=self.attn_mask, reconstruction_steps=reconstruction_steps)
        try:
            for transform in self.prompt_optim.optimize(self.blend_latent,
                                                        positive_prompts,
                                                        negative_prompts):
                transform_log.transforms.append(transform.clone().detach())
                self.current_prompt_transforms[-1] = transform
                with torch.no_grad():
                    image = self._render_all_transformations(return_twice=False)
                if log:
                    wandb.log({"image": wandb.Image(image)})
                yield (image, image)
        finally:
            # Close the wandb run even when the consumer stops iterating early
            # or a step raises.
            if log:
                wandb.finish()
        self.attn_mask = None
        self.transform_history.append(transform_log)