yjernite commited on
Commit
7935f3f
1 Parent(s): 7044e35
Files changed (1) hide show
  1. app.py +148 -77
app.py CHANGED
@@ -15,7 +15,6 @@ professions_dset = load_from_disk("professions")
15
  professions_df = professions_dset.to_pandas()
16
 
17
 
18
-
19
  clusters_dicts = dict(
20
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
21
  for num_cl in [12, 24, 48]
@@ -91,8 +90,15 @@ def make_profession_plot(num_clusters, prof_name):
91
  )
92
  df = pd.DataFrame.from_dict(pre_pandas)
93
  prof_plot = df.plot(kind="bar", barmode="group")
94
- return prof_plot, gr.update(
95
- choices=[k for k, _ in sorted_cl_scores], value=sorted_cl_scores[0][0]
 
 
 
 
 
 
 
96
  )
97
 
98
 
@@ -143,43 +149,76 @@ def make_profession_table(num_clusters, prof_names, mod_name, max_cols=8):
143
  for prof_name, prof_clusters in professions_list_clusters
144
  ]
145
  clusters_df = pd.DataFrame.from_dict(prof_list_pre_pandas)
146
- return [c[0] for c in totals], (
147
- clusters_df.style.background_gradient(
148
- axis=None, vmin=0, vmax=100, cmap="YlGnBu"
149
- )
150
- .format(precision=1)
151
- .to_html()
 
 
 
 
 
 
 
152
  )
153
 
 
154
  def get_image(model, fname, score):
155
  return (
156
  professions_dset.select(
157
  professions_df[
158
- (professions_df["image_path"] == fname) & (professions_df["model"] == model)
 
159
  ].index
160
  )["image"][0],
161
- " ".join(fname.split("/")[0].split("_")[4:]) + f" | {score:.2f}" + f" | {models[model]}"
 
 
162
  )
163
 
164
 
165
- def show_examplars(num_clusters, prof_name, cl_id, confidence_threshold=0.5):
166
- # only show images where the similarity to the centroid is > 0.7
167
  examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
168
  "cluster_examplars"
169
  ][str(cl_id)]
170
- l = [tuple(img) for img in examplars_dict["close"] + examplars_dict["mid"][:2] + examplars_dict["far"]]
171
- l = [img for i, img in enumerate(l) if img[0] > confidence_threshold and img not in l[:i]]
 
 
 
 
 
 
 
 
 
172
  return (
173
  [get_image(model, fname, score) for score, model, fname in l],
174
- gr.update(label=f"Generations for profession ''{prof_name}'' assigned to cluster {cl_id} of {num_clusters}")
 
 
175
  )
176
 
177
 
178
  with gr.Blocks(title=TITLE) as demo:
179
- gr.Markdown("# 🤗 Diffusion Cluster Explorer")
180
- gr.Markdown("This tool helps you explore the different clusters that we discovered in the images generated by 3 text-to-image models: Dall-E 2, Stable Diffusion v.1.4 and v.2. This work was done in the scope of the [Stable Bias Project](https://huggingface.co/spaces/society-ethics/StableBias).")
 
 
 
 
 
 
 
 
 
181
  with gr.Tab("Professions Overview"):
182
- gr.Markdown("Select one or more professions and models from the dropdowns on the left to see which clusters are most representative for this combination. Try choosing different numbers of clusters to see if the results change, and then go to the 'Profession Focus' tab to go more in-depth into these results.")
 
 
183
  with gr.Row():
184
  with gr.Column(scale=1):
185
  gr.Markdown("Select the parameters here:")
@@ -211,86 +250,118 @@ with gr.Blocks(title=TITLE) as demo:
211
  table = gr.HTML(
212
  label="Profession assignment per cluster", wrap=True
213
  )
 
214
  # clusters = gr.Dataframe(type="array", visible=False, col_count=1)
215
  clusters = gr.Textbox(label="clusters", visible=False)
216
- demo.load(
217
- make_profession_table,
218
- [num_clusters, profession_choices_overview, model_choices],
219
- [clusters, table],
220
- queue=False,
221
- )
222
- for var in [num_clusters, model_choices, profession_choices_overview]:
223
- var.change(
224
- make_profession_table,
225
- [num_clusters, profession_choices_overview, model_choices],
226
- [clusters, table],
227
- queue=False,
228
  )
229
-
 
 
 
 
230
  with gr.Tab("Profession Focus"):
231
  with gr.Row():
232
  with gr.Column():
233
- gr.Markdown("Select profession to visualize and see which clusters and identity groups are most represented in the profession, as well as some examples of generated images below.")
234
- num_clusters_focus = gr.Radio(
235
- [12, 24, 48],
236
- value=12,
237
- label="How many clusters do you want to use to represent identities?",
238
  )
239
  profession_choice_focus = gr.Dropdown(
240
  choices=professions,
241
- value="social worker",
242
  label="Select profession:",
243
  )
244
- gr.Markdown(
245
- "You can show examples of profession images assigned to each cluster:"
246
- )
247
- cluster_id_focus = gr.Dropdown(
248
- choices=[i for i in range(num_clusters_focus.value)],
249
- value=0,
250
- label="Select cluster to visualize:",
251
  )
