import html
import json
import string
import time
from io import StringIO

import streamlit as st
from transformers import BertTokenizer, BertForMaskedLM
# Largest number of sentences processed per request; longer inputs are cut
# down to their first MAX_INPUT lines by main().
MAX_INPUT = 1000

# Models offered in the "Step 2. Select Model" dropdown, in display order.
# Keys of each entry:
#   name  - label shown in the dropdown; load_model() matches on it
#   model - identifier handed to the wrapper's init_model()
#   class - name of the wrapper class, resolved through globals() at load time
#   mark  - flags the entry for the "Models evaluated" footer built by
#           construct_model_info_for_display()
# Only marked entries carry the attribution keys (fork_url, orig_author_url,
# orig_author, sota_info, paper_url).
model_names = [
    {
        "name": "SGPT-125M",
        "model": "Muennighoff/SGPT-125M-weightedmean-nli-bitfit",
        "mark": False,
        "class": "SGPTModel",
    },
    {
        "name": "SGPT-5.8B",
        "model": "Muennighoff/SGPT-5.8B-weightedmean-msmarco-specb-bitfit",
        "fork_url": "https://github.com/taskswithcode/sgpt",
        "orig_author_url": "https://github.com/Muennighoff",
        "orig_author": "Niklas Muennighoff",
        "sota_info": {
            "task": "#1 in multiple information retrieval & search tasks",
            "sota_link": "https://paperswithcode.com/paper/sgpt-gpt-sentence-embeddings-for-semantic",
        },
        "paper_url": "https://arxiv.org/abs/2202.08904v5",
        "mark": True,
        "class": "SGPTModel",
    },
    {
        "name": "SGPT-1.3B",
        "model": "Muennighoff/SGPT-1.3B-weightedmean-msmarco-specb-bitfit",
        "mark": False,
        "class": "SGPTModel",
    },
    {
        "name": "sentence-transformers/all-MiniLM-L6-v2",
        "model": "sentence-transformers/all-MiniLM-L6-v2",
        "fork_url": "https://github.com/taskswithcode/sentence_similarity_hf_model",
        "orig_author_url": "https://github.com/UKPLab",
        "orig_author": "Ubiquitous Knowledge Processing Lab",
        "sota_info": {
            "task": "Nearly 4 million downloads from huggingface",
            "sota_link": "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2",
        },
        "paper_url": "https://arxiv.org/abs/1908.10084",
        "mark": True,
        "class": "HFModel",
    },
]

# Label shown in the "Example files" dropdown -> path of the file read by main().
example_file_names = {
    "Machine learning terms (30+ phrases)": "tests/small_test.txt",
    "Customer feedback mixed with noise (50+ sentences)": "tests/larger_test.txt",
}
def construct_model_info_for_display(models=None, max_input=None):
    """Build the model dropdown options and the "Models evaluated" footer.

    Args:
        models: Model descriptor dicts to use. Defaults to the module-level
            ``model_names`` table.
        max_input: Sentence limit quoted in the footer note. Defaults to
            ``MAX_INPUT``.

    Returns:
        ``(options_arr, markdown_str)``: the ``"name"`` of every model, in
        table order, for the selectbox; and a markdown string with inline
        HTML (meant for ``st.markdown(..., unsafe_allow_html=True)``) that
        describes each entry whose ``"mark"`` is truthy, followed by usage
        notes.
    """
    # Defaults are resolved at call time so callers can supply their own table.
    if models is None:
        models = model_names
    if max_input is None:
        max_input = MAX_INPUT

    options_arr = [node["name"] for node in models]

    # A raw newline inside a plain "..." literal is a syntax error, and a single
    # newline in markdown does not start a new line, so breaks are <br/> tags.
    parts = ["Models evaluated<br/>"]
    for node in models:
        # Only marked entries carry the attribution keys read below.
        if not node.get("mark"):
            continue
        sota = node["sota_info"]
        parts.append(
            f"• Model: <a href='{node['paper_url']}' target='_blank'>{node['name']}</a><br/>"
            f"&nbsp;&nbsp;Code released by: <a href='{node['orig_author_url']}' target='_blank'>{node['orig_author']}</a><br/>"
            f"&nbsp;&nbsp;Model info: <a href='{sota['sota_link']}' target='_blank'>{sota['task']}</a><br/>"
            f"&nbsp;&nbsp;Forked <a href='{node['fork_url']}' target='_blank'>code</a><br/><br/>"
        )
    parts.append(
        "Note:<br/>• Uploaded files are loaded into non-persistent memory for the "
        "duration of the computation. They are not saved<br/>"
    )
    parts.append(f"• User uploaded file has a maximum limit of {max_input:,} sentences.<br/>")
    return options_arr, "".join(parts)
# Page metadata; Streamlit requires this before any other element is drawn.
st.set_page_config(
    page_title='TWC - Compare state-of-the-art models for Sentence Similarity task',
    page_icon="logo.jpg",
    layout='centered',
    initial_sidebar_state='auto',
    menu_items={
        'About': 'This app was created by taskswithcode. http://taskswithcode.com',
    },
)

