from transformers import AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModel
from transformers import AutoTokenizer
from sklearn.metrics.pairwise import cosine_similarity
import streamlit as st
import torch
import pickle
import numpy as np
import itertools
import tokenizers

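# Cache each (model, tokenizer) pair so Streamlit doesn't reload the weights on
# every rerun; tokenizer and model objects aren't hashable, so hash_funcs stubs
# them out.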
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModelForMaskedLM: lambda _: None})
def load_bert():
  return (AutoModelForMaskedLM.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022", output_hidden_states=True),
  AutoTokenizer.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022"))

model, tokenizer = load_bert()
kp_dict_checkpoint = "kp_dict_merged.pickle"
kp_cosine_checkpoint = "cosine_kp.pickle"

@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModelForSequenceClassification: lambda _: None})
def load_finbert():
  return (AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", output_hidden_states=True),
  AutoTokenizer.from_pretrained("ProsusAI/finbert"))
  
model_finbert, tokenizer_finbert = load_finbert()
kp_dict_finbert_checkpoint = "kp_dict_finance.pickle"
kp_cosine_finbert_checkpoint = "cosine_kp_finance.pickle"

@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, AutoModel: lambda _: None})
def load_sapbert():
  return (AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", output_hidden_states=True),
  AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext"))
  
model_sapbert, tokenizer_sapbert = load_sapbert()
kp_dict_sapbert_checkpoint = "kp_dict_medical.pickle"
kp_cosine_sapbert_checkpoint = "cosine_kp_medical.pickle"

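# Query controls: the search text, result filters, and how many neighbours to return.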
text = st.text_input("Enter word or key-phrase")
exclude_words = st.radio("exclude_words", [True, False], help="Exclude results that contain any word from the query")
exclude_text = st.radio("exclude_text", [True, False], help="Exclude results that contain the query itself (e.g. exclude 'tomato soup recipe' when the query is 'tomato soup')")
k = st.number_input("Top k nearest key-phrases",1,10,5)

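# Diversification trades a little raw similarity for variety: fetch a larger
# candidate pool, then keep only the k mutually least similar phrases.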
with st.sidebar:
  diversify_box = st.checkbox("Diversify results",True)
  if diversify_box:
    k_diversify = st.number_input("Number of key-phrases to diversify from", 10, 30, 20)

# three result columns, one per model
col1, col2, col3 = st.columns(3)
# load the precomputed key-phrase embedding dictionaries (phrase -> vector)
with open(kp_dict_checkpoint,'rb') as handle:
  kp_dict = pickle.load(handle)
keys = list(kp_dict.keys())

with open(kp_dict_finbert_checkpoint,'rb') as handle:
  kp_dict_finbert = pickle.load(handle)
keys_finbert = list(kp_dict_finbert.keys())

with open(kp_dict_sapbert_checkpoint,'rb') as handle:
  kp_dict_sapbert = pickle.load(handle)
keys_sapbert = list(kp_dict_sapbert.keys())

# load the precomputed pairwise cosine similarities between key-phrase embeddings
# (the diversifier below minimizes their sum)
with open(kp_cosine_checkpoint,'rb') as handle:
  cosine_kp = pickle.load(handle)
with open(kp_cosine_finbert_checkpoint,'rb') as handle:
  cosine_finbert_kp = pickle.load(handle)
with open(kp_cosine_sapbert_checkpoint,'rb') as handle:
  cosine_sapbert_kp = pickle.load(handle)
  
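# Rank every key-phrase by cosine similarity to the query embedding and return
# the top k as {phrase: similarity}, optionally skipping phrases that contain
# the query text or any of its words.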
def calculate_top_k(out, tokens, text, kp_dict, exclude_text=False, exclude_words=False, k=5, pooler=True):
  sim_dict = {}
  if pooler:
    pools = pool_embeddings(out, tokens).detach().numpy()
  else:
    pools = out["pooler_output"].detach().numpy()
  for key in kp_dict.keys():
    if key == text:
      continue
    if exclude_text and text in key:
      continue
    if exclude_words and any(x in key for x in text.split(" ")):
      continue
    sim_dict[key] = cosine_similarity(
        pools,
        [kp_dict[key]]
    )[0][0]
  sims = sorted(sim_dict.items(), key=lambda x: x[1], reverse=True)[:k]
  return dict(sims)
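
# Batch-tokenize phrases to a fixed length and stack them into single tensors;
# the extra "KPS" key keeps the raw phrase strings alongside the tensors.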
def concat_tokens(sentences, tokenizer):
  tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
  for sentence in sentences:
    # encode each sentence and append to dictionary
    new_tokens = tokenizer.encode_plus(sentence, max_length=64,
                                       truncation=True, padding='max_length',
                                       return_tensors='pt')
    tokens['input_ids'].append(new_tokens['input_ids'][0])
    tokens['attention_mask'].append(new_tokens['attention_mask'][0])
    tokens['KPS'].append(sentence)
  # reformat list of tensors into single tensor
  tokens['input_ids'] = torch.stack(tokens['input_ids'])
  tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
  return tokens

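# Attention-mask-weighted mean pooling over the last hidden layer, so padding
# tokens don't contribute to the embedding.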
def pool_embeddings(out, tok):
  embeddings = out["hidden_states"][-1]
  attention_mask = tok['attention_mask']
  mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
  masked_embeddings = embeddings * mask
  summed = torch.sum(masked_embeddings, 1)
  summed_mask = torch.clamp(mask.sum(1), min=1e-9)
  mean_pooled = summed / summed_mask
  return mean_pooled
  
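# Map the top-k phrases back to their positions in kp_dict's key order, which
# is also the row/column order of the precomputed cosine matrices.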
def extract_idxs(top_dict, kp_dict):
  idxs = []
  for c, key in enumerate(kp_dict.keys()):
    if key in top_dict:
      idxs.append(c)
  return idxs
  
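# Main flow: embed the query with each of the three models, then show its
# nearest key-phrases per model, optionally diversified.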
if text:
  text = text.lower()
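  # Tokenize the query for each model and drop the bookkeeping "KPS" entry,
  # since the models only accept input_ids/attention_mask.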
  new_tokens = concat_tokens([text], tokenizer)
  new_tokens.pop("KPS")
  new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
  new_tokens_finbert.pop("KPS")
  new_tokens_sapbert = concat_tokens([text], tokenizer_sapbert)
  new_tokens_sapbert.pop("KPS")
  with torch.no_grad():
    outputs = model(**new_tokens)
    outputs_finbert = model_finbert(**new_tokens_finbert)
    outputs_sapbert = model_sapbert(**new_tokens_sapbert)
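  # Without diversification, simply show the k most similar key-phrases per model.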
  if not diversify_box:
    sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text, exclude_words=exclude_words, k=k)
    sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k)
    sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k, pooler=False)
    with col1:
      st.write("distilbert-cvent")
      st.json(sim_dict)
    with col2:
      st.write("finbert")
      st.json(sim_dict_finbert)
    with col3:
      st.write("sapbert")
      st.json(sim_dict_sapbert)
  else:
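    # With diversification: rank a larger pool of k_diversify candidates, then
    # brute-force the k-subset whose summed pairwise cosine similarity is
    # minimal (max-sum diversification).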
    sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify)
    sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify)
    sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text, exclude_words=exclude_words, k=k_diversify, pooler=False)
    idxs = extract_idxs(sim_dict, kp_dict)
    idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
    idxs_sapbert = extract_idxs(sim_dict_sapbert, kp_dict_sapbert)
    distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
    distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
    distances_candidates_sapbert = cosine_sapbert_kp[np.ix_(idxs_sapbert, idxs_sapbert)]
    #first do distilbert
    candidate = None
    min_sim = np.inf
    for combination in itertools.combinations(range(len(idxs)), k):
      sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
      if sim < min_sim:
        candidate = combination
        min_sim = sim
    #then do finbert
    candidate_finbert = None
    min_sim = np.inf
    for combination in itertools.combinations(range(len(idxs_finbert)), k):
      sim = sum([distances_candidates_finbert[i][j] for i in combination for j in combination if i != j])
      if sim < min_sim:
        candidate_finbert = combination
        min_sim = sim
    #sapbert
    candidate_sapbert = None
    min_sim = np.inf
    for combination in itertools.combinations(range(len(idxs_sapbert)), k):
      sim = sum([distances_candidates_sapbert[i][j] for i in combination for j in combination if i != j])
      if sim < min_sim:
        candidate_sapbert = combination
        min_sim = sim
    #distilbert
    ret = {keys[idxs[idx]]:sim_dict[keys[idxs[idx]]] for idx in candidate}
    ret = dict(sorted(ret.items(), key=lambda x: x[1], reverse=True))
    #finbert
    ret_finbert = {keys_finbert[idxs_finbert[idx]]:sim_dict_finbert[keys_finbert[idxs_finbert[idx]]] for idx in candidate_finbert}
    ret_finbert = dict(sorted(ret_finbert.items(), key=lambda x: x[1], reverse=True))
    #sapbert
    ret_sapbert = {keys_sapbert[idxs_sapbert[idx]]:sim_dict_sapbert[keys_sapbert[idxs_sapbert[idx]]] for idx in candidate_sapbert}
    ret_sapbert = dict(sorted(ret_sapbert.items(), key=lambda x: x[1], reverse=True))
    with col1:
      st.write("distilbert-cvent")
      st.json(ret)
    with col2:
      st.write("finbert")
      st.json(ret_finbert)
    with col3:
      st.write("sapbert")
      st.json(ret_sapbert)