252
  with gr.Column():
253
  plot = gr.Plot(
254
  label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
255
  )
256
- demo.load(
257
- make_profession_plot,
258
- [num_clusters_focus, profession_choice_focus],
259
- [plot, cluster_id_focus],
260
- queue=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  )
262
- for var in [num_clusters_focus, profession_choice_focus]:
263
- var.change(
264
- make_profession_plot,
265
- [num_clusters_focus, profession_choice_focus],
266
- [plot, cluster_id_focus],
267
- queue=False,
268
  )
269
  with gr.Row():
270
  examplars_plot = gr.Gallery(
271
  label="Profession images assigned to the selected cluster."
272
  ).style(grid=4, height="auto", container=True)
273
- demo.load(
274
- show_examplars,
275
- [
276
- num_clusters_focus,
277
- profession_choice_focus,
278
- cluster_id_focus,
279
- ],
280
- [examplars_plot, examplars_plot],
281
- queue=False,
282
- )
283
- for var in [num_clusters_focus, profession_choice_focus, cluster_id_focus]:
284
- var.change(
285
- show_examplars,
286
- [
287
- num_clusters_focus,
288
- profession_choice_focus,
289
- cluster_id_focus,
290
- ],
291
- [examplars_plot, examplars_plot],
292
- queue=False,
293
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
 
296
  if __name__ == "__main__":
 
15
  professions_df = professions_dset.to_pandas()
16
 
17
 
 
18
  clusters_dicts = dict(
19
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
20
  for num_cl in [12, 24, 48]
 
90
  )
91
  df = pd.DataFrame.from_dict(pre_pandas)
92
  prof_plot = df.plot(kind="bar", barmode="group")
93
+ cl_summary_text = f"Profession ``{prof_name}'':\n"
94
+ for cl_id, _ in sorted_cl_scores:
95
+ cl_summary_text += f"- {cluster_summaries_by_size[str(num_clusters)][int(cl_id)].replace(' gender terms', '').replace('; ethnicity terms:', ',')} \n"
96
+ return (
97
+ prof_plot,
98
+ gr.update(
99
+ choices=[k for k, _ in sorted_cl_scores], value=sorted_cl_scores[0][0]
100
+ ),
101
+ gr.update(value=cl_summary_text),
102
  )
103
 
104
 
 
149
  for prof_name, prof_clusters in professions_list_clusters
150
  ]
151
  clusters_df = pd.DataFrame.from_dict(prof_list_pre_pandas)
152
+ cl_summary_text = ""
153
+ for cl_id, _ in totals[:max_cols]:
154
+ cl_summary_text += f"- {cluster_summaries_by_size[str(num_clusters)][cl_id].replace(' gender terms', '').replace('; ethnicity terms:', ',')} \n"
155
+ return (
156
+ [c[0] for c in totals],
157
+ (
158
+ clusters_df.style.background_gradient(
159
+ axis=None, vmin=0, vmax=100, cmap="YlGnBu"
160
+ )
161
+ .format(precision=1)
162
+ .to_html()
163
+ ),
164
+ gr.update(value=cl_summary_text),
165
  )
166
 
167
+
168
  def get_image(model, fname, score):
169
  return (
170
  professions_dset.select(
171
  professions_df[
172
+ (professions_df["image_path"] == fname)
173
+ & (professions_df["model"] == model)
174
  ].index
175
  )["image"][0],
176
+ " ".join(fname.split("/")[0].split("_")[4:])
177
+ + f" | {score:.2f}"
178
+ + f" | {models[model]}",
179
  )
180
 
181
 
182
+ def show_examplars(num_clusters, prof_name, cl_id, confidence_threshold=0.6):
183
+ # only show images where the similarity to the centroid is > confidence_threshold
184
  examplars_dict = clusters_dicts[num_clusters]["All"][prof_name][
185
  "cluster_examplars"
186
  ][str(cl_id)]
187
+ l = [
188
+ tuple(img)
189
+ for img in examplars_dict["close"]
190
+ + examplars_dict["mid"][:2]
191
+ + examplars_dict["far"]
192
+ ]
193
+ l = [
194
+ img
195
+ for i, img in enumerate(l)
196
+ if img[0] > confidence_threshold and img not in l[:i]
197
+ ]
198
  return (
199
  [get_image(model, fname, score) for score, model, fname in l],
200
+ gr.update(
201
+ label=f"Generations for profession ''{prof_name}'' assigned to cluster {cl_id} of {num_clusters}"
202
+ ),
203
  )
204
 
205
 
206
  with gr.Blocks(title=TITLE) as demo:
