import gradio as gr from PIL import Image import shutil import pickle import random import json import os if not os.path.exists("./data/pacs/"): shutil.unpack_archive("./data/pacs.zip", './data/', 'zip') METHODS = { "Textual Inversion (LDM)": "textualinversion_ldm", "Textual Inversion (Stable Diffusion)": "none_with_emb_without_multires", "DreamBooth": "unet_without_emb_without_multires", "Custom Diffusion": "kv_with_emb_without_multires", } for method in list(METHODS.values()): if not os.path.exists(f"./data/imagenet/images/{method}"): shutil.unpack_archive(f"./data/imagenet/images/{method}.zip", f"./", 'zip') if not os.path.exists(f"./data/imagenet/compositions/images/{method}"): shutil.unpack_archive(f"./data/imagenet/compositions/images/{method}.zip", f"./", 'zip') method="original" if not os.path.exists(f"./data/imagenet/images/{method}"): shutil.unpack_archive(f"./data/imagenet/images/{method}.zip", f"./", 'zip') print("Ready to go") CONCEPTS = { "Art Painting": "art_painting", "Cartoon": "cartoon", "Photo": "photo", "Sketch": "sketch", } DOMAINS = ["art_painting", "cartoon", "photo", "sketch"] with open("./data/imagenet/imagenet_mapping.pkl", "rb") as h: imagenet_mapping = pickle.load(h) OBJECTS = [] for k,v in imagenet_mapping.items(): CONCEPTS[f"{k}:{v}"] = k OBJECTS.append(f"{k}:{v}") def get_domains(method, concept): gen_cls=random.choice(os.listdir(os.path.join('./data/pacs', method, concept))) fname=random.choice(os.listdir(os.path.join('./data/pacs', method, concept, gen_cls))) gen_img = Image.open(os.path.join('./data/pacs', method, concept, gen_cls, fname)).resize((128, 128)) ref_images = [] for i in range(3): cls=random.choice(os.listdir(os.path.join('./data/pacs', 'original', concept))) fname=random.choice(os.listdir(os.path.join('./data/pacs', 'original', concept, cls))) img = Image.open(os.path.join('./data/pacs', 'original', concept, cls, fname)).resize((128, 128)) ref_images.append(img) return gen_img, f"a photo of {gen_cls} in the style of {concept}", ref_images def get_objects(method, concept, evaluation): if evaluation=="Concept Alignment": gen_cls = "" if "ldm" in method: gen_cls="samples" fname=random.choice(os.listdir(os.path.join('./data/imagenet/images', method, concept, gen_cls))) gen_img = Image.open(os.path.join('./data/imagenet/images', method, concept, gen_cls, fname)).resize((128, 128)) ref_images = [] for i in range(3): fname=random.choice(os.listdir(os.path.join('./data/imagenet/images', 'original', concept))) img = Image.open(os.path.join('./data/imagenet/images', 'original', concept, fname)).resize((128, 128)) ref_images.append(img) return gen_img, f"a photo of **{imagenet_mapping[concept]}**", ref_images else: gen_cls = "" if "ldm" in method: gen_cls="samples" with open(f"./data/imagenet/compositions/prompts/{concept}.json", "r") as h: prompts = json.load(h) fname=random.choice(os.listdir(os.path.join('./data/imagenet/compositions/images', method, concept, gen_cls))) gen_img = Image.open(os.path.join('./data/imagenet/compositions/images', method, concept, gen_cls, fname)).resize((128, 128)) idx = int(fname.split("_")[0]) caption = prompts[idx]["caption"].replace(prompts[idx]["entity"], f"**{prompts[idx]['entity']}**") ref_images = [] for i in range(3): fname=random.choice(os.listdir(os.path.join('./data/imagenet/images', 'original', concept))) img = Image.open(os.path.join('./data/imagenet/images', 'original', concept, fname)).resize((128, 128)) ref_images.append(img) return gen_img, caption, ref_images def get_images(method, concept, evaluation): method = METHODS[method] concept = CONCEPTS[concept] if concept in DOMAINS: images, captions, ref_images = get_domains(method, concept) return images, captions, ref_images elif concept in list(imagenet_mapping.keys()): images, captions, ref_images = get_objects(method, concept, evaluation) return images, captions, ref_images else: return css=''' #image_upload{min-height:4px} #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{max-height: 5} ''' image_blocks = gr.Blocks(css=css) with image_blocks as demo: # with gr.Blocks() as demo: gr.Markdown("