import json
import operator
from itertools import chain

import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_from_disk
|
|
|
# Route DataFrame.plot() through plotly; make_profession_plot passes the
# plotly-specific `barmode` argument and hands the result to gr.Plot.
pd.options.plotting.backend = "plotly"
|
|
|
|
|
# Page title handed to gr.Blocks.
TITLE = "Identity Biases in Diffusion Models: Professions"

# Markdown rendered at the top of the app.
_INTRO = """
# Identity Biases in Diffusion Models: Professions

Explore profession-level social biases in the data from [DiffusionBiasExplorer](https://hf.co/spaces/tti-bias/diffusion-bias-explorer)!
This demo leverages the gender and ethnicity representation clusters described in the [companion app](https://hf.co/spaces/tti-bias/diffusion-face-clustering)
to analyze social trends in machine-generated visual representations of professions.
The **Professions Overview** tab lets you compare the distribution over
[identity clusters](https://hf.co/spaces/tti-bias/diffusion-face-clustering "Identity clusters identify visual features in the systems' output space correlated with variation of gender and ethnicity in input prompts.")
across professions for Stable Diffusion and Dalle-2 systems (or aggregated for `All Models`).
The **Profession Focus** tab provides more details for each of the individual professions, including direct system comparisons and examples of profession images for each cluster.
This work was done in the scope of the [Stable Bias Project](https://hf.co/spaces/tti-bias/stable-bias).
"""

# Extra guidance text bound to a throwaway name; nothing in the visible code
# reads this module-level value.
_ = """
For example, you can use this tool to investigate:
- How do each model's representation of professions correlate with the gender ratios reported by the [U.S. Bureau of Labor
Statistics](https://www.bls.gov/cps/cpsaat11.htm "The reported percentage of women in each profession in the US is indicated in the `Labor Women` column in the Professions Overview tab.")?
Are social trends reflected, are they exaggerated?
- Which professions have the starkest differences in how different models represent them?
"""
|
|
|
def _load_json(path):
    """Parse the JSON file at *path*, closing the file handle afterwards."""
    with open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


# Dataset loaded from the local "professions" directory, plus a DataFrame
# copy that get_image filters to find row indices.
professions_dset = load_from_disk("professions")
professions_df = professions_dset.to_pandas()

# One statistics file per clustering granularity, keyed by cluster count.
clusters_dicts = {
    num_cl: _load_json(f"clusters/professions_to_clusters_{num_cl}.json")
    for num_cl in [12, 24, 48]
}

# Cluster summaries, looked up elsewhere by str(number of clusters).
cluster_summaries_by_size = _load_json("clusters/cluster_summaries_by_size.json")

prompts = pd.read_csv("promptsadjectives.csv")
# Dropdown entries: the aggregate entry first, then the occupations sorted.
professions = ["all professions"] + sorted(prompts["Occupation-Noun"].tolist())
|
# Model keys used in the cluster statistics -> display names used in the UI.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}

# Reverse lookup (display name -> model key), derived from `models` so the
# two mappings cannot drift apart.
df_models = {display_name: key for key, display_name in models.items()}
|
|
|
|
|
def describe_cluster(num_clusters, block="label"):
    """Build an HTML summary of how the entries of a clustering are distributed.

    Reads ``clusters_dicts[num_clusters]``, ranks its items by value and
    reports each one as a rounded percentage of the sum of all values.

    NOTE(review): this calls ``to_string``, which is not defined in the
    visible part of this file, and it sums the values of
    ``clusters_dicts[num_clusters]`` directly, whereas the other functions
    here index that mapping by model and then profession. Nothing visible
    calls this function -- confirm it is still usable before relying on it.

    Args:
        num_clusters: Key into ``clusters_dicts``.
        block: Name of the kind of item being described; passed through
            ``to_string`` for display.

    Returns:
        An HTML string naming the top entry and listing the remaining ones.

    Raises:
        IndexError: If the mapping is empty.
        ZeroDivisionError: If the values sum to zero.
    """
    cl_dict = clusters_dicts[num_clusters]
    # Ascending sort followed by a reversal, so tied entries keep the same
    # relative order the sort-then-reverse approach has always produced.
    ranked = sorted(cl_dict.items(), key=operator.itemgetter(1))[::-1]
    total = float(sum(cl_dict.values()))
    percentages = [
        (label, round(count * 100 / total, 0)) for label, count in ranked
    ]
    top_label, top_share = percentages[0]
    headline = (
        "<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
        % (to_string(block), to_string(top_label), top_share)
    )
    followers = "".join(
        "<BR/><b>%s:</b> %d%%" % (to_string(label), share)
        for label, share in percentages[1:]
    )
    return headline + "<p>This is followed by: " + followers + "</p>"
