from transformers import AutoModelForMaskedLM, AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import numpy as np  # used below for np.ix_ and np.inf
import pickle
import itertools

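# Load the fine-tuned DistilBERT checkpoint and its tokenizer.
# output_hidden_states=True exposes the hidden states that pool_embeddings()
# mean-pools into a single vector per input.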
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
text = st.text_input("Enter word or key-phrase")

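# Filters applied to the returned key-phrases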
exclude_words = st.radio("exclude_words", [True, False], help="Exclude results that contain any word from the query (e.g. exclude 'hot coffee' when the query is 'cold coffee')")

exclude_text = st.radio("exclude_text", [True, False], help="Exclude results that contain the full query (e.g. exclude 'tomato soup recipe' when the query is 'tomato soup')")
k = st.number_input("Top k nearest key-phrases", 1, 10, 5)

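# Sidebar option: when enabled, the final k results are picked from a larger
# candidate pool so that the returned key-phrases are less similar to each other.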
with st.sidebar:
  diversify_box = st.checkbox("Diversify results",True)
  if diversify_box:
    k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)

# Load the precomputed key-phrase embeddings and convert them to NumPy arrays
with open("kp_dict_merged.pickle", 'rb') as handle:
  kp_dict = pickle.load(handle)
for key in kp_dict:
  kp_dict[key] = kp_dict[key].detach().numpy()

# Load the precomputed pairwise cosine matrix for the key-phrases (used to diversify results)
with open("cosine_kp.pickle", 'rb') as handle:
  cosine_kp = pickle.load(handle)
  
def calculate_top_k(out, tokens, text, exclude_text=False, exclude_words=False, k=5):
  """Return the k key-phrases whose embeddings are most cosine-similar to the query."""
  sim_dict = {}
  pools = pool_embeddings(out, tokens).detach().numpy()
  for key in kp_dict:
    # never return the query itself
    if key == text:
      continue
    # optionally drop candidates that contain the full query string
    if exclude_text and text in key:
      continue
    # optionally drop candidates that contain any individual query word
    if exclude_words and any(word in key for word in text.split()):
      continue
    sim_dict[key] = cosine_similarity(
        pools,
        [kp_dict[key]]
    )[0][0]
  sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
  return dict(sims)

def concat_tokens(sentences):
  """Tokenize a batch of phrases and stack them into single input tensors."""
  tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
  for sentence in sentences:
    # encode each phrase and append its tensors to the dictionary
    new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                       truncation=True, padding='max_length',
                                       return_tensors='pt')
    tokens['input_ids'].append(new_tokens['input_ids'][0])
    tokens['attention_mask'].append(new_tokens['attention_mask'][0])
    tokens['KPS'].append(sentence)
  # reformat the lists of tensors into single stacked tensors
  tokens['input_ids'] = torch.stack(tokens['input_ids'])
  tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
  return tokens

def pool_embeddings(out, tok):
  """Mean-pool the last hidden layer over real (non-padding) token positions."""
  embeddings = out["hidden_states"][-1]
  attention_mask = tok['attention_mask']
  # broadcast the mask across the hidden dimension and zero out padding tokens
  mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
  masked_embeddings = embeddings * mask
  # sum over the sequence dimension and divide by the number of real tokens
  summed = torch.sum(masked_embeddings, 1)
  summed_mask = torch.clamp(mask.sum(1), min=1e-9)
  mean_pooled = summed / summed_mask
  return mean_pooled
  
def extract_idxs(top_dict, kp_dict):
  """Return the positions (in kp_dict order) of the keys present in top_dict."""
  idxs = []
  for c, key in enumerate(kp_dict):
    if key in top_dict:
      idxs.append(c)
  return idxs
  
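# Main flow: once a query is entered, embed it, rank all stored key-phrases by
# cosine similarity, and either show the top k directly or diversify the
# selection from a larger candidate pool.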
if text:
  new_tokens = concat_tokens([text])
  new_tokens.pop("KPS")
  with torch.no_grad():
    outputs = model(**new_tokens)
  if not diversify_box:
    sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
    st.json(sim_dict)
  else:
    # retrieve a larger candidate pool, then choose the k least mutually similar
    sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify)
    idxs = extract_idxs(sim_dict, kp_dict)
    distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
    min_sim = np.inf
    candidate = None
    for combination in itertools.combinations(range(len(idxs)), k):
      sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
      # keep the combination with the lowest summed pairwise similarity (max-sum diversification)
      if sim < min_sim:
        candidate = combination
        min_sim = sim