from typing import List, Union

import torch
from PIL import Image
from transformers import (
    CLIPProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
)

from diffusers import StableDiffusionPipeline

from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path

import os
import glob
import math

# Evaluation prompts; "<obj>" is replaced with the learnt concept token at run time.
EXAMPLE_PROMPTS = [
    "<obj> swimming in a pool",
    "<obj> at a beach with a view of seashore",
    "<obj> in times square",
    "<obj> wearing sunglasses",
    "<obj> in a construction outfit",
    "<obj> playing with a ball",
    "<obj> wearing headphones",
    "<obj> oil painting ghibli inspired",
    "<obj> working on the laptop",
    "<obj> with mountains and sunset in background",
    "Painting of <obj> at a beach by artist claude monet",
    "<obj> digital painting 3d render geometric style",
    "A screaming <obj>",
    "A depressed <obj>",
    "A sleeping <obj>",
    "A sad <obj>",
    "A joyous <obj>",
    "A frowning <obj>",
    "A sculpture of <obj>",
    "<obj> near a pool",
    "<obj> at a beach with a view of seashore",
    "<obj> in a garden",
    "<obj> in grand canyon",
    "<obj> floating in ocean",
    "<obj> and an armchair",
    "A maple tree on the side of <obj>",
    "<obj> and an orange sofa",
    "<obj> with chocolate cake on it",
    "<obj> with a vase of rose flowers on it",
    "A digital illustration of <obj>",
    "Georgia O'Keeffe style <obj> painting",
    "A watercolor painting of <obj> on a beach",
]


def image_grid(_imgs, rows=None, cols=None):
    """Paste a list of PIL images into a rows x cols grid (roughly square by default)."""

    if rows is None and cols is None:
        rows = cols = math.ceil(len(_imgs) ** 0.5)

    if rows is None:
        rows = math.ceil(len(_imgs) / cols)
    if cols is None:
        cols = math.ceil(len(_imgs) / rows)

    w, h = _imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(_imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
    # evaluation inspired by the textual inversion paper
    # https://arxiv.org/abs/2208.01618

    # text alignment: cosine similarity between generated-image and prompt embeddings
    assert img_embeds.shape[0] == text_embeds.shape[0]
    text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
        img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
    )

    # image alignment: cosine similarity to the mean embedding of the target images
    img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)

    avg_target_img_embed = (
        (target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
        .mean(dim=0)
        .unsqueeze(0)
        .repeat(img_embeds.shape[0], 1)
    )

    img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)

    return {
        "text_alignment_avg": text_img_sim.mean().item(),
        "image_alignment_avg": img_img_sim.mean().item(),
        "text_alignment_all": text_img_sim.tolist(),
        "image_alignment_all": img_img_sim.tolist(),
    }


def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
    """Load the CLIP text/vision models, tokenizer, and processor used for evaluation."""
    text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
    tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
    vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
    processor = CLIPProcessor.from_pretrained(eval_clip_id)

    return text_model, tokenizer, vis_model, processor


def evaluate_pipe(
    pipe,
    target_images: List[Image.Image],
    class_token: str = "",
    learnt_token: str = "",
    guidance_scale: float = 5.0,
    seed=0,
    clip_model_sets=None,
    eval_clip_id: str = "openai/clip-vit-large-patch14",
    n_test: int = 10,
    n_step: int = 50,
):
    """Generate images for the first `n_test` example prompts and score text/image alignment."""

    if clip_model_sets is not None:
        text_model, tokenizer, vis_model, processor = clip_model_sets
    else:
        text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
            eval_clip_id
        )

    images = []
    img_embeds = []
    text_embeds = []
    for prompt in EXAMPLE_PROMPTS[:n_test]:
        prompt = prompt.replace("<obj>", learnt_token)

        torch.manual_seed(seed)
        with torch.autocast("cuda"):
            img = pipe(
                prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
            ).images[0]
        images.append(img)

        # embed the generated image
        inputs = processor(images=img, return_tensors="pt")
        img_embed = vis_model(**inputs).image_embeds
        img_embeds.append(img_embed)

        # embed the prompt, with the learnt token swapped back to the class token
        prompt = prompt.replace(learnt_token, class_token)
        inputs = tokenizer([prompt], padding=True, return_tensors="pt")
        outputs = text_model(**inputs)
        text_embed = outputs.text_embeds
        text_embeds.append(text_embed)

    # embed the target (training) images
    inputs = processor(images=target_images, return_tensors="pt")
    target_img_embeds = vis_model(**inputs).image_embeds

    img_embeds = torch.cat(img_embeds, dim=0)
    text_embeds = torch.cat(text_embeds, dim=0)

    return text_img_alignment(img_embeds, text_embeds, target_img_embeds)


def visualize_progress(
    path_alls: Union[str, List[str]],
    prompt: str,
    model_id: str = "runwayml/stable-diffusion-v1-5",
    device="cuda:0",
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    unet_scale=1.0,
    text_scale=1.0,
    num_inference_steps=50,
    guidance_scale=5.0,
    offset: int = 0,
    limit: int = 10,
    seed: int = 0,
):
    """Render `prompt` once per checkpoint (glob pattern or explicit list) and return the images."""

    imgs = []
    if isinstance(path_alls, str):
        alls = list(set(glob.glob(path_alls)))
        alls.sort(key=os.path.getmtime)
    else:
        alls = path_alls

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)

    print(f"Found {len(alls)} checkpoints")
    for path in alls[offset:limit]:
        print(path)

        patch_pipe(
            pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
        )

        tune_lora_scale(pipe.unet, unet_scale)
        tune_lora_scale(pipe.text_encoder, text_scale)

        torch.manual_seed(seed)
        image = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).images[0]

        imgs.append(image)

    return imgs
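

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API): score a patched LoRA
# pipeline with `evaluate_pipe`, then tile per-checkpoint samples from
# `visualize_progress` into a single image. The checkpoint path, image folder,
# and the "<s1>"/"dog" tokens below are illustrative placeholders, not values
# prescribed by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda:0")

    # Hypothetical paths: a trained LoRA checkpoint and the concept's training images.
    patch_pipe(pipe, "./output/lora_weight.safetensors")
    target_images = [
        Image.open(p).convert("RGB") for p in glob.glob("./data/concept/*.jpg")
    ]

    scores = evaluate_pipe(
        pipe,
        target_images,
        class_token="dog",
        learnt_token="<s1>",
        n_test=5,
        n_step=30,
    )
    print(
        "text alignment:", scores["text_alignment_avg"],
        "image alignment:", scores["image_alignment_avg"],
    )

    # Render the same prompt across all saved checkpoints (sorted by mtime)
    # and save them as one grid for a quick visual progress check.
    grid = image_grid(
        visualize_progress("./output/*.safetensors", "a photo of <s1> at the beach")
    )
    grid.save("progress.png")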