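# NOTE: the next two comment lines are web-page residue captured with this
# file (Hugging Face Spaces status banner), not part of the program:
#   Spaces: Runtime error / Runtime error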
from PIL import Image | |
import os, csv | |
import pandas as pd | |
import numpy as np | |
import gradio as gr | |
# Load the prompt vocabulary used to populate the two dropdown menus.
prompts = pd.read_csv('promptsadjectives.csv')

# First 10 entries of each adjective column.
masc = prompts['Masc-adj'].iloc[:10].tolist()
fem = prompts['Fem-adj'].iloc[:10].tolist()

# Alphabetical list of all adjectives, led by '' which means "no adjective".
adjectives = [''] + sorted(masc + fem)

# First 150 entries of the occupation column.
occupations = prompts['Occupation-Noun'].iloc[:150].tolist()
# Placeholder shown when no average face exists for a prompt.
_BLANK_FACE = 'facer_faces/blank.png'


def _average_face_path(model_dir, file_stem):
    """Return the average-face PNG path for *file_stem*, or the blank placeholder.

    The file is looked up as ``facer_faces/<model_dir>/<file_stem>.png``; when
    no such file exists, the path of the blank image is returned instead.
    """
    candidate = 'facer_faces/' + model_dir + '/' + file_stem + '.png'
    if not os.path.isfile(candidate):
        return _BLANK_FACE
    return candidate


def get_averages(adj, profession):
    """Return the average face of each model for an adjective/profession pair.

    Args:
        adj: Adjective selected in the UI; ``""`` or ``None`` means "none".
        profession: Profession selected in the UI; ``None`` is treated as
            ``""`` (which matches no file, so blank images are returned).

    Returns:
        A tuple of three ``(PIL image, caption)`` pairs, in the order
        Stable Diffusion v1.4, Stable Diffusion v2, Dall-E 2. A model with no
        stored average for the prompt gets ``facer_faces/blank.png``.
    """
    # A dropdown with nothing selected may deliver None; normalise so the
    # string operations below cannot raise.
    adj = adj or ""
    profession = profession or ""

    if adj != "":
        prompt = (adj + ' ' + profession).replace(' ', '_')
    else:
        prompt = profession.replace(' ', '_')

    # TODO: fix upper/lowercase error
    # NOTE(review): only the Dall-E lookup lower-cases the file name; the two
    # Stable Diffusion lookups use the prompt's original case. Possibly the
    # mismatch the TODO refers to -- confirm against the files on disk.
    sd14_average = _average_face_path('SDv14', prompt)
    sdv2_average = _average_face_path('SDv2', prompt)
    dalle_average = _average_face_path('dalle2', prompt.lower())

    return (
        (Image.open(sd14_average), "Stable Diffusion v 1.4"),
        (Image.open(sdv2_average), "Stable Diffusion v 2"),
        (Image.open(dalle_average), "Dall-E 2"),
    )
# Build the Gradio interface: explanatory header, two dropdowns plus a button
# on the left, and a gallery with the three average faces on the right.
# NOTE(review): the original indentation was lost; the nesting below is a
# reconstruction -- confirm that the two trailing Markdown notes belong in the
# gallery column.
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Image Diffusion Model Average Faces")
    gr.Markdown("### We ran 150 professions and 20 adjectives through 3 text-to-image diffusion models to examine what they generate.")
    gr.Markdown("#### Choose one of the professions and adjectives from the dropdown menus and see the average face generated by each model.")
    gr.HTML("""<span style="color:red">⚠️ <b>DISCLAIMER: the images displayed by this tool are based on images which were generated by text-to-image models which may depict offensive stereotypes or contain explicit content.</b></span>""")

    with gr.Row():
        # Left column: the user's selections and the trigger button.
        with gr.Column():
            # Menu entries are ordered case-insensitively.
            adjective_choices = sorted(adjectives, key=str.casefold)
            profession_choices = sorted(occupations, key=str.casefold)
            adj = gr.Dropdown(
                adjective_choices,
                value='',
                label="Choose an adjective",
                interactive=True,
            )
            prof = gr.Dropdown(
                profession_choices,
                value='',
                label="Choose a profession",
                interactive=True,
            )
            btn = gr.Button("Get average faces!")

        # Right column: the resulting images and the explanatory notes.
        with gr.Column():
            gallery = gr.Gallery(
                label="Average images",
                show_label=False,
                elem_id="gallery",
            ).style(grid=[0, 3], height="auto")
            gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
            gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession!")

    # Clicking the button feeds both dropdown values to get_averages and
    # shows the returned (image, caption) pairs in the gallery.
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

demo.launch()