# app.py -- copied from the Hugging Face Space file view (author: sasha, HF staff).
# Last commit: "Update app.py" (91b2885); file size: 2.39 kB; virus scan: no virus.
from PIL import Image
import os,csv
import pandas as pd
import numpy as np
import gradio as gr
# Load the prompt table and derive the dropdown contents from it.
prompts = pd.read_csv('promptsadjectives.csv')

# Only the first 10 entries of each adjective column are used.
masc = prompts['Masc-adj'].head(10).tolist()
fem = prompts['Fem-adj'].head(10).tolist()

# The leading empty string is the "no adjective" choice.
adjectives = ['', *sorted(masc + fem)]

# Only the first 150 entries of the occupation column are used.
occupations = prompts['Occupation-Noun'].head(150).tolist()
def _average_face_path(model_dir, name):
    """Return the average-face image path for *name* inside *model_dir*.

    The file name is tried exactly as given first and then lower-cased, so a
    prompt whose capitalisation differs from the file on disk is still found.
    When neither file exists the shared blank placeholder path is returned.
    """
    # dict.fromkeys drops the duplicate when *name* is already lower-case.
    for candidate_name in dict.fromkeys((name, name.lower())):
        candidate = 'facer_faces/' + model_dir + '/' + candidate_name + '.png'
        if os.path.isfile(candidate):
            return candidate
    return 'facer_faces/blank.png'


def get_averages(adj, profession):
    """Open the average-face image of each model for an adjective/profession pair.

    Args:
        adj: Adjective chosen by the user; an empty string (or None) means
            "no adjective".
        profession: Profession chosen by the user.

    Returns:
        A tuple of three ``(PIL image, caption)`` pairs, in the order
        Stable Diffusion v 1.4, Stable Diffusion v 2, Dall-E 2. A model with
        no matching file gets ``facer_faces/blank.png`` instead.
    """
    # Treat a missing selection (None) like an empty string so the string
    # concatenation below cannot raise TypeError.
    adj = adj or ""
    profession = profession or ""

    # File names use underscores where the prompt has spaces.
    if adj:
        prompt = (adj + ' ' + profession).replace(' ', '_')
    else:
        prompt = profession.replace(' ', '_')

    # (sub-directory, file stem, gallery caption) for each model.
    # Dall-E 2 files are looked up by their lower-cased name.
    models = (
        ('SDv14', prompt, "Stable Diffusion v 1.4"),
        ('SDv2', prompt, "Stable Diffusion v 2"),
        ('dalle2', prompt.lower(), "Dall-E 2"),
    )
    return tuple(
        (Image.open(_average_face_path(model_dir, stem)), caption)
        for model_dir, stem, caption in models
    )
# Gradio UI: two dropdowns and a button on the left, a gallery on the right.
# NOTE(review): the nesting below was reconstructed from a copy that had lost
# its indentation -- confirm the placement of the two notes under the row.
with gr.Blocks() as demo:
    # Page title and instructions.
    gr.Markdown("# Text-to-Image Diffusion Model Average Faces")
    gr.Markdown("### We ran 150 professions and 20 adjectives through 3 text-to-image diffusion models to examine what they generate.")
    gr.Markdown("#### Choose one of the professions and adjectives from the dropdown menus and see the average face generated by each model.")

    with gr.Row():
        # Input column: adjective, profession, and the trigger button.
        with gr.Column():
            adj = gr.Dropdown(
                sorted(adjectives, key=str.casefold),
                value='',
                label="Choose an adjective",
                interactive=True,
            )
            prof = gr.Dropdown(
                sorted(occupations, key=str.casefold),
                value='',
                label="Choose a profession",
                interactive=True,
            )
            btn = gr.Button("Get average faces!")

        # Output column: one gallery holding the image returned for each model.
        with gr.Column():
            gallery = gr.Gallery(
                label="Average images",
                show_label=False,
                elem_id="gallery",
            ).style(grid=[0, 3], height="auto")

    # Explanatory notes.
    gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
    gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession!")

    # Clicking the button passes both dropdown values to get_averages and
    # shows the returned (image, caption) pairs in the gallery.
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

demo.launch()