Spaces:
Running
Running
import gradio as gr | |
from PIL import Image | |
import shutil | |
import pickle | |
import random | |
import json | |
import os | |
# First-run setup: extract the bundled PACS archive unless it is already unpacked.
if not os.path.exists("./data/pacs/"):
    shutil.unpack_archive("./data/pacs.zip", "./data/", "zip")

# Dropdown label of each concept-learning method -> directory name that holds
# the images generated with that method.
METHODS = {
    "Textual Inversion (LDM)": "textualinversion_ldm",
    "Textual Inversion (Stable Diffusion)": "none_with_emb_without_multires",
    "DreamBooth": "unet_without_emb_without_multires",
    "Custom Diffusion": "kv_with_emb_without_multires",
}

# Every method has two archives: plain generations and compositional ones.
# NOTE(review): extraction targets "./" while the existence check looks under
# ./data/imagenet/..., so the archives presumably store paths starting at
# data/ -- confirm against the zip contents.
for method in METHODS.values():
    for image_root in (
        "./data/imagenet/images",
        "./data/imagenet/compositions/images",
    ):
        if os.path.exists(f"{image_root}/{method}"):
            continue
        shutil.unpack_archive(f"{image_root}/{method}.zip", "./", "zip")

# The "original" images are stored next to the per-method generations.
method = "original"
if not os.path.exists(f"./data/imagenet/images/{method}"):
    shutil.unpack_archive(f"./data/imagenet/images/{method}.zip", "./", "zip")

print("Ready to go")
# Dropdown label -> concept identifier. Starts with the four style domains and
# is extended below with one entry per key of the ImageNet mapping.
CONCEPTS = {
    "Art Painting": "art_painting",
    "Cartoon": "cartoon",
    "Photo": "photo",
    "Sketch": "sketch",
}

# Concept identifiers that are handled as style domains rather than objects.
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]

# NOTE(review): pickle.load can execute arbitrary code; this is only safe if
# the mapping file is a trusted asset bundled with the app -- confirm.
with open("./data/imagenet/imagenet_mapping.pkl", "rb") as mapping_file:
    imagenet_mapping = pickle.load(mapping_file)

# Object concepts are labelled "<key>:<value>" and resolve back to the key.
OBJECTS = [f"{key}:{name}" for key, name in imagenet_mapping.items()]
CONCEPTS.update(
    {f"{key}:{name}": key for key, name in imagenet_mapping.items()}
)
def get_domains(method, concept):
    """Pick a random generated image for a style domain plus three references.

    Args:
        method: Directory name of the concept-learning method under
            ``./data/pacs``.
        concept: Domain directory name (one sub-directory per class inside).

    Returns:
        A tuple ``(gen_img, caption, ref_images)``: the generated image resized
        to 128x128, the caption naming the sampled class and the domain, and a
        list of three 128x128 images drawn from the ``original`` tree.
    """
    thumb_size = (128, 128)

    def pick_entry(directory):
        # Uniformly choose one directory entry (a class folder or a file).
        return random.choice(os.listdir(directory))

    # Generated sample: random class first, then a random file of that class.
    gen_root = os.path.join('./data/pacs', method, concept)
    gen_cls = pick_entry(gen_root)
    gen_dir = os.path.join(gen_root, gen_cls)
    gen_img = Image.open(os.path.join(gen_dir, pick_entry(gen_dir))).resize(thumb_size)

    # References: class and file are re-drawn independently for each image,
    # so they need not share the generated image's class.
    ref_root = os.path.join('./data/pacs', 'original', concept)
    ref_images = []
    for _ in range(3):
        ref_dir = os.path.join(ref_root, pick_entry(ref_root))
        ref_path = os.path.join(ref_dir, pick_entry(ref_dir))
        ref_images.append(Image.open(ref_path).resize(thumb_size))

    caption = f"a photo of {gen_cls} in the style of {concept}"
    return gen_img, caption, ref_images
def get_objects(method, concept, evaluation):
    """Pick a random generated image for an object concept plus three references.

    Args:
        method: Directory name of the concept-learning method.
        concept: Key of ``imagenet_mapping``; also used as the concept's
            directory name.
        evaluation: ``"Concept Alignment"`` samples from the plain generations;
            any other value samples from the compositional generations.

    Returns:
        A tuple ``(gen_img, caption, ref_images)``: the generated image resized
        to 128x128, a caption with the concept wrapped in ``**`` markers, and a
        list of three 128x128 images from the ``original`` tree.
    """
    thumb_size = (128, 128)
    aligned = evaluation == "Concept Alignment"

    # Methods with "ldm" in their name keep their outputs in a "samples"
    # sub-folder; the others store images directly in the concept folder.
    gen_cls = "samples" if "ldm" in method else ""

    if aligned:
        gen_root = './data/imagenet/images'
    else:
        gen_root = './data/imagenet/compositions/images'
        with open(f"./data/imagenet/compositions/prompts/{concept}.json", "r") as prompt_file:
            prompts = json.load(prompt_file)

    gen_dir = os.path.join(gen_root, method, concept, gen_cls)
    fname = random.choice(os.listdir(gen_dir))
    gen_img = Image.open(os.path.join(gen_dir, fname)).resize(thumb_size)

    if aligned:
        caption = f"a photo of **{imagenet_mapping[concept]}**"
    else:
        # The integer before the first "_" in the file name indexes the prompts.
        prompt = prompts[int(fname.split("_")[0])]
        caption = prompt["caption"].replace(prompt["entity"], f"**{prompt['entity']}**")

    # Reference images always come from the plain "original" tree.
    ref_dir = os.path.join('./data/imagenet/images', 'original', concept)
    ref_images = [
        Image.open(os.path.join(ref_dir, random.choice(os.listdir(ref_dir)))).resize(thumb_size)
        for _ in range(3)
    ]
    return gen_img, caption, ref_images
