vives's picture
Update app.py
f050ba4
raw
history blame
2.73 kB
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
# Checkpoint identifier shared by the model and its tokenizer.
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"

# output_hidden_states=True makes the forward pass return the hidden states
# that pool_embeddings() reads.
model = AutoModelForMaskedLM.from_pretrained(
    model_checkpoint,
    output_hidden_states=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# User inputs, rendered in this order.
text = st.text_input("Enter word or key-phrase")
exclude_text = st.radio("exclude_text", options=[True, False])
exclude_words = st.radio("exclude_words", options=[True, False])
k = st.number_input(
    "Top k nearest key-phrases",
    min_value=1,
    max_value=10,
    value=5,
)

# NOTE(review): pickle.load can execute arbitrary code from the file it reads;
# only ship a pickle that was produced by a trusted pipeline.
with open("kp_dict_merged.pickle", mode='rb') as handle:
    kp_dict = pickle.load(handle)

# Convert every stored tensor to a NumPy array once, so the similarity code
# does not have to do it per query. Only values are reassigned, so iterating
# the dict while updating it is safe.
for key, embedding in kp_dict.items():
    kp_dict[key] = embedding.detach().numpy()
def calculate_top_k(out, tokens, text, exclude_text=False, exclude_words=False, k=5):
    """Return the ``k`` stored key-phrases most similar to the query.

    The query embedding is ``pool_embeddings(out, tokens)``; candidates are
    the entries of the module-level ``kp_dict`` (key-phrase -> embedding
    array). Similarity is cosine similarity.

    Args:
        out: Model output for the query, passed through to ``pool_embeddings``.
        tokens: Tokenizer output for the query, passed through to
            ``pool_embeddings``.
        text: The query string. A key-phrase equal to ``text`` is never
            returned.
        exclude_text: If true, skip key-phrases that contain ``text`` as a
            substring.
        exclude_words: If true, skip key-phrases that contain any
            whitespace-separated word of ``text`` as a substring.
        k: Maximum number of results to return.

    Returns:
        A dict mapping key-phrase to similarity (plain Python ``float``),
        ordered from most to least similar, with at most ``k`` entries.
    """
    query_embedding = pool_embeddings(out, tokens).detach().numpy()

    # split() with no argument drops the empty strings that split(" ") yields
    # for leading, trailing or repeated spaces. An empty "word" is a substring
    # of every key and would silently exclude all candidates.
    query_words = text.split() if exclude_words else []

    sim_dict = {}
    for key, embedding in kp_dict.items():
        if key == text:
            continue
        if exclude_text and text in key:
            continue
        if any(word in key for word in query_words):
            continue
        # Only the first query row is compared ([0][0]); the caller in this
        # file always passes a single query.
        score = cosine_similarity(query_embedding, [embedding])[0][0]
        # Cast the NumPy scalar so the result is JSON-serializable as a number.
        sim_dict[key] = float(score)

    sims = sorted(sim_dict.items(), key=lambda item: item[1], reverse=True)[:k]
    return dict(sims)
def concat_tokens(sentences):
    """Tokenize each sentence to a fixed length and stack the results.

    Every sentence is truncated or padded to 64 tokens by the module-level
    ``tokenizer``.

    Args:
        sentences: Iterable of strings to encode. Must contain at least one
            item, since ``torch.stack`` rejects an empty list.

    Returns:
        A dict with:
            'input_ids': tensor with one row of token ids per sentence.
            'attention_mask': tensor with one mask row per sentence.
            'KPS': list of the original sentences, in input order.
    """
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
    for sentence in sentences:
        # Calling the tokenizer directly is the supported API; encode_plus is
        # deprecated in its favour and takes the same arguments here.
        encoded = tokenizer(
            sentence,
            max_length=64,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch axis of size 1; drop it so
        # the rows can be stacked below.
        tokens['input_ids'].append(encoded['input_ids'][0])
        tokens['attention_mask'].append(encoded['attention_mask'][0])
        tokens['KPS'].append(sentence)
    # Reformat the lists of per-sentence tensors into single tensors.
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens
def pool_embeddings(out, tok):
    """Mean-pool the last hidden state over the positions the mask keeps.

    Args:
        out: Mapping whose ``"hidden_states"`` entry is a sequence of tensors;
            the last one is pooled.
        tok: Mapping whose ``'attention_mask'`` entry marks the positions to
            include (non-zero) or ignore (zero).

    Returns:
        The masked average of the last hidden state along dimension 1. Rows
        whose mask is all zeros come out as zeros because the divisor is
        clamped to a tiny positive value rather than zero.
    """
    last_hidden = out["hidden_states"][-1]
    # Broadcast the mask across the trailing feature axis so it can weight
    # every component of each token vector.
    weights = tok['attention_mask'].unsqueeze(-1).expand(last_hidden.size()).float()
    token_sum = (last_hidden * weights).sum(dim=1)
    # Clamp to avoid dividing by zero when a row has no kept positions.
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    return token_sum / token_count
# Run a query only once the user has typed something.
if text:
    new_tokens = concat_tokens([text])
    # "KPS" holds the raw strings and is not a model input, so drop it before
    # unpacking the dict into the forward call.
    del new_tokens["KPS"]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**new_tokens)

    sim_dict = calculate_top_k(
        outputs,
        new_tokens,
        text,
        exclude_text=exclude_text,
        exclude_words=exclude_words,
        k=k,
    )
    st.json(sim_dict)