malteos committed on
Commit
00e4f69
•
1 Parent(s): 8f50a7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -27
app.py CHANGED
@@ -45,14 +45,15 @@ st.set_page_config(
45
  }
46
  )
47
 
48
- aspects = [
49
- 'task', 'method', 'dataset'
50
- ]
 
 
 
51
  tokenizer_name_or_path = f'malteos/aspect-scibert-{aspects[0]}' # any aspect
52
  dataset_config = 'malteos/aspect-paper-metadata'
53
 
54
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
55
-
56
 
57
  @st.cache(show_spinner=False)
58
  def st_load_model(name_or_path):
@@ -63,7 +64,7 @@ def st_load_model(name_or_path):
63
 
64
  @st.cache(show_spinner=False)
65
  def st_load_dataset(name_or_path):
66
- with st.spinner('Loading the dataset (this might take a while)...'):
67
  dataset = load_dataset(name_or_path)
68
 
69
  if isinstance(dataset, DatasetDict):
@@ -84,6 +85,7 @@ aspect_to_model = dict(
84
  dataset = st_load_dataset(dataset_config)
85
 
86
 
 
87
  def get_paper(doc_id):
88
  res = requests.get(f'https://api.semanticscholar.org/v1/paper/{doc_id}')
89
 
@@ -93,32 +95,35 @@ def get_paper(doc_id):
93
  raise ValueError(f'Cannot load paper from S2 API: {doc_id}')
94
 
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  def find_related_papers(paper_id, user_aspect):
97
  with st.spinner('Searching for related papers...'):
 
98
 
99
  paper = get_paper(paper_id)
100
 
101
  if paper is None or 'title' not in paper or paper['title'] is None or 'abstract' not in paper or paper['abstract'] is None:
102
- raise ValueError(f'Could not retrieve title and abstract for input paper: {paper_id}')
103
 
104
  title_abs = paper['title'] + ': ' + paper['abstract']
105
 
106
- # preprocess the input
107
- inputs = tokenizer(title_abs, padding=True, truncation=True, return_tensors="pt", max_length=512)
108
-
109
- # inference
110
- outputs = aspect_to_model[user_aspect](**inputs)
111
-
112
- # logger.info(f'attention_mask: {inputs["attention_mask"].shape}')
113
- #
114
- # logger.info(f'Outputs: {outputs["last_hidden_state"]}')
115
- # logger.info(f'Outputs: {outputs["last_hidden_state"].shape}')
116
-
117
- # Mean pool the token-level embeddings to get sentence-level embeddings
118
- embeddings = torch.sum(
119
- outputs["last_hidden_state"] * inputs['attention_mask'].unsqueeze(-1), dim=1
120
- ) / torch.clamp(torch.sum(inputs['attention_mask'], dim=1, keepdims=True), min=1e-9)
121
-
122
  result = dict(
123
  paper=paper,
124
  aspect=user_aspect,
@@ -129,7 +134,7 @@ def find_related_papers(paper_id, user_aspect):
129
  ))
130
 
131
  # Retrieval
132
- prompt = embeddings.detach().numpy()[0]
133
  scores, retrieved_examples = dataset.get_nearest_examples(f'{user_aspect}_embeddings', prompt, k=10)
134
 
135
  result.update(dict(
@@ -144,9 +149,9 @@ st.title('Aspect-based Paper Similarity')
144
  st.markdown("""This demo showcases [Specialized Document Embeddings for Aspect-based Research Paper Similarity](#TODO).""")
145
 
146
  # Introduction
147
- st.markdown(f"""The model was trained using a triplet loss on machine learning papers from the [paperswithcode.com](https://paperswithcode.com/) corpus with the objective of pulling embeddings of papers with the same task, method, or datasetclose together.
148
  For a more comprehensive overview of the model check out the [model card on 🤗 Model Hub]({model_hub_url}) or read [our paper](#TODO).""")
149
- st.markdown("""Enter a ArXiv ID or a DOI of a paper for that you want find similar papers.
150
 
151
  Try it yourself! 👇""",
152
  unsafe_allow_html=True)
@@ -165,6 +170,8 @@ with st.form("aspect-input", clear_on_submit=False):
165
  "ACL:N19-1423": "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding",
166
  "10.18653/v1/S16-1001": "SemEval-2016 Task 4: Sentiment Analysis in Twitter",
167
  "10.1145/3065386": "ImageNet classification with deep convolutional neural networks",
 
 
168
  }
169
 
170
  example = st.selectbox(
@@ -175,7 +182,8 @@ with st.form("aspect-input", clear_on_submit=False):
175
 
176
  user_aspect = st.radio(
177
  label="In what aspect are you interested?",
178
- options=aspects
 
179
  )
180
 
181
  cols = st.columns(3)
 
45
  }
46
  )
