taskswithcode committed on
Commit
5fe6115
1 Parent(s): e4cf805
Files changed (3) hide show
  1. app.py +21 -14
  2. clus_app_clustypes.json +4 -0
  3. twc_clustering.py +82 -22
app.py CHANGED
@@ -103,16 +103,16 @@ def load_model(model_name,model_class,load_model_name):
103
 
104
 
105
  @st.experimental_memo
106
- def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster):
107
  texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
108
- results = _cluster.cluster(None,texts,embeddings,threshold)
109
  return results
110
 
111
 
112
- def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster):
113
  with st.spinner('Computing vectors for sentences'):
114
  texts,embeddings = _model.compute_embeddings(sentences,is_file=False)
115
- results = cluster.cluster(None,texts,embeddings,threshold)
116
  #st.success("Similarity computation complete")
117
  return results
118
 
@@ -124,7 +124,7 @@ def get_model_info(model_names,model_name):
124
  return get_model_info(model_names,DEFAULT_HF_MODEL)
125
 
126
 
127
- def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model):
128
  display_area.text("Loading model:" + model_name)
129
  #Note. model_name may get mapped to new name in the call below for custom models
130
  orig_model_name = model_name
@@ -140,10 +140,10 @@ def run_test(model_names,model_name,sentences,display_area,threshold,user_upload
140
  display_area.text("Model " + model_name + " load complete")
141
  try:
142
  if (user_uploaded):
143
- results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
144
  else:
145
  display_area.text("Computing vectors for sentences")
146
- results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"])
147
  display_area.text("Similarity computation complete")
148
  return results
149
 
@@ -193,16 +193,19 @@ def init_session():
193
  st.session_state["model_name"] = "ss_test"
194
  st.session_state["threshold"] = 1.5
195
  st.session_state["file_name"] = "default"
 
196
  st.session_state["cluster"] = TWCClustering()
197
  else:
198
  print("Skipping init session")
199
 
200
- def app_main(app_mode,example_files,model_name_files):
201
  init_session()
202
  with open(example_files) as fp:
203
  example_file_names = json.load(fp)
204
  with open(model_name_files) as fp:
205
  model_names = json.load(fp)
 
 
206
  curr_use_case = use_case[app_mode].split(".")[0]
207
  st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
208
  st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
@@ -215,7 +218,7 @@ def app_main(app_mode,example_files,model_name_files):
215
 
216
  with st.form('twc_form'):
217
 
218
- step1_line = "Step 1. Upload text file(one sentence in a line) or choose an example text file below"
219
  if (app_mode == DOC_RETRIEVAL):
220
  step1_line += ". The first line is treated as the query"
221
  uploaded_file = st.file_uploader(step1_line, type=".txt")
@@ -224,14 +227,17 @@ def app_main(app_mode,example_files,model_name_files):
224
  options = list(dict.keys(example_file_names)), index=0, key = "twc_file")
225
  st.write("")
226
  options_arr,markdown_str = construct_model_info_for_display(model_names)
227
- selection_label = 'Step 2. Select Model'
228
  selected_model = st.selectbox(label=selection_label,
229
  options = options_arr, index=0, key = "twc_model")
230
  st.write("")
231
  custom_model_selection = st.text_input("Model not listed above? Type any Huggingface sentence embedding model name ", "",key="custom_model")
232
  hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface sentence embedding models</a><br/><br/><br/></div>"
233
  st.markdown(hf_link_str, unsafe_allow_html=True)
234
- threshold = st.number_input('Step 3. Choose a zscore threshold (number of std devs from mean)',value=st.session_state["threshold"],min_value = 0.0,step=.01)
 
 
 
235
  st.write("")
236
  submit_button = st.form_submit_button('Run')
237
 
@@ -256,7 +262,8 @@ def app_main(app_mode,example_files,model_name_files):
256
  run_model = selected_model
257
  st.session_state["model_name"] = selected_model
258
  st.session_state["threshold"] = threshold
