"""Gradio demo showing the average face produced by three text-to-image models
for a chosen (adjective, profession) prompt."""
from PIL import Image
import os
import csv
import pandas as pd
import numpy as np
import gradio as gr

# Root folder of the precomputed average faces; one sub-folder per model.
FACES_DIR = 'facer_faces'
# Placeholder returned when no average face exists for a prompt.
BLANK_IMAGE = FACES_DIR + '/blank.png'

prompts = pd.read_csv('promptsadjectives.csv')
# The first 10 masculine + 10 feminine adjectives and the first 150 occupations
# are offered; these counts match the numbers quoted in the page heading below.
masc = prompts['Masc-adj'][:10].tolist()
fem = prompts['Fem-adj'][:10].tolist()
adjectives = sorted(masc + fem)
adjectives.insert(0, '')  # '' lets the user pick "no adjective"
occupations = prompts['Occupation-Noun'][:150].tolist()


def _find_average(model_dir, prompt, lowercase_first=False):
    """Return the path of the average-face PNG for *prompt* in *model_dir*.

    The SD folders were originally looked up with the prompt's own casing and
    the DALL-E 2 folder with a lowercased prompt (the old "upper/lowercase"
    TODO). Both spellings are therefore tried, preferred one first. Returns
    BLANK_IMAGE when neither file exists.
    """
    lowered = prompt.lower()
    candidates = [lowered, prompt] if lowercase_first else [prompt, lowered]
    for name in candidates:
        path = FACES_DIR + '/' + model_dir + '/' + name + '.png'
        if os.path.isfile(path):
            return path
    return BLANK_IMAGE


def _load_image(path):
    """Open *path* and return an in-memory copy so the file handle is closed."""
    with Image.open(path) as img:
        return img.copy()


def get_averages(adj, profession):
    """Return (image, caption) pairs with each model's average face.

    Args:
        adj: adjective selected in the dropdown; '' (or None) means none.
        profession: occupation selected in the dropdown; None is treated as ''.

    Returns:
        A tuple of three (PIL.Image, caption) pairs for SD v1.4, SD v2 and
        DALL-E 2, with the blank placeholder where no average face exists.
    """
    # Dropdowns may send None when cleared; treat that like an empty choice.
    profession = profession or ''
    prompt = adj + ' ' + profession if adj else profession
    # Image files use underscores in place of spaces.
    prompt = prompt.replace(' ', '_')
    return (
        (_load_image(_find_average('SDv14', prompt)), "Stable Diffusion v 1.4"),
        (_load_image(_find_average('SDv2', prompt)), "Stable Diffusion v 2"),
        (_load_image(_find_average('dalle2', prompt, lowercase_first=True)), "Dall-E 2"),
    )


with gr.Blocks() as demo:
    gr.Markdown("# Text-to-Image Diffusion Model Average Faces")
    gr.Markdown("### We ran 150 professions and 20 adjectives through 3 text-to-image diffusion models to examine what they generate.")
    gr.Markdown("#### Choose one of the professions and adjectives from the dropdown menus and see the average face generated by each model.")
    with gr.Row():
        with gr.Column():
            adj = gr.Dropdown(sorted(adjectives, key=str.casefold), value='', label="Choose an adjective", interactive=True)
            prof = gr.Dropdown(sorted(occupations, key=str.casefold), value='', label="Choose a profession", interactive=True)
            btn = gr.Button("Get average faces!")
        with gr.Column():
            # NOTE(review): `.style()` exists only in Gradio 3.x; on Gradio 4+
            # pass the layout options to the Gallery constructor instead.
            gallery = gr.Gallery(
                label="Average images", show_label=False, elem_id="gallery"
            ).style(grid=[0, 3], height="auto")
            gr.Markdown("The three models are: Stable Diffusion v.1.4, Stable Diffusion v.2, and Dall-E 2.")
            gr.Markdown("If you see a black square above, we weren't able to compute an average face for this profession!")
    btn.click(fn=get_averages, inputs=[adj, prof], outputs=gallery)

demo.launch()