meg-branch

#3
by meg HF staff - opened
Files changed (1) hide show
  1. app.py +79 -39
app.py CHANGED
@@ -17,28 +17,39 @@ clusters_by_size = {
17
  48: clusters_48,
18
  }
19
 
 
20
  def to_string(label):
21
  if label == "SD_2":
22
- label = "Stable Diffusion 2"
23
  elif label == "SD_14":
24
- label = "Stable Diffusion 14"
25
  elif label == "DallE":
26
  label = "Dall-E 2"
 
 
 
 
 
 
27
  return label
28
 
 
29
  def describe_cluster(cl_dict, block="label"):
30
  labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
31
  labels_values.reverse()
32
  total = float(sum(cl_dict.values()))
33
- lv_prcnt = list((item[0], round(item[1] * 100/total, 0)) for item in labels_values)
 
34
  top_label = lv_prcnt[0][0]
35
- description_string = "<span>The most represented %s is <b>%s</b>, making up about %d%% of the cluster.</span>" % (block, to_string(lv_prcnt[0][0]), lv_prcnt[0][1])
 
36
  description_string += "<p>This is followed by: "
37
  for lv in lv_prcnt[1:]:
38
  description_string += "<BR/><b>%s:</b> %d%%" % (to_string(lv[0]), lv[1])
39
  description_string += "</p>"
40
  return description_string
41
 
 
42
  def show_cluster(cl_id, num_clusters):
43
  if not cl_id:
44
  cl_id = 0
@@ -47,60 +58,89 @@ def show_cluster(cl_id, num_clusters):
47
  cl_dct = clusters_by_size[num_clusters][cl_id]
48
  images = []
49
  for i in range(6):
50
- img_path = "/".join([st.replace("/", "") for st in cl_dct['img_path_list'][i].split("//")][3:])
51
- images.append((Image.open(os.path.join("identities-images", img_path)), "_".join([img_path.split("/")[0], img_path.split("/")[-1]]).replace('Photo_portrait_of_an_','').replace('Photo_portrait_of_a_','').replace('SD_v2_random_seeds_identity_','(SD v.2) ').replace('dataset-identities-dalle2_','(Dall-E 2) ').replace('SD_v1.4_random_seeds_identity_','(SD v.1.4) ').replace('_',' ')))
 
 
 
 
 
 
 
 
 
52
  model_fig = go.Figure()
53
- model_fig.add_trace(go.Pie(labels=list(dict(cl_dct["labels_model"]).keys()),
54
- values=list(dict(cl_dct["labels_model"]).values())))
 
55
  model_description = describe_cluster(dict(cl_dct["labels_model"]), "model")
56
 
57
  gender_fig = go.Figure()
58
- gender_fig.add_trace(go.Pie(labels=list(dict(cl_dct["labels_gender"]).keys()),
59
- values=list(dict(cl_dct["labels_gender"]).values())))
60
- gender_description = describe_cluster(dict(cl_dct["labels_gender"]), "gender")
 
 
61
 
62
  ethnicity_fig = go.Figure()
63
- ethnicity_fig.add_trace(go.Bar(x=list(dict(cl_dct["labels_ethnicity"]).keys()),
64
- y=list(dict(cl_dct["labels_ethnicity"]).values()),
65
- marker_color=px.colors.qualitative.G10))
 
66
  return (len(cl_dct['img_path_list']),
67
- gender_fig,gender_description,
68
  model_fig, model_description,
69
  ethnicity_fig,
70
  images,
71
- gr.update(maximum=num_clusters-1))
 
72
 
73
  with gr.Blocks(title=TITLE) as demo:
74
  gr.Markdown(f"# {TITLE}")
75
- gr.Markdown("## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!")
76
- gr.Markdown("### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 diffusion models.")
77
- gr.Markdown("### Below, see results on how the images from different prompts cluster together.")
78
- gr.HTML("""<span style="color:red" font-size:smaller>⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</span>""")
79
- num_clusters = gr.Radio([12,24,48], value=12, label="How many clusters do you want to make from the data?")
 
 
 
 
 
80
 
81
-
82
  with gr.Row():
83
  with gr.Column(scale=4):
84
- gallery = gr.Gallery(label="Most representative images in cluster").style(grid=(3,3))
 
 
85
  with gr.Column():
86
- cluster_id = gr.Slider(minimum=0, maximum=num_clusters.value-1, step=1, value=0, label="Click to move between clusters")
 
 
87
  a = gr.Text(label="Number of images")
