vives commited on
Commit
06da3ff
1 Parent(s): 73c3a05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -26
app.py CHANGED
@@ -7,25 +7,18 @@ import pickle
7
  import numpy as np
8
  import itertools
9
 
10
- choice = st.radio("Choose model",["distilbert-cvent","finbert"])
11
- if choice == "distilbert-cvent":
12
- model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
13
- model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
14
- tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
15
- kp_dict_checkpoint = "kp_dict_merged.pickle"
16
- kp_cosine_checkpoint = "cosine_kp.pickle"
17
-
18
- elif choice == "finbert":
19
- model_checkpoint = "ProsusAI/finbert"
20
- tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")
21
- model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", output_hidden_states=True)
22
- kp_dict_checkpoint = "kp_dict_finbert.pickle"
23
- kp_cosine_checkpoint = "cosine_kp_finbert.pickle"
24
 
25
  text = st.text_input("Enter word or key-phrase")
26
-
27
- exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query (i.e exclude 'hot coffee' if the query is 'cold coffee')")
28
-
29
  exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
30
  k = st.number_input("Top k nearest key-phrases",1,10,5)
31
 
@@ -34,16 +27,24 @@ with st.sidebar:
34
  if diversify_box:
35
  k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
36
 
37
- #load kp di
 
 
38
  with open(kp_dict_checkpoint,'rb') as handle:
39
  kp_dict = pickle.load(handle)
40
  keys = list(kp_dict.keys())
41
 
 
 
 
 
42
  #load cosine distances of kp dict
43
  with open(kp_cosine_checkpoint,'rb') as handle:
44
  cosine_kp = pickle.load(handle)
 
 
45
 
46
- def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5):
47
  sim_dict = {}
48
  pools = pool_embeddings(out, tokens).detach().numpy()
49
  for key in kp_dict.keys():
@@ -59,7 +60,7 @@ def calculate_top_k(out, tokens,text,exclude_text=False,exclude_words=False, k=5
59
  )[0][0]
60
  sims = sorted(sim_dict.items(), key= lambda x: x[1], reverse = True)[:k]
61
  return {x:y for x,y in sims}
62
- def concat_tokens(sentences):
63
  tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
64
  for sentence in sentences:
65
  # encode each sentence and append to dictionary
@@ -95,25 +96,54 @@ def extract_idxs(top_dict, kp_dict):
95
 
96
  if text:
97
  text = text.lower()
98
- new_tokens = concat_tokens([text])
99
  new_tokens.pop("KPS")
 
 
100
  with torch.no_grad():
101
  outputs = model(**new_tokens)
 
 
 
102
  if not diversify_box:
103
- sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
104
- st.json(sim_dict)
 
 
 
 
105
  else:
106
- sim_dict = calculate_top_k(outputs, new_tokens, text, exclude_text=exclude_text,exclude_words=exclude_words,k=k_diversify)
107
  idxs = extract_idxs(sim_dict, kp_dict)
 
108
  distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
109
- min_sim = np.inf
 
110
  candidate = None
 
111
  for combination in itertools.combinations(range(len(idxs)), k):
112
  sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
113
  if sim < min_sim:
114
  candidate = combination
115
  min_sim = sim
 
 
 
 
 
 
 
 
 
116
  ret = {keys[idxs[idx]]:sim_dict[keys[idxs[idx]]] for idx in candidate}
117
  ret = sorted(ret.items(), key= lambda x: x[1], reverse = True)
118
  ret = {x:y for x,y in ret}
119
- st.json(ret)
 
 
 
 
 
 
 
 
 
 
7
  import numpy as np
8
  import itertools
9
 
10
+ model = AutoModelForMaskedLM.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022", output_hidden_states=True)
11
+ tokenizer = AutoTokenizer.from_pretrained("vives/distilbert-base-uncased-finetuned-cvent-2019_2022")
12
+ kp_dict_checkpoint = "kp_dict_merged.pickle"
13
+ kp_cosine_checkpoint = "cosine_kp.pickle"
14
+
15
+ model_finbert = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert", output_hidden_states=True)
16
+ tokenizer_finbert = AutoTokenizer.from_pretrained("ProsusAI/finbert")
17
+ kp_dict_finbert_checkpoint = "kp_dict_finbert.pickle"
18
+ kp_cosine_finbert_checkpoint = "cosine_kp_finbert.pickle"
 
 
 
 
 
19
 
20
  text = st.text_input("Enter word or key-phrase")
