jskim committed
Commit 580aef7
1 Parent(s): 4bea31b

added phrase highlights

Files changed (2):
  1. app.py +5 -3
  2. score.py +54 -11
app.py CHANGED
@@ -10,12 +10,16 @@ from input_format import *
 from score import *
 
 # load document scoring model
+torch.cuda.is_available = lambda : False
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 pretrained_model = 'allenai/specter'
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
 doc_model = AutoModel.from_pretrained(pretrained_model)
+doc_model.to(device)
 
 # load sentence model
 sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
+sent_model.to(device)
 
 def get_similar_paper(
     abstract_text_input,
@@ -25,8 +29,6 @@ def get_similar_paper(
 ):
     input_sentences = sent_tokenize(abstract_text_input)
 
-    pickle.dump(input_sentences, open('tmp_input_sents.pkl', 'wb'))
-
     # TODO handle pdf file input
     if pdf_file_input is not None:
         name = None
@@ -42,7 +44,7 @@ def get_similar_paper(
         tokenizer,
         abstract_text_input,
         papers,
-        batch=30
+        batch=50
     )
 
     tmp = {
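
A note on the app.py change: assigning torch.cuda.is_available a lambda that returns False monkey-patches the CUDA check, so the device selection on the next line always resolves to CPU and both models are explicitly moved there; the scoring batch size is also raised from 30 to 50. A minimal standalone sketch of the device-pinning pattern, assuming only torch is installed:

import torch

# Monkey-patch the CUDA check so any later
# 'cuda if available else cpu' selection picks the CPU,
# even on a host that has a GPU.
torch.cuda.is_available = lambda: False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert device.type == 'cpu'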
score.py CHANGED
@@ -1,5 +1,6 @@
 from sentence_transformers import util
 from nltk.tokenize import sent_tokenize
+from nltk import word_tokenize, pos_tag
 import torch
 import numpy as np
 
@@ -33,19 +34,52 @@ def get_words(sent):
     sent_start_id = [] # keep track of the word index where the new sentence starts
     counter = 0
     for x in sent:
-        w = x.split()
+        #w = x.split()
+        w = word_tokenize(x)
         nw = len(w)
         counter += nw
         words.append(w)
         sent_start_id.append(counter)
-    words = [x.split() for x in sent]
+    words = [word_tokenize(x) for x in sent]
     all_words = [item for sublist in words for item in sublist]
     sent_start_id.pop()
     sent_start_id = [0] + sent_start_id
     assert(len(sent_start_id) == len(sent))
     return words, all_words, sent_start_id
 
-def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
+def get_match_phrase(w1, w2):
+    # list of words for query and candidate as input
+    # return the word list and binary mask of matching phrases
+    # POS tags that should be considered for matching phrase
+    include = [
+        'JJ',
+        'JJR',
+        'JJS',
+        'MD',
+        'NN',
+        'NNS',
+        'NNP',
+        'NNPS',
+        'RB',
+        'RBR',
+        'RBS',
+        'SYM',
+        'VB',
+        'VBD',
+        'VBG',
+        'VBN',
+        'FW'
+    ]
+    mask1 = np.zeros(len(w1))
+    mask2 = np.zeros(len(w2))
+    pos1 = pos_tag(w1)
+    pos2 = pos_tag(w2)
+    for i, (w, p) in enumerate(pos2):
+        if w.lower() in w1 and p in include:
+            mask2[i] = 1
+    return mask2
+
+def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
     num_query_sent = sent_ids.shape[0]
     num_words = len(all_words)
 
@@ -55,22 +89,29 @@ def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
 
     # for each query sentence, mark the highlight information
     for i in range(num_query_sent):
+        query_words = word_tokenize(query_sents[i])
         is_selected_sent = np.zeros(num_words)
         is_selected_phrase = np.zeros(num_words)
-        word_scores = np.zeros(num_words) + 1e-4
+        word_scores = np.zeros(num_words)
 
-        # get sentence selection information
+        # for each selected sentences from the candidate, compile information
         for sid, sscore in zip(sent_ids[i], sent_scores[i]):
             #print(len(sent_start_id), sid, sid+1)
             if sid+1 < len(sent_start_id):
                 sent_range = (sent_start_id[sid], sent_start_id[sid+1])
                 is_selected_sent[sent_range[0]:sent_range[1]] = 1
                 word_scores[sent_range[0]:sent_range[1]] = sscore
+                is_selected_phrase[sent_range[0]:sent_range[1]] = \
+                    get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
             else:
-                is_selected_sent[sent_range[0]:] = 1
-                word_scores[sent_range[0]:] = sscore
+                is_selected_sent[sent_start_id[sid]:] = 1
+                word_scores[sent_start_id[sid]:] = sscore
+                is_selected_phrase[sent_start_id[sid]:] = \
+                    get_match_phrase(query_words, all_words[sent_start_id[sid]:])
+
+        # update selected phrase scores (-1 meaning a different color in gradio)
+        word_scores[is_selected_sent+is_selected_phrase==2] = -1
 
-        # TODO get phrase selection information
         output[i] = {
             'is_selected_sent': is_selected_sent,
             'is_selected_phrase': is_selected_phrase,
@@ -79,16 +120,18 @@ def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
 
     return output
 
-def get_highlight_info(model, text1, text2, K=3):
+def get_highlight_info(model, text1, text2, K=None):
     sent1 = sent_tokenize(text1) # query
     sent2 = sent_tokenize(text2) # candidate
+    if K is None: # if K is not set, select based on the length of the candidate
+        K = int(len(sent2) / 3)
     score_mat = compute_sentencewise_scores(model, sent1, sent2)
 
     sent_ids, sent_scores = get_top_k(score_mat, K=K)
     #print(sent_ids, sent_scores)
-    words1, all_words1, sent_start_id1 = get_words(sent2)
+    words2, all_words2, sent_start_id2 = get_words(sent2)
     #print(all_words1, sent_start_id1)
-    info = mark_words(words1, all_words1, sent_start_id1, sent_ids, sent_scores)
+    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
 
     return sent_ids, sent_scores, info
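
How the pieces in score.py fit together: get_highlight_info now defaults K to about a third of the candidate's sentence count, mark_words receives the raw query sentences so it can tokenize each one with word_tokenize, and get_match_phrase marks candidate words whose surface form appears in the query and whose POS tag is on the content-word whitelist; words that fall inside a selected sentence and also match a query phrase get a word score of -1, which the comment says gradio renders in a different highlight color. A minimal usage sketch of get_match_phrase on its own (the example sentences are made up; assumes NLTK's punkt and averaged_perceptron_tagger data have been downloaded):

from nltk import word_tokenize

query = "We propose a neural reranking model for paper retrieval."
candidate = "This paper presents a neural approach to retrieval."

w1 = word_tokenize(query)      # query words
w2 = word_tokenize(candidate)  # candidate words
mask = get_match_phrase(w1, w2)
# mask[i] == 1 where candidate word i also appears in the query and
# carries a whitelisted POS tag; here that marks 'paper', 'neural',
# and 'retrieval', while function words like 'a' and 'to' stay 0.

Note the match is one-directional: only the candidate-side mask2 is returned (mask1 and pos1 are computed but unused so far), and each candidate word is lowercased before the membership test while the query tokens are not, so a capitalized query word will not match its lowercase candidate form.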