88
  with gr.Row():
89
- with gr.Column(scale=1):
90
- c = gr.Plot(label="How many images from each model?")
91
- c_desc = gr.HTML(label="")
92
- with gr.Column(scale=1):
93
- b = gr.Plot(label="How many genders are represented?")
94
- b_desc = gr.HTML(label="")
95
- with gr.Column(scale=2):
96
- d = gr.Plot(label="Which ethnicities are present?")
97
 
98
- gr.Markdown(f"The 'Model makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.")
99
- gr.Markdown('The Gender plot shows the number of images based on the input prompts that used the words man, woman, non-binary, and unmarked, which we label "person".')
100
- gr.Markdown(f"The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity.")
101
- demo.load(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=[a, b, b_desc, c, c_desc, d, gallery, cluster_id])
102
- num_clusters.change(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=[a, b, b_desc, c, c_desc, d, gallery, cluster_id])
103
- cluster_id.change(fn=show_cluster, inputs=[cluster_id, num_clusters], outputs=[a, b, b_desc, c, c_desc, d, gallery, cluster_id])
 
 
 
 
 
 
 
104
 
105
  if __name__ == "__main__":
106
- demo.queue().launch(debug=True)
 
17
  48: clusters_48,
18
  }
19
 
20
+
# Raw cluster labels / block names -> text shown to the user.
# Any value that is not listed here is displayed unchanged.
_DISPLAY_NAMES = {
    "SD_2": "Stable Diffusion 2.0",
    "SD_14": "Stable Diffusion 1.4",
    "DallE": "Dall-E 2",
    "non-binary": "non-binary person",
    # The value contains HTML markup; callers embed it in HTML strings.
    "person": "<i>unmarked</i> (person)",
    "gender": "gender term",
}


def to_string(label):
    """Return the display text for a raw label.

    Args:
        label: A raw label or block name such as ``"SD_2"``, ``"person"``
            or ``"gender"``. Must be hashable (labels are used as dict keys
            by the callers in this file).

    Returns:
        The mapped display string when ``label`` is one of the known keys,
        otherwise ``label`` itself, unchanged.
    """
    return _DISPLAY_NAMES.get(label, label)
35
 
36
+
def describe_cluster(cl_dict, block="label"):
    """Build an HTML summary of how labels are distributed in a cluster.

    Args:
        cl_dict: Mapping of label -> count for one cluster.
        block: Name of the label family being described (for example
            ``"model"`` or ``"gender"``); it is passed through ``to_string``
            before being inserted into the sentence.

    Returns:
        An HTML string. The first ``<span>`` names the label with the highest
        count and its share of the total; the following ``<p>`` lists every
        other label with its share, in descending order of count. Shares are
        whole-number percentages. If ``cl_dict`` is empty or its counts sum
        to zero, a short "no data" ``<span>`` is returned instead, because no
        share can be computed.
    """
    total = float(sum(cl_dict.values())) if cl_dict else 0.0
    if not total:
        # Nothing to rank, and dividing by the total would fail.
        return "<span>No %s data is available for this cluster.</span>" % (
            to_string(block))

    # Ascending stable sort followed by reverse(): among equal counts the
    # label that appears later in cl_dict is listed first. sorted(...,
    # reverse=True) would order ties the other way, so it is not used here.
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1))
    labels_values.reverse()

    lv_prcnt = [(name, round(count * 100 / total, 0))
                for name, count in labels_values]
    top_label, top_prcnt = lv_prcnt[0]

    headline = (
        "<span>The most represented %s is <b>%s</b>, making up about "
        "<b>%d%%</b> of the cluster.</span>"
        % (to_string(block), to_string(top_label), top_prcnt))
    followers = "".join(
        "<BR/><b>%s:</b> %d%%" % (to_string(name), prcnt)
        for name, prcnt in lv_prcnt[1:])
    return headline + "<p>This is followed by: " + followers + "</p>"
51
 
52
+
53
  def show_cluster(cl_id, num_clusters):
54
  if not cl_id:
55
  cl_id = 0
 
58
  cl_dct = clusters_by_size[num_clusters][cl_id]
59
  images = []
60
  for i in range(6):