# Banner: the logo sits in the wider (85%) column; ``pad`` stays empty.
col, pad = st.columns([85, 15])
with col:
    st.image("long_form_logo_with_icon.png")
# experimental_memo stores pickled copies of return values, which suits data;
# a loaded model is a shared resource, so it is cached with
# experimental_singleton (one instance reused across reruns).
@st.experimental_singleton
def _instantiate_model(model_name):
    """Create and initialise the wrapper for the first entry matching ``model_name``.

    Failures are raised rather than reported here. Streamlit caches only
    returned values, so a failed load is retried on the next run instead of
    being remembered as ``None``. It also keeps Streamlit UI calls out of the
    cached function.

    Raises:
        NameError: the entry's wrapper class is not defined in this module.
        ValueError: no entry in ``model_names`` matches ``model_name``.
    """
    for node in model_names:
        if model_name.startswith(node["name"]):
            # NOTE(review): "SGPTModel" / "HFModel" are not imported anywhere
            # in this file; presumably a project module provides them --
            # confirm the import exists, otherwise every load fails here.
            obj_class = globals().get(node["class"])
            if obj_class is None:
                raise NameError("model class " + node["class"] + " is not defined")
            model = obj_class()
            model.init_model(node["model"])
            # Stop at the first match so at most one model is instantiated.
            return model
    raise ValueError("no configured model matches this name")


def load_model(model_name):
    """Return the cached model wrapper for ``model_name``, or None on failure.

    Any failure is shown to the user with ``st.error``; callers must check
    for ``None``.
    """
    try:
        return _instantiate_model(model_name)
    except Exception as e:
        st.error("Unable to load model:" + model_name + " " + str(e))
        return None
def _embed_and_rank(sentences, model, main_index):
    """Embed ``sentences`` with ``model`` and score them against the main sentence."""
    embedded_texts, vectors = model.compute_embeddings(sentences, is_file=False)
    return model.output_results(None, embedded_texts, vectors, main_index)


@st.experimental_memo
def cached_compute_similarity(sentences,_model,model_name,main_index):
    """Memoised similarity scoring.

    Streamlit does not hash parameters whose names start with an underscore,
    so ``_model`` is excluded from the memo key. ``model_name`` is what keeps
    results from different models apart; it is otherwise unused.
    """
    return _embed_and_rank(sentences, _model, main_index)


def uncached_compute_similarity(sentences,_model,model_name,main_index):
    """Similarity scoring without memoisation, shown under a spinner.

    Takes the same arguments as ``cached_compute_similarity``; ``model_name``
    is unused here.
    """
    with st.spinner('Computing vectors for sentences'):
        outcome = _embed_and_rank(sentences, _model, main_index)
    return outcome
def run_test(model_name,sentences,display_area,main_index,user_uploaded):
    """Load ``model_name`` and score ``sentences`` against the main sentence.

    Args:
        model_name: Display name of the model (matched against ``model_names``).
        sentences: Input sentences.
        display_area: Placeholder (``st.empty()``) used for progress text.
        main_index: 0-based index of the main sentence in ``sentences``.
        user_uploaded: True for uploaded files, which take the uncached path.

    Returns:
        Whatever the model's ``output_results`` returns (consumed by
        ``display_results`` as a sentence -> score mapping). On failure an
        error is shown and ``st.stop()`` ends the run; ``{}`` is returned
        only if ``st.stop()`` itself returns.
    """
    display_area.text("Loading model:" + model_name)
    model = load_model(model_name)
    if model is None:
        # load_model has already reported the cause via st.error; carrying on
        # with None would only add a misleading "prediction" error.
        st.stop()
        return {}
    display_area.text("Model " + model_name + " load complete")
    try:
        if user_uploaded:
            # Uploaded content is not memoised; this path shows its own spinner.
            results = uncached_compute_similarity(sentences, model, model_name, main_index)
        else:
            display_area.text("Computing vectors for sentences")
            results = cached_compute_similarity(sentences, model, model_name, main_index)
        display_area.text("Similarity computation complete")
        return results
    except Exception as e:
        st.error("Some error occurred during prediction: " + str(e))
        st.stop()
        return {}
