vives committed on
Commit
2a61a57
1 Parent(s): 314007e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -2
app.py CHANGED
@@ -17,6 +17,11 @@ tokenizer_finbert = AutoTokenizer.from_pretrained("ProsusAI/finbert")
17
  kp_dict_finbert_checkpoint = "kp_dict_finbert.pickle"
18
  kp_cosine_finbert_checkpoint = "cosine_kp_finbert.pickle"
19
 
 
 
 
 
 
20
  text = st.text_input("Enter word or key-phrase")
21
  exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query")
22
  exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
@@ -28,7 +33,7 @@ with st.sidebar:
28
  k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
29
 
30
  #columns
31
- col1, col2 = st.columns(2)
32
  #load kp dicts
33
  with open(kp_dict_checkpoint,'rb') as handle:
34
  kp_dict = pickle.load(handle)
@@ -38,11 +43,17 @@ with open(kp_dict_finbert_checkpoint,'rb') as handle:
38
  kp_dict_finbert = pickle.load(handle)
39
  keys_finbert = list(kp_dict_finbert.keys())
40
 
 
 
 
 
41
  #load cosine distances of kp dict
42
  with open(kp_cosine_checkpoint,'rb') as handle:
43
  cosine_kp = pickle.load(handle)
44
  with open(kp_cosine_finbert_checkpoint,'rb') as handle:
45
  cosine_finbert_kp = pickle.load(handle)
 
 
46
 
47
  def calculate_top_k(out, tokens,text,kp_dict,exclude_text=False,exclude_words=False, k=5):
48
  sim_dict = {}
@@ -100,11 +111,15 @@ if text:
100
  new_tokens.pop("KPS")
101
  new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
102
  new_tokens_finbert.pop("KPS")
 
 
103
  with torch.no_grad():
104
  outputs = model(**new_tokens)
105
  outputs_finbert = model_finbert(**new_tokens_finbert)
 
106
  sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
107
  sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
 
108
  if not diversify_box:
109
  with col1:
110
  st.write("distilbert-cvent")
@@ -112,11 +127,16 @@ if text:
112
  with col2:
113
  st.write("finbert")
114
  st.json(sim_dict_finbert)
 
 
 
115
  else:
116
  idxs = extract_idxs(sim_dict, kp_dict)
117
  idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
 
118
  distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
119
  distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
 
120
  #first do distilbert
121
  candidate = None
122
  min_sim = np.inf
@@ -133,6 +153,14 @@ if text:
133
  if sim < min_sim:
134
  candidate_finbert = combination
135
  min_sim = sim
 
 
 
 
 
 
 
 
136
  #distilbert
137
  ret = {keys[idxs[idx]]:sim_dict[keys[idxs[idx]]] for idx in candidate}
138
  ret = sorted(ret.items(), key= lambda x: x[1], reverse = True)
@@ -141,9 +169,16 @@ if text:
141
  ret_finbert = {keys_finbert[idxs_finbert[idx]]:sim_dict_finbert[keys_finbert[idxs_finbert[idx]]] for idx in candidate_finbert}
142
  ret_finbert = sorted(ret_finbert.items(), key= lambda x: x[1], reverse = True)
143
  ret_finbert = {x:y for x,y in ret_finbert}
 
 
 
 
144
  with col1:
145
  st.write("distilbert-cvent")
146
  st.json(ret)
147
  with col2:
148
  st.write("finbert")
149
- st.json(ret_finbert)
 
 
 
 
17
  kp_dict_finbert_checkpoint = "kp_dict_finbert.pickle"
18
  kp_cosine_finbert_checkpoint = "cosine_kp_finbert.pickle"
19
 
20
+ tokenizer_sapbert = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
21
+ model_sapbert = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext", output_hidden_states=True)
22
+ kp_dict_sapbert_checkpoint = "kp_dict_sapbert.pickle"
23
+ kp_cosine_sapbert_checkpoint = "cosine_kp_sapbert.pickle"
24
+
25
  text = st.text_input("Enter word or key-phrase")
26
  exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query")
27
  exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
 
33
  k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
34
 
35
  #columns
36
+ col1, col2, col3 = st.columns(3)
37
  #load kp dicts
38
  with open(kp_dict_checkpoint,'rb') as handle:
39
  kp_dict = pickle.load(handle)
 
43
  kp_dict_finbert = pickle.load(handle)
44
  keys_finbert = list(kp_dict_finbert.keys())
45
 
46
+ with open(kp_dict_sapbert_checkpoint,'rb') as handle:
47
+ kp_dict_sapbert = pickle.load(handle)
48
+ keys_sapbert = list(kp_dict_sapbert.keys())
49
+
50
  #load cosine distances of kp dict
51
  with open(kp_cosine_checkpoint,'rb') as handle:
52
  cosine_kp = pickle.load(handle)
53
  with open(kp_cosine_finbert_checkpoint,'rb') as handle:
54
  cosine_finbert_kp = pickle.load(handle)
55
+ with open(kp_cosine_sapbert_checkpoint,'rb') as handle:
56
+ cosine_sapbert_kp = pickle.load(handle)
57
 
58
  def calculate_top_k(out, tokens,text,kp_dict,exclude_text=False,exclude_words=False, k=5):
59
  sim_dict = {}
 
111
  new_tokens.pop("KPS")
112
  new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
113
  new_tokens_finbert.pop("KPS")
114
+ new_tokens_sapbert = concat_tokens([text], tokenizer_sapbert)
115
+ new_tokens_sapbert.pop("KPS")
116
  with torch.no_grad():
117
  outputs = model(**new_tokens)
118
  outputs_finbert = model_finbert(**new_tokens_finbert)
119
+ outputs_sapbert = model_sapbert(**new_tokens_sapbert)
120
  sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
121
  sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
122
+ sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
123
  if not diversify_box:
124
  with col1:
125
  st.write("distilbert-cvent")
 
127
  with col2:
128
  st.write("finbert")
129
  st.json(sim_dict_finbert)
130
+ with col3:
131
+ st.write("sapbert")
132
+ st.json(sim_dict_sapbert)
133
  else:
134
  idxs = extract_idxs(sim_dict, kp_dict)
135
  idxs_finbert = extract_idxs(sim_dict_finbert, kp_dict_finbert)
136
+ idxs_sapbert = extract_idxs(sim_dict_sapbert, kp_dict_sapbert)
137
  distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
138
  distances_candidates_finbert = cosine_finbert_kp[np.ix_(idxs_finbert, idxs_finbert)]
139
+ distances_candidates_sapbert = cosine_sapbert_kp[np.ix_(idxs_sapbert, idxs_sapbert)]
140
  #first do distilbert
141
  candidate = None
142
  min_sim = np.inf
 
153
  if sim < min_sim:
154
  candidate_finbert = combination
155
  min_sim = sim
156
+ #sapbert
157
+ candidate_sapbert = None
158
+ min_sim = np.inf
159
+ for combination in itertools.combinations(range(len(idxs_sapbert)), k):
160
+ sim = sum([distances_candidates_sapbert[i][j] for i in combination for j in combination if i != j])
161
+ if sim < min_sim:
162
+ candidate_sapbert = combination
163
+ min_sim = sim
164
  #distilbert
165
  ret = {keys[idxs[idx]]:sim_dict[keys[idxs[idx]]] for idx in candidate}
166
  ret = sorted(ret.items(), key= lambda x: x[1], reverse = True)
 
169
  ret_finbert = {keys_finbert[idxs_finbert[idx]]:sim_dict_finbert[keys_finbert[idxs_finbert[idx]]] for idx in candidate_finbert}
170
  ret_finbert = sorted(ret_finbert.items(), key= lambda x: x[1], reverse = True)
171
  ret_finbert = {x:y for x,y in ret_finbert}
172
+ #sapbert
173
+ ret_sapbert = {keys_sapbert[idxs_sapbert[idx]]:sim_dict_sapbert[keys_sapbert[idxs_sapbert[idx]]] for idx in candidate_sapbert}
174
+ ret_sapbert = sorted(ret_sapbert.items(), key= lambda x: x[1], reverse = True)
175
+ ret_sapbert = {x:y for x,y in ret_sapbert}
176
  with col1:
177
  st.write("distilbert-cvent")
178
  st.json(ret)
179
  with col2:
180
  st.write("finbert")
181
+ st.json(ret_finbert)
182
+ with col3:
183
+ st.write("sapbert")
184
+ st.json(ret_sapbert)