vaibhavsharda committed on
Commit
bd4f103
1 Parent(s): 7e8b361

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -203
app.py CHANGED
@@ -2,135 +2,136 @@ import time
2
  import sys
3
  import streamlit as st
4
  import string
5
- from io import StringIO
6
  import pdb
7
  import json
8
- from twc_embeddings import HFModel, SimCSEModel, SGPTModel, CausalLMModel, SGPTQnAModel
9
  from twc_openai_embeddings import OpenAIModel
10
  from twc_clustering import TWCClustering
11
  import torch
12
  import requests
13
  import socket
14
 
 
15
  MAX_INPUT = 10000
16
 
17
- SEM_SIMILARITY = "1"
18
- DOC_RETRIEVAL = "2"
19
- CLUSTERING = "3"
 
 
 
 
 
20
 
21
- use_case = {"1": "Finding similar phrases/sentences",
22
- "2": "Retrieving semantically matching information to a query. It may not be a factual match",
23
- "3": "Clustering"}
24
- use_case_url = {"1": "https://huggingface.co/spaces/taskswithcode/semantic_similarity",
25
- "2": "https://huggingface.co/spaces/taskswithcode/semantic_search", "3": ""}
26
 
27
  from transformers import BertTokenizer, BertForMaskedLM
28
 
 
29
  APP_NAME = "hf/semantic_clustering"
30
  INFO_URL = "https://www.taskswithcode.com/stats/"
31
 
32
 
 
 
 
33
  def get_views(action):
34
  ret_val = 0
35
  hostname = socket.gethostname()
36
  ip_address = socket.gethostbyname(hostname)
37
  if ("view_count" not in st.session_state):
38
  try:
39
- app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}
40
- res = requests.post(INFO_URL, json=app_info).json()
41
- print(res)
42
- data = res["count"]
43
  except:
44
- data = 0
45
  ret_val = data
46
  st.session_state["view_count"] = data
47
  else:
48
  ret_val = st.session_state["view_count"]
49
  if (action != "init"):
50
- app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}
51
- res = requests.post(INFO_URL, json=app_info).json()
52
  return "{:,}".format(ret_val)
 
 
53
 
54
 
55
  def construct_model_info_for_display(model_names):
56
- options_arr = []
57
  markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>The selected models satisfy one or more of the following (1) state-of-the-art (2) the most downloaded models on Hugging Face (3) Large Language Models (e.g. GPT-3)</i></div>"
58
  markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
59
  for node in model_names:
60
- options_arr.append(node["name"])
61
  if (node["mark"] == "True"):
62
  markdown_str += f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\">&nbsp;•&nbsp;Model:&nbsp;<a href=\'{node['paper_url']}\' target='_blank'>{node['name']}</a><br/>&nbsp;&nbsp;&nbsp;&nbsp;Code released by:&nbsp;<a href=\'{node['orig_author_url']}\' target='_blank'>{node['orig_author']}</a><br/>&nbsp;&nbsp;&nbsp;&nbsp;Model info:&nbsp;<a href=\'{node['sota_info']['sota_link']}\' target='_blank'>{node['sota_info']['task']}</a></div>"
63
  if ("Note" in node):
64
  markdown_str += f"<div style=\"font-size:16px; color: #a91212; text-align: left\">&nbsp;&nbsp;&nbsp;&nbsp;{node['Note']}<a href=\'{node['alt_url']}\' target='_blank'>link</a></div>"
65
  markdown_str += "<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>"
66
-
67
  markdown_str += "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>•&nbsp;Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>"
68
  limit = "{:,}".format(MAX_INPUT)
69
  markdown_str += f"<div style=\"font-size:12px; color: #9f9f9f; text-align: left\">•&nbsp;User uploaded file has a maximum limit of {limit} sentences.</div>"
70
- return options_arr, markdown_str
71
 
72
 
73
- st.set_page_config(
74
- page_title='TWC - Compare popular/state-of-the-art models for semantic clustering using sentence embeddings',
75
- page_icon="logo.jpg", layout='centered', initial_sidebar_state='auto',
76
- menu_items={
77
- 'About': 'This app was created by taskswithcode. http://taskswithcode.com'
78
-
79
- })
80
- col, pad = st.columns([85, 15])
81
 