def display_results(orig_sentences,main_index,results,response_info):
    """Render the scored sentences and stash them as JSON for download.

    Args:
        orig_sentences: Input sentences in file order.
        main_index: 0-based index of the main sentence in ``orig_sentences``.
        results: Mapping of sentence -> numeric score, displayed in its own
            iteration order. The header says it is sorted by cosine distance;
            presumably the model's ``output_results`` does that -- confirm.
        response_info: Timing line shown above the results.

    Side effects:
        Writes the results to the page with ``st.markdown`` and stores a JSON
        object (sentence -> score formatted to two decimals) in
        ``st.session_state["download_ready"]``.
    """
    # 1-based line number of each sentence's first occurrence, built once
    # instead of a list.index() scan per result (quadratic on large files).
    first_line_number = {}
    for number, sentence in enumerate(orig_sentences, start=1):
        first_line_number.setdefault(sentence, number)

    # Sentences can come from a user-uploaded file and are rendered with
    # unsafe_allow_html=True, so they are HTML-escaped to keep markup inert.
    main_sent = (
        f"{response_info}<br/>"
        "Results sorted by cosine distance. Closest(1) to furthest(-1) away from main sentence<br/>"
        f"Main sentence: {html.escape(orig_sentences[main_index])}<br/>"
    )
    body_sent = []
    download_data = {}
    for key in results:
        score = results[key]
        body_sent.append(f"{first_line_number[key]}] {html.escape(key)} {score:.2f}<br/>")
        # The download keeps the raw, unescaped sentence text.
        download_data[key] = f"{score:.2f}"
    main_sent = main_sent + "\n" + "\n".join(body_sent)
    st.markdown(main_sent, unsafe_allow_html=True)
    st.session_state["download_ready"] = json.dumps(download_data, indent=4)
def init_session():
    """Reset the session-state keys used by the form and download button.

    main() calls this first, so every run starts from these defaults.
    """
    defaults = {
        "download_ready": None,
        "model_name": "ss_test",
        "main_index": 1,
        "file_name": "default",
    }
    for state_key, initial_value in defaults.items():
        st.session_state[state_key] = initial_value
def _read_input_sentences(uploaded_file, example_label):
    """Return ``(file_name, sentences)`` for the upload, or the chosen example file.

    Text is decoded as UTF-8 and split with ``str.splitlines()``. The last line
    is therefore kept even without a trailing newline, and Windows (CRLF) line
    endings leave no stray carriage returns.
    """
    if uploaded_file is not None:
        return uploaded_file.name, uploaded_file.getvalue().decode("utf-8").splitlines()
    path = example_file_names[example_label]
    with open(path, encoding="utf-8") as handle:
        return path, handle.read().splitlines()


def _download_file_name():
    """Build the JSON download name from the model, main index and input file in session state."""
    stem = '_'.join(st.session_state["file_name"].split(".")[:-1])
    name = st.session_state["model_name"] + "_" + str(st.session_state["main_index"]) + "_" + stem + ".json"
    # Model names and example paths contain "/", which is not valid in a file name.
    return name.replace("/", "_")


def main():
    """Render the input form; on submit, score the sentences and show the results."""
    init_session()
    st.markdown("Compare state-of-the-art models for Sentence Similarity task<br/>", unsafe_allow_html=True)
    try:
        with st.form('twc_form'):
            uploaded_file = st.file_uploader("Step 1. Upload text file(one sentence in a line) or choose an example text file below.", type=".txt")
            selected_example = st.selectbox(label='Example files ',
                options=list(example_file_names), index=0, key="twc_file")
            st.write("")
            options_arr, markdown_str = construct_model_info_for_display()
            selected_model = st.selectbox(label='Step 2. Select Model',
                options=options_arr, index=0, key="twc_model")
            st.write("")
            main_index = st.number_input('Step 3. Enter index of sentence in file to make it the main sentence:', value=1, min_value=1)
            st.write("")
            submit_button = st.form_submit_button('Run')
        st.empty()  # spare placeholder slot above the progress area (nothing is written to it)
        display_area = st.empty()
        if submit_button:
            start = time.time()
            file_name, sentences = _read_input_sentences(uploaded_file, selected_example)
            st.session_state["file_name"] = file_name
            # Truncate before validating main_index, so the index is checked
            # against the sentences that are actually processed.
            if len(sentences) > MAX_INPUT:
                st.info(f"Input sentence count exceeds maximum sentence limit. First {MAX_INPUT} out of {len(sentences)} sentences chosen")
                sentences = sentences[:MAX_INPUT]
            if not sentences:
                # Without this, main_index would become 0 and index -1 downstream.
                st.error("The input file contains no sentences")
            else:
                if len(sentences) < main_index:
                    main_index = len(sentences)
                    st.info("Selected sentence index is larger than number of sentences in file. Truncating to " + str(main_index))
                st.session_state["model_name"] = selected_model
                st.session_state["main_index"] = main_index
                results = run_test(selected_model, sentences, display_area, main_index - 1, uploaded_file is not None)
                display_area.empty()
                with display_area.container():
                    response_info = f"Response time - {time.time() - start:.2f} secs for {len(sentences)} sentences"
                    display_results(sentences, main_index - 1, results, response_info)
        download_ready = st.session_state["download_ready"]
        st.download_button(
            label="Download results as json",
            data=download_ready if download_ready is not None else "",
            disabled=download_ready is None,
            file_name=_download_file_name(),
            mime='text/json',
            key="download",
        )
    except Exception as e:
        st.error("Some error occurred during loading: " + str(e))
        st.stop()
    st.markdown(markdown_str, unsafe_allow_html=True)


if __name__ == "__main__":
    main()