vives commited on
Commit
e9c9402
1 Parent(s): aed5b51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -55,9 +55,12 @@ with open(kp_cosine_finbert_checkpoint,'rb') as handle:
55
  with open(kp_cosine_sapbert_checkpoint,'rb') as handle:
56
  cosine_sapbert_kp = pickle.load(handle)
57
 
58
- def calculate_top_k(out, tokens,text,kp_dict,exclude_text=False,exclude_words=False, k=5):
59
  sim_dict = {}
60
- pools = pool_embeddings(out, tokens).detach().numpy()
 
 
 
61
  for key in kp_dict.keys():
62
  if key == text:
63
  continue
@@ -116,7 +119,7 @@ if text:
116
  with torch.no_grad():
117
  outputs = model(**new_tokens)
118
  outputs_finbert = model_finbert(**new_tokens_finbert)
119
- outputs_sapbert = model_sapbert(**new_tokens_sapbert)
120
  sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
121
  sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
122
  sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
 
55
  with open(kp_cosine_sapbert_checkpoint,'rb') as handle:
56
  cosine_sapbert_kp = pickle.load(handle)
57
 
58
+ def calculate_top_k(out, tokens,text,kp_dict,exclude_text=False,exclude_words=False, k=5, pooler=True):
59
  sim_dict = {}
60
+ if pooler:
61
+ pools = pool_embeddings(out, tokens).detach().numpy()
62
+ else:
63
+ pools = out["pooler_outputs"].detach().numpy()
64
  for key in kp_dict.keys():
65
  if key == text:
66
  continue
 
119
  with torch.no_grad():
120
  outputs = model(**new_tokens)
121
  outputs_finbert = model_finbert(**new_tokens_finbert)
122
+ outputs_sapbert = model_sapbert(**new_tokens_sapbert, pooler=False)
123
  sim_dict = calculate_top_k(outputs, new_tokens, text, kp_dict, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
124
  sim_dict_finbert = calculate_top_k(outputs_finbert, new_tokens_finbert, text, kp_dict_finbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)
125
  sim_dict_sapbert = calculate_top_k(outputs_sapbert, new_tokens_sapbert, text, kp_dict_sapbert, exclude_text=exclude_text,exclude_words=exclude_words,k=k)