47
 
48
+ aspect_labels = {
49
+ 'task': 'Task 🎯 ',
50
  'method': 'Method 🔨 ',
51
+ 'dataset': 'Dataset 🏷️',
52
+ }
53
+ aspects = list(aspect_labels.keys())
54
  tokenizer_name_or_path = f'malteos/aspect-scibert-{aspects[0]}' # any aspect
55
  dataset_config = 'malteos/aspect-paper-metadata'
56
 
 
 
57
 
58
  @st.cache(show_spinner=False)
59
  def st_load_model(name_or_path):
 
64
 
65
  @st.cache(show_spinner=False)
66
  def st_load_dataset(name_or_path):
67
+ with st.spinner('Loading the dataset and search index (this might take a while)...'):
68
  dataset = load_dataset(name_or_path)
69
 
70
  if isinstance(dataset, DatasetDict):
 
85
  dataset = st_load_dataset(dataset_config)
86
 
87
 
88
+ @st.cache(show_spinner=False)
89
  def get_paper(doc_id):
90
  res = requests.get(f'https://api.semanticscholar.org/v1/paper/{doc_id}')
91
 
 
95
  raise ValueError(f'Cannot load paper from S2 API: {doc_id}')
96
 
97
 
98
+ def get_embedding(input_text, user_aspect):
99
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
100
+
101
+ # preprocess the input
102
+ inputs = tokenizer(input_text, padding=True, truncation=True, return_tensors="pt", max_length=512)
103
+
104
+ # inference
105
+ outputs = aspect_to_model[user_aspect](**inputs)
106
+
107
+ # Mean pool the token-level embeddings to get sentence-level embeddings
108
+ embeddings = torch.sum(
109
+ outputs["last_hidden_state"] * inputs['attention_mask'].unsqueeze(-1), dim=1
110
+ ) / torch.clamp(torch.sum(inputs['attention_mask'], dim=1, keepdims=True), min=1e-9)
111
+
112
+ return embeddings.detach().numpy()[0]
113
+
114
+
115
+ @st.cache(show_spinner=False)
116
  def find_related_papers(paper_id, user_aspect):
117
  with st.spinner('Searching for related papers...'):
118
+ paper_id = paper_id.strip() # remove white spaces
119
 
120
  paper = get_paper(paper_id)
121
 
122
  if paper is None or 'title' not in paper or paper['title'] is None or 'abstract' not in paper or paper['abstract'] is None:
123
+ raise ValueError(f'Could not retrieve title and abstract for input paper (the paper is probably behind a paywall): {paper_id}')
124
 
125
  title_abs = paper['title'] + ': ' + paper['abstract']
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  result = dict(
128
  paper=paper,
129
  aspect=user_aspect,
 
134
  ))
135
 
136
  # Retrieval
137
+ prompt = get_embedding(title_abs, user_aspect)
138
  scores, retrieved_examples = dataset.get_nearest_examples(f'{user_aspect}_embeddings', prompt, k=10)
139
 
140
  result.update(dict(
 
149
  st.markdown("""This demo showcases [Specialized Document Embeddings for Aspect-based Research Paper Similarity](#TODO).""")
150
 
151
  # Introduction
152
+ st.markdown(f"""The model was trained using a triplet loss on machine learning papers from the [paperswithcode.com](https://paperswithcode.com/) corpus with the objective of pulling embeddings of papers with the same task, method, or dataset close together.
153
  For a more comprehensive overview of the model check out the [model card on 🤗 Model Hub]({model_hub_url}) or read [our paper](#TODO).""")
154
+ st.markdown("""Enter a ArXiv ID or a DOI of a paper for that you want find similar papers. The title and abstract of the input paper must be available through the [Semantic Scholar API](https://www.semanticscholar.org/product/api).
155
 
156
  Try it yourself! 👇""",
157
  unsafe_allow_html=True)
 
170
  "ACL:N19-1423": "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding",
171
  "10.18653/v1/S16-1001": "SemEval-2016 Task 4: Sentiment Analysis in Twitter",
172
  "10.1145/3065386": "ImageNet classification with deep convolutional neural networks",
173
+ "arXiv:2101.08700": "Multi-sense embeddings through a word sense disambiguation process",
174
+ "10.1145/3340531.3411878": "Incremental and parallel computation of structural graph summaries for evolving graphs",
175
  }
176
 
177
  example = st.selectbox(
 
182
 
183
  user_aspect = st.radio(
184
  label="In what aspect are you interested?",
185
+ options=aspects,
186
+ format_func=lambda option_key: aspect_labels[option_key],
187
  )
188
 
189
  cols = st.columns(3)