259
- results = run_test(model_names,run_model,sentences,display_area,threshold,(uploaded_file is not None),(len(custom_model_selection) != 0))
 
260
  display_area.empty()
261
  with display_area.container():
262
  device = 'GPU' if torch.cuda.is_available() else 'CPU'
@@ -269,7 +276,7 @@ def app_main(app_mode,example_files,model_name_files):
269
  label="Download results as json",
270
  data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
271
  disabled = False if st.session_state["download_ready"] != None else True,
272
- file_name= (st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
273
  mime='text/json',
274
  key ="download"
275
  )
@@ -288,5 +295,5 @@ if __name__ == "__main__":
288
  #print("comand line input:",len(sys.argv),str(sys.argv))
289
  #app_main(sys.argv[1],sys.argv[2],sys.argv[3])
290
  #app_main("1","sim_app_examples.json","sim_app_models.json")
291
- app_main("3","clus_app_examples.json","clus_app_models.json")
292
 
 
103
 
104
 
105
@st.experimental_memo
def cached_compute_similarity(sentences,_model,model_name,threshold,_cluster,clustering_type):
    """Embed the sentences and cluster them, memoized by Streamlit.

    Parameters prefixed with an underscore (_model, _cluster) are skipped by
    st.experimental_memo's hashing; model_name keeps the cache keyed per model.
    Returns the cluster dict produced by the clustering backend.
    """
    texts, embeddings = _model.compute_embeddings(sentences, is_file=False)
    return _cluster.cluster(None, texts, embeddings, threshold, clustering_type)
110
 
111
 
112
def uncached_compute_similarity(sentences,_model,model_name,threshold,cluster,clustering_type):
    """Embed the sentences and cluster them with no caching (fresh computation per call).

    Used for user-uploaded files, where memoizing on file content is not wanted.
    Returns the cluster dict produced by the clustering backend.
    """
    with st.spinner('Computing vectors for sentences'):
        texts, embeddings = _model.compute_embeddings(sentences, is_file=False)
        results = cluster.cluster(None, texts, embeddings, threshold, clustering_type)
        #st.success("Similarity computation complete")
    return results
118
 
 
124
  return get_model_info(model_names,DEFAULT_HF_MODEL)
125
 
126
 
127
+ def run_test(model_names,model_name,sentences,display_area,threshold,user_uploaded,custom_model,clustering_type):
128
  display_area.text("Loading model:" + model_name)
129
  #Note. model_name may get mapped to new name in the call below for custom models
130
  orig_model_name = model_name
 
140
  display_area.text("Model " + model_name + " load complete")
141
  try:
142
  if (user_uploaded):
143
+ results = uncached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
144
  else:
145
  display_area.text("Computing vectors for sentences")
146
+ results = cached_compute_similarity(sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
147
  display_area.text("Similarity computation complete")
148
  return results
149
 
 
193
  st.session_state["model_name"] = "ss_test"
194
  st.session_state["threshold"] = 1.5
195
  st.session_state["file_name"] = "default"
196
+ st.session_state["overlapped"] = "overlapped"
197
  st.session_state["cluster"] = TWCClustering()
198
  else:
199
  print("Skipping init session")
200
 
201
+ def app_main(app_mode,example_files,model_name_files,clus_types):
202
  init_session()
203
  with open(example_files) as fp:
204
  example_file_names = json.load(fp)
205
  with open(model_name_files) as fp:
206
  model_names = json.load(fp)
207
+ with open(clus_types) as fp:
208
+ cluster_types = json.load(fp)
209
  curr_use_case = use_case[app_mode].split(".")[0]
210
  st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for tasks using sentence embeddings</h5>", unsafe_allow_html=True)
211
  st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
 
218
 
219
  with st.form('twc_form'):
220
 
221
+ step1_line = "Upload text file(one sentence in a line) or choose an example text file below"
222
  if (app_mode == DOC_RETRIEVAL):
223
  step1_line += ". The first line is treated as the query"
224
  uploaded_file = st.file_uploader(step1_line, type=".txt")
 
227
  options = list(dict.keys(example_file_names)), index=0, key = "twc_file")