61
+ img_path = "/".join([st.replace("/", "") for st in
62
+ cl_dct['img_path_list'][i].split("//")][3:])
63
+ images.append((Image.open(os.path.join("identities-images", img_path)),
64
+ "_".join([img_path.split("/")[0],
65
+ img_path.split("/")[-1]]).replace(
66
+ 'Photo_portrait_of_an_', '').replace(
67
+ 'Photo_portrait_of_a_', '').replace(
68
+ 'SD_v2_random_seeds_identity_', '(SD v.2) ').replace(
69
+ 'dataset-identities-dalle2_', '(Dall-E 2) ').replace(
70
+ 'SD_v1.4_random_seeds_identity_',
71
+ '(SD v.1.4) ').replace('_', ' ')))
72
  model_fig = go.Figure()
73
+ model_fig.add_trace(go.Pie(labels=list(dict(cl_dct["labels_model"]).keys()),
74
+ values=list(
75
+ dict(cl_dct["labels_model"]).values())))
76
  model_description = describe_cluster(dict(cl_dct["labels_model"]), "model")
77
 
78
  gender_fig = go.Figure()
79
+ gender_fig.add_trace(
80
+ go.Pie(labels=list(dict(cl_dct["labels_gender"]).keys()),
81
+ values=list(dict(cl_dct["labels_gender"]).values())))
82
+ gender_description = describe_cluster(dict(cl_dct["labels_gender"]),
83
+ "gender")
84
 
85
  ethnicity_fig = go.Figure()
86
+ ethnicity_fig.add_trace(
87
+ go.Bar(x=list(dict(cl_dct["labels_ethnicity"]).keys()),
88
+ y=list(dict(cl_dct["labels_ethnicity"]).values()),
89
+ marker_color=px.colors.qualitative.G10))
90
  return (len(cl_dct['img_path_list']),
91
+ gender_fig, gender_description,
92
  model_fig, model_description,
93
  ethnicity_fig,
94
  images,
95
+ gr.update(maximum=num_clusters - 1))
96
+
97
 
# Gradio UI: a cluster-count selector, a gallery plus cluster slider, and
# three plots (model, gender term, ethnicity term) for the selected cluster.
# NOTE(review): the nesting of the layout below was reconstructed from a
# flattened listing -- confirm it matches the intended layout.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(
        "## Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!")
    gr.Markdown(
        "### This demo showcases patterns in the images generated from different prompts input to Stable Diffusion and Dalle-2 diffusion models.")
    gr.Markdown(
        "### Below, see results on how the images from different prompts cluster together.")
    # font-size has to live inside the style attribute to take effect.
    gr.HTML(
        """<span style="color:red; font-size:smaller">⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image models and may depict offensive stereotypes or contain explicit content.</span>""")
    num_clusters = gr.Radio(
        [12, 24, 48], value=12,
        label="How many clusters do you want to make from the data?")

    with gr.Row():
        with gr.Column(scale=4):
            gallery = gr.Gallery(
                label="Most representative images in cluster").style(
                grid=(3, 3))
        with gr.Column():
            # The initial maximum follows the default radio value;
            # show_cluster sends an updated maximum on every call.
            cluster_id = gr.Slider(minimum=0, maximum=num_clusters.value - 1,
                                   step=1, value=0,
                                   label="Click to move between clusters")
            a = gr.Text(label="Number of images")
    with gr.Row():
        with gr.Column(scale=1):
            c = gr.Plot(label="How many images from each model?")
            c_desc = gr.HTML(label="")
        with gr.Column(scale=1):
            b = gr.Plot(label="How many gender terms are represented?")
            b_desc = gr.HTML(label="")
        with gr.Column(scale=2):
            d = gr.Plot(label="Which ethnicity terms are present?")

    gr.Markdown(
        "The 'Model makeup' plot corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.")
    gr.Markdown(
        'The Gender plot shows the number of images based on the input prompts that used the words man, woman, non-binary person, and unmarked, which we label "person".')
    gr.Markdown(
        "The 'Ethnicity label makeup' plot corresponds to the number of images from each of the 18 ethnicities used in the prompts. A blank value means unmarked ethnicity.")

    # Page load, a change of cluster count and a slider move all refresh the
    # same widgets. The order of `outputs` must match the tuple returned by
    # show_cluster; cluster_id is last because it receives the gr.update
    # carrying the new slider maximum.
    inputs = [cluster_id, num_clusters]
    outputs = [a, b, b_desc, c, c_desc, d, gallery, cluster_id]
    demo.load(fn=show_cluster, inputs=inputs, outputs=outputs)
    num_clusters.change(fn=show_cluster, inputs=inputs, outputs=outputs)
    cluster_id.change(fn=show_cluster, inputs=inputs, outputs=outputs)

if __name__ == "__main__":
    demo.queue().launch(debug=True)