82
  with col:
83
  st.image("long_form_logo_with_icon.png")
84
 
85
 
86
  @st.experimental_memo
87
- def load_model(model_name, model_class, load_model_name):
88
  try:
89
  ret_model = None
90
  obj_class = globals()[model_class]
91
  ret_model = obj_class()
92
  ret_model.init_model(load_model_name)
93
- assert (ret_model is not None)
94
  except Exception as e:
95
- st.error(
96
- f"Unable to load model class:{model_class} model_name: {model_name} load_model_name: {load_model_name} {str(e)}")
97
  pass
98
  return ret_model
99
 
100
 
 
101
  @st.experimental_memo
102
- def cached_compute_similarity(input_file_name, sentences, _model, model_name, threshold, _cluster, clustering_type):
103
- texts, embeddings = _model.compute_embeddings(input_file_name, sentences, is_file=False)
104
- results = _cluster.cluster(None, texts, embeddings, threshold, clustering_type)
105
  return results
106
 
107
 
108
- def uncached_compute_similarity(input_file_name, sentences, _model, model_name, threshold, cluster, clustering_type):
109
  with st.spinner('Computing vectors for sentences'):
110
- texts, embeddings = _model.compute_embeddings(input_file_name, sentences, is_file=False)
111
- results = cluster.cluster(None, texts, embeddings, threshold, clustering_type)
112
- # st.success("Similarity computation complete")
113
  return results
114
 
115
-
116
  DEFAULT_HF_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"
117
-
118
-
119
- def get_model_info(model_names, model_name):
120
  for node in model_names:
121
  if (model_name == node["name"]):
122
- return node, model_name
123
- return get_model_info(model_names, DEFAULT_HF_MODEL)
124
 
125
 
126
- def run_test(model_names, model_name, input_file_name, sentences, display_area, threshold, user_uploaded, custom_model,
127
- clustering_type):
128
  display_area.text("Loading model:" + model_name)
129
- # Note. model_name may get mapped to new name in the call below for custom models
130
  orig_model_name = model_name
131
- model_info, model_name = get_model_info(model_names, model_name)
132
  if (model_name != orig_model_name):
133
- load_model_name = orig_model_name
134
  else:
135
  load_model_name = model_info["model"]
136
  if ("Note" in model_info):
