init files
- app.py +196 -0
- input_format.py +114 -0
- requirements.txt +83 -0
- score.py +149 -0
app.py
ADDED
@@ -0,0 +1,196 @@
import gradio as gr
import os
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import pickle

from input_format import *
from score import *

# load document scoring model
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model)

# load sentence model
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')

def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10
):
    input_sentences = sent_tokenize(abstract_text_input)

    pickle.dump(input_sentences, open('tmp_input_sents.pkl', 'wb'))

    # TODO handle pdf file input
    if pdf_file_input is not None:
        name = None
        papers = []
        raise ValueError('Use submission abstract instead.')
    else:
        # Get author papers from id
        name, papers = get_text_from_author_id(author_id_input)

    # Compute doc-level affinity scores for the papers
    titles, abstracts, doc_scores = compute_overall_score(
        doc_model,
        tokenizer,
        abstract_text_input,
        papers,
        batch=30
    )

    tmp = {
        'titles': titles,
        'abstracts': abstracts,
        'doc_scores': doc_scores
    }
    pickle.dump(tmp, open('tmp_paperinfo.pkl', 'wb'))

    # Select top K papers to show
    titles = titles[:num_papers_show]
    abstracts = abstracts[:num_papers_show]
    doc_scores = doc_scores[:num_papers_show]

    return titles[0], abstracts[0], doc_scores[0], gr.update(choices=input_sentences, interactive=True), gr.update(visible=True)

def get_highlights(
    abstract_text_input,
    pdf_file_input,
    abstract,
    K=2
):
    # Compute sent-level and phrase-level affinity scores for each paper
    sent_ids, sent_scores, info = get_highlight_info(
        sent_model,
        abstract_text_input,
        abstract,
        K=K
    )

    input_sentences = sent_tokenize(abstract_text_input)
    num_sents = len(input_sentences)

    word_scores = dict()
    # different highlights for each input sentence
    for i in range(num_sents):
        word_scores[str(i)] = {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores']))
        }

    tmp = {
        'source_sentences': input_sentences,
        'highlight': word_scores
    }
    pickle.dump(tmp, open('highlight_info.pkl', 'wb'))

    # update the visibility of the radio choices
    return gr.update(visible=True)

def update_name(author_id_input):
    # update the name of the author based on the id input
    name, _ = get_text_from_author_id(author_id_input)
    return gr.update(value=name)

def change_output_highlight(source_sent_choice):
    # change the output highlight based on the sentence selected from the submission
    if os.path.exists('highlight_info.pkl'):
        tmp = pickle.load(open('highlight_info.pkl', 'rb'))
        source_sents = tmp['source_sentences']
        highlights = tmp['highlight']
        for i, s in enumerate(source_sents):
            print('changing highlight!')
            if source_sent_choice == s:
                return highlights[str(i)]
    else:
        return

with gr.Blocks() as demo:

    ### INPUT
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
                author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')


    # with gr.Row(visible=False) as reviewer_name_info:
    #     name = gr.Textbox(label='Reviewer Author Name')
    # with gr.Row():
    #     with gr.Tabs():
    #         for tt in range(num_papers_show):
    #             with gr.TabItem('Paper %d'%(tt+1)):

    # TODO handle multiple papers

    ### PAPER INFORMATION
    with gr.Row():
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity = gr.Number(label='Affinity', interactive=False, value=0)
    with gr.Row():
        paper_abstract = gr.Textbox(label='Abstract', interactive=False)

    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')

    ### RELEVANT PARTS (HIGHLIGHTS)

    with gr.Row():
        with gr.Column(scale=2):  # text from submission
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        with gr.Column(scale=3):  # highlighted text from paper
            highlight = gr.components.Interpretation(paper_abstract)

    compute_btn.click(
        fn=get_similar_paper,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            author_id_input
        ],
        outputs=[
            paper_title,
            paper_abstract,
            affinity,
            source_sentences,
            explain_button_row
        ]
    )

    explain_btn.click(
        fn=get_highlights,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            paper_abstract
        ],
        outputs=source_sentences
    )

    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight
    )

if __name__ == "__main__":
    demo.launch()
input_format.py
ADDED
@@ -0,0 +1,114 @@
import numpy as np
from pypdf import PdfReader
from urllib.parse import urlparse
import requests
from semanticscholar import SemanticScholar

### Input Formatting Module

## Input formatting for the given paper
# Extracting text from a pdf or a link

def get_text_from_pdf(file_path):
    """
    Convert a pdf to a list of per-page texts
    """
    reader = PdfReader(file_path)
    text = []
    for p in reader.pages:
        t = p.extract_text()
        text.append(t)
    return text

def get_text_from_url(url, file_path='paper.pdf'):
    """
    Get the text of the paper from a url
    """
    # TODO check for other valid urls (e.g. semantic scholar)

    ## Check for different URL cases
    url_parts = urlparse(url)
    # arxiv
    if 'arxiv' in url_parts.netloc:
        if 'abs' in url_parts.path:
            # abstract page, change the url to the pdf link
            paper_id = url_parts.path.split('/')[-1]
            url = 'https://www.arxiv.org/pdf/%s.pdf'%(paper_id)
        elif 'pdf' in url_parts.path:
            # already a pdf link, pass
            pass
        else:
            raise ValueError('invalid url')
    else:
        raise ValueError('invalid url')

    # download the file
    download_pdf(url, file_path)

    # get the text from the pdf file
    text = get_text_from_pdf(file_path)
    return text

def download_pdf(url, file_name):
    """
    Download the pdf file from the given url and save it as file_name
    """
    # Send GET request
    response = requests.get(url)

    # Save the PDF
    if response.status_code == 200:
        with open(file_name, "wb") as f:
            f.write(response.content)
    elif response.status_code == 404:
        raise ValueError('cannot download the file')
    else:
        print(response.status_code)

## Input formatting for the given author (reviewer)
# Extracting text from a link

def get_text_from_author_id(author_id, max_count=100):
    if author_id is None:
        raise ValueError('Input valid author ID')
    author_id = str(author_id)
    # author_id = '1737249'
    url = "https://api.semanticscholar.org/graph/v1/author/%s?fields=url,name,paperCount,papers,papers.title,papers.abstract"%author_id
    r = requests.get(url)
    if r.status_code == 404:
        raise ValueError('Input valid author ID')
    data = r.json()
    papers = data['papers'][:max_count]
    name = data['name']

    return name, papers

## TODO Preprocess Extracted Texts from PDFs
# Get a portion of the text for the actual task

def get_title(text):
    pass

def get_abstract(text):
    pass

def get_introduction(text):
    pass

def get_conclusion(text):
    pass


if __name__ == '__main__':
    def run_sample():
        url = 'https://arxiv.org/abs/2105.06506'
        text = get_text_from_url(url)
        assert(text[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods')

        text2 = get_text_from_url('https://arxiv.org/pdf/2105.06506.pdf')
        assert(text2[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods')

        # text = get_text_from_url('https://arxiv.org/paetseths.pdf')

    # test the code
    run_sample()
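For reference, a minimal usage sketch (not part of the committed files) of the Semantic Scholar lookup above; the author ID is a made-up placeholder and a network connection is assumed:

# hypothetical example; substitute any valid Semantic Scholar author ID
name, papers = get_text_from_author_id('1741101')
print(name)                   # author display name returned by the API
print(papers[0]['title'])     # each entry carries the requested 'title' and 'abstract' fields
print(papers[0]['abstract'])  # abstract may be None for some papers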
requirements.txt
ADDED
@@ -0,0 +1,83 @@
gradio==3.19.1
huggingface-hub==0.8.1
nltk==3.7
numpy==1.21.6
py-pdf-parser==0.10.2
py-rouge==1.1
pypdf==3.3.0
pyrogue==0.0.2
requests==2.28.1
rouge-score==0.1.2
scikit-learn==1.0.2
scipy==1.7.3
scs==2.1.4
seaborn==0.11.2
segtok==1.5.11
semanticscholar==0.3.2
sentence-transformers==2.2.0
sentencepiece==0.1.96
sentry-sdk==1.9.0
setproctitle==1.3.0
shap==0.40.0
shapely==2.0.0
shortuuid==1.0.9
six
sklearn==0.0
slicer==0.0.7
smart-open==5.2.1
smmap==5.0.0
sniffio==1.2.0
spacy==3.0.8
spacy-legacy==3.0.9
spacy-loggers==1.0.3
sqlitedict==2.0.0
srsly==2.4.4
starlette==0.22.0
statsmodels==0.13.2
tabulate==0.8.9
tea==0.1.4
tea-client==0.0.7
tea-console==0.0.6
tenacity==8.1.0
tensorboardX==2.5.1
termcolor==1.1.0
terminado==0.9.4
testpath
text-unidecode==1.3
thinc==8.0.17
threadpoolctl==2.2.0
tifffile==2021.11.2
tld==0.10
tokenizers==0.10.3
tomli==2.0.1
toolz==0.12.0
torch==1.9.0
torchaudio==0.10.2
torchdata==0.3.0
torchtext==0.12.0
torchvision==0.8.2
tornado
tqdm==4.62.2
traitlets==5.3.0
transformers==4.3.3
transformers-interpret==0.5.2
typer==0.3.2
typing-extensions
tzlocal==2.1
uc-micro-py==1.0.1
urllib3==1.26.6
uvicorn==0.20.0
Wand==0.6.10
wandb==0.12.21
wasabi==0.10.1
wcwidth
webencodings==0.5.1
websockets==10.4
Werkzeug==2.2.2
widgetsnbextension==3.5.1
Wikipedia-API==0.5.4
word2number==1.1
wrapt==1.12.1
xxhash==2.0.2
yarl==1.7.2
zipp
score.py
ADDED
@@ -0,0 +1,149 @@
from sentence_transformers import util
from nltk.tokenize import sent_tokenize
import torch
import numpy as np

def compute_sentencewise_scores(model, query_sents, candidate_sents):
    # cosine similarity between every query sentence and every candidate sentence
    q_v, c_v = get_embedding(model, query_sents, candidate_sents)
    return util.cos_sim(q_v, c_v)

def get_embedding(model, query_sents, candidate_sents):
    q_v = model.encode(query_sents)
    c_v = model.encode(candidate_sents)

    return q_v, c_v

def get_top_k(score_mat, K=3):
    """
    Pick top K sentences to show
    """
    idx = torch.argsort(-score_mat)
    picked_sent = idx[:, :K]
    picked_scores = torch.vstack(
        [score_mat[i, picked_sent[i]] for i in range(picked_sent.shape[0])]
    )

    return picked_sent, picked_scores

def get_words(sent):
    words = []
    sent_start_id = []  # keep track of the word index where each new sentence starts
    counter = 0
    for x in sent:
        w = x.split()
        counter += len(w)
        words.append(w)
        sent_start_id.append(counter)
    all_words = [item for sublist in words for item in sublist]
    # shift so that each entry marks the index of the first word of its sentence
    sent_start_id.pop()
    sent_start_id = [0] + sent_start_id
    assert(len(sent_start_id) == len(sent))
    return words, all_words, sent_start_id

def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
    num_query_sent = sent_ids.shape[0]
    num_words = len(all_words)

    output = dict()
    output['all_words'] = all_words
    output['words_by_sentence'] = words

    # for each query sentence, mark the highlight information
    for i in range(num_query_sent):
        is_selected_sent = np.zeros(num_words)
        is_selected_phrase = np.zeros(num_words)
        word_scores = np.zeros(num_words) + 1e-4

        # get sentence selection information
        for sid, sscore in zip(sent_ids[i], sent_scores[i]):
            if sid + 1 < len(sent_start_id):
                sent_range = (sent_start_id[sid], sent_start_id[sid+1])
                is_selected_sent[sent_range[0]:sent_range[1]] = 1
                word_scores[sent_range[0]:sent_range[1]] = sscore
            else:
                # last sentence: highlight from its start to the end of the text
                is_selected_sent[sent_start_id[sid]:] = 1
                word_scores[sent_start_id[sid]:] = sscore

        # TODO get phrase selection information
        output[i] = {
            'is_selected_sent': is_selected_sent,
            'is_selected_phrase': is_selected_phrase,
            'scores': word_scores
        }

    return output

def get_highlight_info(model, text1, text2, K=3):
    sent1 = sent_tokenize(text1)  # query
    sent2 = sent_tokenize(text2)  # candidate
    score_mat = compute_sentencewise_scores(model, sent1, sent2)

    sent_ids, sent_scores = get_top_k(score_mat, K=K)
    words2, all_words2, sent_start_id2 = get_words(sent2)
    info = mark_words(words2, all_words2, sent_start_id2, sent_ids, sent_scores)

    return sent_ids, sent_scores, info

## Document-level operations
def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):

    # concatenate title and abstract
    title_abs = []
    for t, a in zip(titles, abstracts):
        if t is not None and a is not None:
            title_abs.append(t + ' [SEP] ' + a)

    num_docs = len(title_abs)
    no_iter = int(np.ceil(num_docs / batch))

    # score the candidates in batches, prepending the query to each batch
    scores = []
    with torch.no_grad():
        for i in range(no_iter):
            inputs = tokenizer(
                [query] + title_abs[i*batch:(i+1)*batch],
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512
            )
            inputs = inputs.to(doc_model.device)
            result = doc_model(**inputs)

            # take the first token ([CLS]) of each sequence as its embedding
            embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()

            # compute cosine similarity between the query and each candidate
            q_emb = embeddings[0, :]
            p_emb = embeddings[1:, :]
            nn = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
            scores += list(np.dot(p_emb, q_emb) / nn)

    assert(len(scores) == num_docs)

    return scores

def compute_overall_score(doc_model, tokenizer, query, papers, batch=5):
    titles = []
    abstracts = []
    for p in papers:
        # skip papers with a missing title or abstract so scores stay aligned with titles/abstracts
        if p['title'] is not None and p['abstract'] is not None:
            titles.append(p['title'])
            abstracts.append(p['abstract'])
    scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
    idx_sorted = np.argsort(scores)[::-1]

    titles_sorted = [titles[x] for x in idx_sorted]
    abstracts_sorted = [abstracts[x] for x in idx_sorted]
    scores_sorted = [scores[x] for x in idx_sorted]

    return titles_sorted, abstracts_sorted, scores_sorted
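A minimal sketch (not part of the committed files) of exercising the scoring pipeline above on its own, assuming the same models app.py loads, NLTK's punkt data (nltk.download('punkt')), and the score.py functions in scope; the query and papers are made-up placeholders:

from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
doc_model = AutoModel.from_pretrained('allenai/specter')
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')

query = "We study sanity checks for saliency-based interpretability methods."
papers = [
    {'title': 'Sanity Simulations for Saliency Methods',
     'abstract': 'Saliency methods are a popular class of interpretability tools.'},
    {'title': 'An Unrelated Paper',
     'abstract': 'This paper studies distributed graph partitioning.'},
]

# document-level affinity scores, sorted from most to least similar
titles, abstracts, doc_scores = compute_overall_score(doc_model, tokenizer, query, papers, batch=2)

# sentence-level highlight information for the top-ranked paper
sent_ids, sent_scores, info = get_highlight_info(sent_model, query, abstracts[0], K=1)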