jskim committed on
Commit
6eff5e7
1 Parent(s): 82d22b3

init files

Files changed (4)
  1. app.py +196 -0
  2. input_format.py +114 -0
  3. requirements.txt +83 -0
  4. score.py +149 -0
app.py ADDED
@@ -0,0 +1,196 @@
+ import gradio as gr
+ import os
+ from transformers import AutoTokenizer, AutoModel
+ from sentence_transformers import SentenceTransformer
+ import pickle
+ from nltk.tokenize import sent_tokenize
+
+ from input_format import *
+ from score import *
+
+ # load document scoring model
+ pretrained_model = 'allenai/specter'
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
+ doc_model = AutoModel.from_pretrained(pretrained_model)
+
+ # load sentence model
+ sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')
+
+ def get_similar_paper(
+     abstract_text_input,
+     pdf_file_input,
+     author_id_input,
+     num_papers_show=10
+ ):
+     input_sentences = sent_tokenize(abstract_text_input)
+
+     pickle.dump(input_sentences, open('tmp_input_sents.pkl', 'wb'))
+
+     # TODO handle pdf file input
+     if pdf_file_input is not None:
+         name = None
+         papers = []
+         raise ValueError('Use submission abstract instead.')
+     else:
+         # get author papers from id
+         name, papers = get_text_from_author_id(author_id_input)
+
+     # compute document-level affinity scores for the papers
+     titles, abstracts, doc_scores = compute_overall_score(
+         doc_model,
+         tokenizer,
+         abstract_text_input,
+         papers,
+         batch=30
+     )
+
+     tmp = {
+         'titles': titles,
+         'abstracts': abstracts,
+         'doc_scores': doc_scores
+     }
+     pickle.dump(tmp, open('tmp_paperinfo.pkl', 'wb'))
+
+     # select top K papers to show
+     titles = titles[:num_papers_show]
+     abstracts = abstracts[:num_papers_show]
+     doc_scores = doc_scores[:num_papers_show]
+
+     return titles[0], abstracts[0], doc_scores[0], gr.update(choices=input_sentences, interactive=True), gr.update(visible=True)
+
+ def get_highlights(
+     abstract_text_input,
+     pdf_file_input,
+     abstract,
+     K=2
+ ):
+     # compute sentence-level and phrase-level affinity scores for each paper
+     sent_ids, sent_scores, info = get_highlight_info(
+         sent_model,
+         abstract_text_input,
+         abstract,
+         K=K
+     )
+
+     input_sentences = sent_tokenize(abstract_text_input)
+     num_sents = len(input_sentences)
+
+     word_scores = dict()
+     # different highlights for each input sentence
+     for i in range(num_sents):
+         word_scores[str(i)] = {
+             "original": abstract,
+             "interpretation": list(zip(info['all_words'], info[i]['scores']))
+         }
+
+     tmp = {
+         'source_sentences': input_sentences,
+         'highlight': word_scores
+     }
+     pickle.dump(tmp, open('highlight_info.pkl', 'wb'))
+
+     # update the visibility of radio choices
+     return gr.update(visible=True)
+
+ def update_name(author_id_input):
+     # update the author name based on the ID input
+     name, _ = get_text_from_author_id(author_id_input)
+     return gr.update(value=name)
+
+ def change_output_highlight(source_sent_choice):
+     # change the output highlight based on the sentence selected from the submission
+     if os.path.exists('highlight_info.pkl'):
+         tmp = pickle.load(open('highlight_info.pkl', 'rb'))
+         source_sents = tmp['source_sentences']
+         highlights = tmp['highlight']
+         for i, s in enumerate(source_sents):
+             if source_sent_choice == s:
+                 return highlights[str(i)]
+     else:
+         return
+
+ with gr.Blocks() as demo:
+
+     ### INPUT
+     with gr.Row() as input_row:
+         with gr.Column():
+             abstract_text_input = gr.Textbox(label='Submission Abstract')
+         with gr.Column():
+             pdf_file_input = gr.File(label='OR upload a submission PDF File')
+         with gr.Column():
+             with gr.Row():
+                 author_id_input = gr.Textbox(label='Reviewer ID (Semantic Scholar)')
+             with gr.Row():
+                 name = gr.Textbox(label='Confirm Reviewer Name', interactive=False)
+             author_id_input.change(fn=update_name, inputs=author_id_input, outputs=name)
+     with gr.Row():
+         compute_btn = gr.Button('Search Similar Papers from the Reviewer')
+
+     # with gr.Row(visible=False) as reviewer_name_info:
+     #     name = gr.Textbox(label='Reviewer Author Name')
+     # with gr.Row():
+     #     with gr.Tabs():
+     #         for tt in range(num_papers_show):
+     #             with gr.TabItem('Paper %d'%(tt+1)):
+
+     # TODO handle multiple papers
+
+     ### PAPER INFORMATION
+     with gr.Row():
+         with gr.Column(scale=3):
+             paper_title = gr.Textbox(label='Title', interactive=False)
+         with gr.Column(scale=1):
+             affinity = gr.Number(label='Affinity', interactive=False, value=0)
+     with gr.Row():
+         paper_abstract = gr.Textbox(label='Abstract', interactive=False)
+
+     with gr.Row(visible=False) as explain_button_row:
+         explain_btn = gr.Button('Show Relevant Parts from Selected Paper')
+
+     ### RELEVANT PARTS (HIGHLIGHTS)
+     with gr.Row():
+         with gr.Column(scale=2):  # text from submission
+             source_sentences = gr.Radio(
+                 choices=[],
+                 visible=False,
+                 label='Sentences from Submission Abstract',
+             )
+         with gr.Column(scale=3):  # highlighted text from paper
+             highlight = gr.components.Interpretation(paper_abstract)
+
+     compute_btn.click(
+         fn=get_similar_paper,
+         inputs=[
+             abstract_text_input,
+             pdf_file_input,
+             author_id_input
+         ],
+         outputs=[
+             paper_title,
+             paper_abstract,
+             affinity,
+             source_sentences,
+             explain_button_row
+         ]
+     )
+
+     explain_btn.click(
+         fn=get_highlights,
+         inputs=[
+             abstract_text_input,
+             pdf_file_input,
+             paper_abstract
+         ],
+         outputs=source_sentences
+     )
+
+     source_sentences.change(
+         fn=change_output_highlight,
+         inputs=source_sentences,
+         outputs=highlight
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
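
For a quick sanity check of the pipeline outside the UI, a minimal sketch (hypothetical, not part of the commit; assumes the models above have loaded, network access to Semantic Scholar, and '1737249', the sample author ID noted in input_format.py; the two trailing gr.update values are dropped):

    # hedged smoke test: call the scoring entry point directly
    query = 'We study affinity scoring between submissions and reviewer papers.'
    title, abstract, score, _, _ = get_similar_paper(query, None, '1737249')
    print(title, score)  # top-ranked paper title and its affinity score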
input_format.py ADDED
@@ -0,0 +1,114 @@
+ import numpy as np
+ from pypdf import PdfReader
+ from urllib.parse import urlparse
+ import requests
+ from semanticscholar import SemanticScholar
+
+ ### Input Formatting Module
+
+ ## Input formatting for the given paper
+ # Extracting text from a pdf or a link
+
+ def get_text_from_pdf(file_path):
+     """
+     Convert a pdf to a list of per-page texts
+     """
+     reader = PdfReader(file_path)
+     text = []
+     for p in reader.pages:
+         t = p.extract_text()
+         text.append(t)
+     return text
+
+ def get_text_from_url(url, file_path='paper.pdf'):
+     """
+     Get the text of a paper from a url
+     """
+     # TODO check for other valid urls (e.g. semantic scholar)
+
+     ## Check for different URL cases
+     url_parts = urlparse(url)
+     # arxiv
+     if 'arxiv' in url_parts.netloc:
+         if 'abs' in url_parts.path:
+             # abstract page, change the url to the pdf link
+             paper_id = url_parts.path.split('/')[-1]
+             url = 'https://www.arxiv.org/pdf/%s.pdf'%(paper_id)
+         elif 'pdf' in url_parts.path:
+             # already a pdf link, pass
+             pass
+         else:
+             raise ValueError('invalid arxiv url')
+     else:
+         raise ValueError('invalid url: only arxiv links are supported')
+
+     # download the file
+     download_pdf(url, file_path)
+
+     # get the text from the pdf file
+     text = get_text_from_pdf(file_path)
+     return text
+
+ def download_pdf(url, file_name):
+     """
+     Download the pdf file from the given url and save it as file_name
+     """
+     # send GET request
+     response = requests.get(url)
+
+     # save the PDF
+     if response.status_code == 200:
+         with open(file_name, "wb") as f:
+             f.write(response.content)
+     elif response.status_code == 404:
+         raise ValueError('cannot download the file')
+     else:
+         print('unexpected status code: %d' % response.status_code)
+
+ ## Input formatting for the given author (reviewer)
+ # Extracting text from a link
+
+ def get_text_from_author_id(author_id, max_count=100):
+     if author_id is None:
+         raise ValueError('Input a valid author ID')
+     author_id = str(author_id)
+     # author_id = '1737249'  # sample author ID
+     url = "https://api.semanticscholar.org/graph/v1/author/%s?fields=url,name,paperCount,papers,papers.title,papers.abstract"%author_id
+     r = requests.get(url)
+     if r.status_code == 404:
+         raise ValueError('Input a valid author ID')
+     data = r.json()
+     papers = data['papers'][:max_count]
+     name = data['name']
+
+     return name, papers
+
+ ## TODO Preprocess Extracted Texts from PDFs
+ # Get a portion of the text for the actual task
+
+ def get_title(text):
+     pass
+
+ def get_abstract(text):
+     pass
+
+ def get_introduction(text):
+     pass
+
+ def get_conclusion(text):
+     pass
+
+
+ if __name__ == '__main__':
+     def run_sample():
+         url = 'https://arxiv.org/abs/2105.06506'
+         text = get_text_from_url(url)
+         assert(text[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods')
+
+         text2 = get_text_from_url('https://arxiv.org/pdf/2105.06506.pdf')
+         assert(text2[0].split('\n')[0] == 'Sanity Simulations for Saliency Methods')
+
+         # text = get_text_from_url('https://arxiv.org/paetseths.pdf')
+
+     # test the code
+     run_sample()
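
For reference, a hedged sketch of the author-side entry point above (the ID is the sample commented out inside get_text_from_author_id; a live Semantic Scholar API call is assumed):

    name, papers = get_text_from_author_id('1737249', max_count=100)
    print(name, len(papers))   # author name and up to max_count paper records
    print(papers[0]['title'])  # each record carries 'title' and 'abstract' fields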
requirements.txt ADDED
@@ -0,0 +1,83 @@
+ gradio==3.19.1
+ huggingface-hub==0.8.1
+ nltk==3.7
+ numpy==1.21.6
+ py-pdf-parser==0.10.2
+ py-rouge==1.1
+ pypdf==3.3.0
+ pyrogue==0.0.2
+ requests==2.28.1
+ rouge-score==0.1.2
+ scikit-learn==1.0.2
+ scipy==1.7.3
+ scs==2.1.4
+ seaborn==0.11.2
+ segtok==1.5.11
+ semanticscholar==0.3.2
+ sentence-transformers==2.2.0
+ sentencepiece==0.1.96
+ sentry-sdk==1.9.0
+ setproctitle==1.3.0
+ shap==0.40.0
+ shapely==2.0.0
+ shortuuid==1.0.9
+ six @ file:///tmp/build/80754af9/six_1623709665295/work
+ sklearn==0.0
+ slicer==0.0.7
+ smart-open==5.2.1
+ smmap==5.0.0
+ sniffio==1.2.0
+ spacy==3.0.8
+ spacy-legacy==3.0.9
+ spacy-loggers==1.0.3
+ sqlitedict==2.0.0
+ srsly==2.4.4
+ starlette==0.22.0
+ statsmodels==0.13.2
+ tabulate==0.8.9
+ tea==0.1.4
+ tea-client==0.0.7
+ tea-console==0.0.6
+ tenacity==8.1.0
+ tensorboardX==2.5.1
+ termcolor==1.1.0
+ terminado==0.9.4
+ testpath @ file:///tmp/build/80754af9/testpath_1624638946665/work
+ text-unidecode==1.3
+ thinc==8.0.17
+ threadpoolctl==2.2.0
+ tifffile==2021.11.2
+ tld==0.10
+ tokenizers==0.10.3
+ tomli==2.0.1
+ toolz==0.12.0
+ torch==1.9.0
+ torchaudio==0.10.2
+ torchdata==0.3.0
+ torchtext==0.12.0
+ torchvision==0.8.2
+ tornado @ file:///tmp/build/80754af9/tornado_1606942283357/work
+ tqdm==4.62.2
+ traitlets==5.3.0
+ transformers==4.3.3
+ transformers-interpret==0.5.2
+ typer==0.3.2
+ typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1624965014186/work
+ tzlocal==2.1
+ uc-micro-py==1.0.1
+ urllib3==1.26.6
+ uvicorn==0.20.0
+ Wand==0.6.10
+ wandb==0.12.21
+ wasabi==0.10.1
+ wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
+ webencodings==0.5.1
+ websockets==10.4
+ Werkzeug==2.2.2
+ widgetsnbextension==3.5.1
+ Wikipedia-API==0.5.4
+ word2number==1.1
+ wrapt==1.12.1
+ xxhash==2.0.2
+ yarl==1.7.2
+ zipp @ file:///tmp/build/80754af9/zipp_1625570634446/work
score.py ADDED
@@ -0,0 +1,149 @@
+ from sentence_transformers import util
+ from nltk.tokenize import sent_tokenize
+ import torch
+ import numpy as np
+
+ def compute_sentencewise_scores(model, query_sents, candidate_sents):
+     # pairwise cosine similarity between query and candidate sentences
+     q_v, c_v = get_embedding(model, query_sents, candidate_sents)
+     return util.cos_sim(q_v, c_v)
+
+ def get_embedding(model, query_sents, candidate_sents):
+     q_v = model.encode(query_sents)
+     c_v = model.encode(candidate_sents)
+     return q_v, c_v
+
+ def get_top_k(score_mat, K=3):
+     """
+     Pick top K sentences to show
+     """
+     idx = torch.argsort(-score_mat)
+     picked_sent = idx[:,:K]
+     picked_scores = torch.vstack(
+         [score_mat[i,picked_sent[i]] for i in range(picked_sent.shape[0])]
+     )
+     return picked_sent, picked_scores
+
+ def get_words(sent):
+     words = []          # words of each sentence
+     sent_start_id = []  # keep track of the word index where each new sentence starts
+     counter = 0
+     for x in sent:
+         w = x.split()
+         words.append(w)
+         counter += len(w)
+         sent_start_id.append(counter)
+     all_words = [item for sublist in words for item in sublist]
+     sent_start_id.pop()
+     sent_start_id = [0] + sent_start_id
+     assert(len(sent_start_id) == len(sent))
+     return words, all_words, sent_start_id
+
+ def mark_words(words, all_words, sent_start_id, sent_ids, sent_scores):
+     num_query_sent = sent_ids.shape[0]
+     num_words = len(all_words)
+
+     output = dict()
+     output['all_words'] = all_words
+     output['words_by_sentence'] = words
+
+     # for each query sentence, mark the highlight information
+     for i in range(num_query_sent):
+         is_selected_sent = np.zeros(num_words)
+         is_selected_phrase = np.zeros(num_words)
+         word_scores = np.zeros(num_words) + 1e-4
+
+         # get sentence selection information
+         for sid, sscore in zip(sent_ids[i], sent_scores[i]):
+             if sid + 1 < len(sent_start_id):
+                 sent_range = (sent_start_id[sid], sent_start_id[sid+1])
+                 is_selected_sent[sent_range[0]:sent_range[1]] = 1
+                 word_scores[sent_range[0]:sent_range[1]] = sscore
+             else:
+                 # last sentence runs to the end of the text
+                 is_selected_sent[sent_start_id[sid]:] = 1
+                 word_scores[sent_start_id[sid]:] = sscore
+
+         # TODO get phrase selection information
+         output[i] = {
+             'is_selected_sent': is_selected_sent,
+             'is_selected_phrase': is_selected_phrase,
+             'scores': word_scores
+         }
+
+     return output
+
+ def get_highlight_info(model, text1, text2, K=3):
+     sent1 = sent_tokenize(text1)  # query
+     sent2 = sent_tokenize(text2)  # candidate
+     score_mat = compute_sentencewise_scores(model, sent1, sent2)
+
+     sent_ids, sent_scores = get_top_k(score_mat, K=K)
+     words2, all_words2, sent_start_id2 = get_words(sent2)
+     info = mark_words(words2, all_words2, sent_start_id2, sent_ids, sent_scores)
+
+     return sent_ids, sent_scores, info
+
+ ## Document-level operations
+
+ def predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=20):
+     # concatenate title and abstract
+     title_abs = []
+     for t, a in zip(titles, abstracts):
+         if t is not None and a is not None:
+             title_abs.append(t + ' [SEP] ' + a)
+
+     num_docs = len(title_abs)
+     no_iter = int(np.ceil(num_docs / batch))
+
+     scores = []
+     with torch.no_grad():
+         # batched inference
+         for i in range(no_iter):
+             inputs = tokenizer(
+                 [query] + title_abs[i*batch:(i+1)*batch],
+                 padding=True,
+                 truncation=True,
+                 return_tensors="pt",
+                 max_length=512
+             )
+             inputs = inputs.to(doc_model.device)
+             result = doc_model(**inputs)
+
+             # take the first ([CLS]) token of each sequence as its embedding
+             embeddings = result.last_hidden_state[:, 0, :].detach().cpu().numpy()
+
+             # compute cosine similarity between the query and each document
+             q_emb = embeddings[0,:]
+             p_emb = embeddings[1:,:]
+             nn = np.linalg.norm(q_emb) * np.linalg.norm(p_emb, axis=1)
+             scores += list(np.dot(p_emb, q_emb) / nn)
+
+     assert(len(scores) == num_docs)
+
+     return scores
+
+ def compute_overall_score(doc_model, tokenizer, query, papers, batch=5):
+     titles = []
+     abstracts = []
+     for p in papers:
+         # keep only papers with both fields so scores stay aligned with titles/abstracts
+         if p['title'] is not None and p['abstract'] is not None:
+             titles.append(p['title'])
+             abstracts.append(p['abstract'])
+     scores = predict_docscore(doc_model, tokenizer, query, titles, abstracts, batch=batch)
+
+     # sort by descending affinity score
+     idx_sorted = np.argsort(scores)[::-1]
+     titles_sorted = [titles[x] for x in idx_sorted]
+     abstracts_sorted = [abstracts[x] for x in idx_sorted]
+     scores_sorted = [scores[x] for x in idx_sorted]
+
+     return titles_sorted, abstracts_sorted, scores_sorted
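
A minimal sketch of driving the sentence-level scoring in score.py directly (the model name matches the one loaded in app.py; sent_tokenize additionally needs the NLTK 'punkt' data via nltk.download('punkt'); the texts are placeholders):

    from sentence_transformers import SentenceTransformer
    sent_model = SentenceTransformer('sentence-transformers/gtr-t5-base')

    query = 'We propose a new saliency method. It is validated on synthetic data.'
    candidate = 'This paper studies saliency maps. Experiments use synthetic benchmarks.'

    sent_ids, sent_scores, info = get_highlight_info(sent_model, query, candidate, K=1)
    # info[i]['scores'] holds one score per word of the candidate text,
    # aligned with info['all_words'], for query sentence i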