|
|
|
|
|
def make_profession_plot(num_clusters, prof_name):
    """Build the per-model cluster bar plot and related updates for a profession.

    Args:
        num_clusters: Key into ``clusters_dicts`` and (as a string) into
            ``cluster_summaries_by_size``.
        prof_name: Profession key looked up under each model.

    Returns:
        A 3-tuple of:
          * the figure returned by ``DataFrame.plot(kind="bar", barmode="group")``,
          * a ``gr.update`` setting the cluster choices (clusters whose "All"
            proportion is positive, largest first) and selecting the first,
          * a ``gr.update`` carrying the text summary of those clusters.
    """
    per_model = clusters_dicts[num_clusters]
    overall = per_model["All"][prof_name]["cluster_proportions"]

    # Clusters with a positive aggregate share, largest share first.
    sorted_cl_scores = [
        (cl_id, share)
        for cl_id, share in sorted(
            overall.items(), key=operator.itemgetter(1), reverse=True
        )
        if share > 0
    ]
    cluster_ids = [cl_id for cl_id, _ in sorted_cl_scores]

    # One column per model (display name), one row per retained cluster.
    pre_pandas = {
        display_name: {
            f"Cluster {cl_id}": per_model[mod_name][prof_name][
                "cluster_proportions"
            ][cl_id]
            for cl_id in cluster_ids
        }
        for mod_name, display_name in models.items()
    }
    df = pd.DataFrame.from_dict(pre_pandas)
    prof_plot = df.plot(kind="bar", barmode="group")

    summaries = cluster_summaries_by_size[str(num_clusters)]
    summary_parts = [f"Profession '{prof_name}':\n"]
    summary_parts.extend(
        "- "
        + summaries[int(cl_id)]
        .replace(" gender terms", "")
        .replace("; ethnicity terms:", ",")
        + " \n"
        for cl_id in cluster_ids
    )
    cl_summary_text = "".join(summary_parts)

    return (
        prof_plot,
        gr.update(
            choices=cluster_ids,
            # No cluster may have a positive share; avoid indexing an empty list.
            value=cluster_ids[0] if cluster_ids else None,
        ),
        gr.update(value=cl_summary_text),
    )
|
|
|
|
|
def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
    """Build the styled overview table of cluster proportions per profession.

    Args:
        num_clusters: Number of clusters; key into ``clusters_dicts`` and
            the range of cluster ids that are summed.
        prof_names: Iterable of profession keys to include as rows.
        mod_name: Model display name; translated through ``df_models``.
        max_cols: Maximum number of clusters kept, ranked by summed proportion.

    Returns:
        A 3-tuple of:
          * the ids of the retained clusters, largest summed proportion first,
          * the HTML of the styled table,
          * a ``gr.update`` carrying the text summary of the retained clusters.
    """
    model_stats = clusters_dicts[num_clusters][df_models[mod_name]]
    professions_list_clusters = [
        (prof_name, model_stats[prof_name]["cluster_proportions"])
        for prof_name in prof_names
    ]

    # Sum every cluster's proportion over the selected professions and keep
    # the `max_cols` clusters with the largest sums.
    totals = sorted(
        (
            (k, sum(props[str(k)] for _, props in professions_list_clusters))
            for k in range(num_clusters)
        ),
        key=operator.itemgetter(1),
        reverse=True,
    )[:max_cols]

    # Only clusters with a positive sum get a table column.
    shown_clusters = [k for k, total in totals if total > 0]
    rows = [
        dict(
            [
                ("Profession", prof_name),
                ("Entropy", model_stats[prof_name]["entropy"]),
                ("Labor Women", model_stats[prof_name]["labor_fm"][0]),
                # Column with an empty header and empty cells.
                ("", ""),
            ]
            + [(f"Cluster {k}", props[str(k)]) for k in shown_clusters]
        )
        for prof_name, props in professions_list_clusters
    ]
    clusters_df = pd.DataFrame.from_dict(rows)

    summaries = cluster_summaries_by_size[str(num_clusters)]
    cl_summary_text = "".join(
        "- "
        + summaries[cl_id]
        .replace(" gender terms", "")
        .replace("; ethnicity terms:", ",")
        + " \n"
        for cl_id, _ in totals
    )

    table_html = (
        clusters_df.style.background_gradient(
            axis=None, vmin=0, vmax=100, cmap="YlGnBu"
        )
        .format(precision=1)
        .to_html()
    )
    return (
        [cl_id for cl_id, _ in totals],
        table_html,
        gr.update(value=cl_summary_text),
    )
