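"""Streamlit demo that compares nearest-neighbour keyphrase retrieval across
three embedding models: a DistilBERT finetuned on Cvent data, FinBERT
(finance), and SapBERT (biomedical). Results can optionally be diversified."""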
from transformers import AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModel
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools
import tokenizers
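# Legacy st.cache cannot hash tokenizer/model objects, so hash_funcs maps
# those types to a constant to skip hashing them between reruns.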
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModelForMaskedLM: lambda _: None})
def load_bert():
return (AutoModelForMaskedLM.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022", output_hidden_states=True),
AutoTokenizer.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022"))
model, tokenizer = load_bert()
kp_dict_checkpoint = "kp_dict_merged.pickle"
kp_cosine_checkpoint = "cosine_kp.pickle"
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModelForSequenceClassification: lambda _: None})
def load_finbert():
return (AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", output_hidden_states=True),
AutoTokenizer.from_pretrained("ProsusAI/finbert"))
model_finbert, tokenizer_finbert = load_finbert()
kp_dict_finbert_checkpoint = "kp_dict_finance.pickle"
kp_cosine_finbert_checkpoint = "cosine_kp_finance.pickle"
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModel: lambda _: None})
def load_sapbert():
return (AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", output_hidden_states=True),
AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext"))
model_sapbert, tokenizer_sapbert = load_sapbert()
kp_dict_sapbert_checkpoint = "kp_dict_medical.pickle"
kp_cosine_sapbert_checkpoint = "cosine_kp_medical.pickle"
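# Assumption: each kp_dict_* pickle maps keyphrase -> precomputed embedding,
# and each cosine_kp_* pickle holds the pairwise cosine-similarity matrix
# over those keyphrases.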
text = st.text_input("Enter word or key-phrase")
exclude_words = st.radio("Exclude words", [True, False], help="Exclude results that contain any of the words in the query")
exclude_text = st.radio("Exclude text", [True, False], help="Exclude results that contain the full query (e.g., exclude 'tomato soup recipe' if the query is 'tomato soup')")
k = st.number_input("Top k nearest key-phrases", 1, 10, 5)
with st.sidebar:
diversify_box = st.checkbox("Diversify results",True)
if diversify_box:
k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
#columns
col1, col2, col3 = st.columns(3)
#load kp dicts
with open(kp_dict_checkpoint,'rb') as handle:
kp_dict = pickle.load(handle)
keys = list(kp_dict.keys())
with open(kp_dict_finbert_checkpoint,'rb') as handle:
kp_dict_finbert = pickle.load(handle)
keys_finbert = list(kp_dict_finbert.keys())
with open(kp_dict_sapbert_checkpoint,'rb') as handle:
kp_dict_sapbert = pickle.load(handle)
keys_sapbert = list(kp_dict_sapbert.keys())
#load cosine distances of kp dict
with open(kp_cosine_checkpoint,'rb') as handle:
cosine_kp = pickle.load(handle)
with open(kp_cosine_finbert_checkpoint,'rb') as handle:
cosine_finbert_kp = pickle.load(handle)
with open(kp_cosine_sapbert_checkpoint,'rb') as handle:
cosine_sapbert_kp = pickle.load(handle)
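# The np.ix_ lookups below assume each cosine matrix is indexed in the same
# order as the corresponding kp_dict keys.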
def calculate_top_k(out, tokens, text, kp_dict, exclude_text=False, exclude_words=False, k=5, pooler=True):
    """Return the k keyphrases from kp_dict most cosine-similar to the query."""
    sim_dict = {}
    if pooler:
        pools = pool_embeddings(out, tokens).detach().numpy()
    else:
        pools = out["pooler_output"].detach().numpy()
    for key in kp_dict.keys():
        # skip the query itself and anything caught by the exclusion filters
        if key == text:
            continue
        if exclude_text and text in key:
            continue
        if exclude_words and any(word in key for word in text.split(" ")):
            continue
        sim_dict[key] = cosine_similarity(pools, [kp_dict[key]])[0][0]
    sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
    return dict(sims)
def concat_tokens(sentences, tokenizer):
    """Tokenize each sentence to a fixed length and stack into one batch."""
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
for sentence in sentences:
# encode each sentence and append to dictionary
new_tokens = tokenizer.encode_plus(sentence, max_length=64,
truncation=True, padding='max_length',
return_tensors='pt')
tokens['input_ids'].append(new_tokens['input_ids'][0])
tokens['attention_mask'].append(new_tokens['attention_mask'][0])
tokens['KPS'].append(sentence)
# reformat list of tensors into single tensor
tokens['input_ids'] = torch.stack(tokens['input_ids'])
tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
return tokens
def pool_embeddings(out, tok):
    """Mean-pool the last hidden state over non-padding tokens only."""
    embeddings = out["hidden_states"][-1]
attention_mask = tok['attention_mask']
mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
masked_embeddings = embeddings * mask
summed = torch.sum(masked_embeddings, 1)
summed_mask = torch.clamp(mask.sum(1), min=1e-9)
mean_pooled = summed / summed_mask
return mean_pooled
def extract_idxs(top_dict, kp_dict):
    """Return the positions of top_dict keys within kp_dict's key order."""
    idxs = []
    for c, key in enumerate(kp_dict.keys()):
        if key in top_dict:
            idxs.append(c)
    return idxs
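# Main flow: embed the query with each model, then either show the raw
# top-k neighbours or re-rank a larger candidate pool for diversity.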
if text:
text = text.lower()
new_tokens = concat_tokens([text], tokenizer)
new_tokens.pop("KPS")
new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
new_tokens_finbert.pop("KPS")
new_tokens_sapbert = concat_tokens([text], tokenizer_sapbert)
new_tokens_sapbert.pop("KPS")
with torch.no_grad():
outputs = model(**new_tokens)
outputs_finbert = model_finbert(**new_tokens_finbert)
outputs_sapbert = model_sapbert(**new_tokens_sapbert)
if not diversify_box:
        sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text, exclude_words=exclude_words, k=k)
        sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k)
        sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k, pooler=False)
with col1:
st.write("distilbert-cvent")
st.json(sim_dict)
with col2:
st.write("finbert")
st.json(sim_dict_finbert)
with col3:
st.write("sapbert")
st.json(sim_dict_sapbert)
else:
        # retrieve a larger pool of k_diversify candidates to diversify from
        sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify)
        sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify)
        sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify, pooler=False)
idxs = extract_idxs(sim_dict, kp_dict)
idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
idxs_sapbert = extract_idxs(sim_dict_sapbert, kp_dict_sapbert)
distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
distances_candidates_sapbert = cosine_sapbert_kp[np.ix_(idxs_sapbert, idxs_sapbert)]
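        # Max-sum diversification: exhaustively search for the k candidates
        # whose total pairwise similarity is lowest (assumes at least k
        # candidates survived filtering; otherwise candidate stays None).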
#first do distilbert
candidate = None
min_sim = np.inf
for combination in itertools.combinations(range(len(idxs)), k):
sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
if sim < min_sim:
candidate = combination
min_sim = sim
#then do finbert
candidate_finbert = None
min_sim = np.inf
for combination in itertools.combinations(range(len(idxs_finbert)), k):
sim = sum([distances_candidates_finbert[i][j] for i in combination for j in combination if i != j])
if sim < min_sim:
candidate_finbert = combination
min_sim = sim
#sapbert
candidate_sapbert = None
min_sim = np.inf
for combination in itertools.combinations(range(len(idxs_sapbert)), k):
sim = sum([distances_candidates_sapbert[i][j] for i in combination for j in combination if i != j])
if sim < min_sim:
candidate_sapbert = combination
min_sim = sim
        # distilbert
        ret = {keys[idxs[i]]: sim_dict[keys[idxs[i]]] for i in candidate}
        ret = dict(sorted(ret.items(), key=lambda x: x[1], reverse=True))
        # finbert
        ret_finbert = {keys_finbert[idxs_finbert[i]]: sim_dict_finbert[keys_finbert[idxs_finbert[i]]] for i in candidate_finbert}
        ret_finbert = dict(sorted(ret_finbert.items(), key=lambda x: x[1], reverse=True))
        # sapbert
        ret_sapbert = {keys_sapbert[idxs_sapbert[i]]: sim_dict_sapbert[keys_sapbert[idxs_sapbert[i]]] for i in candidate_sapbert}
        ret_sapbert = dict(sorted(ret_sapbert.items(), key=lambda x: x[1], reverse=True))
with col1:
st.write("distilbert-cvent")
st.json(ret)
with col2:
st.write("finbert")
st.json(ret_finbert)
with col3:
st.write("sapbert")
st.json(ret_sapbert)