|
|
|
|
|
import gradio as gr |
|
import torch |
|
import numpy as np |
|
|
|
import matplotlib |
|
matplotlib.use("Agg") |
|
import matplotlib.pyplot as plt |
|
|
|
from PIL import Image |
|
import collections |
|
import numpy as np |
|
import pandas as pd |
|
import io |
|
import os |
|
from saac.prompt_generation.prompts import generate_prompts,generate_occupations,generate_traits |
|
from saac.prompt_generation.prompt_utils import score_prompt |
|
from saac.image_analysis.process import process_image_pil |
|
from saac.evaluation.eval_utils import generate_countplot, lumia_violinplot, process_analysis, generate_histplot,rgb_intensity,EVAL_DATA_DIRECTORY |
|
from saac.evaluation.evaluate import evaluate_by_adjectives,evaluate_by_occupation |
|
from datasets import load_dataset |
|
from diffusers import DiffusionPipeline, PNDMScheduler,EulerDiscreteScheduler |
|
|
|
# Run the diffusion pipeline on the GPU when one is visible to torch.
device = "cuda" if torch.cuda.is_available() else "cpu"

# UI display name -> internal model key used by ``results`` and the evaluators.
STABLE_MODELS = {"Stable Diffusion v1.5": 'stable', "Midjourney": 'midjourney', "Stable Diffusion v2": "sd2"}

# Precomputed analysis CSV (under EVAL_DATA_DIRECTORY/raw) for each internal model key.
_RAW_ANALYSIS_FILES = {
    "stable": "stable_diffusion_raw_processed.csv",
    "midjourney": "midjourney_deepface_calibrated_equalized_mode.csv",
    "sd2": "sd2_analysis.csv",
}

# model key -> (trait/sentiment frame, occupation frame), as unpacked by the graph helpers.
results = {
    key: process_analysis(os.path.join(EVAL_DATA_DIRECTORY, 'raw', csv_name), filtered=True, model=key)
    for key, csv_name in _RAW_ANALYSIS_FILES.items()
}

# Log how many rows survived filtering for each model.
for model_key, (trait_frame, occ_frame) in results.items():
    print(model_key, len(trait_frame.index), len(occ_frame.index))
|
|
|
|
|
# Hugging Face model id of the generator used by the "Image Analysis" tab.
SD2_MODEL_ID = "stabilityai/stable-diffusion-2-base"

# Load the pipeline once at import time with an Euler discrete scheduler taken
# from the same repository, then move it to the selected device.
esched = EulerDiscreteScheduler.from_pretrained(SD2_MODEL_ID, subfolder="scheduler")

pipe = DiffusionPipeline.from_pretrained(SD2_MODEL_ID, scheduler=esched)

pipe = pipe.to(device)

# Alphabetically sorted dropdown choices for the example-prompt selectors.
# ``sorted`` accepts any iterable, so no intermediate ``list`` is needed.
LOOKS = sorted(generate_traits()['tag'])

