vives committed on
Commit
caee7a0
1 Parent(s): 560e686

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -3
app.py CHANGED
@@ -4,11 +4,13 @@ from sklearn.metrics.pairwise import cosine_similarity
4
  import streamlit as st
5
  import torch
6
  import pickle
 
7
 
8
  model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
9
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
10
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
11
  text = st.text_input("Enter word or key-phrase")
 
12
  exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')")
13
 
14
  exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
@@ -17,12 +19,17 @@ k = st.number_input("Top k nearest key-phrases",1,10,5)
17
  with st.sidebar:
18
  diversify_box = st.checkbox("Diversify results",True)
19
  if diversify_box:
20
- k = st.number_input("Top k nearest key-phrases",10,30,20)
21
 
 
22
  with open("kp_dict_merged.pickle",'rb') as handle:
23
  kp_dict = pickle.load(handle)
24
  for key in kp_dict.keys():
25
  kp_dict[key] = kp_dict[key].detach().numpy()
 
 
 
 
26
 
27
  def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5):
28
  sim_dict = {}
@@ -65,10 +72,30 @@ def pool_embeddings(out, tok):
65
  mean_pooled = summed / summed_mask
66
  return mean_pooled
67
 
 
 
 
 
 
 
 
 
 
68
  if text:
69
  new_tokens = concat_tokens([text])
70
  new_tokens.pop("KPS")
71
  with torch.no_grad():
72
  outputs = model(**new_tokens)
73
- sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
74
- st.json(sim_dict)
 
 
 
 
 
 
 
 
 
 
 
 
4
  import streamlit as st
5
  import torch
6
  import pickle
7
+ import itertools
8
 
9
  model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
10
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
11
  tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
12
  text = st.text_input("Enter word or key-phrase")
13
+
14
  exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')")
15
 
16
  exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
 
19
  with st.sidebar:
20
  diversify_box = st.checkbox("Diversify results",True)
21
  if diversify_box:
22
+ k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
23
 
24
+ #load kp dict
25
  with open("kp_dict_merged.pickle",'rb') as handle:
26
  kp_dict = pickle.load(handle)
27
  for key in kp_dict.keys():
28
  kp_dict[key] = kp_dict[key].detach().numpy()
29
+
30
+ #load cosine distances of kp dict
31
+ with open("cosine_kp.pickle",'rb') as handle:
32
+ cosine_kp = pickle.load(handle)
33
 
34
  def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5):
35
  sim_dict = {}
 
72
  mean_pooled = summed / summed_mask
73
  return mean_pooled
74
 
75
def extract_idxs(top_dict, kp_dict):
    """Return positions (in kp_dict's iteration order) of keys also present in top_dict.

    The resulting index list is used to slice the precomputed cosine-distance
    matrix, whose rows/columns follow kp_dict's key order (dicts preserve
    insertion order, so enumerate gives stable positions).

    Args:
        top_dict: mapping whose keys select which key-phrases to keep.
        kp_dict: mapping whose key order defines the index space.

    Returns:
        list[int]: ascending indices of kp_dict keys found in top_dict.
    """
    # Iterate the dict directly (no list(...) copy) and test membership on
    # the dict itself — O(1) per lookup, no .keys() view needed.
    return [idx for idx, key in enumerate(kp_dict) if key in top_dict]
83
+
84
  if text:
85
  new_tokens = concat_tokens([text])
86
  new_tokens.pop("KPS")
87
  with torch.no_grad():
88
  outputs = model(**new_tokens)
89
+ if not diversify_box:
90
+ sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
91
+ st.json(sim_dict)
92
+ else:
93
+ sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k_diversify)
94
+ idxs = extract_idxs(sim_dict, kp_dict)
95
+ distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
96
+ min_sim = np.inf
97
+ candidate = None
98
+ for combination in itertools.combinations(range(len(idxs)), k):
99
+ sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
100
+
101
+