228
  st.write("")
229
  options_arr,markdown_str = construct_model_info_for_display(model_names)
230
+ selection_label = 'Select Model'
231
  selected_model = st.selectbox(label=selection_label,
232
  options = options_arr, index=0, key = "twc_model")
233
  st.write("")
234
  custom_model_selection = st.text_input("Model not listed above? Type any Huggingface sentence embedding model name ", "",key="custom_model")
235
  hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Huggingface sentence embedding models</a><br/><br/><br/></div>"
236
  st.markdown(hf_link_str, unsafe_allow_html=True)
237
+ threshold = st.number_input('Choose a zscore threshold (number of std devs from mean)',value=st.session_state["threshold"],min_value = 0.0,step=.01)
238
+ st.write("")
239
+ clustering_type = st.selectbox(label=f'Select type of clustering',
240
+ options = list(dict.keys(cluster_types)), index=0, key = "twc_cluster_types")
241
  st.write("")
242
  submit_button = st.form_submit_button('Run')
243
 
 
262
  run_model = selected_model
263
  st.session_state["model_name"] = selected_model
264
  st.session_state["threshold"] = threshold
265
+ st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
266
+ results = run_test(model_names,run_model,sentences,display_area,threshold,(uploaded_file is not None),(len(custom_model_selection) != 0),cluster_types[clustering_type]["type"])
267
  display_area.empty()
268
  with display_area.container():
269
  device = 'GPU' if torch.cuda.is_available() else 'CPU'
 
276
  label="Download results as json",
277
  data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
278
  disabled = False if st.session_state["download_ready"] != None else True,
279
+ file_name= (st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + st.session_state["overlapped"] + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
280
  mime='text/json',
281
  key ="download"
282
  )
 
295
  #print("comand line input:",len(sys.argv),str(sys.argv))
296
  #app_main(sys.argv[1],sys.argv[2],sys.argv[3])
297
  #app_main("1","sim_app_examples.json","sim_app_models.json")
298
+ app_main("3","clus_app_examples.json","clus_app_models.json","clus_app_clustypes.json")
299
 
clus_app_clustypes.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "Overlapped clustering (cluster size determined by zscore)": {"type":"overlapped"},
3
+ "Non-overlapped clustering (overlapped clusters aggregated)":{"type":"non-overlapped"}
4
+ }
twc_clustering.py CHANGED
@@ -31,27 +31,30 @@ class TWCClustering:
31
  picked_arr = []
32
  while (run_index < len(embeddings)):
33
  if (matrix[pivot_index][run_index] >= threshold):
34
- #picked_arr.append({"index":run_index,"val":matrix[pivot_index][run_index]})
35
- picked_arr.append({"index":run_index})
36
  run_index += 1
37
  return picked_arr
38
 
 
 
 
 
39
  def update_picked_dict(self,picked_dict,in_dict):
40
  for key in in_dict:
41
  picked_dict[key] = 1
42
 
43
- def find_pivot_subgraph(self,pivot_index,arr,matrix,threshold):
44
  center_index = pivot_index
45
  center_score = 0
46
  center_dict = {}
47
  for i in range(len(arr)):
48
- node_i_index = arr[i]["index"]
49
  running_score = 0
50
  temp_dict = {}
51
  for j in range(len(arr)):
52
- node_j_index = arr[j]["index"]
53
  cosine_dist = matrix[node_i_index][node_j_index]
54
- if (cosine_dist < threshold):
55
  continue
56
  running_score += cosine_dist
57
  temp_dict[node_j_index] = cosine_dist
@@ -80,8 +83,76 @@ class TWCClustering:
80
  bucket_dict[overlap_dict[key]] += 1
81
  sorted_d = OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))
82
  return sorted_d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- def cluster(self,output_file,texts,embeddings,threshold = 1.5):
 
85
  matrix = self.compute_matrix(embeddings)
