cakiki commited on
Commit
792be34
1 Parent(s): 70e85c0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -2
app.py CHANGED
@@ -2,6 +2,8 @@ import gradio as gr
2
  import json
3
  import numpy as np
4
  import pandas as pd
 
 
5
  import operator
6
 
7
  pd.options.plotting.backend = "plotly"
@@ -9,6 +11,12 @@ pd.options.plotting.backend = "plotly"
9
 
10
  TITLE = "Diffusion Professions Cluster Explorer"
11
 
 
 
 
 
 
 
12
  clusters_dicts = dict(
13
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
14
  for num_cl in [12, 24, 48]
@@ -142,7 +150,8 @@ def show_examplars(num_clusters, prof_name, mod_name, cl_id):
142
  examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
143
  "cluster_examplars"
144
  ][str(cl_id)]
145
- return json.dumps(examplars_dict)
 
146
 
147
 
148
  with gr.Blocks(title=TITLE) as demo:
@@ -249,7 +258,7 @@ with gr.Blocks(title=TITLE) as demo:
249
  )
250
  with gr.Row():
251
  examplars_plot = (
252
- gr.JSON()
253
  ) # TODO: turn this into a plot with the actual images
254
  demo.load(
255
  show_examplars,
 
2
  import json
3
  import numpy as np
4
  import pandas as pd
5
+ from datasets import load_from_disk
6
+ from itertools import chain
7
  import operator
8
 
9
  pd.options.plotting.backend = "plotly"
 
11
 
12
  TITLE = "Diffusion Professions Cluster Explorer"
13
 
14
+ professions = load_from_disk("professions")
15
+ professions_df = professions.to_pandas()
16
+
17
+ def get_image(model, fname):
18
+ return professions.select(professions_df[(professions_df["image_path"]==fname) & (professions_df["model"]==model)].index)["image"][0]
19
+
20
  clusters_dicts = dict(
21
  (num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
22
  for num_cl in [12, 24, 48]
 
150
  examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
151
  "cluster_examplars"
152
  ][str(cl_id)]
153
+ l = list(chain(*[examplars_dict[k] for k in examplars_dict]))
154
+ return [get_image(model,fname) for _,model,fname in l]
155
 
156
 
157
  with gr.Blocks(title=TITLE) as demo:
 
258
  )
259
  with gr.Row():
260
  examplars_plot = (
261
+ gr.Gallery()
262
  ) # TODO: turn this into a plot with the actual images
263
  demo.load(
264
  show_examplars,