21
+ exclude_words = st.radio("exclude_words",[True,False], help="Exclude results that contain any words in the query")
 
 
22
  exclude_text = st.radio("exclude_text",[True,False], help="Exclude results that contain the query (i.e exclude 'tomato soup recipe' if the query is 'tomato soup')")
23
  k = st.number_input("Top k nearest key-phrases",1,10,5)
24
 
 
27
  if diversify_box:
28
  k_diversify = st.number_input("Set of key-phrases to diversify from",10,30,20)
29
 
30
+ #columns
31
+ col1, col2 = st.columns(2)
32
+ #load kp dicts
33
  with open(kp_dict_checkpoint,'rb') as handle:
34
  kp_dict = pickle.load(handle)
35
  keys = list(kp_dict.keys())
36
 
37
+ with open(kp_dict_finbert_checkpoint,'rb') as handle:
38
+ kp_dict_finbert = pickle.load(handle)
39
+ keys_finbert = list(kp_dict_finbert.keys())
40
+
41
  #load cosine distances of kp dict
42
  with open(kp_cosine_checkpoint,'rb') as handle:
43
  cosine_kp = pickle.load(handle)
44
+ with open(kp_cosine_finbert_checkpoint,'rb') as handle:
45
+ cosine_finbert_kp = pickle.load(handle)
46
 
47
+ def calculate_top_k(out, tokens,text,kp_dict,exclude_text=False,exclude_words=False, k=5):
48
  sim_dict = {}
49
  pools = pool_embeddings(out, tokens).detach().numpy()
50
  for key in kp_dict.keys():
 
60
  )[0][0]
61
  sims = sorted(sim_dict.items(), key= lambda x: x[1], reverse = True)[:k]
62
  return {x:y for x,y in sims}
63
+ def concat_tokens(sentences, tokenizer):
64
  tokens = {'input_ids': [], 'attention_mask': [], 'KPS': []}
65
  for sentence in sentences:
66
  # encode each sentence and append to dictionary
 
96
 
97
  if text:
98
  text = text.lower()
99
+ new_tokens = concat_tokens([text], tokenizer)
100
  new_tokens.pop("KPS")
101
+ new_tokens_finbert = concat_tokens([text], tokenizer_finbert)
102
+ new_tokens_finbert.pop("KPS")
103
  with torch.no_grad():
104
  outputs = model(**new_tokens)
105
+ outputs_finbert = model_finbert(**new_tokens_finbert)
106
+ sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
107
+ sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
108
  if not diversify_box:
109
+ with col1:
110
+ st.write("distilbert-cvent")
111
+ st.json(sim_dict)
112
+ with col2:
113
+ st.write("finbert")
114
+ st.json(sim_dict_finbert)
115
  else:
 
116
  idxs = extract_idxs(sim_dict, kp_dict)
117
+ idxs_finbert = extract_idxs(sim_dict, kp_dict_finbert)
118
  distances_candidates = cosine_kp[np.ix_(idxs, idxs)]
119
+ distances_candidates_finbert = cosine_kp_finbert[np.ix_(idxs_finbert, idxs_finbert)]
120
+ #first do distilbert
121
  candidate = None
122
+ min_sim = np.inf
123
  for combination in itertools.combinations(range(len(idxs)), k):
124
  sim = sum([distances_candidates[i][j] for i in combination for j in combination if i != j])
125
  if sim < min_sim:
126
  candidate = combination
127
  min_sim = sim
128
+ #then do finbert
129
+ candidate_finbert = None
130
+ min_sim = np.inf
131
+ for combination in itertools.combinations(range(len(idxs_finbert)), k):
132
+ sim = sum([distances_candidates_finbert[i][j] for i in combination for j in combination if i != j])
133
+ if sim < min_sim:
134
+ candidate_finbert = combination
135
+ min_sim = sim
136
+ #distilbert
137
  ret = {keys[idxs[idx]]:sim_dict[keys[idxs[idx]]] for idx in candidate}
138
  ret = sorted(ret.items(), key= lambda x: x[1], reverse = True)
139
  ret = {x:y for x,y in ret}
140
+ #finbert
141
+ ret_finbert = {keys_finbert[idxs_finbert[idx]]:sim_dict_finbert[keys_finbert[idxs[idx]]] for idx in candidate_finbert}
142
+ candidate_finbert = sorted(candidate_finbert.items(), key= lambda x: x[1], reverse = True)
143
+ candidate_finbert = {x:y for x,y in candidate_finbert}
144
+ with col1:
145
+ st.write("distilbert-cvent")
146
+ st.json(ret)
147
+ with col2:
148
+ st.write("finbert")
149
+ st.json(ret_finbert)