Spaces:
Runtime error
Runtime error
File size: 4,254 Bytes
ed162b2 7fdb2f4 caee7a0 ed162b2 caee7a0 121c34f 383394d 121c34f f050ba4 ed162b2 560e686 caee7a0 560e686 caee7a0 ed162b2 52ce487 9a15376 caee7a0 9a15376 ed162b2 caee7a0 ed162b2 0b48aba ed162b2 caee7a0 9c74908 4c852a8 d20a440 4c852a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from transformers import AutoModelForMaskedLM
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools
# Fine-tuned DistilBERT checkpoint. Hidden states are requested because the
# query embedding is built from the last hidden layer (pool_embeddings), not
# from the masked-LM logits.
# NOTE(review): Streamlit re-executes the script on every widget interaction,
# so both objects are reloaded each time; consider caching them.
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForMaskedLM.from_pretrained(
    model_checkpoint,
    output_hidden_states=True,
)
# --- Query controls (main panel); widget order defines the page layout ---
text = st.text_input("Enter word or key-phrase")
exclude_words = st.radio(
    "exclude_words",
    options=[True, False],
    help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')",
)
exclude_text = st.radio(
    "exclude_text",
    options=[True, False],
    help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')",
)
k = st.number_input("Top k nearest key-phrases", min_value=1, max_value=10, value=5)

# --- Diversification controls (sidebar) ---
with st.sidebar:
    diversify_box = st.checkbox("Diversify results", value=True)
    # k_diversify is only defined when the box is ticked; it is read only on
    # the diversify path of the main script.
    if diversify_box:
        k_diversify = st.number_input(
            "Set of key-phrases to diversify from",
            min_value=10,
            max_value=30,
            value=20,
        )
# Pre-computed key-phrase embeddings: phrase -> tensor. The tensors are
# converted to NumPy arrays so sklearn's cosine_similarity can consume them.
with open("kp_dict_merged.pickle", 'rb') as handle:
    kp_dict = pickle.load(handle)
keys = list(kp_dict)
kp_dict = {phrase: emb.detach().numpy() for phrase, emb in kp_dict.items()}

# Pairwise key-phrase matrix. It is sliced below with positions from
# extract_idxs, so presumably rows/columns follow kp_dict order (confirm).
with open("cosine_kp.pickle", 'rb') as handle:
    cosine_kp = pickle.load(handle)
def calculate_top_k(out, tokens, text, exclude_text=False, exclude_words=False, k=5):
    """Rank the pre-computed key-phrases in ``kp_dict`` by similarity to the query.

    The query embedding is the attention-masked mean of the model's last
    hidden layer (``pool_embeddings``). Each surviving key-phrase is scored by
    cosine similarity against it.

    Args:
        out: model output for the query; ``out["hidden_states"]`` is pooled.
        tokens: tokenized query; must contain ``"attention_mask"``.
        text: the query string. A key-phrase identical to it is never returned.
        exclude_text: skip key-phrases that contain ``text`` as a substring.
        exclude_words: skip key-phrases that contain any whitespace-separated
            word of ``text`` as a substring (so "tea" also excludes "steam").
        k: number of results to keep (applied as a ``[:k]`` slice).

    Returns:
        dict of key-phrase -> cosine similarity (plain ``float``), ordered
        from most to least similar. It is empty if no key-phrase survives the
        filters. Ties keep ``kp_dict`` order.
    """
    # str.split() with no argument never yields "", unlike split(" "). An
    # empty "word" is a substring of every key, so a double or trailing space
    # in the query used to exclude every result.
    query_words = text.split()
    candidates = [
        key
        for key in kp_dict
        if key != text
        and not (exclude_text and text in key)
        and not (exclude_words and any(word in key for word in query_words))
    ]
    if not candidates:
        return {}

    pools = pool_embeddings(out, tokens).detach().numpy()
    # One vectorized call instead of one cosine_similarity call per key-phrase.
    # vstack accepts both 1-D vectors and (1, dim) rows.
    scores = cosine_similarity(pools, np.vstack([kp_dict[key] for key in candidates]))[0]
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)[:k]
    # float(): NumPy scalars are not plain JSON numbers.
    return {key: float(score) for key, score in ranked}
