added phrase highlights
app.py
CHANGED
@@ -10,12 +10,16 @@ from input_format import *
 from score import *
 
 # load document scoring model
+torch.cuda.is_available = lambda : False
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 pretrained_model = 'allenai/specter'
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
 doc_model = AutoModel.from_pretrained(pretrained_model)
+doc_model.to(device)
 
 # load sentence model
 sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
+sent_model.to(device)
 
 def get_similar_paper(
     abstract_text_input,
@@ -25,8 +29,6 @@ def get_similar_paper(
 ):
     input_sentences = sent_tokenize(abstract_text_input)
 
-    pickle.dump(input_sentences, open('tmp_input_sents.pkl', 'wb'))
-
     # TODO handle pdf file input
     if pdf_file_input is not None:
         name = None
@@ -42,7 +44,7 @@ def get_similar_paper(
         tokenizer,
         abstract_text_input,
         papers,
-        batch=
+        batch=50
     )
 
     tmp = {
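The app.py changes pin both models to an explicit device and force it to CPU by stubbing out torch.cuda.is_available. A minimal standalone sketch of that pattern (assuming torch, transformers, and sentence-transformers are installed; the model names are the ones app.py already loads):

# Sketch of the device-placement pattern above; illustrative, not the app itself.
import torch
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

# Stubbing out is_available() makes the device selection below resolve to CPU,
# even on a machine with a visible GPU.
torch.cuda.is_available = lambda: False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
assert device.type == 'cpu'

tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
doc_model = AutoModel.from_pretrained('allenai/specter').to(device)

sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
sent_model.to(device)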
score.py
CHANGED
@@ -1,5 +1,6 @@
 from sentence_transformers import util
 from nltk.tokenize import sent_tokenize
+from nltk import word_tokenize, pos_tag
 import torch
 import numpy as np
 
@@ -33,19 +34,52 @@ def get_words(sent):
     sent_start_id = [] # keep track of the word index where the new sentence starts
     counter = 0
     for x in sent:
-        w = x.split()
+        #w = x.split()
+        w = word_tokenize(x)
         nw = len(w)
         counter += nw
         words.append(w)
         sent_start_id.append(counter)
-    words = [x.split() for x in sent]
+    words = [word_tokenize(x) for x in sent]
     all_words = [item for sublist in words for item in sublist]
     sent_start_id.pop()
     sent_start_id = [0] + sent_start_id
     assert(len(sent_start_id) == len(sent))
     return words, all_words, sent_start_id
 
-def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
+def get_match_phrase(w1, w2):
+    # list of words for query and candidate as input
+    # return the word list and binary mask of matching phrases
+    # POS tags that should be considered for matching phrase
+    include = [
+        'JJ',
+        'JJR',
+        'JJS',
+        'MD',
+        'NN',
+        'NNS',
+        'NNP',
+        'NNPS',
+        'RB',
+        'RBR',
+        'RBS',
+        'SYM',
+        'VB',
+        'VBD',
+        'VBG',
+        'VBN',
+        'FW'
+    ]
+    mask1 = np.zeros(len(w1))
+    mask2 = np.zeros(len(w2))
+    pos1 = pos_tag(w1)
+    pos2 = pos_tag(w2)
+    for i, (w, p) in enumerate(pos2):
+        if w.lower() in w1 and p in include:
+            mask2[i] = 1
+    return mask2
+
+def mark_words(query_sents, words, all_words, sent_start_id, sent_ids, sent_scores):
     num_query_sent = sent_ids.shape[0]
     num_words = len(all_words)
 
@@ -55,22 +89,29 @@ def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
 
     # for each query sentence, mark the highlight information
     for i in range(num_query_sent):
+        query_words = word_tokenize(query_sents[i])
         is_selected_sent = np.zeros(num_words)
         is_selected_phrase = np.zeros(num_words)
-        word_scores = np.zeros(num_words)
+        word_scores = np.zeros(num_words)
 
-        #
+        # for each selected sentences from the candidate, compile information
         for sid, sscore in zip(sent_ids[i], sent_scores[i]):
             #print(len(sent_start_id), sid, sid+1)
             if sid+1 < len(sent_start_id):
                 sent_range = (sent_start_id[sid], sent_start_id[sid+1])
                 is_selected_sent[sent_range[0]:sent_range[1]] = 1
                 word_scores[sent_range[0]:sent_range[1]] = sscore
+                is_selected_phrase[sent_range[0]:sent_range[1]] = \
+                    get_match_phrase(query_words, all_words[sent_range[0]:sent_range[1]])
             else:
-                is_selected_sent[sent_start_id[sid]:] = 1
-                word_scores[sent_start_id[sid]:] = sscore
+                is_selected_sent[sent_start_id[sid]:] = 1
+                word_scores[sent_start_id[sid]:] = sscore
+                is_selected_phrase[sent_start_id[sid]:] = \
+                    get_match_phrase(query_words, all_words[sent_start_id[sid]:])
+
+        # update selected phrase scores (-1 meaning a different color in gradio)
+        word_scores[is_selected_sent+is_selected_phrase==2] = -1
 
-        # TODO get phrase selection information
         output[i] = {
             'is_selected_sent': is_selected_sent,
             'is_selected_phrase': is_selected_phrase,
@@ -79,16 +120,18 @@ def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
 
     return output
 
-def get_highlight_info(model, text1, text2, K=
+def get_highlight_info(model, text1, text2, K=None):
     sent1 = sent_tokenize(text1) # query
     sent2 = sent_tokenize(text2) # candidate
+    if K is None: # if K is not set, select based on the length of the candidate
+        K = int(len(sent2) / 3)
     score_mat = compute_sentencewise_scores(model, sent1, sent2)
 
     sent_ids, sent_scores = get_top_k(score_mat, K=K)
     #print(sent_ids, sent_scores)
-
+    words2, all_words2, sent_start_id2 = get_words(sent2)
     #print(all_words1, sent_start_id1)
-    info = mark_words(
+    info = mark_words(sent1, words2, all_words2, sent_start_id2, sent_ids, sent_scores)
 
     return sent_ids, sent_scores, info
 
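For reference, a small usage sketch of the new get_match_phrase helper from score.py; the example sentences are invented, and it assumes the NLTK punkt and averaged_perceptron_tagger data packages are available:

# Illustrative only: shows which candidate words get_match_phrase marks
# for a toy query/candidate pair.
import nltk
from nltk import word_tokenize
from score import get_match_phrase  # helper added in this commit

nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

query = "We propose a neural model for citation recommendation."
candidate = "The neural model ranks candidate citations for an abstract."

query_words = word_tokenize(query)
cand_words = word_tokenize(candidate)

# mask is 1 where a candidate word (lowercased) also appears verbatim in the
# query word list and its POS tag is in the content-word whitelist; mark_words
# later sets word_scores to -1 at those positions, which gradio shows in a
# different highlight color.
mask = get_match_phrase(query_words, cand_words)
print([w for w, m in zip(cand_words, mask) if m == 1])  # ['neural', 'model'] here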