mbahrami's picture
Update app.py
fe35e1b
raw
history blame
No virus
2.77 kB
import streamlit as st
import pandas as pd
from streamlit import cli as stcli
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
import sys
# Weight applied to candidate tokens that relate to the user's history:
# a candidate found in the keyword history has its score multiplied by this
# value, and the semantic-similarity bonus is scaled by it before being added.
HISTORY_WEIGHT = 100
@st.cache(allow_output_mutation=True)
def get_model(model):
    """Build the fill-mask pipeline for the given model name.

    The Streamlit cache decorator keeps the returned pipeline between script
    reruns so the same model name is not rebuilt on every interaction.

    Args:
        model: Model identifier passed to ``transformers.pipeline``.

    Returns:
        A "fill-mask" pipeline configured to return up to 100 candidates
        per inference.
    """
    # top_k is the maximum number of tokens retrieved after each inference.
    max_candidates = 100
    fill_mask = pipeline("fill-mask", model=model, top_k=max_candidates)
    return fill_mask
def main(nlp, semantic_model):
    """Predict the next token for the entered text and render ranked candidates.

    Reads the module-level ``text``, ``semantic_text`` and
    ``history_keyword_text`` values (set from the Streamlit inputs before this
    function is called), appends the tokenizer's mask token to ``text`` and
    runs the fill-mask pipeline. Each candidate's score is then adjusted:

    * if a non-blank semantic history is given, the cosine similarity between
      the candidate sequence and that history, scaled by ``HISTORY_WEIGHT``,
      is added to the score;
    * if the candidate token appears in the keyword history, the score is
      multiplied by ``HISTORY_WEIGHT``.

    The candidates are finally shown as a table sorted by descending score.

    Args:
        nlp: Fill-mask pipeline; must expose ``tokenizer.mask_token`` and
            return a list of dicts with ``sequence``, ``token_str`` and
            ``score`` keys.
        semantic_model: Sentence-embedding model exposing
            ``encode(..., convert_to_tensor=True)``.
    """
    data_load_state = st.text('Inference to model...')
    result = nlp(text + ' ' + nlp.tokenizer.mask_token)
    data_load_state.text('')

    # Decide on the stripped value: a whitespace-only history would otherwise
    # be encoded as an empty string and add a meaningless similarity bonus.
    semantic_query = semantic_text.strip()
    use_semantics = bool(semantic_query)
    if use_semantics:
        predicted_seq = [rec['sequence'] for rec in result]
        predicted_embeddings = semantic_model.encode(predicted_seq, convert_to_tensor=True)
        semantic_history_embeddings = semantic_model.encode([semantic_query], convert_to_tensor=True)
        cosine_scores = util.cos_sim(predicted_embeddings, semantic_history_embeddings)

    # Normalise the keyword history once instead of on every iteration.
    keyword_history = history_keyword_text.lower().strip()

    for index, r in enumerate(result):
        # Skip very short tokens (special chars such as "?") for the semantic bonus.
        if use_semantics and len(r['token_str']) > 2:
            r['score'] += float(sum(cosine_scores[index])) * HISTORY_WEIGHT
        token = r['token_str'].lower().strip()
        # Found in the keyword history: boost the token's score.
        if len(token) > 1 and token in keyword_history:
            r['score'] *= HISTORY_WEIGHT

    # Sort the results and show them as a table.
    df = pd.DataFrame(result).sort_values(by='score', ascending=False)
    st.table(df)
if __name__ == '__main__':
    if st._is_running_with_streamlit:
        # Running under Streamlit: build the UI, load the models and predict.
        st.caption("This is a simple auto-completion where the next token is predicted per probability and a weight if it is appeared in keyword user's history or there is a similarity to semantic user's history")
        history_keyword_text = st.text_input("Enter users's history <keywords matc> (optional, i.e., 'Gates')", value="")
        text = st.text_input("Enter a text for auto completion...", value='Where is Bill')
        semantic_text = st.text_input("Enter users's history <semantic> (optional, i.e., 'Microsoft or President')", value="Microsoft")
        model = st.selectbox("Choose a model", ["roberta-base", "bert-base-uncased"])

        data_load_state = st.text('Loading model...')
        semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
        nlp = get_model(model)
        # Clear the placeholder once loading is done; otherwise the
        # "Loading model..." message stays on the page permanently.
        data_load_state.text('')

        main(nlp, semantic_model)
    else:
        # Launched with plain `python`: re-invoke this script through the
        # Streamlit CLI so the branch above runs.
        sys.argv = ['streamlit', 'run', sys.argv[0]]
        sys.exit(stcli.main())