|
from PIL import Image |
|
import os,csv |
|
import pandas as pd |
|
import numpy as np |
|
import gradio as gr |
|
|
|
# Vocabulary for the two dropdowns, read from the prompts spreadsheet.
prompts = pd.read_csv('promptsadjectives.csv')

# First ten entries of the masculine- and feminine-coded adjective columns.
masc = prompts['Masc-adj'].head(10).tolist()
fem = prompts['Fem-adj'].head(10).tolist()

# Alphabetised adjectives, led by an empty entry meaning "no adjective".
adjectives = [''] + sorted(masc + fem)

# First 150 entries of the occupation column.
occupations = prompts['Occupation-Noun'].head(150).tolist()
|
|
|
|
|
# Placeholder image shown when a model has no average face for a prompt.
_BLANK_FACE = 'facer_faces/blank.png'


def _average_face_path(model_dir, prompt):
    """Return the average-face PNG path for *prompt* under *model_dir*.

    Falls back to the blank placeholder when that file does not exist.
    """
    candidate = f'facer_faces/{model_dir}/{prompt}.png'
    return candidate if os.path.isfile(candidate) else _BLANK_FACE


def get_averages(adj, profession):
    """Load the average face image of each model for an adjective/profession pair.

    Args:
        adj: Adjective to prepend to the profession; ``""`` or ``None``
            means no adjective.
        profession: Profession name; ``None`` is treated as ``""``.

    Returns:
        A tuple of three ``(image, caption)`` pairs, in the order
        Stable Diffusion v1.4, Stable Diffusion v2, Dall-E 2. Each image is
        opened with ``Image.open``; when a model has no file for the prompt,
        the blank placeholder image is opened instead.
    """
    # Guard against None (e.g. nothing selected): the original string
    # concatenation / .replace() would raise on it.
    adj = adj or ""
    profession = profession or ""

    words = f'{adj} {profession}' if adj else profession
    # File names use underscores where the prompt has spaces.
    prompt = words.replace(' ', '_')

    return (
        (Image.open(_average_face_path('SDv14', prompt)), "Stable Diffusion v 1.4"),
        (Image.open(_average_face_path('SDv2', prompt)), "Stable Diffusion v 2"),
        # The Dall-E 2 lookup uses the lower-cased prompt.
        (Image.open(_average_face_path('dalle2', prompt.lower())), "Dall-E 2"),
    )
|
|
|
|
|
# Page layout: headings on top, then a row with the controls on the left and
# the image gallery on the right, followed by two explanatory notes.
with gr.Blocks() as demo:

    gr.Markdown("# Text-to-Image Diffusion Model Average Faces")
    gr.Markdown("### We ran 150 professions through 3 diffusion models to examine what they generate.")
    gr.Markdown("#### Choose one of the professions and adjectives and see the average face generated by each model.")

    with gr.Row():

        # Left column: the two selectors and the trigger button.
        with gr.Column():
            adj = gr.Dropdown(
                sorted(adjectives, key=str.casefold),
                label="Choose an adjective",
                value='',
                interactive=True,
            )
            prof = gr.Dropdown(
                sorted(occupations, key=str.casefold),
                label="Choose a profession",
                value='',
                interactive=True,
            )
            btn = gr.Button("Get average faces!")

        # Right column: one gallery holding the three captioned images.
        with gr.Column():
            gallery = gr.Gallery(
                elem_id="gallery",
                label="Average images",
                show_label=False,
            ).style(height="auto", grid=[0, 3])

    gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
    gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession, sorry!")

    # Clicking the button feeds both dropdown values to get_averages and
    # shows the returned (image, caption) pairs in the gallery.
    btn.click(get_averages, inputs=[adj, prof], outputs=gallery)

# Start the app with a public share link.
demo.launch(share=True)
|
|
|
|