paper-matching / app.py
jskim's picture
init files
6eff5e7
raw
history blame
6.03 kB
import gradio as gr
import os
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import pickle
from input_format import *
from score import *
# Load the document-level scoring model: the tokenizer and weights for the
# 'allenai/specter' checkpoint. get_similar_paper() hands both objects to
# compute_overall_score(). Loading happens once, when this module is imported.
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model)
# Load the sentence-level model; get_highlights() passes it to
# get_highlight_info() to score individual sentences.
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10
):
    """Score a reviewer's papers against the submission abstract.

    Args:
        abstract_text_input: Text of the submission abstract.
        pdf_file_input: Uploaded PDF, or None. PDF input is not supported yet.
        author_id_input: Reviewer ID handed to ``get_text_from_author_id``.
        num_papers_show: Number of leading results kept for display.

    Returns:
        A 5-tuple matching the outputs wired to the search button: the title,
        abstract and document score of the first result, an update that
        fills the sentence radio with the tokenized submission sentences, and
        an update that makes the explain-button row visible.

    Raises:
        ValueError: If a PDF file was supplied.
    """
    input_sentences = sent_tokenize(abstract_text_input)
    # Use a context manager so the data is flushed and the handle is closed
    # deterministically instead of being left to the garbage collector.
    with open('tmp_input_sents.pkl', 'wb') as sents_file:
        pickle.dump(input_sentences, sents_file)

    # TODO handle pdf file input
    if pdf_file_input is not None:
        raise ValueError('Use submission abstract instead.')

    # Get author papers from id (the author name is not needed here)
    _, papers = get_text_from_author_id(author_id_input)

    # Compute Doc-level affinity scores for the Papers
    titles, abstracts, doc_scores = compute_overall_score(
        doc_model,
        tokenizer,
        abstract_text_input,
        papers,
        batch=30
    )
    paper_info = {
        'titles': titles,
        'abstracts': abstracts,
        'doc_scores': doc_scores
    }
    with open('tmp_paperinfo.pkl', 'wb') as info_file:
        pickle.dump(paper_info, info_file)

    # Select top K choices of papers to show
    # NOTE(review): presumably compute_overall_score returns results sorted
    # by score, best first -- confirm in score.py.
    titles = titles[:num_papers_show]
    abstracts = abstracts[:num_papers_show]
    doc_scores = doc_scores[:num_papers_show]

    # Only the first paper is displayed for now (see the multi-paper TODO in
    # the UI definition).
    return (
        titles[0],
        abstracts[0],
        doc_scores[0],
        gr.update(choices=input_sentences, interactive=True),
        gr.update(visible=True),
    )
def get_highlights(
    abstract_text_input,
    pdf_file_input,
    abstract,
    K=2
):
    """Compute per-sentence word highlights and cache them on disk.

    For every sentence of the submission abstract, the words of the selected
    paper abstract are paired with that sentence's scores and stored in
    'highlight_info.pkl', which ``change_output_highlight`` reads back.

    Args:
        abstract_text_input: Text of the submission abstract.
        pdf_file_input: Uploaded PDF, or None. Currently unused.
        abstract: Abstract of the reviewer's paper shown in the UI.
        K: Forwarded to ``get_highlight_info`` as ``K``.

    Returns:
        An update that makes the sentence radio choices visible.
    """
    # Compute sent-level and phrase-level affinity scores for each papers
    _sent_ids, _sent_scores, info = get_highlight_info(
        sent_model,
        abstract_text_input,
        abstract,
        K=K
    )
    input_sentences = sent_tokenize(abstract_text_input)

    # different highlights for each input sentences; keys are the sentence
    # index as a string, matching the lookup in change_output_highlight
    word_scores = {
        str(i): {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores']))
        }
        for i in range(len(input_sentences))
    }

    highlight_cache = {
        'source_sentences': input_sentences,
        'highlight': word_scores
    }
    # Use a context manager so the cache is fully written and closed before
    # the radio-change callback tries to read it.
    with open('highlight_info.pkl', 'wb') as cache_file:
        pickle.dump(highlight_cache, cache_file)

    # update the visibility of radio choices
    return gr.update(visible=True)
def update_name(author_id_input):
    """Resolve the reviewer ID to a name and return an update for the name box."""
    # The second element (the author's papers) is not needed for this callback.
    reviewer_name, _papers = get_text_from_author_id(author_id_input)
    name_update = gr.update(value=reviewer_name)
    return name_update
def change_output_highlight(source_sent_choice):
    """Return the cached highlight for the selected submission sentence.

    Args:
        source_sent_choice: The sentence picked in the radio component.

    Returns:
        The highlight entry (a dict with "original" and "interpretation")
        stored by ``get_highlights`` for the matching sentence, or None when
        no cache file exists yet or the sentence is not in the cache.
    """
    # Open directly instead of checking os.path.exists() first: this avoids
    # the check-then-open race, and ``with`` guarantees the handle is closed.
    try:
        with open('highlight_info.pkl', 'rb') as cache_file:
            # NOTE(review): unpickling is acceptable only because this file is
            # written by get_highlights in this same app; never load a pickle
            # from an untrusted source.
            cached = pickle.load(cache_file)
    except FileNotFoundError:
        return None

    highlights = cached['highlight']
    for idx, sentence in enumerate(cached['source_sentences']):
        if source_sent_choice == sentence:
            # Highlights are keyed by the sentence index as a string.
            return highlights[str(idx)]
    return None
# UI definition: builds the Gradio Blocks app and wires its event handlers.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a copy of this file that had lost its indentation -- confirm the
# Row/Column structure against the rendered page.
with gr.Blocks() as demo:
    ### INPUT
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
                # Refresh the read-only name box whenever the reviewer ID changes.
                author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')
    # with gr.Row(visible=False) as reviewer_name_info:
    # name = gr.Textbox(label='Reveiwer Author Name')
    # with gr.Row():
    # with gr.Tabs():
    # for tt in range(num_papers_show):
    # with gr.TabItem('Paper %d'%(tt+1)):
    # TODO handle multiple papers
    ### PAPER INFORMATION
    # Read-only fields filled by get_similar_paper with the first result.
    with gr.Row():
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity= gr.Number(label='Affinity', interactive=False, value=0)
    with gr.Row():
        paper_abstract = gr.Textbox(label='Abstract', interactive=False)
    # Hidden until get_similar_paper returns an update making it visible.
    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')
    ### RELEVANT PARTS (HIGHLIGHTS)
    with gr.Row():
        with gr.Column(scale=2): # text from submission
            # Starts empty and hidden; choices are filled by get_similar_paper
            # and visibility is switched on by get_highlights.
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        with gr.Column(scale=3): # highlighted text from paper
            highlight = gr.components.Interpretation(paper_abstract)
    # Search button: the five outputs line up, in order, with the 5-tuple
    # returned by get_similar_paper.
    compute_btn.click(
        fn=get_similar_paper,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            author_id_input
        ],
        outputs=[
            paper_title,
            paper_abstract,
            affinity,
            source_sentences,
            explain_button_row
        ]
    )
    # Explain button: computes and caches the highlights, then reveals the
    # sentence radio.
    explain_btn.click(
        fn=get_highlights,
        inputs=[
            abstract_text_input,
            pdf_file_input,
            paper_abstract
        ],
        outputs=source_sentences
    )
    # Selecting a sentence loads its cached highlight into the
    # interpretation component.
    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight
    )
# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()