def concat_tokens(sentences):
    """Tokenize each sentence on its own and stack the results into one batch.

    Every sentence is padded or truncated to exactly 64 tokens.

    Returns:
        dict with 'input_ids' and 'attention_mask' as stacked tensors (one row
        per sentence) and 'KPS' holding the original sentences in order.
        torch.stack raises on an empty list, so ``sentences`` must be non-empty.
    """
    id_rows, mask_rows, phrases = [], [], []
    for phrase in sentences:
        encoded = tokenizer.encode_plus(
            phrase,
            max_length=64,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )
        # return_tensors='pt' gives a batch of one; keep that single row.
        id_rows.append(encoded['input_ids'][0])
        mask_rows.append(encoded['attention_mask'][0])
        phrases.append(phrase)
    return {
        'input_ids': torch.stack(id_rows),
        'attention_mask': torch.stack(mask_rows),
        'KPS': phrases,
    }
def pool_embeddings(out, tok):
    """Average the last hidden layer over the tokens kept by the attention mask.

    Args:
        out: model output mapping; ``out["hidden_states"][-1]`` is pooled
            (presumably shaped (batch, seq_len, hidden) — confirm).
        tok: tokenizer output whose ``"attention_mask"``, after a trailing
            unsqueeze, expands to the hidden-state shape.

    Returns:
        Float tensor with dimension 1 reduced by a masked mean. Rows whose
        mask is all zeros yield zeros rather than NaN, because the token count
        is clamped to at least 1e-9.
    """
    last_layer = out["hidden_states"][-1]
    keep = tok['attention_mask'].unsqueeze(-1).expand_as(last_layer).float()
    token_total = (last_layer * keep).sum(dim=1)
    token_count = keep.sum(dim=1).clamp(min=1e-9)
    return token_total / token_count
def extract_idxs(top_dict, kp_dict):
    """Return the positions (in ``kp_dict`` iteration order) of keys also in ``top_dict``.

    The result is ordered by position in ``kp_dict``, not by ``top_dict`` order.
    """
    return [position for position, phrase in enumerate(kp_dict) if phrase in top_dict]
def _select_diverse(similarities, k, seed=0, max_combinations=200_000):
    """Choose ``k`` positions of a square pairwise matrix with minimal
    off-diagonal total, i.e. the mutually least-similar subset.

    Every k-subset is brute-forced while their count is at most
    ``max_combinations``. Beyond that (k=10 of 30 candidates is over 30 million
    subsets), a greedy pass runs instead. It starts at position ``seed`` and
    repeatedly adds the position with the smallest summed entry against those
    already chosen.

    Returns a tuple of positions, or every position when there are at most ``k``.
    """
    import math

    n = len(similarities)
    if n <= k:
        # combinations() yields nothing here, which previously left no
        # selection at all (and crashed the caller).
        return tuple(range(n))
    if math.comb(n, k) <= max_combinations:
        best, best_total = None, np.inf
        for combo in itertools.combinations(range(n), k):
            block = similarities[np.ix_(combo, combo)]
            total = block.sum() - np.trace(block)  # exclude the i == j pairs
            if total < best_total:
                best, best_total = combo, total
        if best is not None:
            return best
    chosen = [seed]
    remaining = [pos for pos in range(n) if pos != seed]
    while len(chosen) < k:
        totals = similarities[np.ix_(remaining, chosen)].sum(axis=1)
        pick = remaining[int(np.argmin(totals))]
        chosen.append(pick)
        remaining.remove(pick)
    return tuple(chosen)


if text:
    text = text.lower()
    new_tokens = concat_tokens([text])
    # 'KPS' holds the raw strings; keep only the tensor entries for model(**...).
    new_tokens.pop("KPS")
    with torch.no_grad():
        outputs = model(**new_tokens)
    if not diversify_box:
        sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text, exclude_words=exclude_words, k=k)
        st.json(sim_dict)
    else:
        # Over-fetch k_diversify candidates, then keep the k least mutually similar.
        sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify)
        idxs = extract_idxs(sim_dict, kp_dict)
        # NOTE(review): the loader comment calls cosine_kp "cosine distances",
        # but minimizing their sum only diversifies if they are similarities;
        # confirm what the pickle holds.
        distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
        # The greedy fallback starts from the candidate most similar to the query.
        seed = int(np.argmax([sim_dict[keys[i]] for i in idxs])) if idxs else 0
        candidate = _select_diverse(distances_candidates, k, seed=seed)
        ret = {keys[idxs[idx]]: sim_dict[keys[idxs[idx]]] for idx in candidate}
        ret = dict(sorted(ret.items(), key=lambda item: item[1], reverse=True))
        st.json(ret)