import itertools
import pickle

import numpy as np
import streamlit as st
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForMaskedLM, AutoTokenizer

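# Load the fine-tuned DistilBERT checkpoint; output_hidden_states=True exposes
# the per-token hidden states that pool_embeddings() mean-pools into a single
# query vector. (In newer Streamlit versions, wrapping this in
# @st.cache_resource would avoid reloading the model on every rerun.)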
model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
text = st.text_input("Enter a word or key-phrase")

exclude_words = st.radio("exclude_words", [True, False], help="Exclude results that contain any word from the query (e.g. exclude 'hot coffee' if the query is 'cold coffee')")

exclude_text = st.radio("exclude_text", [True, False], help="Exclude results that contain the query as a substring (e.g. exclude 'tomato soup recipe' if the query is 'tomato soup')")
k = st.number_input("Top k nearest key-phrases", 1, 10, 5)

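# Sidebar toggle: when enabled, over-fetch a larger pool of nearest key-phrases
# and then pick a diverse k-subset from it (see the max-sum selection below).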
with st.sidebar:
  diversify_box = st.checkbox("Diversify results", True)
  if diversify_box:
    k_diversify = st.number_input("Set of key-phrases to diversify from", 10, 30, 20)

# Load the precomputed key-phrase -> embedding dictionary.
with open("kp_dict_merged.pickle", 'rb') as handle:
  kp_dict = pickle.load(handle)
keys = list(kp_dict.keys())

# Convert the stored tensors to NumPy arrays once, up front, so that
# cosine_similarity can consume them directly.
for key in kp_dict:
  kp_dict[key] = kp_dict[key].detach().numpy()

# Load the precomputed pairwise cosine-similarity matrix over the key-phrases
# (row/column order is assumed to follow `keys`).
with open("cosine_kp.pickle", 'rb') as handle:
  cosine_kp = pickle.load(handle)
  
def calculate_top_k(out, tokens, text, exclude_text=False, exclude_words=False, k=5):
  """Return the k key-phrases whose embeddings are most similar to the query."""
  sim_dict = {}
  pools = pool_embeddings(out, tokens).detach().numpy()
  for key in kp_dict:
    if key == text:
      continue
    if exclude_text and text in key:
      continue
    if exclude_words and any(word in key for word in text.split()):
      continue
    sim_dict[key] = cosine_similarity(pools, [kp_dict[key]])[0][0]
  sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
  return dict(sims)
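
# Tokenize a batch of phrases into fixed-length padded tensors (max 64 tokens),
# keeping the raw strings under the extra "KPS" key for bookkeeping.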
def concat_tokens(sentences):
  tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
  for sentence in sentences:
    # encode each sentence and append to the dictionary
    new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                       truncation=True, padding='max_length',
                                       return_tensors='pt')
    tokens['input_ids'].append(new_tokens['input_ids'][0])
    tokens['attention_mask'].append(new_tokens['attention_mask'][0])
    tokens['KPS'].append(sentence)
  # reformat the lists of tensors into a single stacked tensor per field
  tokens['input_ids'] = torch.stack(tokens['input_ids'])
  tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
  return tokens

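# Mean-pool the last hidden layer over real (non-padding) tokens: zero out
# padding positions, sum along the sequence axis, and divide by the number of
# unmasked tokens.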
def pool_embeddings(out, tok):
  embeddings = out["hidden_states"][-1]
  attention_mask = tok['attention_mask']
  mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
  masked_embeddings = embeddings * mask
  summed = torch.sum(masked_embeddings, 1)
  summed_mask = torch.clamp(mask.sum(1), min=1e-9)
  mean_pooled = summed / summed_mask
  return mean_pooled
  
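# Map the selected key-phrases back to their row indices in kp_dict's key
# order, which is assumed to match the rows/columns of cosine_kp.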
def extract_idxs(top_dict, kp_dict):
  idxs = []
  for c, key in enumerate(kp_dict):
    if key in top_dict:
      idxs.append(c)
  return idxs
  
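# Main flow: embed the lowercased query, then either show the top-k matches
# directly or select a diverse k-subset from a larger candidate pool.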
if text:
  text = text.lower()
  new_tokens = concat_tokens([text])
  new_tokens.pop("KPS")  # the model only needs input_ids / attention_mask
  with torch.no_grad():
    outputs = model(**new_tokens)
  if not diversify_box:
    sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text, exclude_words=exclude_words, k=k)
    st.json(sim_dict)
  else:
    # Over-fetch k_diversify candidates, then keep the k of them whose summed
    # pairwise cosine similarity is minimal (max-sum diversification).
    sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify)
    idxs = extract_idxs(sim_dict, kp_dict)
    distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
    k = min(k, len(idxs))  # guard: never request more results than there are candidates
    min_sim = np.inf
    candidate = None
    for combination in itertools.combinations(range(len(idxs)), k):
      sim = sum(distances_candidates[i][j] for i in combination for j in combination if i != j)
      if sim < min_sim:
        candidate = combination
        min_sim = sim
    ret = {keys[idxs[idx]]: sim_dict[keys[idxs[idx]]] for idx in candidate}
    ret = dict(sorted(ret.items(), key=lambda x: x[1], reverse=True))
    st.json(ret)