# NOTE: page-capture residue, not part of the program
# (Hugging Face "Spaces" header; status shown at capture time: "Runtime error").
import gradio as gr | |
import os | |
from transformers import AutoTokenizer, AutoModel | |
from sentence_transformers import SentenceTransformer | |
import pickle | |
from input_format import * | |
from score import * | |
# Document-level scorer: tokenizer and encoder are both loaded from the same
# SPECTER checkpoint and handed to compute_overall_score() further down.
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model)

# Sentence-level scorer: passed to get_highlight_info() when building the
# highlights for a selected paper.
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10
):
    """Score a reviewer's papers against a submission abstract.

    Args:
        abstract_text_input: Text of the submission abstract.
        pdf_file_input: Uploaded PDF, or None. PDF input is not supported
            yet; any non-None value is rejected.
        author_id_input: Reviewer ID forwarded to get_text_from_author_id().
        num_papers_show: Upper bound on how many papers are kept before the
            first one is returned.

    Returns:
        A 5-tuple matching the Gradio outputs wired to the search button:
        (title, abstract, doc-level score) of the first paper returned by
        compute_overall_score(), an update that fills the sentence radio
        choices with the tokenized submission sentences, and an update that
        reveals the "explain" button row.

    Raises:
        ValueError: If a PDF file is supplied.
        IndexError: If no paper survives the truncation (e.g. the reviewer
            has no papers or num_papers_show is 0).
    """
    # TODO handle pdf file input
    # Reject unsupported input first, before any scratch file is written.
    if pdf_file_input is not None:
        raise ValueError('Use submission abstract instead.')

    input_sentences = sent_tokenize(abstract_text_input)
    with open('tmp_input_sents.pkl', 'wb') as sents_file:
        pickle.dump(input_sentences, sents_file)

    # Get author papers from id (the reviewer name is not needed here).
    _, papers = get_text_from_author_id(author_id_input)

    # Compute Doc-level affinity scores for the Papers
    titles, abstracts, doc_scores = compute_overall_score(
        doc_model,
        tokenizer,
        abstract_text_input,
        papers,
        batch=30
    )

    # Persist the full (untruncated) results for later inspection.
    paper_info = {
        'titles': titles,
        'abstracts': abstracts,
        'doc_scores': doc_scores
    }
    with open('tmp_paperinfo.pkl', 'wb') as info_file:
        pickle.dump(paper_info, info_file)

    # Select top K choices of papers to show
    # NOTE(review): presumably compute_overall_score returns papers ranked
    # best-first — confirm in score.py. Only the first entry is displayed.
    titles = titles[:num_papers_show]
    abstracts = abstracts[:num_papers_show]
    doc_scores = doc_scores[:num_papers_show]

    return (
        titles[0],
        abstracts[0],
        doc_scores[0],
        gr.update(choices=input_sentences, interactive=True),
        gr.update(visible=True),
    )
def get_highlights(
    abstract_text_input,
    pdf_file_input,
    abstract,
    K=2
):
    """Precompute word-level highlights of a paper for each submission sentence.

    The result is written to 'highlight_info.pkl', which
    change_output_highlight() reads when a sentence is selected.

    Args:
        abstract_text_input: Text of the submission abstract.
        pdf_file_input: Unused; present because the click handler passes it.
        abstract: Abstract of the reviewer's paper to highlight.
        K: Forwarded to get_highlight_info() as its K argument.

    Returns:
        A Gradio update that makes the sentence radio group visible.
    """
    # Compute sent-level and phrase-level affinity scores for each papers.
    # Only the per-sentence word info is used below.
    _sent_ids, _sent_scores, info = get_highlight_info(
        sent_model,
        abstract_text_input,
        abstract,
        K=K
    )
    input_sentences = sent_tokenize(abstract_text_input)

    # different highlights for each input sentences; keys are the sentence
    # index as a string, and info[i] holds the scores for sentence i.
    word_scores = {
        str(i): {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores']))
        }
        for i in range(len(input_sentences))
    }

    highlight_info = {
        'source_sentences': input_sentences,
        'highlight': word_scores
    }
    with open('highlight_info.pkl', 'wb') as out_file:
        pickle.dump(highlight_info, out_file)

    # update the visibility of radio choices
    return gr.update(visible=True)
def update_name(author_id_input):
    """Return a Gradio update that shows the reviewer name for the given ID."""
    # The lookup also returns the reviewer's papers; only the name is shown.
    reviewer_name, _papers = get_text_from_author_id(author_id_input)
    name_update = gr.update(value=reviewer_name)
    return name_update
def change_output_highlight(source_sent_choice):
    """Return the stored highlight for the selected submission sentence.

    Reads 'highlight_info.pkl' (written by get_highlights()) and looks up the
    entry whose source sentence equals the selection.

    Args:
        source_sent_choice: The sentence currently selected in the radio group.

    Returns:
        The highlight dict stored for the first matching sentence, or None
        when the file does not exist or no sentence matches.
    """
    if not os.path.exists('highlight_info.pkl'):
        return None

    # The pickle is a scratch file this app wrote itself; do not point this
    # at externally supplied data.
    with open('highlight_info.pkl', 'rb') as in_file:
        tmp = pickle.load(in_file)

    source_sents = tmp['source_sentences']
    highlights = tmp['highlight']
    for i, s in enumerate(source_sents):
        print('changing highlight!')
        if source_sent_choice == s:
            # Highlights are keyed by the sentence index as a string.
            return highlights[str(i)]
    return None
# NOTE(review): the nesting below was reconstructed from statement order
# because indentation was lost in this copy — confirm against the intended
# layout before relying on it.
with gr.Blocks() as demo:
    # --- Input: submission abstract, optional PDF, reviewer ID ---
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
            # Refresh the confirmation box whenever the reviewer ID changes.
            author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)

    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')

    # TODO handle multiple papers (earlier sketch: a hidden reviewer-name row
    # and one tab per paper, 'Paper %d', for each of num_papers_show papers).

    # --- Paper information: title and affinity score, then the abstract ---
    with gr.Row():
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity = gr.Number(label='Affinity', interactive=False, value=0)
    with gr.Row():
        paper_abstract = gr.Textbox(label='Abstract', interactive=False)

    # Hidden until a search has produced a paper to explain.
    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')

    # --- Relevant parts (highlights) ---
    with gr.Row():
        # Left: sentences from the submission.
        with gr.Column(scale=2):
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        # Right: highlighted text from the reviewer's paper.
        with gr.Column(scale=3):
            highlight = gr.components.Interpretation(paper_abstract)

    # --- Event wiring ---
    # Search: fills the paper fields, the sentence choices, and reveals the
    # explain button row.
    compute_btn.click(
        fn=get_similar_paper,
        inputs=[abstract_text_input, pdf_file_input, author_id_input],
        outputs=[
            paper_title,
            paper_abstract,
            affinity,
            source_sentences,
            explain_button_row,
        ],
    )

    # Explain: precomputes the highlights and shows the sentence radio group.
    explain_btn.click(
        fn=get_highlights,
        inputs=[abstract_text_input, pdf_file_input, paper_abstract],
        outputs=source_sentences,
    )

    # Selecting a sentence swaps in the highlight stored for it.
    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight,
    )

if __name__ == "__main__":
    demo.launch()