JOBS = sorted(generate_occupations()['tag'])
|
|
|
|
|
|
|
def fig2img(fig, fmt="png"):
    """Render a Matplotlib figure into an in-memory PIL Image and return it.

    Args:
        fig: object with a Matplotlib-compatible ``savefig`` method.
        fmt: image format passed to ``savefig``. Defaults to PNG so the result
            does not depend on ``rcParams["savefig.format"]`` (a vector format
            such as SVG/PDF could not be opened by PIL).

    Returns:
        A ``PIL.Image.Image`` whose pixel data is already decoded.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format=fmt)
    buf.seek(0)

    img = Image.open(buf)
    # ``Image.open`` is lazy; decode now so any error is raised here rather
    # than later when the UI consumes the image.
    img.load()
    return img
|
|
|
def trait_graph(model, hist=True):
    """Run the trait/sentiment bias tests for one model and render their plots.

    Args:
        model: internal model key into ``results`` (e.g. ``"stable"``,
            ``"midjourney"``, ``"sd2"``).
        hist: when True (default) draw a gender histogram per sentiment bucket;
            otherwise draw a gender count plot.

    Returns:
        Tuple ``(pass_skin, pass_gen, skin_plot, gender_plot)``: the skin-colour
        and gender results from ``evaluate_by_adjectives``, followed by the skin
        colour violin plot and the gender plot rendered via ``fig2img``.
    """
    tda_res, _ = results[model]

    pass_gen, pass_skin = evaluate_by_adjectives(adjective_df=tda_res, model=model)

    if hist:
        # ``tda_res`` is the cached frame held in the module-level ``results``
        # dict. Impose the sentiment ordering on a copy rather than overwriting
        # the column in place, so rendering a plot never mutates shared state.
        ordered_df = tda_res.copy()
        ordered_df['tda_sentiment_val'] = pd.Categorical(
            ordered_df['tda_sentiment_val'],
            ['very negative', 'negative', 'neutral', 'positive', 'very positive'])
        fig = generate_histplot(ordered_df, 'tda_sentiment_val', 'gender_detected_val',
                                title='Gender Distribution by Trait Sentiment',
                                xlabel='Trait Sentiment',
                                ylabel='Count', )
    else:
        fig = generate_countplot(tda_res, 'tda_sentiment_val', 'gender_detected_val',
                                 title='Gender Count by Trait Sentiment',
                                 xlabel='Trait Sentiment',
                                 ylabel='Count',
                                 legend_title='Gender')

    fig2 = lumia_violinplot(df=tda_res,
                            x_col='tda_compound',
                            rgb_col='skin color',
                            n_bins=21,
                            widths_val=0.05,
                            points_val=100,
                            x_label='TDA Sentiment',
                            y_label='Skin color Intensity',
                            title='Skin Color Intensity, Binned by TDA Sentiment', )

    return pass_skin, pass_gen, fig2img(fig2), fig2img(fig)
|
def occ_graph(model):
    """Run the occupation/income bias tests for one model and render their plots.

    Args:
        model: internal model key into ``results``.

    Returns:
        Tuple ``(pass_skin, pass_gen, skin_plot, gender_plot)``: the results of
        ``evaluate_by_occupation``, then the skin-colour violin plot (binned by
        median salary) and the gender histogram, both rendered via ``fig2img``.
    """
    _, occ_df = results[model]

    pass_gen, pass_skin = evaluate_by_occupation(occupation_df=occ_df, model=model)

    gender_fig = generate_histplot(
        occ_df, 'a_median', 'gender_detected_val',
        title='Gender Distribution by Median Annual Salary',
        xlabel='Median Annual Salary',
        ylabel='Count',
    )

    skin_fig = lumia_violinplot(
        df=occ_df,
        x_col='a_median',
        rgb_col='skin color',
        n_bins=21,
        widths_val=7500.0,   # violin width in salary units
        points_val=100,
        x_label='Median Salary',
        y_label='Skin color Intensity',
        title='Skin Color Intensity, Binned by Median Salary',
    )

    skin_img = fig2img(skin_fig)
    gender_img = fig2img(gender_fig)
    return pass_skin, pass_gen, skin_img, gender_img
|
|
|
def occ_example(occ):
    """Generate and analyse an image for the prompt ``"a <occupation> photorealistic"``."""
    return example_analysis(f"a {occ} photorealistic")
|
def adj_example(adj):
    """Generate and analyse an image for the prompt ``"a <adjective> person photorealistic"``."""
    return example_analysis(f"a {adj} person photorealistic")
|
def _clamp_channel(value):
    """Truncate *value* to an int and clamp it into the 0-255 colour-channel range."""
    return max(0, min(int(value), 255))


def _rgb_to_hex(rgb):
    """Format the first three entries of *rgb* as a ``#rrggbb`` colour string."""
    red, green, blue = (_clamp_channel(rgb[i]) for i in range(3))
    return f"#{red:02x}{green:02x}{blue:02x}"


def example_analysis(prompt):
    """Generate one image for *prompt* and summarise the automated analysis.

    Args:
        prompt: text prompt passed to the Stable Diffusion pipeline.

    Returns:
        Tuple ``(image, gender_str, skin_hex, intensity_hex, prompt_score)``:
        the generated PIL image, a detected-gender label, the detected skin
        colour and its grayscale intensity as ``#rrggbb`` strings, and
        ``score_prompt(prompt)``.
    """
    pil_img = pipe(prompt, num_inference_steps=20).images[0]

    df = process_image_pil(pil_img, prompt)

    # Fall back to mid-grey when the analysis reports no skin colour.
    rgb_tup = df["skin color"][0] if "skin color" in df else (128, 128, 128)
    print('RGB tup', rgb_tup)
    rgb_hex = _rgb_to_hex(rgb_tup)

    intensity_val = rgb_intensity(rgb_tup)
    print('intensity_val', intensity_val)
    # Build a valid grey "#rrggbb"; ``hex()`` would yield "0x.." prefixes
    # (e.g. "#0x7f0x7f0x7f") and drop leading zeros for small values.
    level = _clamp_channel(intensity_val)
    intense_hex = _rgb_to_hex((level, level, level))
    print(intense_hex)

    # -1 marks a missing score column.
    gender_w = float(df["gender.Woman"][0]) if "gender.Woman" in df else -1
    gender_m = float(df["gender.Man"][0]) if "gender.Man" in df else -1
    if gender_w < 0 and gender_m < 0:
        # Neither score is available (presumably no face was detected); avoid
        # reporting a fabricated "Female (-1%)" result.
        gender_str = "Unknown"
    elif gender_m > gender_w:
        gender_str = f"Male ({gender_m}%)"
    else:
        gender_str = f"Female ({gender_w}%)"

    return pil_img, gender_str, rgb_hex, intense_hex, score_prompt(prompt)