def get_images(method, concept, evaluation):
    """Click handler: resolve the dropdown selections and fetch a random sample.

    Args:
        method: Selected key of ``METHODS`` (``None`` while nothing is chosen).
        concept: Selected key of ``CONCEPTS`` (``None`` while nothing is chosen).
        evaluation: Selected evaluation type; only consulted for object
            concepts and passed through to ``get_objects`` unchanged.

    Returns:
        The ``(image, caption, reference_images)`` tuple produced by
        ``get_domains`` or ``get_objects``.

    Raises:
        gr.Error: If a dropdown is unselected or the concept resolves to
            neither a style domain nor an ImageNet object.
    """
    # The dropdowns have no default value, so the first click can arrive with
    # None; report that to the user instead of failing with a bare KeyError.
    if method not in METHODS or concept not in CONCEPTS:
        raise gr.Error("Please select both a concept learner and a concept.")

    method = METHODS[method]
    concept = CONCEPTS[concept]

    if concept in DOMAINS:
        return get_domains(method, concept)
    if concept in imagenet_mapping:
        return get_objects(method, concept, evaluation)

    # The handler feeds three outputs, so silently returning None here would
    # only surface as an unrelated output-count failure.
    raise gr.Error(f"Unknown concept: {concept}")
# Custom CSS handed to gr.Blocks; its selectors target the "image_upload" id.
css='''
#image_upload{min-height:4px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{max-height: 5}
'''


def _explorer_column(image_elem_id=None):
    """Create one explorer column and return its components.

    Must be called while a Gradio layout context is active. Components are
    created in display order: three dropdowns, the reference gallery, the
    generated image, its caption box and the trigger button.

    Args:
        image_elem_id: Optional element id for the generated-image component.

    Returns:
        ``(methods, concept, evaluation, gallery, image, text, button)``.
    """
    with gr.Column():
        methods = gr.Dropdown(
            list(METHODS),
            label="Concept Learner",
            info="Select a concept learning strategy.",
        )
        concept = gr.Dropdown(
            list(CONCEPTS),
            label="Concept",
            info="Select a concept.",
        )
        evaluation = gr.Dropdown(
            ["Concept Alignment", "Compositional Reasoning"],
            label="Evaluation Type",
            info="Select the evaluation type.",
        )
        gallery = gr.Gallery(
            label="Reference images",
            show_label=False,
            elem_id="gallery",
        ).style(columns=[3], rows=[1], height="200px")
        image = gr.Image(elem_id=image_elem_id)
        text = gr.Textbox(label="Caption used to generate above image")
        button = gr.Button(value="Get Image", full_width=False)
    return methods, concept, evaluation, gallery, image, text, button


image_blocks = gr.Blocks(css=css)
with image_blocks as demo:
    # Page header and usage notes.
    gr.Markdown("<h1 style='text-align: center;'>ConceptBed Benchmark Explorer</h1>")
    gr.Markdown("<h1 style='text-align: center;'><a href='https://conceptbed.github.io'>Project Page</a> | <a href='https://arxiv.org/abs/2306.04695'>Paper</a> </h1>")
    gr.Markdown("""
## How to interpret results:
1. The shown three images are reference concept images learned by the diffusion model.
2. The output target concept image is generated by Stable Diffusion using selected methodologies.
3. The output text indicates the prompt used to generate the image.
# """)
    gr.Markdown("""
## Types of evaluations:
1. Concept Alignment: available for all concepts
2. Compositional Reasoning: available for all concepts except -- Art Painting, Cartoon, Sketch, Photo
# """)
    gr.Markdown("""
### For further details on the ConceptBed benchmark, please refer to the paper at: <a href="https://arxiv.org/abs/2306.04695">https://arxiv.org/abs/2306.04695</a>
# """)

    # Two independent explorers side by side so results can be compared.
    with gr.Row():
        (methods1, concept1, evaluation1,
         gallery1, image1, text1, btn1) = _explorer_column()
        (methods2, concept2, evaluation2,
         gallery2, image2, text2, btn2) = _explorer_column(image_elem_id="image_upload")

    # Each button refreshes only the outputs of its own column.
    btn1.click(get_images, inputs=[methods1, concept1, evaluation1], outputs=[image1, text1, gallery1])
    btn2.click(get_images, inputs=[methods2, concept2, evaluation2], outputs=[image2, text2, gallery2])

    with gr.Accordion(label="Notes", open=False):
        gr.HTML(
            """<div class="acknowledgments">
<p><h4>Generated Images:</h4>
As ConceptBed evaluations required training of 1000+ models (one for each concept), it is impossible to host a live demo.
Therefore, we generate 200,000+ images and randomly select a few images for this demo.
"""
        )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()