sasha's picture
sasha HF staff
initial commit
f53a084
raw
history blame
2.35 kB
from PIL import Image
import os,csv
import pandas as pd
import numpy as np
import gradio as gr
# Dropdown choices are read from promptsadjectives.csv.  The adjective columns
# are sliced to their first 10 rows and the occupations to the first 150;
# presumably the shorter columns are NaN-padded in the CSV (TODO confirm).
# NaN cells are dropped defensively: pandas yields them as float NaN, which
# would make ``sorted`` below (and the ``str.casefold`` sort keys used when
# building the dropdowns) raise TypeError.
prompts = pd.read_csv('promptsadjectives.csv')
masc = prompts['Masc-adj'][:10].dropna().tolist()
fem = prompts['Fem-adj'][:10].dropna().tolist()
adjectives = sorted(masc + fem)
# The empty string is the "no adjective" choice, listed first.
adjectives.insert(0, '')
occupations = prompts['Occupation-Noun'][:150].dropna().tolist()
_BLANK_FACE = 'facer_faces/blank.png'


def _load_average(path):
    """Load the image at *path*, falling back to the blank placeholder.

    ``Image.open`` is lazy and keeps the file handle open until the image is
    garbage-collected, so the pixels are copied out inside a ``with`` block
    and the file is closed immediately.
    """
    if not os.path.isfile(path):
        path = _BLANK_FACE
    with Image.open(path) as img:
        return img.copy()


def get_averages(adj, profession):
    """Return each model's average-face image for an adjective/profession prompt.

    Args:
        adj: Adjective chosen in the UI; '' (or None) means "no adjective".
        profession: Profession chosen in the UI; None is treated as ''.

    Returns:
        A 3-tuple of ``(PIL image, caption)`` pairs for Stable Diffusion v1.4,
        Stable Diffusion v2 and DALL-E 2 (in that order), as consumed by
        ``gr.Gallery``. Any model without a precomputed average for the
        prompt gets ``facer_faces/blank.png`` instead.
    """
    # Dropdowns may deliver None when nothing is selected; treat it as empty.
    adj = adj or ''
    profession = profession or ''
    phrase = adj + ' ' + profession if adj else profession
    # File names use underscores in place of spaces.
    prompt = phrase.replace(' ', '_')
    # TODO: fix upper/lowercase error -- Stable Diffusion files are looked up
    # with the prompt's original casing, DALL-E 2 files with a lower-cased one.
    return (
        (_load_average('facer_faces/SDv14/' + prompt + '.png'), "Stable Diffusion v 1.4"),
        (_load_average('facer_faces/SDv2/' + prompt + '.png'), "Stable Diffusion v 2"),
        (_load_average('facer_faces/dalle2/' + prompt.lower() + '.png'), "Dall-E 2"),
    )
# Page heading, rendered top to bottom before the interactive widgets.
_HEADER_LINES = (
    "# Text-to-Image Diffusion Model Average Faces",
    "### We ran 150 professions through 3 diffusion models to examine what they generate.",
    "#### Choose one of the professions and adjectives and see the average face generated by each model.",
)

# Dropdown options ordered case-insensitively (the '' adjective sorts first).
adjective_choices = sorted(adjectives, key=str.casefold)
occupation_choices = sorted(occupations, key=str.casefold)

with gr.Blocks() as demo:
    for header_line in _HEADER_LINES:
        gr.Markdown(header_line)

    with gr.Row():
        # Left column: prompt pickers and the trigger button.
        with gr.Column():
            adj = gr.Dropdown(
                adjective_choices,
                value='',
                label="Choose an adjective",
                interactive=True,
            )
            prof = gr.Dropdown(
                occupation_choices,
                value='',
                label="Choose a profession",
                interactive=True,
            )
            btn = gr.Button("Get average faces!")

        # Right column: one gallery tile per model, plus explanatory notes.
        with gr.Column():
            gallery = gr.Gallery(
                label="Average images",
                show_label=False,
                elem_id="gallery",
            ).style(grid=[0, 3], height="auto")
            gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
            gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession, sorry!")

    # Clicking the button fills the gallery with the three average faces.
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

demo.launch(share=True)