86
  mean = np.mean(matrix)
87
  std = np.std(matrix)
@@ -95,22 +166,11 @@ class TWCClustering:
95
  #print("In clustering:",round(std,2),zscores)
96
  cluster_dict = {}
97
  cluster_dict["clusters"] = []
98
- picked_dict = {}
99
- overlap_dict = {}
100
-
101
- for i in range(len(embeddings)):
102
- if (i in picked_dict):
103
- continue
104
- zscore = mean + threshold*std
105
- arr = self.get_terms_above_threshold(matrix,embeddings,i,zscore)
106
- cluster_info = self.find_pivot_subgraph(i,arr,matrix,zscore)
107
- self.update_picked_dict(picked_dict,cluster_info["neighs"])
108
- self.update_overlap_stats(overlap_dict,cluster_info)
109
- cluster_dict["clusters"].append(cluster_info)
110
  curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
111
- sorted_d = OrderedDict(sorted(overlap_dict.items(), key=lambda kv: kv[1], reverse=True))
112
- #print(sorted_d)
113
- sorted_d = self.bucket_overlap(overlap_dict)
114
  cluster_dict["info"] ={"mean":mean,"std":std,"current_threshold":curr_threshold,"zscores":zscores,"overlap":list(sorted_d.items())}
115
  return cluster_dict
116
 
 
31
  picked_arr = []
32
  while (run_index < len(embeddings)):
33
  if (matrix[pivot_index][run_index] >= threshold):
34
+ picked_arr.append(run_index)
 
35
  run_index += 1
36
  return picked_arr
37
 
38
def update_picked_dict_arr(self, picked_dict, arr):
    """Mark every index in arr as already assigned to a candidate cluster (value 1)."""
    picked_dict.update((member, 1) for member in arr)
41
+
42
def update_picked_dict(self, picked_dict, in_dict):
    """Mark every key of in_dict (a cluster's neighbor map) as already clustered (value 1)."""
    picked_dict.update(dict.fromkeys(in_dict, 1))
45
 
46
+ def find_pivot_subgraph(self,pivot_index,arr,matrix,threshold,strict_cluster = True):
47
  center_index = pivot_index
48
  center_score = 0
49
  center_dict = {}
50
  for i in range(len(arr)):
51
+ node_i_index = arr[i]
52
  running_score = 0
53
  temp_dict = {}
54
  for j in range(len(arr)):
55
+ node_j_index = arr[j]
56
  cosine_dist = matrix[node_i_index][node_j_index]
57
+ if ((cosine_dist < threshold) and strict_cluster):
58
  continue
59
  running_score += cosine_dist
60
  temp_dict[node_j_index] = cosine_dist
 
83
  bucket_dict[overlap_dict[key]] += 1
84
  sorted_d = OrderedDict(sorted(bucket_dict.items(), key=lambda kv: kv[1], reverse=False))
85
  return sorted_d
86
+
87
def merge_clusters(self, ref_cluster, curr_cluster):
    """In-place union: append to ref_cluster every curr_cluster member that was
    absent from ref_cluster *before* the merge started.

    Membership is tested against a snapshot taken up front, so elements
    duplicated within curr_cluster are each appended (matching the original
    snapshot semantics). Members are sentence indices (ints), hence hashable.
    """
    snapshot = set(ref_cluster)
    ref_cluster.extend(m for m in curr_cluster if m not in snapshot)