@@ -139,27 +140,28 @@ def run_test(model_names, model_name, input_file_name, sentences, display_area,
139
  if (user_uploaded and "custom_load" in model_info and model_info["custom_load"] == "False"):
140
  fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
141
  display_area.write(fail_link)
142
- return {"error": fail_link}
143
- model = load_model(model_name, model_info["class"], load_model_name)
144
- display_area.text("Model " + model_name + " load complete")
145
  try:
146
- if (user_uploaded):
147
- results = uncached_compute_similarity(input_file_name, sentences, model, model_name, threshold,
148
- st.session_state["cluster"], clustering_type)
149
- else:
150
- display_area.text("Computing vectors for sentences")
151
- results = cached_compute_similarity(input_file_name, sentences, model, model_name, threshold,
152
- st.session_state["cluster"], clustering_type)
153
- display_area.text("Similarity computation complete")
154
- return results
155
-
156
  except Exception as e:
157
  st.error("Some error occurred during prediction" + str(e))
158
  st.stop()
159
  return {}
160
 
161
 
162
- def display_results(orig_sentences, results, response_info, app_mode, model_name):
 
 
 
163
  main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
164
  main_sent += f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">Showing results for model:&nbsp;<b>{model_name}</b></div>"
165
  score_text = "cosine distance"
@@ -169,32 +171,30 @@ def display_results(orig_sentences, results, response_info, app_mode, model_name
169
  for i in range(len(results["clusters"])):
170
  pivot_index = results["clusters"][i]["pivot_index"]
171
  pivot_sent = orig_sentences[pivot_index]
172
- pivot_index += 1
173
  d_cluster = {}
174
  download_data[i + 1] = d_cluster
175
- d_cluster["pivot"] = {"pivot_index": pivot_index, "sent": pivot_sent, "children": {}}
176
- body_sent.append(
177
- f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{pivot_index}]&nbsp;{pivot_sent}&nbsp;<b><i>(Cluster {i + 1})</i></b>&nbsp;&nbsp;</div>")
178
  neighs_dict = results["clusters"][i]["neighs"]
179
  for key in neighs_dict:
180
  cosine_dist = neighs_dict[key]
181
  child_index = key
182
  sentence = orig_sentences[child_index]
183
  child_index += 1
184
- body_sent.append(
185
- f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{child_index}]&nbsp;{sentence}&nbsp;&nbsp;&nbsp;<b>{cosine_dist:.2f}</b></div>")
186
- d_cluster["pivot"]["children"][sentence] = f"{cosine_dist:.2f}"
187
  body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">&nbsp;</div>")
188
  main_sent = main_sent + "\n" + '\n'.join(body_sent)
189
- st.markdown(main_sent, unsafe_allow_html=True)
190
- st.session_state["download_ready"] = json.dumps(download_data, indent=4)
191
  get_views("submit")
192
 
193
 
194
  def init_session():
195
  if ("model_name" not in st.session_state):
196
  st.session_state["model_name"] = "ss_test"
197
- st.session_state["download_ready"] = None
198
  st.session_state["model_name"] = "ss_test"
199
  st.session_state["threshold"] = 1.5
200
  st.session_state["file_name"] = "default"
@@ -202,139 +202,106 @@ def init_session():
202
  st.session_state["cluster"] = TWCClustering()
203
  else:
204
  print("Skipping init session")
205
-
206
-
207
- def app_main(app_mode, example_files, model_name_files, clus_types):
208
- init_session()
209
- with open(example_files) as fp:
210
- example_file_names = json.load(fp)
211
- with open(model_name_files) as fp:
212
  model_names = json.load(fp)
213
- with open(clus_types) as fp:
214
  cluster_types = json.load(fp)
215
- curr_use_case = use_case[app_mode].split(".")[0]
216
- st.markdown(
217
- "<h5 style='text-align: center;'>Compare popular/state-of-the-art models for semantic clustering using sentence embeddings</h5>",
218
- unsafe_allow_html=True)
219
- st.markdown(
220
- f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>",
221
- unsafe_allow_html=True)
222
- st.markdown(
223
- f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['1']}\' target='_blank'>{use_case['1']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['2']}\' target='_blank'>{use_case['2']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>",
224
- unsafe_allow_html=True)
225
- st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views:&nbsp;{get_views('init')}</div>",
226
- unsafe_allow_html=True)
227
-
228
- try:
229
-
230
- with st.form('twc_form'):
231
-
232
- step1_line = "Upload text file(one sentence in a line) or choose an example text file below"
233
- if (app_mode == DOC_RETRIEVAL):
234
- step1_line += ". The first line is treated as the query"
235
- uploaded_file = st.file_uploader(step1_line, type=".txt")
236
-
237
- selected_file_index = st.selectbox(label=f'Example files ({len(example_file_names)})',
238
- options=list(dict.keys(example_file_names)), index=0, key="twc_file")
239
- st.write("")
240
- options_arr, markdown_str = construct_model_info_for_display(model_names)
241
- selection_label = 'Select Model'
242
- selected_model = st.selectbox(label=selection_label,
243
- options=options_arr, index=0, key="twc_model")
244
- st.write("")
245
- custom_model_selection = st.text_input(
246
- "Model not listed above? Type any Hugging Face sentence embedding model name ", "", key="custom_model")
247
- hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Hugging Face sentence embedding models</a><br/><br/><br/></div>"
248
- st.markdown(hf_link_str, unsafe_allow_html=True)
249
- threshold = st.number_input('Choose a zscore threshold (number of std devs from mean)',
250
- value=st.session_state["threshold"], min_value=0.0, step=.01)
251
- st.write("")
252
- clustering_type = st.selectbox(label=f'Select type of clustering',
253
- options=list(dict.keys(cluster_types)), index=0, key="twc_cluster_types")
254
- st.write("")
255
- submit_button = st.form_submit_button('Run')
256
-
257
- input_status_area = st.empty()
258
- display_area = st.empty()
259
-
260
- st.download_button(
261
- label="Download results as json",
262
- data=st.session_state["download_ready"] if st.session_state["download_ready"] is not None else "",
263
- disabled=False if st.session_state["download_ready"] is not None else True,
264
- file_name=(st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" +
265
- st.session_state["overlapped"] + "_" + '_'.join(
266
- st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/", "_"),
267
- mime='text/json',
268
- key="download"
269
- )
270
- if submit_button:
271
- start = time.time()
272
- if uploaded_file is not None:
273
- st.session_state["file_name"] = uploaded_file.name
274
- sentences = StringIO(uploaded_file.getvalue().decode("utf-8")).read()
 
 
275
  else:
276
- st.session_state["file_name"] = example_file_names[selected_file_index]["name"]
277
- sentences = open(example_file_names[selected_file_index]["name"]).read()
278
- sentences = sentences.split("\n")[:-1]
279
- if (len(sentences) > MAX_INPUT):
280
- st.info(
281
- f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
282
- sentences = sentences[:MAX_INPUT]
283
- if (len(custom_model_selection) != 0):
284
- run_model = custom_model_selection
285
- else:
286
- run_model = selected_model
287
- st.session_state["model_name"] = selected_model
288
- st.session_state["threshold"] = threshold
289
- st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
290
- results = run_test(model_names, run_model, st.session_state["file_name"], sentences, display_area,
291
- threshold, (uploaded_file is not None), (len(custom_model_selection) != 0),
292
- cluster_types[clustering_type]["type"])
293
- display_area.empty()
294
- with display_area.container():
295
- if ("error" in results):
296
- st.error(results["error"])
297
- else:
298
- device = 'GPU' if torch.cuda.is_available() else 'CPU'
299
- response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
300
- if (len(custom_model_selection) != 0):
301
- st.info(
302
- "Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
303
- display_results(sentences, results, response_info, app_mode, run_model)
304
- # st.json(results)
305
- st.download_button(
306
- label="Download results as json",
307
- data=st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
308
- disabled=False if st.session_state["download_ready"] != None else True,
309
- file_name=(st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" +
310
- st.session_state["overlapped"] + "_" + '_'.join(
311
- st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/", "_"),
312
- mime='text/json',
313
- key="download"
314
  )
 
 
315
 
316
-
317
-
318
- except Exception as e:
319
- st.download_button(
320
- label="Download results as json",
321
- data=st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
322
- disabled=False if st.session_state["download_ready"] != None else True,
323
- file_name=(st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" +
324
- st.session_state["overlapped"] + "_" + '_'.join(
325
- st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/", "_"),
326
- mime='text/json',
327
- key="download"
328
- )
329
- st.error("Some error occurred during loading" + str(e))
330
- #st.stop()
331
-
332
- st.markdown(markdown_str, unsafe_allow_html=True)
333
-
334
 
335
  if __name__ == "__main__":
336
- # print("comand line input:",len(sys.argv),str(sys.argv))
337
- # app_main(sys.argv[1],sys.argv[2],sys.argv[3])
338
- # app_main("1","sim_app_examples.json","sim_app_models.json")
339
- app_main("3", "clus_app_examples.json", "clus_app_models.json", "clus_app_clustypes.json")
340
 
 
2
  import sys
3
  import streamlit as st
4
  import string
5
+ from io import StringIO
6
  import pdb
7
  import json
8
+ from twc_embeddings import HFModel,SimCSEModel,SGPTModel,CausalLMModel,SGPTQnAModel
9
  from twc_openai_embeddings import OpenAIModel
10
  from twc_clustering import TWCClustering
11
  import torch
12
  import requests
13
  import socket
14
 
15
+
16
  MAX_INPUT = 10000
17
 
18
+ SEM_SIMILARITY="1"
19
+ DOC_RETRIEVAL="2"
20
+ CLUSTERING="3"
21
+
22
+
23
+ use_case = {"1":"Finding similar phrases/sentences","2":"Retrieving semantically matching information to a query. It may not be a factual match","3":"Clustering"}
24
+ use_case_url = {"1":"https://huggingface.co/spaces/taskswithcode/semantic_similarity","2":"https://huggingface.co/spaces/taskswithcode/semantic_search","3":""}
25
+
26
 
 
 
 
 
 
27
 
28
  from transformers import BertTokenizer, BertForMaskedLM
29
 
30
+
31
  APP_NAME = "hf/semantic_clustering"
32
  INFO_URL = "https://www.taskswithcode.com/stats/"
33
 
34
 
35
+
36
+
37
+
38
def get_views(action):
    """Report an app event to the TWC stats endpoint and return the view count.

    On the first call of a Streamlit session the count is fetched from
    INFO_URL and cached in st.session_state["view_count"]; later calls reuse
    the cached value and only fire a best-effort event notification for
    non-"init" actions.

    Args:
        action: event name, e.g. "init" or "submit".

    Returns:
        The view count formatted with thousands separators, e.g. "1,234".
    """
    ret_val = 0
    hostname = socket.gethostname()
    ip_address = socket.gethostbyname(hostname)
    if ("view_count" not in st.session_state):
        try:
            app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}
            res = requests.post(INFO_URL, json=app_info).json()
            data = res["count"]
        except Exception:
            # Stats are best-effort: fall back to 0 on any network/JSON
            # failure. (Was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit.)
            data = 0
        ret_val = data
        st.session_state["view_count"] = data
    else:
        ret_val = st.session_state["view_count"]
        if (action != "init"):
            # Fire-and-forget notification: the response body is unused, so
            # don't parse it (the old `.json()` call could raise on a
            # non-JSON response and crash the page), and never let a
            # transient network error take down the UI.
            app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}
            try:
                requests.post(INFO_URL, json=app_info)
            except Exception:
                pass
    return "{:,}".format(ret_val)
58
+
59
+
60
 
61
 
62
def construct_model_info_for_display(model_names):
    """Build the selectbox options and the HTML blurb describing each model.

    Args:
        model_names: list of model-descriptor dicts loaded from the models
            JSON (each has at least "name" and "mark" keys).

    Returns:
        (options_arr, markdown_str): display names for the model selectbox,
        and the HTML fragment describing the evaluated models plus usage
        notes.
    """
    options_arr = [entry["name"] for entry in model_names]
    # Accumulate the HTML pieces and join once at the end.
    parts = [
        f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>The selected models satisfy one or more of the following (1) state-of-the-art (2) the most downloaded models on Hugging Face (3) Large Language Models (e.g. GPT-3)</i></div>",
        "<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>",
    ]
    # NOTE(review): nesting of the Note/spacer lines under the "mark" check
    # was reconstructed from a flattened diff — confirm against the original.
    for entry in model_names:
        if (entry["mark"] == "True"):
            parts.append(f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\">&nbsp;•&nbsp;Model:&nbsp;<a href=\'{entry['paper_url']}\' target='_blank'>{entry['name']}</a><br/>&nbsp;&nbsp;&nbsp;&nbsp;Code released by:&nbsp;<a href=\'{entry['orig_author_url']}\' target='_blank'>{entry['orig_author']}</a><br/>&nbsp;&nbsp;&nbsp;&nbsp;Model info:&nbsp;<a href=\'{entry['sota_info']['sota_link']}\' target='_blank'>{entry['sota_info']['task']}</a></div>")
            if ("Note" in entry):
                parts.append(f"<div style=\"font-size:16px; color: #a91212; text-align: left\">&nbsp;&nbsp;&nbsp;&nbsp;{entry['Note']}<a href=\'{entry['alt_url']}\' target='_blank'>link</a></div>")
            parts.append("<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>")
    parts.append("<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>•&nbsp;Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>")
    limit = "{:,}".format(MAX_INPUT)
    parts.append(f"<div style=\"font-size:12px; color: #9f9f9f; text-align: left\">•&nbsp;User uploaded file has a maximum limit of {limit} sentences.</div>")
    markdown_str = "".join(parts)
    return options_arr, markdown_str
78
 
79
 
80
# Page chrome: browser-tab title/icon, centered layout, and the About menu.
st.set_page_config(
    page_title='TWC - Compare popular/state-of-the-art models for semantic clustering using sentence embeddings',
    page_icon="logo.jpg",
    layout='centered',
    initial_sidebar_state='auto',
    menu_items={'About': 'This app was created by taskswithcode. http://taskswithcode.com'},
)

# 85/15 split: the logo sits in the wide column; `pad` is just spacing.
col, pad = st.columns([85, 15])

with col:
    st.image("long_form_logo_with_icon.png")
89
 
90
 
91
@st.experimental_memo
def load_model(model_name, model_class, load_model_name):
    """Instantiate and initialize the embedding-model wrapper, memoized.

    Args:
        model_name: display name (used only for the error message).
        model_class: name of a wrapper class resolved via globals()
            (e.g. "HFModel", "OpenAIModel").
        load_model_name: checkpoint/model identifier passed to init_model.

    Returns:
        The initialized wrapper, or None when loading failed (the failure is
        surfaced to the user via st.error).
    """
    ret_model = None
    try:
        # The models JSON names the wrapper class as a string.
        wrapper_cls = globals()[model_class]
        ret_model = wrapper_cls()
        ret_model.init_model(load_model_name)
        assert (ret_model is not None)
    except Exception as e:
        st.error(f"Unable to load model class:{model_class} model_name: {model_name} load_model_name: {load_model_name} {str(e)}")
    return ret_model
103
 
104
 
105
+
106
@st.experimental_memo
def cached_compute_similarity(input_file_name, sentences, _model, model_name, threshold, _cluster, clustering_type):
    """Embed *sentences* and cluster them, with Streamlit memoization.

    The underscore-prefixed parameters (_model, _cluster) are skipped by the
    memo hasher; model_name stands in for the model as part of the cache key.
    """
    texts, embeddings = _model.compute_embeddings(input_file_name, sentences, is_file=False)
    results = _cluster.cluster(None, texts, embeddings, threshold, clustering_type)
    return results
111
 
112
 
113
def uncached_compute_similarity(input_file_name, sentences, _model, model_name, threshold, cluster, clustering_type):
    """Embed *sentences* and cluster them without any caching.

    Same computation as cached_compute_similarity; used when results must be
    recomputed every run (per run_test, for user-uploaded input).
    """
    with st.spinner('Computing vectors for sentences'):
        texts, embeddings = _model.compute_embeddings(input_file_name, sentences, is_file=False)
        results = cluster.cluster(None, texts, embeddings, threshold, clustering_type)
    return results
119
 
 
120
DEFAULT_HF_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"


def get_model_info(model_names, model_name):
    """Look up *model_name* in the model-descriptor list.

    Falls back to DEFAULT_HF_MODEL when the requested name is unknown (e.g.
    a user-typed custom model), so callers always get a usable descriptor.

    Args:
        model_names: list of model-descriptor dicts, each with a "name" key.
        model_name: display name to resolve.

    Returns:
        (node, resolved_name): the first matching descriptor and the name it
        resolved to (differs from *model_name* when the fallback was taken).

    Raises:
        ValueError: if DEFAULT_HF_MODEL itself is missing from *model_names*.
            (The original recursed unconditionally on a miss, so a model list
            without the default model looped until RecursionError.)
    """
    for node in model_names:
        if (model_name == node["name"]):
            return node, model_name
    if model_name == DEFAULT_HF_MODEL:
        # We are already resolving the fallback and it is absent: fail
        # explicitly instead of recursing forever.
        raise ValueError(f"Default model {DEFAULT_HF_MODEL} not present in model list")
    return get_model_info(model_names, DEFAULT_HF_MODEL)
126
 
127
 
128
+ def run_test(model_names,model_name,input_file_name,sentences,display_area,threshold,user_uploaded,custom_model,clustering_type):
 
129
  display_area.text("Loading model:" + model_name)
130
+ #Note. model_name may get mapped to new name in the call below for custom models
131
  orig_model_name = model_name
132
+ model_info,model_name = get_model_info(model_names,model_name)
133
  if (model_name != orig_model_name):
134
+ load_model_name = orig_model_name
135
  else:
136
  load_model_name = model_info["model"]
137
  if ("Note" in model_info):
 
140
  if (user_uploaded and "custom_load" in model_info and model_info["custom_load"] == "False"):
141
  fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
142
  display_area.write(fail_link)
143
+ return {"error":fail_link}
144
+ model = load_model(model_name,model_info["class"],load_model_name)
145
+ display_area.text("Model " + model_name + " load complete")
146
  try:
147
+ if (user_uploaded):
148
+ results = uncached_compute_similarity(input_file_name,sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
149
+ else:
150
+ display_area.text("Computing vectors for sentences")
151
+ results = cached_compute_similarity(input_file_name,sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
152
+ display_area.text("Similarity computation complete")
153
+ return results
154
+
 
 
155
  except Exception as e:
156
  st.error("Some error occurred during prediction" + str(e))
157
  st.stop()
158
  return {}
159
 
160
 
161
+
162
+
163
+
164
+ def display_results(orig_sentences,results,response_info,app_mode,model_name):
165
  main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
166
  main_sent += f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">Showing results for model:&nbsp;<b>{model_name}</b></div>"
167
  score_text = "cosine distance"
 
171
  for i in range(len(results["clusters"])):
172
  pivot_index = results["clusters"][i]["pivot_index"]
173
  pivot_sent = orig_sentences[pivot_index]
174
+ pivot_index += 1
175
  d_cluster = {}
176
  download_data[i + 1] = d_cluster
177
+ d_cluster["pivot"] = {"pivot_index":pivot_index,"sent":pivot_sent,"children":{}}
178
+ body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{pivot_index}]&nbsp;{pivot_sent}&nbsp;<b><i>(Cluster {i+1})</i></b>&nbsp;&nbsp;</div>")
 
179
  neighs_dict = results["clusters"][i]["neighs"]
180
  for key in neighs_dict:
181
  cosine_dist = neighs_dict[key]
182
  child_index = key
183
  sentence = orig_sentences[child_index]
184
  child_index += 1
185
+ body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{child_index}]&nbsp;{sentence}&nbsp;&nbsp;&nbsp;<b>{cosine_dist:.2f}</b></div>")
186
+ d_cluster["pivot"]["children"][sentence] = f"{cosine_dist:.2f}"
 
187
  body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">&nbsp;</div>")
188
  main_sent = main_sent + "\n" + '\n'.join(body_sent)
189
+ st.markdown(main_sent,unsafe_allow_html=True)
190
+ st.session_state["download_ready"] = json.dumps(download_data,indent=4)
191
  get_views("submit")
192
 
193
 
194
  def init_session():
195
  if ("model_name" not in st.session_state):
196
  st.session_state["model_name"] = "ss_test"
197
+ st.session_state["download_ready"] = None
198
  st.session_state["model_name"] = "ss_test"
199
  st.session_state["threshold"] = 1.5
200
  st.session_state["file_name"] = "default"
 
202
  st.session_state["cluster"] = TWCClustering()
203
  else:
204
  print("Skipping init session")
205
+
206
def app_main(app_mode, example_files, model_name_files, clus_types):
    """Top-level Streamlit page: render the form, run clustering, show results.

    Args:
        app_mode: one of SEM_SIMILARITY/DOC_RETRIEVAL/CLUSTERING ("1"/"2"/"3").
        example_files: path to JSON mapping example display names to files.
        model_name_files: path to the JSON list of model descriptors.
        clus_types: path to JSON describing the available clustering types.
    """
    init_session()
    with open(example_files) as fp:
        example_file_names = json.load(fp)
    with open(model_name_files) as fp:
        model_names = json.load(fp)
    with open(clus_types) as fp:
        cluster_types = json.load(fp)
    curr_use_case = use_case[app_mode].split(".")[0]
    st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for semantic clustering using sentence embeddings</h5>", unsafe_allow_html=True)
    st.markdown("<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['1']}\' target='_blank'>{use_case['1']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;<a href=\'{use_case_url['2']}\' target='_blank'>{use_case['2']}</a><br/>&nbsp;&nbsp;&nbsp;•&nbsp;&nbsp;{use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views:&nbsp;{get_views('init')}</div>", unsafe_allow_html=True)

    try:
        with st.form('twc_form'):
            step1_line = "Upload text file(one sentence in a line) or choose an example text file below"
            if (app_mode == DOC_RETRIEVAL):
                step1_line += ". The first line is treated as the query"
            uploaded_file = st.file_uploader(step1_line, type=".txt")

            selected_file_index = st.selectbox(label=f'Example files ({len(example_file_names)})',
                                               options=list(example_file_names.keys()), index=0, key="twc_file")
            st.write("")
            options_arr, markdown_str = construct_model_info_for_display(model_names)
            selection_label = 'Select Model'
            selected_model = st.selectbox(label=selection_label,
                                          options=options_arr, index=0, key="twc_model")
            st.write("")
            custom_model_selection = st.text_input("Model not listed above? Type any Hugging Face sentence embedding model name ", "", key="custom_model")
            hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Hugging Face sentence embedding models</a><br/><br/><br/></div>"
            st.markdown(hf_link_str, unsafe_allow_html=True)
            threshold = st.number_input('Choose a zscore threshold (number of std devs from mean)',
                                        value=st.session_state["threshold"], min_value=0.0, step=.01)
            st.write("")
            clustering_type = st.selectbox(label='Select type of clustering',
                                           options=list(cluster_types.keys()), index=0, key="twc_cluster_types")
            st.write("")
            submit_button = st.form_submit_button('Run')

        input_status_area = st.empty()
        display_area = st.empty()
        if submit_button:
            start = time.time()
            if uploaded_file is not None:
                st.session_state["file_name"] = uploaded_file.name
                sentences = StringIO(uploaded_file.getvalue().decode("utf-8")).read()
            else:
                st.session_state["file_name"] = example_file_names[selected_file_index]["name"]
                # Close the example file promptly instead of leaking the
                # handle (was a bare open(...).read()).
                with open(example_file_names[selected_file_index]["name"]) as ex_fp:
                    sentences = ex_fp.read()
            # Drop the empty trailing element produced by the final newline.
            sentences = sentences.split("\n")[:-1]
            if (len(sentences) > MAX_INPUT):
                st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
                sentences = sentences[:MAX_INPUT]
            # A non-empty custom model name overrides the selectbox choice.
            if (len(custom_model_selection) != 0):
                run_model = custom_model_selection
            else:
                run_model = selected_model
            st.session_state["model_name"] = selected_model
            st.session_state["threshold"] = threshold
            st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
            results = run_test(model_names, run_model, st.session_state["file_name"], sentences, display_area,
                               threshold, (uploaded_file is not None), (len(custom_model_selection) != 0),
                               cluster_types[clustering_type]["type"])
            display_area.empty()
            with display_area.container():
                if ("error" in results):
                    st.error(results["error"])
                else:
                    device = 'GPU' if torch.cuda.is_available() else 'CPU'
                    response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
                    if (len(custom_model_selection) != 0):
                        st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
                    # display_results populates st.session_state["download_ready"].
                    display_results(sentences, results, response_info, app_mode, run_model)
                    st.download_button(
                        label="Download results as json",
                        data=st.session_state["download_ready"] if st.session_state["download_ready"] is not None else "",
                        disabled=st.session_state["download_ready"] is None,
                        file_name=(st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" +
                                   st.session_state["overlapped"] + "_" +
                                   '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/", "_"),
                        mime='text/json',
                        key="download"
                    )

    except Exception as e:
        # Top-level UI boundary: surface the error and halt this script run.
        st.error("Some error occurred during loading" + str(e))
        st.stop()

    st.markdown(markdown_str, unsafe_allow_html=True)
299
+
300
+
 
 
 
 
 
 
 
 
 
 
 
301
 
302
if __name__ == "__main__":
    # Earlier invocations, kept for reference:
    #   app_main(sys.argv[1], sys.argv[2], sys.argv[3])
    #   app_main("1", "sim_app_examples.json", "sim_app_models.json")
    app_main("3", "clus_app_examples.json", "clus_app_models.json", "clus_app_clustypes.json")
307