|
|
|
|
|
def get_image(model, fname, score):
    """Fetch one image from the dataset together with a caption.

    Args:
        model: Model key; matched against the ``model`` column and looked up
            in ``models`` for the caption.
        fname: Value matched against the ``image_path`` column.
        score: Number formatted with two decimals in the caption.

    Returns:
        ``(image, caption)`` where ``image`` is the ``image`` field of the
        first matching row, and ``caption`` joins the underscore-separated
        tokens (from the fifth onward) of the first path component of
        ``fname`` with the score and the model display name.
    """
    row_mask = (professions_df["image_path"] == fname) & (
        professions_df["model"] == model
    )
    matching_rows = professions_df[row_mask].index
    image = professions_dset.select(matching_rows)["image"][0]

    name_tokens = fname.split("/")[0].split("_")[4:]
    caption = " ".join(name_tokens) + f" | {score:.2f}" + f" | {models[model]}"
    return image, caption
|
|
|
|
|
def show_examplars(num_clusters, prof_name, cl_id, confidence_threshold=0.6):
    """Collect example images of a profession assigned to one cluster.

    Candidates are the ``close`` entries, the first two ``mid`` entries and
    the ``far`` entries stored under the "All" model for this profession and
    cluster. Entries whose first element is not above
    ``confidence_threshold`` are dropped, as are repeated entries.

    Args:
        num_clusters: Key into ``clusters_dicts``.
        prof_name: Profession key.
        cl_id: Cluster id; converted with ``str`` for the lookup. ``None``
            (no cluster selected) yields an empty gallery.
        confidence_threshold: Exclusive lower bound on an entry's first element.

    Returns:
        ``(gallery_items, label_update)``: the ``get_image`` results for the
        retained entries, and a ``gr.update`` setting the gallery label.
    """
    if cl_id is None:
        # Nothing selected: show no images instead of failing on the lookup.
        return [], gr.update()

    examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
        "cluster_examplars"
    ][str(cl_id)]
    candidates = [
        tuple(entry)
        for entry in examplars_dict["close"]
        + examplars_dict["mid"][:2]
        + examplars_dict["far"]
    ]

    # Keep first occurrences only. Equal entries share the same first element,
    # so filtering on the threshold before de-duplicating loses nothing.
    selected = []
    for candidate in candidates:
        if candidate[0] > confidence_threshold and candidate not in selected:
            selected.append(candidate)

    gallery_items = [
        get_image(model, fname, score) for score, model, fname in selected
    ]
    gallery_label = f"Generations for profession ''{prof_name}'' assigned to cluster {cl_id} of {num_clusters}"
    return gallery_items, gr.update(label=gallery_label)
|
|
|
|
|
# Markdown shown next to the cluster summaries in both tabs.
_CLUSTERS_HELP_MD = (
    "\n"
    "##### What do the clusters mean?\n"
    "Below is a summary of the identity cluster compositions.\n"
    "For more details, see the [companion demo](https://huggingface.co/spaces/tti-bias/DiffusionFaceClustering):\n"
)