92
+
93
+
94
def non_overlapped_clustering(self,matrix,embeddings,threshold,mean,std,cluster_dict):
    """Build disjoint (non-overlapping) clusters and append them to cluster_dict["clusters"].

    Phase 1 gathers a candidate neighbor list for every sentence index not yet
    absorbed into an earlier candidate. Phase 2 repeatedly merges any two
    candidate lists that share a member until all lists are pairwise disjoint.
    Returns an empty dict: overlap statistics do not apply in this mode
    (the caller stores the return value as the "overlap" info).
    """
    picked_dict = {}   # indices already absorbed into some candidate list
    overlap_dict = {}  # NOTE(review): never populated here; kept for symmetry with overlapped_clustering
    candidates = []

    for i in range(len(embeddings)):
        if (i in picked_dict):
            continue
        # cosine cutoff derived from the zscore threshold (loop-invariant)
        zscore = mean + threshold*std
        arr = self.get_terms_above_threshold(matrix,embeddings,i,zscore)
        candidates.append(arr)
        self.update_picked_dict_arr(picked_dict,arr)

    # Merge arrays to create non-overlapping sets
    run_index_i = 0
    while (run_index_i < len(candidates)):
        ref_cluster = candidates[run_index_i]
        run_index_j = run_index_i + 1
        found = False
        while (run_index_j < len(candidates)):
            curr_cluster = candidates[run_index_j]
            for k in range(len(curr_cluster)):
                if (curr_cluster[k] in ref_cluster):
                    # Shared member: fold curr_cluster into ref_cluster, drop it,
                    # and restart the outer scan from 0 — the merged list may now
                    # intersect candidates that were already checked.
                    self.merge_clusters(ref_cluster,curr_cluster)
                    candidates.pop(run_index_j)
                    found = True
                    run_index_i = 0
                    break
            if (found):
                break
            else:
                run_index_j += 1
        if (not found):
            # ref_cluster is disjoint from everything after it; advance.
            run_index_i += 1


    zscore = mean + threshold*std
    for i in range(len(candidates)):
        arr = candidates[i]
        # strict_cluster=False: keep every merged member even if its cosine
        # similarity to the chosen pivot falls below the cutoff.
        cluster_info = self.find_pivot_subgraph(arr[0],arr,matrix,zscore,strict_cluster = False)
        cluster_dict["clusters"].append(cluster_info)
    return {}
136
+
137
def overlapped_clustering(self,matrix,embeddings,threshold,mean,std,cluster_dict):
    """Greedy overlapped clustering: each not-yet-assigned sentence seeds a
    cluster of all sentences whose cosine similarity to it clears the cutoff.
    Clusters may share members. Each cluster is appended to
    cluster_dict["clusters"]; returns bucketed overlap statistics
    (OrderedDict of membership-count -> frequency) from bucket_overlap.
    """
    assigned = {}
    membership_counts = {}
    cutoff = mean + threshold * std  # zscore threshold mapped onto the cosine scale
    for pivot in range(len(embeddings)):
        if pivot in assigned:
            continue
        neighbors = self.get_terms_above_threshold(matrix, embeddings, pivot, cutoff)
        info = self.find_pivot_subgraph(pivot, neighbors, matrix, cutoff, strict_cluster=True)
        self.update_picked_dict(assigned, info["neighs"])
        self.update_overlap_stats(membership_counts, info)
        cluster_dict["clusters"].append(info)
    return self.bucket_overlap(membership_counts)
152
+
153
 
154
+ def cluster(self,output_file,texts,embeddings,threshold,clustering_type):
155
+ is_overlapped = True if clustering_type == "overlapped" else False
156
  matrix = self.compute_matrix(embeddings)
157
  mean = np.mean(matrix)
158
  std = np.std(matrix)
 
166
  #print("In clustering:",round(std,2),zscores)
167
  cluster_dict = {}
168
  cluster_dict["clusters"] = []
169
+ if (is_overlapped):
170
+ sorted_d = self.overlapped_clustering(matrix,embeddings,threshold,mean,std,cluster_dict)
171
+ else:
172
+ sorted_d = self.non_overlapped_clustering(matrix,embeddings,threshold,mean,std,cluster_dict)
 
 
 
 
 
 
 
 
173
  curr_threshold = f"{threshold} (cosine:{mean+threshold*std:.2f})"
 
 
 
174
  cluster_dict["info"] ={"mean":mean,"std":std,"current_threshold":curr_threshold,"zscores":zscores,"overlap":list(sorted_d.items())}
175
  return cluster_dict
176