sasha's picture
sasha HF staff
Update app.py
8ee8a44
raw history blame
No virus
6.37 kB
import gradio as gr
import random, os, shutil
from PIL import Image
import pandas as pd
import tempfile
def open_sd_ims(adj, group, seed):
    """Return up to nine Stable Diffusion images for a prompt and seed.

    The prompt is ``"<adj>_<group with spaces as underscores>"``, or just
    ``group`` when *adj* is blank. Its archive
    ``zipped_images/stablediffusion/<prompt with spaces as underscores>.zip``
    is unpacked once into a directory named after the prompt in the current
    working directory. That directory is reused on later calls.

    Args:
        adj: Adjective to prefix the group with; '' or None means none.
        group: Profession/group noun; '' or None means nothing selected.
        seed: Selects the ``Seed_<seed>`` sub-folder to read images from.

    Returns:
        A list of at most nine fully loaded PIL images in file-name order,
        or None when no group is selected.
    """
    # Treat None (an untouched dropdown) the same as an explicit blank ''.
    if not group:
        return None
    prompt = (adj + '_' + group.replace(' ', '_')) if adj else group
    if not os.path.isdir(prompt):
        shutil.unpack_archive('zipped_images/stablediffusion/' + prompt.replace(' ', '_') + '.zip', prompt, 'zip')
    image_dir = prompt + '/Seed_' + str(seed) + '/'
    images = []
    # Sort for a deterministic selection, and only open the nine files we
    # return instead of opening (and leaking handles to) every file.
    for name in sorted(os.listdir(image_dir))[:9]:
        img = Image.open(image_dir + name)
        # Decode now; Pillow closes the underlying file after loading a
        # single-frame image, so no handle is left dangling.
        img.load()
        images.append(img)
    return images
def open_ims(model, adj, group, seed=48040):
    """Unzip the images one model generated for a prompt and return up to nine.

    The prompt is ``"<adj>_<group>"`` (or just ``group`` when *adj* is blank)
    with spaces replaced by underscores. The archive is read from
    ``zipped_images/<model name without spaces, lowercased>/<prompt>.zip`` and
    extracted into a temporary directory that is removed before returning.

    Args:
        model: Model name as shown in the UI, e.g. "Dall-E 2" or
            "Stable Diffusion 1.4".
        adj: Adjective to prefix the group with; '' or None means none.
        group: Profession/group noun; '' or None means nothing selected.
        seed: ``Seed_<seed>`` sub-folder to read for models other than
            Dall-E 2. Defaults to 48040, the value previously hard-coded.

    Returns:
        A list of at most nine fully loaded PIL images in file-name order
        (converted to RGB for Dall-E 2), or None when no group is selected.
    """
    # Treat None (an untouched dropdown) the same as an explicit blank ''.
    if not group:
        return None
    prompt = (adj + '_' + group) if adj else group
    slug = prompt.replace(' ', '_')
    model_dir = model.replace(' ', '').lower()
    archive = os.path.join('zipped_images', model_dir, slug + '.zip')
    with tempfile.TemporaryDirectory() as tmpdirname:
        print('created temporary directory', tmpdirname)
        extract_dir = os.path.join(tmpdirname, model_dir, slug)
        shutil.unpack_archive(archive, extract_dir, 'zip')
        if model == "Dall-E 2":
            # Dall-E 2 archives hold the images directly, without seed folders.
            image_dir = extract_dir
        else:
            image_dir = os.path.join(extract_dir, 'Seed_' + str(seed))
        images = []
        for name in sorted(os.listdir(image_dir))[:9]:
            path = os.path.join(image_dir, name)
            if model == "Dall-E 2":
                img = Image.open(path).convert("RGB")
            else:
                img = Image.open(path)
                # Decode pixel data while the file still exists: the temp
                # directory is deleted when this `with` block exits.
                img.load()
            images.append(img)
    return images
# Module-level data feeding the dropdowns below.
vowels = list("aeiou")  # not referenced by the code below
prompts = pd.read_csv('promptsadjectives.csv')
# Candidate seeds; only the commented-out seed dropdowns would use these.
seeds = [46267, 48040, 51237, 54325, 60884, 64830, 67031, 72935, 92118, 93109]

# First ten masculine- and feminine-coded adjectives from the CSV, merged and
# sorted, with a leading '' so the adjective dropdown can be left blank.
m_adjectives = prompts['Masc-adj'].tolist()[:10]
f_adjectives = prompts['Fem-adj'].tolist()[:10]
adjectives = [''] + sorted(m_adjectives + f_adjectives)
professions = sorted(occupation.lower() for occupation in prompts['Occupation-Noun'].tolist())
models = ["Stable Diffusion 1.4", "Dall-E 2","Stable Diffusion 2"]
# Two side-by-side panels, each pairing a model / adjective / group selection
# with a gallery of that model's images for the resulting prompt.
with gr.Blocks() as demo:
    gr.Markdown("# Diffusion Bias Explorer")
    gr.Markdown(
        "## Choose from the prompts below to explore how text-to-image models like "
        "[Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original), "
        "[Stable Diffusion v2](https://huggingface.co/stabilityai/stable-diffusion-2) and "
        "[DALL-E 2](https://openai.com/dall-e-2/) represent different professions and adjectives."
    )
    with gr.Row():
        with gr.Column():
            model1 = gr.Dropdown(models, label="Choose a model to compare results", value=models[0], interactive=True)
            adj1 = gr.Dropdown(adjectives, label="Choose a first adjective (or leave this blank!)", interactive=True)
            choice1 = gr.Dropdown(professions, label="Choose a first group", interactive=True)
            images1 = gr.Gallery(label="Images").style(grid=[3], height="auto")
        with gr.Column():
            model2 = gr.Dropdown(models, label="Choose a model to compare results", value=models[0], interactive=True)
            adj2 = gr.Dropdown(adjectives, label="Choose a second adjective (or leave this blank!)", interactive=True)
            choice2 = gr.Dropdown(professions, label="Choose a second group", interactive=True)
            images2 = gr.Gallery(label="Images").style(grid=[3], height="auto")
    gr.Markdown(
        "### [Research](http://gender-decoder.katmatfield.com/static/documents/Gaucher-Friesen-Kay-JPSP-Gendered-Wording-in-Job-ads.pdf) "
        "has shown that certain words are considered more masculine- or feminine-coded based on how appealing "
        "job descriptions containing these words seemed to male and female research participants and to what "
        "extent the participants felt that they 'belonged' in that occupation."
    )
    # Refresh a gallery whenever any of its three inputs changes. The model
    # dropdowns were previously not wired up, so switching the model left the
    # gallery showing the old model's images.
    for trigger in (model1, adj1, choice1):
        trigger.change(open_ims, [model1, adj1, choice1], [images1])
    for trigger in (model2, adj2, choice2):
        trigger.change(open_ims, [model2, adj2, choice2], [images2])
demo.launch()