import os
from glob import glob

import imageio
import torch
import torchvision
import wandb
from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
from loaders import load_vqgan
from PIL import Image
from torch import nn
from transformers import CLIPModel, CLIPTokenizerFast
from utils import get_device, get_timestamp, show_pil


class ProcessorGradientFlow:
    """
    This wraps the huggingface CLIP processor to allow backprop through the image processing step.
    The original processor forces conversion to PIL images, which is faster for image processing but breaks gradient flow.
    We call the original processor to get the text embeddings, but use our own image processing to keep images as torch tensors.
    """

    def __init__(self, device: str = "cpu", clip_model: str = "openai/clip-vit-large-patch14") -> None:
        self.device = device
        self.tokenizer = CLIPTokenizerFast.from_pretrained(clip_model)
        self.image_mean = [0.48145466, 0.4578275, 0.40821073]
        self.image_std = [0.26862954, 0.26130258, 0.27577711]
        self.normalize = torchvision.transforms.Normalize(self.image_mean, self.image_std)
        self.resize = torchvision.transforms.Resize(224)
        self.center_crop = torchvision.transforms.CenterCrop(224)

    def preprocess_img(self, images):
        images = self.resize(images)
        images = self.center_crop(images)
        images = self.normalize(images)
        return images

    def __call__(self, text=None, images=None, **kwargs):
        encoding = self.tokenizer(text=text, **kwargs)
        encoding["pixel_values"] = self.preprocess_img(images)
        encoding = {key: value.to(self.device) for (key, value) in encoding.items()}
        return encoding
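

# Illustrative sketch, not used by this module: because preprocess_img operates on torch tensors
# instead of PIL images, gradients from a CLIP loss can flow back through the preprocessing to the
# input image (and hence to a VQGAN latent). The tensor shape below is an arbitrary example.
#
#   processor = ProcessorGradientFlow(device="cpu")
#   dummy_img = torch.rand(1, 3, 256, 256, requires_grad=True)
#   inputs = processor(text=["a smiling woman"], images=dummy_img, return_tensors="pt", padding=True)
#   inputs["pixel_values"].sum().backward()  # populates dummy_img.grad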


class VQGAN_CLIP(nn.Module):
    def __init__(
        self,
        iterations=10,
        lr=0.01,
        vqgan=None,
        vqgan_config=None,
        vqgan_checkpoint=None,
        clip=None,
        clip_preprocessor=None,
        device=None,
        log=False,
        save_vector=True,
        return_val="image",
        quantize=True,
        save_intermediate=False,
        show_intermediate=False,
        make_grid=False,
    ) -> None:
        """
        Instantiate a VQGAN_CLIP model. If you want to use a custom VQGAN model, pass it as vqgan.
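
        Example (illustrative; `my_vqgan` stands for a VQGAN instance you have already loaded):
            model = VQGAN_CLIP(vqgan=my_vqgan, iterations=25, lr=0.05)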
| """ | |
| super().__init__() | |
| self.latent = None | |
| self.device = device if device else get_device() | |
| if vqgan: | |
| self.vqgan = vqgan | |
| else: | |
| self.vqgan = load_vqgan(self.device, conf_path=vqgan_config, ckpt_path=vqgan_checkpoint) | |
| self.vqgan.eval() | |
| if clip: | |
| self.clip = clip | |
| else: | |
| self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") | |
| self.clip.to(self.device) | |
| self.clip_preprocessor = ProcessorGradientFlow(device=self.device) | |
| self.iterations = iterations | |
| self.lr = lr | |
| self.log = log | |
| self.make_grid = make_grid | |
| self.return_val = return_val | |
| self.quantize = quantize | |
| self.latent_dim = self.vqgan.decoder.z_shape | |

    def make_animation(self, input_path=None, output_path=None, total_duration=5, extend_frames=True):
        """
        Make an animation from the intermediate images saved during generation.
        By default, uses the images from the most recent generation created by the generate function.
        If you want to use images from a different generation, pass the path to the folder containing the images as input_path.
        """
        images = []
        if output_path is None:
            output_path = "./animation.gif"
        if input_path is None:
            input_path = self.save_path
        paths = sorted(glob(input_path + "/*"))
        if not len(paths):
            raise ValueError(
                "No images found in save path, aborting (did you pass save_intermediate=True to the generate"
                " function?)"
            )
        if len(paths) == 1:
            print("Only one image found in save path (did you pass save_intermediate=True to the generate function?)")
        frame_duration = total_duration / len(paths)
        durations = [frame_duration] * len(paths)
        if extend_frames:
            # Hold the first and last frames a little longer so the animation is easier to follow.
            durations[0] = 1.5
            durations[-1] = 3
        for file_name in paths:
            if file_name.endswith(".png"):
                images.append(imageio.imread(file_name))
        imageio.mimsave(output_path, images, duration=durations)
        print(f"gif saved to {output_path}")

    def _get_latent(self, path=None, img=None):
        if not (path or img):
            raise ValueError("Input either path or tensor")
        if img is not None:
            raise NotImplementedError
        x = preprocess(Image.open(path), target_image_size=256).to(self.device)
        x_processed = preprocess_vqgan(x)
        z, *_ = self.vqgan.encode(x_processed)
        return z

    def _add_vector(self, transform_vector):
        """Add a vector transform to the base latent and return the resulting image."""
        base_latent = self.latent.detach().requires_grad_()
        trans_latent = base_latent + transform_vector
        if self.quantize:
            z_q, *_ = self.vqgan.quantize(trans_latent)
        else:
            z_q = trans_latent
        return self.vqgan.decode(z_q)

    def _get_clip_similarity(self, prompts, image, weights=None):
        clip_inputs = self.clip_preprocessor(text=prompts, images=image, return_tensors="pt", padding=True)
        clip_outputs = self.clip(**clip_inputs)
        similarity_logits = clip_outputs.logits_per_image
        if weights is not None:
            similarity_logits = similarity_logits * weights
        return similarity_logits.sum()

    def _get_clip_loss(self, pos_prompts, neg_prompts, image):
        pos_logits = self._get_clip_similarity(pos_prompts["prompts"], image, weights=(1 / pos_prompts["weights"]))
        if neg_prompts:
            neg_logits = self._get_clip_similarity(neg_prompts["prompts"], image, weights=neg_prompts["weights"])
        else:
            neg_logits = torch.tensor([1], device=self.device)
        loss = -torch.log(pos_logits) + torch.log(neg_logits)
        return loss

    def _optimize_CLIP(self, original_img, pos_prompts, neg_prompts):
        vector = torch.randn_like(self.latent, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([vector], lr=self.lr)
        for _ in range(self.iterations):
            optim.zero_grad()
            transformed_img = self._add_vector(vector)
            processed_img = loop_post_process(transformed_img)
            clip_loss = self._get_clip_loss(pos_prompts, neg_prompts, processed_img)
            print("CLIP loss", clip_loss)
            if self.log:
                wandb.log({"CLIP Loss": clip_loss})
            clip_loss.backward(retain_graph=True)
            optim.step()
            if self.return_val == "image":
                yield custom_to_pil(transformed_img[0])
            else:
                yield vector

    def _init_logging(self, positive_prompts, negative_prompts, image_path):
        wandb.init(reinit=True, project="face-editor")
        wandb.config.update({"Positive Prompts": positive_prompts})
        wandb.config.update({"Negative Prompts": negative_prompts})
        wandb.config.update({"lr": self.lr, "iterations": self.iterations})
        if image_path:
            image = Image.open(image_path)
            image = image.resize((256, 256))
            wandb.log({"Original Image": wandb.Image(image)})

    def process_prompts(self, prompts):
        if not prompts:
            return []
        processed_prompts = []
        weights = []
        if isinstance(prompts, str):
            prompts = [prompt.strip() for prompt in prompts.split("|")]
        for prompt in prompts:
            if isinstance(prompt, (tuple, list)):
                processed_prompt = prompt[0]
                weight = float(prompt[1])
            elif ":" in prompt:
                processed_prompt, weight = prompt.split(":")
                weight = float(weight)
            else:
                processed_prompt = prompt
                weight = 1.0
            processed_prompts.append(processed_prompt)
            weights.append(weight)
        return {
            "prompts": processed_prompts,
            "weights": torch.tensor(weights, device=self.device),
        }

    def generate(
        self,
        pos_prompts,
        neg_prompts=None,
        image_path=None,
        show_intermediate=True,
        save_intermediate=False,
        show_final=True,
        save_final=True,
        save_path=None,
    ):
        """Generate an image from the given prompts.

        If image_path is provided, the image is used as a starting point for the optimization.
        If image_path is not provided, a random latent vector is used as a starting point.
        You must provide at least one positive prompt, and may optionally provide negative prompts.
        Prompts must be formatted in one of the following ways:
        - A single prompt as a string, e.g. "A smiling woman"
        - A set of prompts separated by pipes: "A smiling woman | a woman with brown hair"
        - A set of prompts and their weights separated by colons: "A smiling woman:1 | a woman with brown hair: 3" (default weight is 1)
        - A list of prompts, e.g. ["A smiling woman", "a woman with brown hair"]
        - A list of prompts and weights, e.g. [("A smiling woman", 1), ("a woman with brown hair", 3)]
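
        Example (illustrative only; the prompts, weights, and image path below are placeholders):
            model.generate(
                pos_prompts="A smiling woman:1 | a woman with brown hair:3",
                neg_prompts="glasses",
                image_path="face.png",
            )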
| """ | |
        if image_path:
            self.latent = self._get_latent(image_path)
        else:
            self.latent = torch.randn(self.latent_dim, device=self.device)
        if self.log:
            self._init_logging(pos_prompts, neg_prompts, image_path)

        assert pos_prompts, "You must provide at least one positive prompt."
        pos_prompts = self.process_prompts(pos_prompts)
        neg_prompts = self.process_prompts(neg_prompts)

        if save_final and save_path is None:
            save_path = os.path.join("./outputs/", "_".join(pos_prompts["prompts"]))
            if not os.path.exists(save_path):
                os.makedirs(save_path)
            else:
                save_path = save_path + "_" + get_timestamp()
                os.makedirs(save_path)
        # Record the save path so that intermediate/final saves and make_animation can find it.
        self.save_path = save_path

        original_img = self.vqgan.decode(self.latent)[0]
        if show_intermediate:
            print("Original Image")
            show_pil(custom_to_pil(original_img))

        original_img = loop_post_process(original_img)
        for step, transformed_img in enumerate(self._optimize_CLIP(original_img, pos_prompts, neg_prompts)):
            if show_intermediate:
                show_pil(transformed_img)
            if save_intermediate:
                transformed_img.save(os.path.join(self.save_path, f"iter_{step:03d}.png"))
            if self.log:
                wandb.log({"Image": wandb.Image(transformed_img)})
        if show_final:
            show_pil(transformed_img)
        if save_final:
            transformed_img.save(os.path.join(self.save_path, f"iter_{step:03d}_final.png"))
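

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions: the VQGAN config/checkpoint paths and the prompts below
    # are placeholders; point them at your own files before running.
    model = VQGAN_CLIP(
        iterations=20,
        lr=0.05,
        vqgan_config="./vqgan_config.yaml",  # placeholder path
        vqgan_checkpoint="./vqgan_checkpoint.ckpt",  # placeholder path
    )
    model.generate(
        pos_prompts="A smiling woman:1 | a woman with brown hair:2",
        neg_prompts="glasses",
        show_intermediate=False,
        save_intermediate=True,  # required so make_animation has frames to stitch together
    )
    model.make_animation(output_path="./animation.gif")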