diffusion-faces / app.py
Anonymous Authors
Update app.py
5b42d79
raw
history blame contribute delete
2.63 kB
from PIL import Image
import os, csv
import pandas as pd
import numpy as np
import gradio as gr
# Load the prompt vocabulary that feeds both dropdown menus.
prompts = pd.read_csv('promptsadjectives.csv')

# Keep the first ten entries of each adjective column.
masc = prompts['Masc-adj'].iloc[:10].tolist()
fem = prompts['Fem-adj'].iloc[:10].tolist()

# Alphabetised adjectives, led by an empty entry (get_averages treats ''
# as "no adjective").
adjectives = [''] + sorted(masc + fem)

# Keep the first 150 occupation nouns.
occupations = prompts['Occupation-Noun'].iloc[:150].tolist()
def get_averages(adj, profession):
    """Return each model's average-face image for an adjective/profession pair.

    Args:
        adj: Adjective chosen in the UI. ``""`` or ``None`` means "no
            adjective".
        profession: Profession chosen in the UI. ``None`` is treated as ``""``.

    Returns:
        A 3-tuple of ``(PIL image, caption)`` pairs, ordered Stable Diffusion
        v1.4, Stable Diffusion v2, Dall-E 2. When no image file exists for a
        model, ``facer_faces/blank.png`` is opened in its place.
    """
    # Gradio may deliver None for a dropdown with no selection (TODO confirm
    # for the deployed Gradio version); normalise so the string operations
    # below cannot raise.
    adj = adj or ""
    profession = profession or ""

    words = adj + ' ' + profession if adj else profession
    prompt = words.replace(' ', '_')

    #TODO: fix upper/lowercase error
    # (sub-directory, file stem, caption). Only the dalle2 lookup lower-cases
    # the stem; the Stable Diffusion lookups use the prompt's case as-is.
    sources = (
        ('SDv14', prompt, "Stable Diffusion v 1.4"),
        ('SDv2', prompt, "Stable Diffusion v 2"),
        ('dalle2', prompt.lower(), "Dall-E 2"),
    )

    results = []
    for folder, stem, caption in sources:
        path = 'facer_faces/' + folder + '/' + stem + '.png'
        if not os.path.isfile(path):
            # No average face on disk for this model/prompt: use the placeholder.
            path = 'facer_faces/blank.png'
        results.append((Image.open(path), caption))
    return tuple(results)
# Build the demo page: headings, disclaimer, input controls, results gallery.
# NOTE(review): the nesting below was reconstructed from the component order;
# confirm that the two trailing Markdown notes belong at page level.
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Image Diffusion Model Average Faces")
    gr.Markdown("### We ran 150 professions and 20 adjectives through 3 text-to-image diffusion models to examine what they generate.")
    gr.Markdown("#### Choose one of the professions and adjectives from the dropdown menus and see the average face generated by each model.")
    gr.HTML("""<span style="color:red">⚠️ <b>DISCLAIMER: the images displayed by this tool are based on images which were generated by text-to-image models which may depict offensive stereotypes or contain explicit content.</b></span>""")

    with gr.Row():
        # First column: the two selectors and the trigger button.
        with gr.Column():
            adj = gr.Dropdown(
                sorted(adjectives, key=str.casefold),
                value='',
                label="Choose an adjective",
                interactive=True,
            )
            prof = gr.Dropdown(
                sorted(occupations, key=str.casefold),
                value='',
                label="Choose a profession",
                interactive=True,
            )
            btn = gr.Button("Get average faces!")

        # Second column: gallery that receives the three (image, caption) pairs.
        with gr.Column():
            gallery = gr.Gallery(
                label="Average images",
                show_label=False,
                elem_id="gallery",
            ).style(grid=[0, 3], height="auto")

    gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
    gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession!")

    # Clicking the button looks up the images for the current selections.
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

demo.launch()