File size: 2,389 Bytes
f53a084 91b2885 c8f2925 f53a084 c8f2925 f53a084 f253729 f53a084 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from PIL import Image
import os,csv
import pandas as pd
import numpy as np
import gradio as gr
# Load the prompt vocabulary from the CSV (path is relative to the working
# directory the app is started from).
prompts = pd.read_csv('promptsadjectives.csv')

# Only the first ten rows of each adjective column are used.
masc = prompts['Masc-adj'].head(10).tolist()
fem = prompts['Fem-adj'].head(10).tolist()

# Alphabetically sorted adjectives, preceded by an empty entry that stands
# for "no adjective".
adjectives = [''] + sorted(masc + fem)

# Only the first 150 rows of the occupation column are used.
occupations = prompts['Occupation-Noun'].head(150).tolist()
def _find_average_face(model_dir, candidates):
    """Return the first existing ``facer_faces/<model_dir>/<name>.png``.

    *candidates* is an ordered iterable of file stems to try. If none of
    them exists on disk, the path of the blank placeholder image is returned.
    """
    for name in candidates:
        path = 'facer_faces/' + model_dir + '/' + name + '.png'
        if os.path.isfile(path):
            return path
    return 'facer_faces/blank.png'


def get_averages(adj, profession):
    """Return the average-face image of each model for the given prompt.

    Args:
        adj: Adjective selected in the UI; ``''`` or ``None`` means no
            adjective.
        profession: Profession selected in the UI; ``''`` or ``None`` is
            treated as an empty profession (which resolves to the blank
            image unless a matching file exists).

    Returns:
        A tuple of three ``(PIL image, caption)`` pairs, in the order
        Stable Diffusion v1.4, Stable Diffusion v2, Dall-E 2. A model for
        which no image file is found gets ``facer_faces/blank.png``.
    """
    # A cleared dropdown may deliver None instead of ''; normalise so the
    # string operations below cannot raise.
    adj = adj or ''
    profession = profession or ''

    # File stems use underscores where the prompt has spaces.
    if adj:
        prompt = (adj + ' ' + profession).replace(' ', '_')
    else:
        prompt = profession.replace(' ', '_')
    lowered = prompt.lower()

    # Each model keeps its original primary file name (as-is for the Stable
    # Diffusion folders, lowercase for dalle2) and falls back to the other
    # case variant, so a pure upper/lowercase mismatch no longer yields the
    # blank image.
    # NOTE(review): the actual on-disk naming convention is not visible
    # here -- confirm this fallback covers the known case mismatch.
    sd14_average = _find_average_face('SDv14', (prompt, lowered))
    sdv2_average = _find_average_face('SDv2', (prompt, lowered))
    dalle_average = _find_average_face('dalle2', (lowered, prompt))

    return (
        (Image.open(sd14_average), "Stable Diffusion v 1.4"),
        (Image.open(sdv2_average), "Stable Diffusion v 2"),
        (Image.open(dalle_average), "Dall-E 2"),
    )
# Dropdown choices, ordered case-insensitively.
adjective_choices = sorted(adjectives, key=str.casefold)
profession_choices = sorted(occupations, key=str.casefold)

# Page layout: heading, then a row with the input controls on the left and
# the result gallery on the right, followed by explanatory notes.
with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Image Diffusion Model Average Faces")
    gr.Markdown("### We ran 150 professions and 20 adjectives through 3 text-to-image diffusion models to examine what they generate.")
    gr.Markdown("#### Choose one of the professions and adjectives from the dropdown menus and see the average face generated by each model.")

    with gr.Row():
        with gr.Column():
            adj = gr.Dropdown(
                choices=adjective_choices,
                value='',
                label="Choose an adjective",
                interactive=True,
            )
            prof = gr.Dropdown(
                choices=profession_choices,
                value='',
                label="Choose a profession",
                interactive=True,
            )
            btn = gr.Button("Get average faces!")
        with gr.Column():
            gallery = gr.Gallery(
                label="Average images",
                show_label=False,
                elem_id="gallery",
            ).style(grid=[0, 3], height="auto")

    gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
    gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession!")

    # Clicking the button feeds both dropdown values to get_averages and
    # shows the returned (image, caption) pairs in the gallery.
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

# Start the web app.
demo.launch()
|