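# Streamlit demo: given a query word or key-phrase, show its k nearest
# key-phrases under two encoders (a Cvent-finetuned DistilBERT and FinBERT),
# side by side, with optional diversification of the results.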
from transformers import AutoModelForMaskedLM, AutoModelForSequenceClassification
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools
model = AutoModelForMaskedLM.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022", output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022")
kp_dict_checkpoint = "kp_dict_merged.pickle"
kp_cosine_checkpoint = "cosine_kp.pickle"

model_finbert = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", output_hidden_states=True)
tokenizer_finbert = AutoTokenizer.from_pretrained("ProsusAI/finbert")
kp_dict_finbert_checkpoint = "kp_dict_finbert.pickle"
kp_cosine_finbert_checkpoint = "cosine_kp_finbert.pickle"
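# The pickles are expected to hold, per model, a dict mapping each key-phrase
# to its mean-pooled embedding, plus the pairwise cosine-similarity matrix of
# those embeddings. A minimal sketch of how they could be produced offline
# (hypothetical preprocessing, not part of this app), using the helpers
# defined further down:
#   tokens = concat_tokens(phrases, tokenizer)
#   tokens.pop("KPS")
#   with torch.no_grad():
#       out = model(**tokens)
#   vecs = pool_embeddings(out, tokens).detach().numpy()
#   kp_dict = dict(zip(phrases, vecs))
#   cosine_kp = cosine_similarity(vecs)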
text = st.text_input("Enter word or key-phrase")
exclude_words = st.radio("exclude_words", [True, False], help="Exclude results that contain any words in the query")
exclude_text = st.radio("exclude_text", [True, False], help="Exclude results that contain the query (i.e. exclude 'tomato soup recipe' if the query is 'tomato soup')")
k = st.number_input("Top k nearest key-phrases", 1, 10, 5)

with st.sidebar:
    diversify_box = st.checkbox("Diversify results", True)
    if diversify_box:
        k_diversify = st.number_input("Set of key-phrases to diversify from", 10, 30, 20)
# columns
col1, col2 = st.columns(2)

# load key-phrase dicts
with open(kp_dict_checkpoint, 'rb') as handle:
    kp_dict = pickle.load(handle)
keys = list(kp_dict.keys())

with open(kp_dict_finbert_checkpoint, 'rb') as handle:
    kp_dict_finbert = pickle.load(handle)
keys_finbert = list(kp_dict_finbert.keys())

# load pairwise cosine similarities of the key-phrase embeddings
with open(kp_cosine_checkpoint, 'rb') as handle:
    cosine_kp = pickle.load(handle)
with open(kp_cosine_finbert_checkpoint, 'rb') as handle:
    cosine_finbert_kp = pickle.load(handle)
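# Rank every stored key-phrase by cosine similarity to the pooled query
# embedding, applying the optional exclusion filters, and keep the top k.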
def calculate_top_k(out, tokens, text, kp_dict, exclude_text=False, exclude_words=False, k=5):
    sim_dict = {}
    pools = pool_embeddings(out, tokens).detach().numpy()
    for key in kp_dict.keys():
        # never return the query itself
        if key == text:
            continue
        # optionally drop key-phrases that contain the full query
        if exclude_text and text in key:
            continue
        # optionally drop key-phrases that contain any single query word
        if exclude_words and any(x in key for x in text.split(" ")):
            continue
        sim_dict[key] = cosine_similarity(
            pools,
            [kp_dict[key]]
        )[0][0]
    sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
    return {x: y for x, y in sims}
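# Batch-tokenize the input phrases into stacked input_ids/attention_mask
# tensors (padded or truncated to 64 tokens), keeping the raw strings under "KPS".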
def concat_tokens(sentences, tokenizer):
    tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
    for sentence in sentences:
        # encode each sentence and append to dictionary
        new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                           truncation=True, padding='max_length',
                                           return_tensors='pt')
        tokens['input_ids'].append(new_tokens['input_ids'][0])
        tokens['attention_mask'].append(new_tokens['attention_mask'][0])
        tokens['KPS'].append(sentence)
    # reformat list of tensors into a single stacked tensor
    tokens['input_ids'] = torch.stack(tokens['input_ids'])
    tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
    return tokens
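# Mean-pool the last hidden layer over real (non-padding) tokens:
# mean = sum_t(h_t * mask_t) / max(sum_t(mask_t), 1e-9)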
def pool_embeddings(out, tok):
    embeddings = out["hidden_states"][-1]
    attention_mask = tok['attention_mask']
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    masked_embeddings = embeddings * mask
    summed = torch.sum(masked_embeddings, 1)
    summed_mask = torch.clamp(mask.sum(1), min=1e-9)
    mean_pooled = summed / summed_mask
    return mean_pooled
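# Map the selected key-phrases back to their row/column positions in the
# precomputed cosine matrix (this assumes kp_dict iteration order matches
# the ordering the matrix was built with).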
def extract_idxs(top_dict, kp_dict):
    idxs = []
    c = 0
    for i in list(kp_dict.keys()):
        if i in top_dict.keys():
            idxs.append(c)
        c += 1
    return idxs
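# Main flow: embed the query with both models, rank the key-phrases, and
# either show the top k directly or diversify from a larger candidate pool.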
if text:
    text = text.lower()
    new_tokens = concat_tokens([text], tokenizer)
    new_tokens.pop("KPS")
    new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
    new_tokens_finbert.pop("KPS")
    with torch.no_grad():
        outputs = model(**new_tokens)
        outputs_finbert = model_finbert(**new_tokens_finbert)
    # when diversifying, retrieve a larger candidate pool to choose from
    top_k = k_diversify if diversify_box else k
    sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text, exclude_words=exclude_words, k=top_k)
    sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text, exclude_words=exclude_words, k=top_k)
    if not diversify_box:
        with col1:
            st.write("distilbert-cvent")
            st.json(sim_dict)
        with col2:
            st.write("finbert")
            st.json(sim_dict_finbert)
    else:
        idxs = extract_idxs(sim_dict, kp_dict)
        idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
        distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
        distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
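        # Brute-force max-sum diversification: among all size-k subsets of the
        # candidate pool, keep the one whose summed pairwise cosine similarity
        # is smallest, i.e. the most mutually dissimilar key-phrases. The number
        # of subsets grows as C(pool, k), so the pool must stay small.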
        # first do distilbert
        candidate = None
        min_sim = np.inf
        for combination in itertools.combinations(range(len(idxs)), k):
            sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
            if sim < min_sim:
                candidate = combination
                min_sim = sim
        # then do finbert
        candidate_finbert = None
        min_sim = np.inf
        for combination in itertools.combinations(range(len(idxs_finbert)), k):
            sim = sum([distances_candidates_finbert[i][j] for i in combination for j in combination if i != j])
            if sim < min_sim:
                candidate_finbert = combination
                min_sim = sim
        # distilbert
        ret = {keys[idxs[idx]]: sim_dict[keys[idxs[idx]]] for idx in candidate}
        ret = sorted(ret.items(), key=lambda x: x[1], reverse=True)
        ret = {x: y for x, y in ret}
        # finbert (sort the result dict, not the index tuple)
        ret_finbert = {keys_finbert[idxs_finbert[idx]]: sim_dict_finbert[keys_finbert[idxs_finbert[idx]]] for idx in candidate_finbert}
        ret_finbert = sorted(ret_finbert.items(), key=lambda x: x[1], reverse=True)
        ret_finbert = {x: y for x, y in ret_finbert}
        with col1:
            st.write("distilbert-cvent")
            st.json(ret)
        with col2:
            st.write("finbert")
            st.json(ret_finbert)