|
|
|
def bias_assessment(model):
    """Run all four bias tests for *model* and package them for the Model Audit tab.

    Args:
        model: internal model key into ``results`` (e.g. ``"stable"``).

    Returns:
        A 9-tuple, in order:
        - a sample-size summary string;
        - the skin-colour-by-sentiment and gender-by-sentiment verdicts;
        - the two sentiment plots (skin, gender);
        - the skin-colour-by-occupation and gender-by-occupation verdicts;
        - the two occupation plots (skin, gender).

        Each verdict is a one-element list ``[(description, "PASS" | "FAIL")]``.
    """

    def verdict(passed, subject, factor):
        # Single-entry (label, category) list as consumed by HighlightedText.
        outcome = "unbiased" if passed else "biased"
        return [(f"{subject} {outcome} by {factor}", "PASS" if passed else "FAIL")]

    # Descriptive names; the previous ``os`` local shadowed the ``os`` module.
    sent_skin_ok, sent_gen_ok, sent_skin_img, sent_gen_img = trait_graph(model)
    occ_skin_ok, occ_gen_ok, occ_skin_img, occ_gen_img = occ_graph(model)

    # results[model] is (trait/sentiment frame, occupation frame).
    trait_df, occ_df = results[model]
    trait_n, occ_n = len(trait_df.index), len(occ_df.index)

    summary = (f"Results are based off of a sample size of {trait_n} to {occ_n} images "
               f"after removing genderless and faceless analysis results.")

    return (
        summary,
        verdict(sent_skin_ok, "Skin color", "Sentiment"),
        verdict(sent_gen_ok, "Gender", "Sentiment"),
        sent_skin_img, sent_gen_img,
        verdict(occ_skin_ok, "Skin color", "Income/Occupation"),
        verdict(occ_gen_ok, "Gender", "Income/Occupation"),
        occ_skin_img, occ_gen_img,
    )
|
|
|
# Compute every model's audit once at import time (Midjourney, then SD v1.5,
# then SD v2); the UI only serves these cached tuples.
mj_analysis, sd_analysis, sd2_analysis = (
    bias_assessment(model_key) for model_key in ("midjourney", "stable", "sd2")
)
|
def cached_results(model):
    """Return the precomputed audit tuple for a model chosen in the UI.

    Args:
        model: display name, i.e. a key of ``STABLE_MODELS``; an unknown name
            raises ``KeyError``.

    Returns:
        ``mj_analysis`` for Midjourney, ``sd2_analysis`` for Stable Diffusion
        v2, and ``sd_analysis`` for any other mapped model.
    """
    key = STABLE_MODELS[model]
    by_key = {"midjourney": mj_analysis, "sd2": sd2_analysis}
    return by_key.get(key, sd_analysis)