with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(_INTRO)
    # Both CSS declarations live inside the style attribute so the font size
    # is actually applied.
    gr.HTML(
        """<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
    )

    # ---- Tab 1: table of cluster proportions for several professions ----
    with gr.Tab("Professions Overview"):
        gr.Markdown(
            "\n"
            "Select one or more professions and models from the dropdowns on the left to see which clusters are most representative for this combination.\n"
            "Try choosing different numbers of clusters to see if the results change, and then go to the 'Profession Focus' tab to go more in-depth into these results.\n"
            "The `Labor Women` column provided for comparison corresponds to the gender ratio reported by the\n"
            "[U.S. Bureau of Labor Statistics](https://www.bls.gov/cps/cpsaat11.htm) for each profession.\n"
        )
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("Select the parameters here:")
                num_clusters = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
                model_choices = gr.Dropdown(
                    # Display names that make_profession_table maps back to
                    # model keys through df_models.
                    list(df_models),
                    value="All Models",
                    label="Which models do you want to compare?",
                    interactive=True,
                )
                profession_choices_overview = gr.Dropdown(
                    professions,
                    value=[
                        "all professions",
                        "CEO",
                        "director",
                        "social assistant",
                        "social worker",
                    ],
                    label="Which professions do you want to compare?",
                    multiselect=True,
                    interactive=True,
                )
            with gr.Column(scale=3):
                with gr.Row():
                    # NOTE(review): `wrap` is kept as found; confirm gr.HTML
                    # accepts it in the Gradio version in use.
                    table = gr.HTML(
                        label="Profession assignment per cluster", wrap=True
                    )
        with gr.Row():
            # Hidden sink for the cluster-id list returned by
            # make_profession_table.
            clusters = gr.Textbox(label="clusters", visible=False)
            gr.Markdown(_CLUSTERS_HELP_MD)
        with gr.Row():
            with gr.Accordion(label="Cluster summaries", open=True):
                cluster_descriptions_table = gr.Text(
                    "TODO", label="Cluster summaries", show_label=False
                )

    # ---- Tab 2: per-model plot and example images for one profession ----
    with gr.Tab("Profession Focus"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    "Select a profession to visualize and see which clusters and identity groups are most represented in the profession, as well as some examples of generated images below."
                )
                profession_choice_focus = gr.Dropdown(
                    choices=professions,
                    value="scientist",
                    label="Select profession:",
                )
                num_clusters_focus = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
            with gr.Column():
                # Static label: formatting the Dropdown component into the
                # text would show the component object, not a profession.
                plot = gr.Plot(
                    label="Makeup of the cluster assignments for the selected profession"
                )
        with gr.Row():
            with gr.Column():
                gr.Markdown(_CLUSTERS_HELP_MD)
                with gr.Accordion(label="Cluster summaries", open=True):
                    cluster_descriptions = gr.Text(
                        "TODO", label="Cluster summaries", show_label=False
                    )
            with gr.Column():
                gr.Markdown(
                    "\n"
                    "##### What's in the clusters?\n"
                    "You can show examples of profession images assigned to each identity cluster by selecting one here:\n"
                )
                with gr.Accordion(label="Cluster selection", open=True):
                    cluster_id_focus = gr.Dropdown(
                        choices=list(range(num_clusters_focus.value)),
                        value=0,
                        label="Select cluster to visualize:",
                    )
        with gr.Row():
            examplars_plot = gr.Gallery(
                label="Profession images assigned to the selected cluster."
            ).style(grid=4, height="auto", container=True)

    # ---- Event wiring ----
    overview_inputs = [num_clusters, profession_choices_overview, model_choices]
    overview_outputs = [clusters, table, cluster_descriptions_table]
    focus_inputs = [num_clusters_focus, profession_choice_focus]
    focus_outputs = [plot, cluster_id_focus, cluster_descriptions]
    examplar_inputs = [num_clusters_focus, profession_choice_focus, cluster_id_focus]
    # show_examplars returns the images and a label update; both go to the
    # gallery.
    examplar_outputs = [examplars_plot, examplars_plot]

    # Populate every view once when the page loads.
    demo.load(make_profession_table, overview_inputs, overview_outputs, queue=False)
    demo.load(make_profession_plot, focus_inputs, focus_outputs, queue=False)
    demo.load(show_examplars, examplar_inputs, examplar_outputs, queue=False)

    # Recompute a view whenever one of its controls changes.
    for trigger in [num_clusters, model_choices, profession_choices_overview]:
        trigger.change(
            make_profession_table, overview_inputs, overview_outputs, queue=False
        )
    for trigger in [num_clusters_focus, profession_choice_focus]:
        trigger.change(
            make_profession_plot, focus_inputs, focus_outputs, queue=False
        )
    for trigger in [num_clusters_focus, profession_choice_focus, cluster_id_focus]:
        trigger.change(
            show_examplars, examplar_inputs, examplar_outputs, queue=False
        )
|
|
|
|
|
# Start the app (with request queueing enabled) only when run as a script.
if __name__ == "__main__":
    demo.queue().launch(debug=True)
|
|