Spaces:
Runtime error
Runtime error
vaibhavsharda
commited on
Commit
•
bd4f103
1
Parent(s):
7e8b361
Update app.py
Browse files
app.py
CHANGED
@@ -2,135 +2,136 @@ import time
|
|
2 |
import sys
|
3 |
import streamlit as st
|
4 |
import string
|
5 |
-
from io import StringIO
|
6 |
import pdb
|
7 |
import json
|
8 |
-
from twc_embeddings import HFModel,
|
9 |
from twc_openai_embeddings import OpenAIModel
|
10 |
from twc_clustering import TWCClustering
|
11 |
import torch
|
12 |
import requests
|
13 |
import socket
|
14 |
|
|
|
15 |
MAX_INPUT = 10000
|
16 |
|
17 |
-
SEM_SIMILARITY
|
18 |
-
DOC_RETRIEVAL
|
19 |
-
CLUSTERING
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
use_case = {"1": "Finding similar phrases/sentences",
|
22 |
-
"2": "Retrieving semantically matching information to a query. It may not be a factual match",
|
23 |
-
"3": "Clustering"}
|
24 |
-
use_case_url = {"1": "https://huggingface.co/spaces/taskswithcode/semantic_similarity",
|
25 |
-
"2": "https://huggingface.co/spaces/taskswithcode/semantic_search", "3": ""}
|
26 |
|
27 |
from transformers import BertTokenizer, BertForMaskedLM
|
28 |
|
|
|
29 |
APP_NAME = "hf/semantic_clustering"
|
30 |
INFO_URL = "https://www.taskswithcode.com/stats/"
|
31 |
|
32 |
|
|
|
|
|
|
|
33 |
def get_views(action):
|
34 |
ret_val = 0
|
35 |
hostname = socket.gethostname()
|
36 |
ip_address = socket.gethostbyname(hostname)
|
37 |
if ("view_count" not in st.session_state):
|
38 |
try:
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
except:
|
44 |
-
|
45 |
ret_val = data
|
46 |
st.session_state["view_count"] = data
|
47 |
else:
|
48 |
ret_val = st.session_state["view_count"]
|
49 |
if (action != "init"):
|
50 |
-
|
51 |
-
|
52 |
return "{:,}".format(ret_val)
|
|
|
|
|
53 |
|
54 |
|
55 |
def construct_model_info_for_display(model_names):
|
56 |
-
options_arr
|
57 |
markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>The selected models satisfy one or more of the following (1) state-of-the-art (2) the most downloaded models on Hugging Face (3) Large Language Models (e.g. GPT-3)</i></div>"
|
58 |
markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
|
59 |
for node in model_names:
|
60 |
-
options_arr.append(node["name"])
|
61 |
if (node["mark"] == "True"):
|
62 |
markdown_str += f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"> • Model: <a href=\'{node['paper_url']}\' target='_blank'>{node['name']}</a><br/> Code released by: <a href=\'{node['orig_author_url']}\' target='_blank'>{node['orig_author']}</a><br/> Model info: <a href=\'{node['sota_info']['sota_link']}\' target='_blank'>{node['sota_info']['task']}</a></div>"
|
63 |
if ("Note" in node):
|
64 |
markdown_str += f"<div style=\"font-size:16px; color: #a91212; text-align: left\"> {node['Note']}<a href=\'{node['alt_url']}\' target='_blank'>link</a></div>"
|
65 |
markdown_str += "<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>"
|
66 |
-
|
67 |
markdown_str += "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>• Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>"
|
68 |
limit = "{:,}".format(MAX_INPUT)
|
69 |
markdown_str += f"<div style=\"font-size:12px; color: #9f9f9f; text-align: left\">• User uploaded file has a maximum limit of {limit} sentences.</div>"
|
70 |
-
return options_arr,
|
71 |
|
72 |
|
73 |
-
st.set_page_config(
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
})
|
80 |
-
col, pad = st.columns([85, 15])
|
81 |
|
82 |
with col:
|
83 |
st.image("long_form_logo_with_icon.png")
|
84 |
|
85 |
|
86 |
@st.experimental_memo
|
87 |
-
def load_model(model_name,
|
88 |
try:
|
89 |
ret_model = None
|
90 |
obj_class = globals()[model_class]
|
91 |
ret_model = obj_class()
|
92 |
ret_model.init_model(load_model_name)
|
93 |
-
assert
|
94 |
except Exception as e:
|
95 |
-
st.error(
|
96 |
-
f"Unable to load model class:{model_class} model_name: {model_name} load_model_name: {load_model_name} {str(e)}")
|
97 |
pass
|
98 |
return ret_model
|
99 |
|
100 |
|
|
|
101 |
@st.experimental_memo
|
102 |
-
def cached_compute_similarity(input_file_name,
|
103 |
-
texts,
|
104 |
-
results = _cluster.cluster(None,
|
105 |
return results
|
106 |
|
107 |
|
108 |
-
def uncached_compute_similarity(input_file_name,
|
109 |
with st.spinner('Computing vectors for sentences'):
|
110 |
-
texts,
|
111 |
-
results = cluster.cluster(None,
|
112 |
-
#
|
113 |
return results
|
114 |
|
115 |
-
|
116 |
DEFAULT_HF_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"
|
117 |
-
|
118 |
-
|
119 |
-
def get_model_info(model_names, model_name):
|
120 |
for node in model_names:
|
121 |
if (model_name == node["name"]):
|
122 |
-
return node,
|
123 |
-
return get_model_info(model_names,
|
124 |
|
125 |
|
126 |
-
def run_test(model_names,
|
127 |
-
clustering_type):
|
128 |
display_area.text("Loading model:" + model_name)
|
129 |
-
#
|
130 |
orig_model_name = model_name
|
131 |
-
model_info,
|
132 |
if (model_name != orig_model_name):
|
133 |
-
load_model_name
|
134 |
else:
|
135 |
load_model_name = model_info["model"]
|
136 |
if ("Note" in model_info):
|
@@ -139,27 +140,28 @@ def run_test(model_names, model_name, input_file_name, sentences, display_area,
|
|
139 |
if (user_uploaded and "custom_load" in model_info and model_info["custom_load"] == "False"):
|
140 |
fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
|
141 |
display_area.write(fail_link)
|
142 |
-
return {"error":
|
143 |
-
model = load_model(model_name,
|
144 |
-
display_area.text("Model " + model_name
|
145 |
try:
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
return results
|
155 |
-
|
156 |
except Exception as e:
|
157 |
st.error("Some error occurred during prediction" + str(e))
|
158 |
st.stop()
|
159 |
return {}
|
160 |
|
161 |
|
162 |
-
|
|
|
|
|
|
|
163 |
main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
|
164 |
main_sent += f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">Showing results for model: <b>{model_name}</b></div>"
|
165 |
score_text = "cosine distance"
|
@@ -169,32 +171,30 @@ def display_results(orig_sentences, results, response_info, app_mode, model_name
|
|
169 |
for i in range(len(results["clusters"])):
|
170 |
pivot_index = results["clusters"][i]["pivot_index"]
|
171 |
pivot_sent = orig_sentences[pivot_index]
|
172 |
-
pivot_index +=
|
173 |
d_cluster = {}
|
174 |
download_data[i + 1] = d_cluster
|
175 |
-
d_cluster["pivot"] = {"pivot_index":
|
176 |
-
body_sent.append(
|
177 |
-
f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{pivot_index}] {pivot_sent} <b><i>(Cluster {i + 1})</i></b> </div>")
|
178 |
neighs_dict = results["clusters"][i]["neighs"]
|
179 |
for key in neighs_dict:
|
180 |
cosine_dist = neighs_dict[key]
|
181 |
child_index = key
|
182 |
sentence = orig_sentences[child_index]
|
183 |
child_index += 1
|
184 |
-
body_sent.append(
|
185 |
-
|
186 |
-
d_cluster["pivot"]["children"][sentence] = f"{cosine_dist:.2f}"
|
187 |
body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"> </div>")
|
188 |
main_sent = main_sent + "\n" + '\n'.join(body_sent)
|
189 |
-
st.markdown(main_sent,
|
190 |
-
st.session_state["download_ready"] = json.dumps(download_data,
|
191 |
get_views("submit")
|
192 |
|
193 |
|
194 |
def init_session():
|
195 |
if ("model_name" not in st.session_state):
|
196 |
st.session_state["model_name"] = "ss_test"
|
197 |
-
st.session_state["download_ready"] = None
|
198 |
st.session_state["model_name"] = "ss_test"
|
199 |
st.session_state["threshold"] = 1.5
|
200 |
st.session_state["file_name"] = "default"
|
@@ -202,139 +202,106 @@ def init_session():
|
|
202 |
st.session_state["cluster"] = TWCClustering()
|
203 |
else:
|
204 |
print("Skipping init session")
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
with open(model_name_files) as fp:
|
212 |
model_names = json.load(fp)
|
213 |
-
|
214 |
cluster_types = json.load(fp)
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
|
|
|
|
275 |
else:
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
|
290 |
-
results = run_test(model_names, run_model, st.session_state["file_name"], sentences, display_area,
|
291 |
-
threshold, (uploaded_file is not None), (len(custom_model_selection) != 0),
|
292 |
-
cluster_types[clustering_type]["type"])
|
293 |
-
display_area.empty()
|
294 |
-
with display_area.container():
|
295 |
-
if ("error" in results):
|
296 |
-
st.error(results["error"])
|
297 |
-
else:
|
298 |
-
device = 'GPU' if torch.cuda.is_available() else 'CPU'
|
299 |
-
response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
|
300 |
-
if (len(custom_model_selection) != 0):
|
301 |
-
st.info(
|
302 |
-
"Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
|
303 |
-
display_results(sentences, results, response_info, app_mode, run_model)
|
304 |
-
# st.json(results)
|
305 |
-
st.download_button(
|
306 |
-
label="Download results as json",
|
307 |
-
data=st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
308 |
-
disabled=False if st.session_state["download_ready"] != None else True,
|
309 |
-
file_name=(st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" +
|
310 |
-
st.session_state["overlapped"] + "_" + '_'.join(
|
311 |
-
st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/", "_"),
|
312 |
-
mime='text/json',
|
313 |
-
key="download"
|
314 |
)
|
|
|
|
|
315 |
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
file_name=(st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" +
|
324 |
-
st.session_state["overlapped"] + "_" + '_'.join(
|
325 |
-
st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/", "_"),
|
326 |
-
mime='text/json',
|
327 |
-
key="download"
|
328 |
-
)
|
329 |
-
st.error("Some error occurred during loading" + str(e))
|
330 |
-
#st.stop()
|
331 |
-
|
332 |
-
st.markdown(markdown_str, unsafe_allow_html=True)
|
333 |
-
|
334 |
|
335 |
if __name__ == "__main__":
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
|
|
|
2 |
import sys
|
3 |
import streamlit as st
|
4 |
import string
|
5 |
+
from io import StringIO
|
6 |
import pdb
|
7 |
import json
|
8 |
+
from twc_embeddings import HFModel,SimCSEModel,SGPTModel,CausalLMModel,SGPTQnAModel
|
9 |
from twc_openai_embeddings import OpenAIModel
|
10 |
from twc_clustering import TWCClustering
|
11 |
import torch
|
12 |
import requests
|
13 |
import socket
|
14 |
|
15 |
+
|
16 |
MAX_INPUT = 10000
|
17 |
|
18 |
+
SEM_SIMILARITY="1"
|
19 |
+
DOC_RETRIEVAL="2"
|
20 |
+
CLUSTERING="3"
|
21 |
+
|
22 |
+
|
23 |
+
use_case = {"1":"Finding similar phrases/sentences","2":"Retrieving semantically matching information to a query. It may not be a factual match","3":"Clustering"}
|
24 |
+
use_case_url = {"1":"https://huggingface.co/spaces/taskswithcode/semantic_similarity","2":"https://huggingface.co/spaces/taskswithcode/semantic_search","3":""}
|
25 |
+
|
26 |
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
from transformers import BertTokenizer, BertForMaskedLM
|
29 |
|
30 |
+
|
31 |
APP_NAME = "hf/semantic_clustering"
|
32 |
INFO_URL = "https://www.taskswithcode.com/stats/"
|
33 |
|
34 |
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
def get_views(action):
|
39 |
ret_val = 0
|
40 |
hostname = socket.gethostname()
|
41 |
ip_address = socket.gethostbyname(hostname)
|
42 |
if ("view_count" not in st.session_state):
|
43 |
try:
|
44 |
+
app_info = {'name': APP_NAME,"action":action,"host":hostname,"ip":ip_address}
|
45 |
+
res = requests.post(INFO_URL, json = app_info).json()
|
46 |
+
print(res)
|
47 |
+
data = res["count"]
|
48 |
except:
|
49 |
+
data = 0
|
50 |
ret_val = data
|
51 |
st.session_state["view_count"] = data
|
52 |
else:
|
53 |
ret_val = st.session_state["view_count"]
|
54 |
if (action != "init"):
|
55 |
+
app_info = {'name': APP_NAME,"action":action,"host":hostname,"ip":ip_address}
|
56 |
+
res = requests.post(INFO_URL, json = app_info).json()
|
57 |
return "{:,}".format(ret_val)
|
58 |
+
|
59 |
+
|
60 |
|
61 |
|
62 |
def construct_model_info_for_display(model_names):
|
63 |
+
options_arr = []
|
64 |
markdown_str = f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Models evaluated ({len(model_names)})</b><br/><i>The selected models satisfy one or more of the following (1) state-of-the-art (2) the most downloaded models on Hugging Face (3) Large Language Models (e.g. GPT-3)</i></div>"
|
65 |
markdown_str += f"<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>"
|
66 |
for node in model_names:
|
67 |
+
options_arr .append(node["name"])
|
68 |
if (node["mark"] == "True"):
|
69 |
markdown_str += f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"> • Model: <a href=\'{node['paper_url']}\' target='_blank'>{node['name']}</a><br/> Code released by: <a href=\'{node['orig_author_url']}\' target='_blank'>{node['orig_author']}</a><br/> Model info: <a href=\'{node['sota_info']['sota_link']}\' target='_blank'>{node['sota_info']['task']}</a></div>"
|
70 |
if ("Note" in node):
|
71 |
markdown_str += f"<div style=\"font-size:16px; color: #a91212; text-align: left\"> {node['Note']}<a href=\'{node['alt_url']}\' target='_blank'>link</a></div>"
|
72 |
markdown_str += "<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>"
|
73 |
+
|
74 |
markdown_str += "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>• Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>"
|
75 |
limit = "{:,}".format(MAX_INPUT)
|
76 |
markdown_str += f"<div style=\"font-size:12px; color: #9f9f9f; text-align: left\">• User uploaded file has a maximum limit of {limit} sentences.</div>"
|
77 |
+
return options_arr,markdown_str
|
78 |
|
79 |
|
80 |
+
st.set_page_config(page_title='TWC - Compare popular/state-of-the-art models for semantic clustering using sentence embeddings', page_icon="logo.jpg", layout='centered', initial_sidebar_state='auto',
|
81 |
+
menu_items={
|
82 |
+
'About': 'This app was created by taskswithcode. http://taskswithcode.com'
|
83 |
+
|
84 |
+
})
|
85 |
+
col,pad = st.columns([85,15])
|
|
|
|
|
86 |
|
87 |
with col:
|
88 |
st.image("long_form_logo_with_icon.png")
|
89 |
|
90 |
|
91 |
@st.experimental_memo
|
92 |
+
def load_model(model_name,model_class,load_model_name):
|
93 |
try:
|
94 |
ret_model = None
|
95 |
obj_class = globals()[model_class]
|
96 |
ret_model = obj_class()
|
97 |
ret_model.init_model(load_model_name)
|
98 |
+
assert(ret_model is not None)
|
99 |
except Exception as e:
|
100 |
+
st.error(f"Unable to load model class:{model_class} model_name: {model_name} load_model_name: {load_model_name} {str(e)}")
|
|
|
101 |
pass
|
102 |
return ret_model
|
103 |
|
104 |
|
105 |
+
|
106 |
@st.experimental_memo
|
107 |
+
def cached_compute_similarity(input_file_name,sentences,_model,model_name,threshold,_cluster,clustering_type):
|
108 |
+
texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
|
109 |
+
results = _cluster.cluster(None,texts,embeddings,threshold,clustering_type)
|
110 |
return results
|
111 |
|
112 |
|
113 |
+
def uncached_compute_similarity(input_file_name,sentences,_model,model_name,threshold,cluster,clustering_type):
|
114 |
with st.spinner('Computing vectors for sentences'):
|
115 |
+
texts,embeddings = _model.compute_embeddings(input_file_name,sentences,is_file=False)
|
116 |
+
results = cluster.cluster(None,texts,embeddings,threshold,clustering_type)
|
117 |
+
#st.success("Similarity computation complete")
|
118 |
return results
|
119 |
|
|
|
120 |
DEFAULT_HF_MODEL = "sentence-transformers/paraphrase-MiniLM-L6-v2"
|
121 |
+
def get_model_info(model_names,model_name):
|
|
|
|
|
122 |
for node in model_names:
|
123 |
if (model_name == node["name"]):
|
124 |
+
return node,model_name
|
125 |
+
return get_model_info(model_names,DEFAULT_HF_MODEL)
|
126 |
|
127 |
|
128 |
+
def run_test(model_names,model_name,input_file_name,sentences,display_area,threshold,user_uploaded,custom_model,clustering_type):
|
|
|
129 |
display_area.text("Loading model:" + model_name)
|
130 |
+
#Note. model_name may get mapped to new name in the call below for custom models
|
131 |
orig_model_name = model_name
|
132 |
+
model_info,model_name = get_model_info(model_names,model_name)
|
133 |
if (model_name != orig_model_name):
|
134 |
+
load_model_name = orig_model_name
|
135 |
else:
|
136 |
load_model_name = model_info["model"]
|
137 |
if ("Note" in model_info):
|
|
|
140 |
if (user_uploaded and "custom_load" in model_info and model_info["custom_load"] == "False"):
|
141 |
fail_link = f"{model_info['Note']} [link]({model_info['alt_url']})"
|
142 |
display_area.write(fail_link)
|
143 |
+
return {"error":fail_link}
|
144 |
+
model = load_model(model_name,model_info["class"],load_model_name)
|
145 |
+
display_area.text("Model " + model_name + " load complete")
|
146 |
try:
|
147 |
+
if (user_uploaded):
|
148 |
+
results = uncached_compute_similarity(input_file_name,sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
149 |
+
else:
|
150 |
+
display_area.text("Computing vectors for sentences")
|
151 |
+
results = cached_compute_similarity(input_file_name,sentences,model,model_name,threshold,st.session_state["cluster"],clustering_type)
|
152 |
+
display_area.text("Similarity computation complete")
|
153 |
+
return results
|
154 |
+
|
|
|
|
|
155 |
except Exception as e:
|
156 |
st.error("Some error occurred during prediction" + str(e))
|
157 |
st.stop()
|
158 |
return {}
|
159 |
|
160 |
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
def display_results(orig_sentences,results,response_info,app_mode,model_name):
|
165 |
main_sent = f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
|
166 |
main_sent += f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">Showing results for model: <b>{model_name}</b></div>"
|
167 |
score_text = "cosine distance"
|
|
|
171 |
for i in range(len(results["clusters"])):
|
172 |
pivot_index = results["clusters"][i]["pivot_index"]
|
173 |
pivot_sent = orig_sentences[pivot_index]
|
174 |
+
pivot_index += 1
|
175 |
d_cluster = {}
|
176 |
download_data[i + 1] = d_cluster
|
177 |
+
d_cluster["pivot"] = {"pivot_index":pivot_index,"sent":pivot_sent,"children":{}}
|
178 |
+
body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{pivot_index}] {pivot_sent} <b><i>(Cluster {i+1})</i></b> </div>")
|
|
|
179 |
neighs_dict = results["clusters"][i]["neighs"]
|
180 |
for key in neighs_dict:
|
181 |
cosine_dist = neighs_dict[key]
|
182 |
child_index = key
|
183 |
sentence = orig_sentences[child_index]
|
184 |
child_index += 1
|
185 |
+
body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\">{child_index}] {sentence} <b>{cosine_dist:.2f}</b></div>")
|
186 |
+
d_cluster["pivot"]["children"][sentence] = f"{cosine_dist:.2f}"
|
|
|
187 |
body_sent.append(f"<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"> </div>")
|
188 |
main_sent = main_sent + "\n" + '\n'.join(body_sent)
|
189 |
+
st.markdown(main_sent,unsafe_allow_html=True)
|
190 |
+
st.session_state["download_ready"] = json.dumps(download_data,indent=4)
|
191 |
get_views("submit")
|
192 |
|
193 |
|
194 |
def init_session():
|
195 |
if ("model_name" not in st.session_state):
|
196 |
st.session_state["model_name"] = "ss_test"
|
197 |
+
st.session_state["download_ready"] = None
|
198 |
st.session_state["model_name"] = "ss_test"
|
199 |
st.session_state["threshold"] = 1.5
|
200 |
st.session_state["file_name"] = "default"
|
|
|
202 |
st.session_state["cluster"] = TWCClustering()
|
203 |
else:
|
204 |
print("Skipping init session")
|
205 |
+
|
206 |
+
def app_main(app_mode,example_files,model_name_files,clus_types):
|
207 |
+
init_session()
|
208 |
+
with open(example_files) as fp:
|
209 |
+
example_file_names = json.load(fp)
|
210 |
+
with open(model_name_files) as fp:
|
|
|
211 |
model_names = json.load(fp)
|
212 |
+
with open(clus_types) as fp:
|
213 |
cluster_types = json.load(fp)
|
214 |
+
curr_use_case = use_case[app_mode].split(".")[0]
|
215 |
+
st.markdown("<h5 style='text-align: center;'>Compare popular/state-of-the-art models for semantic clustering using sentence embeddings</h5>", unsafe_allow_html=True)
|
216 |
+
st.markdown(f"<p style='font-size:14px; color: #4f4f4f; text-align: center'><i>Or compare your own model with state-of-the-art/popular models</p>", unsafe_allow_html=True)
|
217 |
+
st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for sentence embeddings<br/> • <a href=\'{use_case_url['1']}\' target='_blank'>{use_case['1']}</a><br/> • <a href=\'{use_case_url['2']}\' target='_blank'>{use_case['2']}</a><br/> • {use_case['3']}<br/><i>This app illustrates <b>'{curr_use_case}'</b> use case</i></div>", unsafe_allow_html=True)
|
218 |
+
st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views: {get_views('init')}</div>", unsafe_allow_html=True)
|
219 |
+
|
220 |
+
|
221 |
+
try:
|
222 |
+
|
223 |
+
|
224 |
+
with st.form('twc_form'):
|
225 |
+
|
226 |
+
step1_line = "Upload text file(one sentence in a line) or choose an example text file below"
|
227 |
+
if (app_mode == DOC_RETRIEVAL):
|
228 |
+
step1_line += ". The first line is treated as the query"
|
229 |
+
uploaded_file = st.file_uploader(step1_line, type=".txt")
|
230 |
+
|
231 |
+
selected_file_index = st.selectbox(label=f'Example files ({len(example_file_names)})',
|
232 |
+
options = list(dict.keys(example_file_names)), index=0, key = "twc_file")
|
233 |
+
st.write("")
|
234 |
+
options_arr,markdown_str = construct_model_info_for_display(model_names)
|
235 |
+
selection_label = 'Select Model'
|
236 |
+
selected_model = st.selectbox(label=selection_label,
|
237 |
+
options = options_arr, index=0, key = "twc_model")
|
238 |
+
st.write("")
|
239 |
+
custom_model_selection = st.text_input("Model not listed above? Type any Hugging Face sentence embedding model name ", "",key="custom_model")
|
240 |
+
hf_link_str = "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><a href='https://huggingface.co/models?pipeline_tag=sentence-similarity' target = '_blank'>List of Hugging Face sentence embedding models</a><br/><br/><br/></div>"
|
241 |
+
st.markdown(hf_link_str, unsafe_allow_html=True)
|
242 |
+
threshold = st.number_input('Choose a zscore threshold (number of std devs from mean)',value=st.session_state["threshold"],min_value = 0.0,step=.01)
|
243 |
+
st.write("")
|
244 |
+
clustering_type = st.selectbox(label=f'Select type of clustering',
|
245 |
+
options = list(dict.keys(cluster_types)), index=0, key = "twc_cluster_types")
|
246 |
+
st.write("")
|
247 |
+
submit_button = st.form_submit_button('Run')
|
248 |
+
|
249 |
+
|
250 |
+
input_status_area = st.empty()
|
251 |
+
display_area = st.empty()
|
252 |
+
if submit_button:
|
253 |
+
start = time.time()
|
254 |
+
if uploaded_file is not None:
|
255 |
+
st.session_state["file_name"] = uploaded_file.name
|
256 |
+
sentences = StringIO(uploaded_file.getvalue().decode("utf-8")).read()
|
257 |
+
else:
|
258 |
+
st.session_state["file_name"] = example_file_names[selected_file_index]["name"]
|
259 |
+
sentences = open(example_file_names[selected_file_index]["name"]).read()
|
260 |
+
sentences = sentences.split("\n")[:-1]
|
261 |
+
if (len(sentences) > MAX_INPUT):
|
262 |
+
st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
|
263 |
+
sentences = sentences[:MAX_INPUT]
|
264 |
+
if (len(custom_model_selection) != 0):
|
265 |
+
run_model = custom_model_selection
|
266 |
+
else:
|
267 |
+
run_model = selected_model
|
268 |
+
st.session_state["model_name"] = selected_model
|
269 |
+
st.session_state["threshold"] = threshold
|
270 |
+
st.session_state["overlapped"] = cluster_types[clustering_type]["type"]
|
271 |
+
results = run_test(model_names,run_model,st.session_state["file_name"],sentences,display_area,threshold,(uploaded_file is not None),(len(custom_model_selection) != 0),cluster_types[clustering_type]["type"])
|
272 |
+
display_area.empty()
|
273 |
+
with display_area.container():
|
274 |
+
if ("error" in results):
|
275 |
+
st.error(results["error"])
|
276 |
else:
|
277 |
+
device = 'GPU' if torch.cuda.is_available() else 'CPU'
|
278 |
+
response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for {len(sentences)} sentences"
|
279 |
+
if (len(custom_model_selection) != 0):
|
280 |
+
st.info("Custom model overrides model selection in step 2 above. So please clear the custom model text box to choose models from step 2")
|
281 |
+
display_results(sentences,results,response_info,app_mode,run_model)
|
282 |
+
#st.json(results)
|
283 |
+
st.download_button(
|
284 |
+
label="Download results as json",
|
285 |
+
data= st.session_state["download_ready"] if st.session_state["download_ready"] != None else "",
|
286 |
+
disabled = False if st.session_state["download_ready"] != None else True,
|
287 |
+
file_name= (st.session_state["model_name"] + "_" + str(st.session_state["threshold"]) + "_" + st.session_state["overlapped"] + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".json").replace("/","_"),
|
288 |
+
mime='text/json',
|
289 |
+
key ="download"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
)
|
291 |
+
|
292 |
+
|
293 |
|
294 |
+
except Exception as e:
|
295 |
+
st.error("Some error occurred during loading" + str(e))
|
296 |
+
st.stop()
|
297 |
+
|
298 |
+
st.markdown(markdown_str, unsafe_allow_html=True)
|
299 |
+
|
300 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
if __name__ == "__main__":
|
303 |
+
#print("comand line input:",len(sys.argv),str(sys.argv))
|
304 |
+
#app_main(sys.argv[1],sys.argv[2],sys.argv[3])
|
305 |
+
#app_main("1","sim_app_examples.json","sim_app_models.json")
|
306 |
+
app_main("3","clus_app_examples.json","clus_app_models.json","clus_app_clustypes.json")
|
307 |
|