yjernite committed on
Commit
91c823d
•
1 Parent(s): 2fe028c

profession visualization

Browse files
Files changed (1) hide show
  1. app.py +80 -19
app.py CHANGED
@@ -17,38 +17,99 @@ clusters_by_size = {
17
  48: clusters_48,
18
  }
19
 
20
- prompts = pd.read_csv('promptsadjectives.csv')
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- m_adjectives = prompts['Masc-adj'].tolist()[:10]
23
- f_adjectives = prompts['Fem-adj'].tolist()[:10]
24
- adjectives = sorted(m_adjectives+f_adjectives)
25
- adjectives.insert(0, '')
26
- professions = sorted([p.lower() for p in prompts['Occupation-Noun'].tolist()])
27
- models = ["Dall-E 2", "Stable Diffusion 1.4", "Stable Diffusion 2"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  with gr.Blocks() as demo:
31
  gr.Markdown("# πŸ€— Diffusion Cluster Explorer")
32
  gr.Markdown("description will go here")
33
- with gr.Tab("Exploring all professions together"):
34
  gr.Markdown("TODO")
35
  with gr.Row():
36
  with gr.Column():
37
  gr.Markdown("Select your settings below:")
38
- num_clusters = gr.Radio([12,24,48], value=12, label="How many clusters do you want to make from the data?")
39
- professions = gr.Dropdown
 
 
 
40
  with gr.Column():
41
  gr.Markdown("")
42
- order = gr.Dropdown(["entropy", "cluster/sum of clusters"], value="entryopy", label="Order rows by:", interactive=True)
43
- # with gr.Row():
44
-
45
- # with gr.Accordion("Tag Frequencies", open=False):
 
 
 
46
 
47
- with gr.Tab("Tab 2"):
48
- gr.Markdown("TODO"
49
- )
50
-
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- demo.launch()
54
 
 
 
17
  48: clusters_48,
18
  }
19
 
20
# Prompt vocabulary; only the "Occupation-Noun" column is read at the moment.
prompts = pd.read_csv("promptsadjectives.csv")

# NOTE: the adjective lists are currently disabled; kept for reference.
# m_adjectives = prompts['Masc-adj'].tolist()[:10]
# f_adjectives = prompts['Fem-adj'].tolist()[:10]
# adjectives = sorted(m_adjectives+f_adjectives)
# adjectives.insert(0, '')

# Lower-cased profession names in alphabetical order. `sorted` already
# returns a new list, so no extra `list(...)` wrapper is needed.
professions = sorted(p.lower() for p in prompts["Occupation-Noun"].tolist())

# Maps the model keys used in the cluster data to their display names.
models = {
    "All": "All Models",
    "SD_14": "Stable Diffusion 1.4",
    "SD_2": "Stable Diffusion 2",
    "DallE": "Dall-E 2",
}
32
+
33
 
34
def make_profession_table(num_clusters, prof_name):
    """Plot per-model cluster proportions for a single profession.

    Clusters are ordered by their proportion in the pooled ``"All"`` entry
    (largest first), and clusters whose pooled proportion is not positive
    are dropped. The same ordered cluster list is then used to look up the
    proportions for every model in ``models``.

    Args:
        num_clusters: Key into ``clusters_by_size`` selecting the clustering.
        prof_name: Profession key inside each model's cluster data.

    Returns:
        Whatever ``DataFrame.plot(kind="bar", barmode="group")`` returns.
        NOTE(review): ``barmode`` is not a matplotlib argument, so this
        presumably relies on the pandas plotting backend being set to
        plotly elsewhere in the file -- confirm.

    Raises:
        KeyError: If ``num_clusters``, ``prof_name`` or a model key is
            missing from the cluster data.
    """
    by_model = clusters_by_size[num_clusters]

    # The ranking depends only on the pooled "All" data, so compute it once
    # instead of re-sorting for every model.
    ranked = sorted(
        by_model["All"][prof_name]["cluster_proportions"].items(),
        key=lambda item: item[1],
        reverse=True,
    )
    cluster_ids = [cluster_id for cluster_id, share in ranked if share > 0]

    # Outer keys (display names) become columns, inner keys become rows.
    pre_pandas = {
        models[mod_name]: {
            f"Cluster {cluster_id}": by_model[mod_name][prof_name][
                "cluster_proportions"
            ][cluster_id]
            for cluster_id in cluster_ids
        }
        for mod_name in models
    }
    df = pd.DataFrame.from_dict(pre_pandas)
    return df.plot(kind="bar", barmode="group")
62
 
63
 
# Two-tab Gradio UI: an overview tab (still a placeholder) and a
# per-profession tab that plots cluster proportions for each model.
with gr.Blocks() as demo:
    gr.Markdown("# 🤗 Diffusion Cluster Explorer")
    gr.Markdown("description will go here")
    with gr.Tab("Professions Overview"):
        gr.Markdown("TODO")
        with gr.Row():
            with gr.Column():
                gr.Markdown("Select your settings below:")
                num_clusters = gr.Radio(
                    [12, 24, 48],
                    value=12,
                    label="How many clusters do you want to use to represent identities?",
                )
            with gr.Column():
                gr.Markdown("")
                order = gr.Dropdown(
                    ["entropy", "cluster/sum of clusters"],
                    # The default must be one of the choices above; the
                    # previous "entryopy" matched neither.
                    value="entropy",
                    label="Order rows by:",
                    interactive=True,
                )
        # with gr.Row():

        # with gr.Accordion("Tag Frequencies", open=False):

    with gr.Tab("Profession Focus"):
        with gr.Row():
            # Rebinds `num_clusters`: from here on the name refers to this
            # tab's radio, which is the one wired to the plot below.
            num_clusters = gr.Radio(
                [12, 24, 48],
                value=12,
                label="How many clusters do you want to use to represent identities?",
            )
        with gr.Row():
            with gr.Column():
                profession_choice = gr.Dropdown(
                    choices=professions, label="Select profession:"
                )
            with gr.Column():
                # Static label: formatting the Dropdown component into the
                # string would embed the object's repr, not the selected
                # profession.
                plot = gr.Plot(
                    label="Makeup of the cluster assignments for the selected profession"
                )
        # Redraw the plot whenever a different profession is picked.
        profession_choice.change(
            make_profession_table,
            [num_clusters, profession_choice],
            plot,
            queue=False,
        )
        with gr.Row():
            gr.Markdown("TODO: show exemplars for cluster")


demo.launch()