malteos commited on
Commit
a1866c7
β€’
1 Parent(s): 18556fd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +254 -0
app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+
3
+ Run via: streamlit run app.py
4
+
5
+ """
6
+
7
+ import json
8
+ import logging
9
+
10
+ import requests
11
+ import streamlit as st
12
+ import torch
13
+ from datasets import load_dataset
14
+ from datasets.dataset_dict import DatasetDict
15
+ from transformers import AutoTokenizer, AutoModel
16
+
17
# Configure root logging once at import time so all module loggers inherit it.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Module-level logger, standard `getLogger(__name__)` convention.
logger = logging.getLogger(__name__)

# Link shown in the intro text; points at one of the aspect-specific models.
model_hub_url = 'https://huggingface.co/malteos/aspect-scibert-task'

# Markdown rendered in Streamlit's "About" menu entry (see st.set_page_config).
about_page_markdown = f"""# 🔍 Find Papers With Similar Task

See
- GitHub: https://github.com/malteos/aspect-document-embeddings
- Paper: #TODO
- Model hub: https://huggingface.co/malteos/aspect-scibert-task

"""
34
+
35
# Page setup
# Must be the first Streamlit call of the script; configures tab title, icon,
# layout, and the entries of the hamburger menu.
st.set_page_config(
    page_title="Papers with similar Task",
    page_icon="🔍",
    layout="centered",
    initial_sidebar_state="auto",
    menu_items={
        'Get help': None,  # hide the default "Get help" entry
        'Report a bug': None,  # hide the default "Report a bug" entry
        'About': about_page_markdown,
    }
)
47
+
48
# The three similarity aspects supported by the demo; one model exists per aspect.
aspects = [
    'task', 'method', 'dataset'
]
# All aspect models share the same SciBERT tokenizer, so any aspect works here.
tokenizer_name_or_path = f'malteos/aspect-scibert-{aspects[0]}'  # any aspect
dataset_config = 'malteos/aspect-paper-metadata'

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
55
+
56
+
57
@st.cache(show_spinner=False, allow_output_mutation=True)
def st_load_model(name_or_path):
    """Load a Hugging Face model once and cache it across Streamlit reruns.

    ``allow_output_mutation=True`` stops Streamlit from hashing the (large)
    model object on every rerun to detect mutations — hashing a torch module
    is slow and triggers CachedObjectMutationWarning.

    :param name_or_path: Model ID on the HF Hub or a local path.
    :return: The loaded ``AutoModel`` instance.
    """
    with st.spinner(f'Loading the model `{name_or_path}` (this might take a while)...'):
        model = AutoModel.from_pretrained(name_or_path)
    return model
62
+
63
+
64
@st.cache(show_spinner=False, allow_output_mutation=True)
def st_load_dataset(name_or_path):
    """Load the paper-metadata dataset and attach its pre-built FAISS indexes.

    Cached across Streamlit reruns. ``allow_output_mutation=True`` is needed
    because ``load_faiss_index`` mutates the returned dataset object and
    hashing the full dataset on every rerun would be prohibitively slow.

    :param name_or_path: Dataset ID on the HF Hub or a local path.
    :return: A ``datasets.Dataset`` (the ``train`` split when the repo
        resolves to a ``DatasetDict``) with one FAISS index per aspect.
    """
    with st.spinner('Loading the dataset (this might take a while)...'):
        dataset = load_dataset(name_or_path)

        if isinstance(dataset, DatasetDict):
            dataset = dataset['train']

        # Attach the pre-computed FAISS index for every aspect so that
        # `get_nearest_examples` can be used for retrieval later on.
        for a in aspects:
            dataset.load_faiss_index(f'{a}_embeddings', f'{a}_embeddings.faiss')

    return dataset
82
+
83
+
84
# One fine-tuned SciBERT model per aspect; loading is cached by st_load_model,
# so reruns of the script reuse the already-loaded models.
aspect_to_model = dict(
    task=st_load_model('malteos/aspect-scibert-task'),
    method=st_load_model('malteos/aspect-scibert-method'),
    dataset=st_load_model('malteos/aspect-scibert-dataset'),
)
# Paper metadata with one pre-built FAISS index per aspect (cached as well).
dataset = st_load_dataset(dataset_config)
90
+
91
+
92
def get_paper(doc_id):
    """Fetch paper metadata from the Semantic Scholar API.

    :param doc_id: Any ID the S2 API accepts, e.g. "arXiv:2202.06671",
        a DOI, or "ACL:<acl_id>".
    :return: Parsed JSON response as a dict (title, abstract, url, year, ...).
    :raises ValueError: If the API does not answer with HTTP 200.
    :raises requests.exceptions.RequestException: On network errors/timeouts.
    """
    # An explicit timeout prevents a slow/unreachable API from hanging the
    # Streamlit worker indefinitely (requests has no default timeout).
    res = requests.get(f'https://api.semanticscholar.org/v1/paper/{doc_id}', timeout=30)

    if res.status_code == 200:
        return res.json()

    raise ValueError(f'Cannot load paper from S2 API: {doc_id}')
99
+
100
+
101
def find_related_papers(paper_id, user_aspect):
    """Embed the input paper with the aspect-specific model and retrieve the
    ten nearest papers from the FAISS index of that aspect.

    :param paper_id: ID understood by the Semantic Scholar API.
    :param user_aspect: One of ``aspects`` ('task', 'method', 'dataset').
    :return: dict with keys ``paper`` (S2 metadata of the input paper),
        ``aspect``, and ``related_papers`` (columnar dict of the retrieved
        dataset rows, e.g. ``related_papers['title'][i]``).
    :raises ValueError: If the paper cannot be fetched or lacks title/abstract.
    """
    paper = get_paper(paper_id)

    # The S2 API may return the key "abstract" with a null value; a plain
    # `'abstract' not in paper` check would let None through and crash on the
    # string concatenation below, so test for truthy values instead.
    if paper is None or not paper.get('title') or not paper.get('abstract'):
        raise ValueError('Could not retrieve data for input paper')

    title_abs = paper['title'] + ': ' + paper['abstract']

    # preprocess the input (single sequence, truncated to the model maximum)
    inputs = tokenizer(title_abs, padding=True, truncation=True, return_tensors="pt", max_length=512)

    # inference — no gradients needed, which saves memory and time
    with torch.no_grad():
        outputs = aspect_to_model[user_aspect](**inputs)

    # Mean pool the token-level embeddings to get sentence-level embeddings:
    # mask out padding tokens via attention_mask and guard against division
    # by zero with a tiny clamp.
    embeddings = torch.sum(
        outputs["last_hidden_state"] * inputs['attention_mask'].unsqueeze(-1), dim=1
    ) / torch.clamp(torch.sum(inputs['attention_mask'], dim=1, keepdim=True), min=1e-9)

    result = dict(
        paper=paper,
        aspect=user_aspect,
    )

    # Retrieval via the pre-built FAISS index for the chosen aspect.
    query = embeddings.detach().numpy()[0]
    scores, retrieved_examples = dataset.get_nearest_examples(f'{user_aspect}_embeddings', query, k=10)

    result.update(dict(
        related_papers=retrieved_examples,
    ))

    return result
147
+
148
+
149
# Page
st.title('Aspect-based Paper Similarity')
st.markdown("""This demo showcases [Specialized Document Embeddings for Aspect-based Research Paper Similarity](#TODO).""")

# Introduction
st.markdown(f"""The model was trained using a triplet loss on machine learning papers from the [paperswithcode.com](https://paperswithcode.com/) corpus with the objective of pulling embeddings of papers with the same task, method, or dataset close together. For a more comprehensive overview of the model check out the [model card on 🤗 Model Hub]({model_hub_url}) or read [our paper](#TODO).
""")
st.markdown("""Enter an arXiv ID or a DOI of a paper for which you want to find similar papers.

Try it yourself! 👇""",
            unsafe_allow_html=True)

# Demo: input form (paper ID or example + aspect selection)
with st.form("aspect-input", clear_on_submit=False):
    paper_id = st.text_input(
        label='Enter paper ID (format "arXiv:<arxiv_id>", "<doi>", or "ACL:<acl_id>"):',
        placeholder='Any DOI, ACL, or ArXiv ID'
    )

    # Fallback examples used when the text input above is left empty.
    example = st.selectbox(
        label='Or select example',
        options=[
            "arXiv:2202.06671",
            '10.1016/j.eswa.2019.06.026'
        ]
    )

    user_aspect = st.radio(
        label="In what aspect are you interested?",
        options=aspects
    )

    # Center the submit button in the middle column of three.
    cols = st.columns(3)
    submitted = cols[1].form_submit_button("Find related papers")

# Listener: runs once per form submission.
if submitted:
    if paper_id or example:
        with st.spinner('Finding related papers...'):
            try:
                # Explicit text input takes precedence over the example.
                result = find_related_papers(paper_id if paper_id else example, user_aspect)

                input_paper = result['paper']
                related_papers = result['related_papers']

                st.markdown(
                    f'''Your input paper: \n\n<a href="{input_paper['url']}"><b>{input_paper['title']}</b></a> ({input_paper['year']})<hr />''',
                    unsafe_allow_html=True)

                # Build a simple HTML list of the retrieved papers
                # (related_papers is columnar: one list per field).
                related_html = '<ul>'

                for i in range(len(related_papers['paper_id'])):
                    related_html += f'''<li><a href="{related_papers['url_abs'][i]}">{related_papers['title'][i]}</a></li>'''

                related_html += '</ul>'

                st.markdown(f'''Related papers with similar {result['aspect']}: {related_html}''', unsafe_allow_html=True)

            # find_related_papers raises ValueError/KeyError/TypeError on bad
            # IDs or missing metadata; show them as a friendly error banner.
            except (TypeError, ValueError, KeyError) as e:
                st.error(f'**Error**: {e}')

    else:
        st.error('**Error**: No paper ID provided. Please provide an arXiv ID or DOI.')