207
+ gr.Markdown(
208
+ """
209
+ # Identity Biases in Diffusion Models: Professions
210
+
211
+ This tool helps you explore the different clusters that we discovered in the images generated by 3 text-to-image models: Dall-E 2, Stable Diffusion v.1.4 and v.2.
212
+ This work was done in the scope of the [Stable Bias Project](https://huggingface.co/spaces/society-ethics/StableBias).
213
+ """
214
+ )
215
+ gr.HTML(
216
+ """<span style="color:red" font-size:smaller>⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
217
+ )
218
  with gr.Tab("Professions Overview"):
219
+ gr.Markdown(
220
+ "Select one or more professions and models from the dropdowns on the left to see which clusters are most representative for this combination. Try choosing different numbers of clusters to see if the results change, and then go to the 'Profession Focus' tab to go more in-depth into these results."
221
+ )
222
  with gr.Row():
223
  with gr.Column(scale=1):
224
  gr.Markdown("Select the parameters here:")
 
250
  table = gr.HTML(
251
  label="Profession assignment per cluster", wrap=True
252
  )
253
+ with gr.Row():
254
  # clusters = gr.Dataframe(type="array", visible=False, col_count=1)
255
  clusters = gr.Textbox(label="clusters", visible=False)
256
+ gr.Markdown(
257
+ """
258
+ ##### What do the clusters mean?
259
+ Below is a summary of the identity cluster compositions.
260
+ For more details, see the [companion demo](https://huggingface.co/spaces/society-ethics/DiffusionFaceClustering):
261
+ """
 
 
 
 
 
 
262
  )
263
+ with gr.Row():
264
+ with gr.Accordion(label="Cluster summaries", open=False):
265
+ cluster_descriptions_table = gr.Text(
266
+ "TODO", label="Cluster summaries", show_label=False
267
+ )
268
  with gr.Tab("Profession Focus"):
269
  with gr.Row():
270
  with gr.Column():
271
+ gr.Markdown(
272
+ "Select profession to visualize and see which clusters and identity groups are most represented in the profession, as well as some examples of generated images below."
 
 
 
273
  )
274
  profession_choice_focus = gr.Dropdown(
275
  choices=professions,
276
+ value="scientist",
277
  label="Select profession:",
278
  )
279
+ num_clusters_focus = gr.Radio(
280
+ [12, 24, 48],
281
+ value=12,
282
+ label="How many clusters do you want to use to represent identities?",
 
 
 
283
  )
284
  with gr.Column():
285
  plot = gr.Plot(
286
  label=f"Makeup of the cluster assignments for profession {profession_choice_focus}"
287
  )
288
+ with gr.Row():
289
+ with gr.Column():
290
+ gr.Markdown(
291
+ """
292
+ ##### What do the clusters mean?
293
+ Below is a summary of the identity cluster compositions.
294
+ For more details, see the [companion demo](https://huggingface.co/spaces/society-ethics/DiffusionFaceClustering):
295
+ """
296
+ )
297
+ with gr.Accordion(label="Cluster summaries", open=False):
298
+ cluster_descriptions = gr.Text(
299
+ "TODO", label="Cluster summaries", show_label=False
300
+ )
301
+ with gr.Column():
302
+ gr.Markdown(
303
+ """
304
+ ##### What's in the clusters?
305
+ You can show examples of profession images assigned to each identity cluster by selecting one here:
306
+ """
307
  )
308
+ with gr.Accordion(label="Cluster selection", open=False):
309
+ cluster_id_focus = gr.Dropdown(
310
+ choices=[i for i in range(num_clusters_focus.value)],
311
+ value=0,
312
+ label="Select cluster to visualize:",
 
313
  )
314
  with gr.Row():
315
  examplars_plot = gr.Gallery(
316
  label="Profession images assigned to the selected cluster."
317
  ).style(grid=4, height="auto", container=True)
318
+ demo.load(
319
+ make_profession_table,
320
+ [num_clusters, profession_choices_overview, model_choices],
321
+ [clusters, table, cluster_descriptions_table],
322
+ queue=False,
323
+ )
324
+ demo.load(
325
+ make_profession_plot,
326
+ [num_clusters_focus, profession_choice_focus],
327
+ [plot, cluster_id_focus, cluster_descriptions],
328
+ queue=False,
329
+ )
330
+ demo.load(
331
+ show_examplars,
332
+ [
333
+ num_clusters_focus,
334
+ profession_choice_focus,
335
+ cluster_id_focus,
336
+ ],
337
+ [examplars_plot, examplars_plot],
338
+ queue=False,
339
+ )
340
+ for var in [num_clusters, model_choices, profession_choices_overview]:
341
+ var.change(
342
+ make_profession_table,
343
+ [num_clusters, profession_choices_overview, model_choices],
344
+ [clusters, table, cluster_descriptions_table],
345
+ queue=False,
346
+ )
347
+ for var in [num_clusters_focus, profession_choice_focus]:
348
+ var.change(
349
+ make_profession_plot,
350
+ [num_clusters_focus, profession_choice_focus],
351
+ [plot, cluster_id_focus, cluster_descriptions],
352
+ queue=False,
353
+ )
354
+ for var in [num_clusters_focus, profession_choice_focus, cluster_id_focus]:
355
+ var.change(
356
+ show_examplars,
357
+ [
358
+ num_clusters_focus,
359
+ profession_choice_focus,
360
+ cluster_id_focus,
361
+ ],
362
+ [examplars_plot, examplars_plot],
363
+ queue=False,
364
+ )
365
 
366
 
367
  if __name__ == "__main__":