|
|
|
if __name__ == '__main__':

    # Shared colour mapping for the PASS/FAIL HighlightedText verdicts.
    PASS_FAIL_COLORS = {"PASS": "green", "FAIL": "red"}

    with gr.Blocks() as demo:

        gr.Markdown("# Facial Adjectival Color and Income Auditor")

        gr.Markdown("## Assessing the bias towards gender and skin color in text-to-image models introduced by sentiment and profession.")

        # --- Tab 1: serve the precomputed bias audit for a chosen model. ---
        with gr.Tab("Model Audit"):

            with gr.Row():

                with gr.Column():

                    model_choice = gr.Dropdown(list(STABLE_MODELS.keys()), label="Text-to-Image Model", value="Midjourney")

                    audit_btn = gr.Button("Assess Model Bias")

                    gr.Markdown("The training set, vocabulary, pre and post processing of generative AI tools don't treat everyone equally. "
                                "Within a 95% margin of statistical error, the following tests expose bias in gender and skin color. "
                                "To learn more about this process, <a href=\"http://github.com/TRSS-Research/SAAC.git\">Visit the repo</a>")

                with gr.Column(variant="compact"):

                    sample = gr.Text(interactive=False, show_label=False)

                    ss_pass = gr.HighlightedText(label="Skin Color Bias by Sentiment").style(color_map=PASS_FAIL_COLORS)

                    with gr.Accordion("See Graph", open=False):

                        sent_skin = gr.Image()

                        gr.Markdown("A violin plot depicting the distribution of skin color intensity values, binned by trait sentiment."
                                    " Individual violin bins are colored by the median RGB intensity value of their bin, while the black notches signify the mean RGB intensity value.")

                    sg_pass = gr.HighlightedText(label="Gender Bias by Sentiment").style(color_map=PASS_FAIL_COLORS)

                    with gr.Accordion("See Graph", open=False):

                        sent_gen = gr.Image()

                        gr.Markdown("A histogram depicting gender distribution by trait sentiment")

                    os_pass = gr.HighlightedText(label="Skin Color Bias by Occupation/Income").style(color_map=PASS_FAIL_COLORS)

                    with gr.Accordion("See Graph", open=False):

                        occ_skin = gr.Image()

                        gr.Markdown("A violin plot depicting the distribution of skin color intensity values, binned by median salary."
                                    " Individual violin bins are colored by the median RGB intensity value of their bin, while the black notches signify the mean RGB intensity value.")

                    og_pass = gr.HighlightedText(label="Gender Bias by Occupation/Income").style(color_map=PASS_FAIL_COLORS)

                    with gr.Accordion("See Graph", open=False):

                        occ_gen = gr.Image()

                        gr.Markdown("A histogram depicting gender distribution by median annual salary of occupational title")

        # Output order must match the 9-tuple returned by cached_results.
        audit_btn.click(fn=cached_results, inputs=model_choice,
                        outputs=[sample, ss_pass, sg_pass, sent_skin, sent_gen, os_pass, og_pass, occ_skin, occ_gen])

        # --- Tab 2: generate a single image and show its automated analysis. ---
        with gr.Tab("Image Analysis"):

            gr.Markdown("# Generate an example image and view the automated analysis")

            with gr.Row():

                with gr.Column():

                    inp = gr.Textbox(label="Prompt", placeholder="Try selecting a prompt or enter your own",)

                    gr.Markdown("Please note, generation uses <a href=\"https://huggingface.co/stabilityai/stable-diffusion-2-base\">Stable Diffusion 2</a> and may take several minutes to generate.")

                    with gr.Tab("Trait/Sentiment"):

                        sent = gr.Dropdown(LOOKS, label="Trait", value=LOOKS[0])

                        gr.Markdown("Certain adjectives can reinforce harmful stereotypes associated with gender roles and ethnic backgrounds. "
                                    "Text to image models provide an opportunity to understand how prompting a particular human expression could be triggering, "
                                    "or why an uncommon combination might provide important examples to minorities without default representation. "
                                    "To view how positive, neutral, and negative words affect the gender and skin colors in the faces generated, select an adjective.")

                        # Same prompt template as adj_example.
                        sent.change(fn=lambda k: f"a {k} person photorealistic", inputs=sent, outputs=[inp])

                    with gr.Tab("Occupation/Income"):

                        occs = gr.Dropdown(JOBS, label="Occupation", value=JOBS[0])

                        gr.Markdown("Referencing a specific profession comes loaded with associations of gender and ethnicity."
                                    " Text to image models provide an opportunity to explicitly specify an underrepresented group, but first we must understand our default behavior. "
                                    "To view how mentioning a particular occupation affects the gender and skin colors in faces of text to image generators, select a job. Promotional materials,"
                                    " advertising, and even criminal sketches which do not explicitly specify a gender or ethnicity term will tend towards the distributions in the Model Audit tab.")

                        # Same prompt template as occ_example.
                        occs.change(fn=lambda k: f"a {k} photorealistic", inputs=occs, outputs=[inp], )

                    gen_btn = gr.Button("Generate and Analyze")

                with gr.Column():

                    gender = gr.Text(label="Detected Gender", interactive=False)

                    with gr.Row(variant="compact"):

                        skin = gr.ColorPicker(label="Facial skin color")

                        inten = gr.ColorPicker(label="Grayscale intensity")

                    img = gr.Image(label="Stable Diffusion v2.0")

                    sentscore = gr.Text(label="VADER sentiment score", interactive=False)

        # Output order must match the 5-tuple returned by example_analysis.
        gen_btn.click(fn=example_analysis, inputs=inp, outputs=[img, gender, skin, inten, sentscore])

    demo.launch(enable_queue=True,)