import os
import pickle

import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

# Star imports provide sent_tokenize, get_text_from_author_id,
# compute_overall_score and get_highlight_info (exact origin module not
# visible here).
from input_format import *
from score import *

# NOTE(review): these temp files live in the working directory and are shared
# by every concurrent session of the app, so two users can overwrite each
# other's state. Consider per-session gr.State instead.
INPUT_SENTS_PATH = 'tmp_input_sents.pkl'
PAPER_INFO_PATH = 'tmp_paperinfo.pkl'
HIGHLIGHT_INFO_PATH = 'highlight_info.pkl'

# load document scoring model
pretrained_model = 'allenai/specter'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
doc_model = AutoModel.from_pretrained(pretrained_model)

# load sentence model
sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')


def _save_pickle(obj, path):
    """Pickle *obj* to *path*, closing the file handle afterwards."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def get_similar_paper(
    abstract_text_input,
    pdf_file_input,
    author_id_input,
    num_papers_show=10,
):
    """Score the reviewer's papers against the submission abstract.

    Args:
        abstract_text_input: submission abstract text.
        pdf_file_input: uploaded PDF; not supported yet.
        author_id_input: Semantic Scholar author id of the reviewer.
        num_papers_show: number of top papers to keep (only the first is
            currently displayed).

    Returns:
        (title, abstract, affinity score) of the first paper, plus Gradio
        updates for the sentence radio choices and the explain-button row.

    Raises:
        ValueError: if a PDF is supplied, or if no papers were scored for
            the reviewer.
    """
    input_sentences = sent_tokenize(abstract_text_input)
    _save_pickle(input_sentences, INPUT_SENTS_PATH)

    # TODO handle pdf file input
    if pdf_file_input is not None:
        raise ValueError('Use submission abstract instead.')

    # Get author papers from id
    _name, papers = get_text_from_author_id(author_id_input)

    # Compute Doc-level affinity scores for the Papers
    titles, abstracts, doc_scores = compute_overall_score(
        doc_model, tokenizer, abstract_text_input, papers, batch=30
    )
    if len(titles) == 0:
        # Otherwise titles[0] below fails with an uninformative IndexError.
        raise ValueError('No papers found for the given reviewer ID.')

    _save_pickle(
        {'titles': titles, 'abstracts': abstracts, 'doc_scores': doc_scores},
        PAPER_INFO_PATH,
    )

    # Select top K choices of papers to show
    # TODO handle multiple papers (only the first one is shown for now)
    titles = titles[:num_papers_show]
    abstracts = abstracts[:num_papers_show]
    doc_scores = doc_scores[:num_papers_show]

    return (
        titles[0],
        abstracts[0],
        doc_scores[0],
        gr.update(choices=input_sentences, interactive=True),
        gr.update(visible=True),
    )


def get_highlights(abstract_text_input, pdf_file_input, abstract, K=2):
    """Compute per-sentence word highlights of *abstract* and cache them.

    For each sentence of the submission abstract, stores the interpretation
    (word, score) pairs for the selected paper's abstract in
    ``highlight_info.pkl``, keyed by the sentence index as a string.

    Returns:
        A Gradio update that makes the sentence radio visible.
    """
    # Compute sent-level and phrase-level affinity scores for each paper;
    # only the word-level info is used here.
    _sent_ids, _sent_scores, info = get_highlight_info(
        sent_model, abstract_text_input, abstract, K=K
    )
    input_sentences = sent_tokenize(abstract_text_input)

    # different highlights for each input sentence
    word_scores = {
        str(i): {
            "original": abstract,
            "interpretation": list(zip(info['all_words'], info[i]['scores'])),
        }
        for i in range(len(input_sentences))
    }

    _save_pickle(
        {'source_sentences': input_sentences, 'highlight': word_scores},
        HIGHLIGHT_INFO_PATH,
    )

    # update the visibility of radio choices
    return gr.update(visible=True)


def update_name(author_id_input):
    """Return a Gradio update showing the author name for *author_id_input*.

    Clears the name when the id box is empty. This listener fires on every
    edit, including when the user clears the box.
    """
    if not author_id_input or not author_id_input.strip():
        return gr.update(value='')
    name, _ = get_text_from_author_id(author_id_input)
    return gr.update(value=name)


def change_output_highlight(source_sent_choice):
    """Return the cached highlight for the selected submission sentence.

    Returns None when no highlights have been computed yet, or when the
    choice does not match any cached sentence.
    """
    if not os.path.exists(HIGHLIGHT_INFO_PATH):
        return None
    # NOTE(review): pickle.load is only safe because this app wrote the file
    # itself; never point it at user-supplied data.
    with open(HIGHLIGHT_INFO_PATH, 'rb') as f:
        tmp = pickle.load(f)

    source_sents = tmp['source_sentences']
    highlights = tmp['highlight']
    for i, s in enumerate(source_sents):
        print('changing highlight!')
        if source_sent_choice == s:
            return highlights[str(i)]
    return None


with gr.Blocks() as demo:

    ### INPUT
    with gr.Row() as input_row:
        with gr.Column():
            abstract_text_input = gr.Textbox(label='Submission Abstract')
        with gr.Column():
            pdf_file_input = gr.File(label='OR upload a submission PDF File')
        with gr.Column():
            with gr.Row():
                author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
            with gr.Row():
                name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
            author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
    with gr.Row():
        compute_btn = gr.Button('Search Similar Papers from the Reviewer')

    # TODO handle multiple papers: show one gr.TabItem per paper
    # (num_papers_show tabs) instead of a single paper view.

    ### PAPER INFORMATION
    with gr.Row():
        with gr.Column(scale=3):
            paper_title = gr.Textbox(label='Title', interactive=False)
        with gr.Column(scale=1):
            affinity = gr.Number(label='Affinity', interactive=False, value=0)
    with gr.Row():
        paper_abstract = gr.Textbox(label='Abstract', interactive=False)

    with gr.Row(visible=False) as explain_button_row:
        explain_btn = gr.Button('Show Relevant Parts from Selected Paper')

    ### RELEVANT PARTS (HIGHLIGHTS)
    with gr.Row():
        with gr.Column(scale=2):  # text from submission
            source_sentences = gr.Radio(
                choices=[],
                visible=False,
                label='Sentences from Submission Abstract',
            )
        with gr.Column(scale=3):  # highlighted text from paper
            highlight = gr.components.Interpretation(paper_abstract)

    compute_btn.click(
        fn=get_similar_paper,
        inputs=[abstract_text_input, pdf_file_input, author_id_input],
        outputs=[
            paper_title,
            paper_abstract,
            affinity,
            source_sentences,
            explain_button_row,
        ],
    )

    explain_btn.click(
        fn=get_highlights,
        inputs=[abstract_text_input, pdf_file_input, paper_abstract],
        outputs=source_sentences,
    )

    source_sentences.change(
        fn=change_output_highlight,
        inputs=source_sentences,
        outputs=highlight,
    )


if __name__ == "__main__":
    demo.launch()