Sophie committed
Commit c38e4ce · 1 Parent(s): 6de8c39

added filtering and citation counts

Files changed (1):
  1. src/streamlit_app.py +242 -39
src/streamlit_app.py CHANGED

@@ -6,6 +6,10 @@ import os
 import boto3
 import psycopg2
 from psycopg2.extensions import connection
+import torch
+import re
+import requests
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from dotenv import load_dotenv
 from latex_clean import clean_latex_for_display
 
@@ -50,6 +54,11 @@ ALLOWED_TYPES = [
     "theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"
 ]
 
+ARXIV_ID_RE = re.compile(
+    r'arxiv\.org/(?:abs|pdf)/((?:\d{4}\.\d{4,5}|[a-z\-]+/\d{7}))(?:v\d+)?',
+    re.IGNORECASE
+)
+
 # Load the Embedding Model
 @st.cache_resource
 def load_model():
@@ -63,7 +72,6 @@ def load_model():
         st.error(f"Error loading the embedding model: {e}")
         return None
 
-
 # Load Data from RDS
 @st.cache_data
 def load_papers_from_rds():
@@ -129,6 +137,25 @@ def load_papers_from_rds():
             elif isinstance(embedding, np.ndarray):
                 embedding = embedding.astype(np.float32)
 
+            # Determine source from url
+            link_str = link or ""
+            if link_str.startswith("http://arxiv.org") or link_str.startswith("https://arxiv.org"):
+                source = "arXiv"
+            else:
+                source = "Stacks Project"
+
+            # Determine type from name
+            def infer_type(name: str) -> str:
+                if not name:
+                    return "theorem"
+                lower = name.lower()
+                for t in ["theorem", "lemma", "proposition", "corollary", "definition", "remark", "assumption"]:
+                    if t in lower:
+                        return t
+                return "theorem"
+
+            inferred_type = infer_type(theorem_name or "")
+
             all_theorems_data.append({
                 "paper_id": paper_id,
                 "authors": authors,
@@ -136,11 +163,15 @@ def load_papers_from_rds():
                 "paper_url": link,
                 "year": last_updated.year,
                 "primary_category": primary_category,
+                "source": source,
+                "type": inferred_type,
+                "journal_published": bool(journal_ref),
+                "citations": None,
                 "theorem_name": theorem_name,
                 "theorem_slogan": theorem_slogan,
                 "theorem_body": theorem_body,
                 "global_context": global_context,
-                "stored_embedding": embedding
+                "stored_embedding": embedding,
             })
 
         return all_theorems_data
@@ -149,63 +180,181 @@ def load_papers_from_rds():
         st.error(f"Error loading data from RDS: {e}")
         return []
 
-
-# --- 3. The Search Function ---
-def search_theorems(query, model, theorems_data, embeddings_db):
+@st.cache_data(ttl=60*60*24)  # cache for 24 hours
+def fetch_citations(paper_url: str, title: str) -> int | None:
     """
-    Takes a user query and finds the top 10 most similar theorems.
+    Returns citation count if found, else None.
+    Tries the following sources in order:
+      1) OpenAlex by arXiv id
+      2) Semantic Scholar by arXiv id
+      3) Semantic Scholar by title
     """
+    arx_id = None
+    if paper_url:
+        m = ARXIV_ID_RE.search(paper_url)
+        if m:
+            arx_id = m.group(1)
+    # OpenAlex by arXiv id
+    if arx_id:
+        try:
+            r = requests.get(f"https://api.openalex.org/works/arXiv:{arx_id}", timeout=10)
+            if r.ok:
+                data = r.json()
+                c = data.get("cited_by_count")
+                if isinstance(c, int):
+                    return c
+        except Exception:
+            pass
+    # Semantic Scholar by arXiv id
+    if arx_id:
+        try:
+            r = requests.get(
+                f"https://api.semanticscholar.org/graph/v1/paper/arXiv:{arx_id}",
+                params={"fields": "citationCount"},
+                timeout=10
+            )
+            if r.ok:
+                j = r.json()
+                c = j.get("citationCount")
+                if isinstance(c, int):
+                    return c
+        except Exception:
+            pass
+    # Fallback: Semantic Scholar by title
+    if title:
+        try:
+            r = requests.get(
+                "https://api.semanticscholar.org/graph/v1/paper/search",
+                params={"query": title, "limit": 1, "fields": "title,citationCount"},
+                timeout=10
+            )
+            if r.ok:
+                j = r.json()
+                if j.get("data"):
+                    c = j["data"][0].get("citationCount")
+                    if isinstance(c, int):
+                        return c
+        except Exception:
+            pass
+
+    return None
+
+def add_citations(candidates: list[dict], max_workers: int = 6) -> None:
+    # Select targets with missing citations
+    targets = [
+        it for it in candidates
+        if it.get("source") == "arXiv" and (it.get("citations") in (None, 0))
+    ]
+    if not targets:
+        return
+
+    with ThreadPoolExecutor(max_workers=max_workers) as exe:
+        fut2item = {
+            exe.submit(fetch_citations, it.get("paper_url"), it.get("paper_title")): it
+            for it in targets
+        }
+        for fut in as_completed(fut2item):
+            it = fut2item[fut]
+            try:
+                c = fut.result()
+                if c is not None:
+                    it["citations"] = c
+            except Exception:
+                pass
+
+# --- Search and Display ---
+def search_and_display_with_filters(query, model, theorems_data, embeddings_db, filters):
     if not query:
         st.info("Please enter a search query.")
         return
+    if not filters['sources']:
+        st.warning("Please select at least one source.")
+        return
 
     query_embedding = model.encode(query, convert_to_tensor=True)
     cosine_scores = util.cos_sim(query_embedding, embeddings_db)[0]
-    top_results_indices = np.argsort(-cosine_scores.cpu())[:10]
-
-    st.subheader("Top 5 Most Similar Theorems")
 
-    if len(top_results_indices) == 0:
-        st.write("No results found.")
+    # Get a larger pool to filter from
+    top_k_pool = min(200, len(theorems_data))
+    top_indices = torch.topk(cosine_scores, k=top_k_pool, sorted=True).indices
+    pool_items = [theorems_data[int(i.item())] for i in top_indices]
+    add_citations(pool_items)
+
+    results = []
+    low, high = filters['citation_range']
+
+    # Filter results
+    for item in pool_items:
+        type_match = (not filters['types']) or (item.get('type', '').lower() in filters['types'])
+        tag_match = (not filters['tags']) or (item.get('primary_category') in filters['tags'])
+        author_match = (not filters['authors']) or any(a in (item.get('authors') or []) for a in filters['authors'])
+        source_match = item.get('source') in filters['sources']
+
+        # Citations & year & journal only meaningful for arXiv
+        cit = item.get('citations')
+        if cit is None:
+            if not filters['include_unknown_citations']:
+                continue
+            citation_match = True
+        else:
+            citation_match = (low <= int(cit) <= high)
+
+        year_match = True
+        if filters['year_range'] and item.get('source') == 'arXiv':
+            y = item.get('year') or 0
+            yr0, yr1 = filters['year_range']
+            year_match = (yr0 <= y <= yr1)
+
+        journal_match = True
+        if item.get('source') == 'arXiv':
+            status = filters['journal_status']
+            jp = bool(item.get('journal_published'))
+            if status == "Journal Article":
+                journal_match = jp
+            elif status == "Preprint Only":
+                journal_match = not jp
+
+        if all([type_match, tag_match, author_match, source_match, citation_match, year_match, journal_match]):
+            results.append({"info": item, "similarity": float(cosine_scores[theorems_data.index(item)].item())})
+            if len(results) >= filters['top_k']:
+                break
+
+    st.subheader(f"Found {len(results)} Matching Results")
+    if not results:
+        st.warning("No results found for the current filters.")
        return
 
-    for i, idx in enumerate(top_results_indices):
-        idx = idx.item()
-        similarity = cosine_scores[idx].item()
-        theorem_info = theorems_data[idx]
-
-        expander_title = f"**Result {i+1} | Similarity: {similarity:.4f}**"
-        if theorem_info.get("theorem_name"):
-            expander_title += f" | {theorem_info['theorem_name']}"
-
+    for i, r in enumerate(results):
+        info = r["info"]
+        expander_title = f"**Result {i+1} | Similarity: {r['similarity']:.4f} | Type: {info.get('type', '').title()}**"
         with st.expander(expander_title):
-            st.markdown(f"**Paper:** {theorem_info.get('paper_title', 'Unknown')}")
-            st.markdown(f"**Authors:** {', '.join(theorem_info['authors']) if theorem_info['authors'] else 'N/A'}")
-            st.markdown(f"**Source:** [{theorem_info['paper_url']}]({theorem_info['paper_url']})")
+            st.markdown(f"**Paper:** *{info.get('paper_title', 'Unknown')}*")
+            st.markdown(f"**Authors:** {', '.join(info.get('authors') or []) or 'N/A'}")
+            st.markdown(f"**Source:** {info.get('source')} ([Link]({info.get('paper_url')}))")
+            cit = info.get("citations")
+            cit_str = "Unknown" if cit is None else str(cit)
             st.markdown(
-                f"**Math Tag:** `{theorem_info['primary_category']}` | **Year:** {theorem_info.get('year', 'N/A')}")
+                f"**Math Tag:** `{info.get('primary_category')}` | "
+                f"**Citations:** {cit_str} | "
+                f"**Year:** {info.get('year', 'N/A')}"
+            )
             st.markdown("---")
 
-            if theorem_info.get("theorem_slogan"):
-                st.markdown(f"**Slogan:** {theorem_info['theorem_slogan']}")
-                st.write("")
+            if info.get("theorem_slogan"):
+                st.markdown(f"**Slogan:** {info['theorem_slogan']}\n")
 
-            if theorem_info["global_context"]:
-                cleaned_ctx = clean_latex_for_display(theorem_info["global_context"])
-                blockquote_context = "> " + cleaned_ctx.replace("\n", "\n> ")
-                st.markdown(blockquote_context)
-                st.write("")
+            if info.get("global_context"):
+                cleaned_ctx = clean_latex_for_display(info["global_context"])
+                st.markdown("> " + cleaned_ctx.replace("\n", "\n> "))
 
-            cleaned_content = clean_latex_for_display(theorem_info['theorem_body'])
-            st.markdown(f"**Theorem Body:**")
+            cleaned_content = clean_latex_for_display(info['theorem_body'])
+            st.markdown("**Theorem Body:**")
             st.markdown(cleaned_content)
 
 # --- Main App Interface ---
 st.set_page_config(page_title="Theorem Search Demo", layout="wide")
 st.title("📚 Semantic Theorem Search")
 st.write("This demo uses a specialized mathematical language model to find theorems semantically similar to your query.")
-st.markdown("*Note: Linking to a specific page within an arXiv PDF is not directly possible.*",
-            help="arXiv links redirect to the paper's abstract, not a specific page in the PDF.")
 
 model = load_model()
 theorems_data = load_papers_from_rds()
@@ -214,11 +363,65 @@ if model and theorems_data:
     with st.spinner("Preparing embeddings from database..."):
         corpus_embeddings = np.array([item['stored_embedding'] for item in theorems_data])
 
-    st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv. Ready to search!")
+    st.success(f"Successfully loaded {len(theorems_data)} theorems from arXiv and the Stacks Project. Ready to search!")
+
+    # --- Sidebar filters ---
+    with st.sidebar:
+        st.header("Search Filters")
+
+        all_sources = ['arXiv', 'Stacks Project']
+        selected_sources = st.multiselect(
+            "Filter by Source(s):",
+            all_sources,
+            default=all_sources[:1] if all_sources else [],
+            help="Select one or more sources to reveal more filters."
+        )
+
+        selected_authors, selected_types, selected_tags = [], [], []
+        year_range, journal_status = None, "All"
+        citation_range = (0, 1000)
+        top_k_results, include_unknown_citations = 5, True
+
+        if selected_sources:
+            st.write("---")
+            selected_types = st.multiselect("Filter by Type:", ALLOWED_TYPES)
+            all_authors = sorted(list(set(a for it in theorems_data for a in (it.get('authors') or []))))
+            selected_authors = st.multiselect("Filter by Author(s):", all_authors)
+
+            # Tags come from union of categories per selected source
+            from collections import defaultdict
+            tags_per_source = defaultdict(set)
+            for it in theorems_data:
+                tags_per_source[it['source']].add(it.get('primary_category'))
+            union_tags = sorted({t for s in selected_sources for t in tags_per_source.get(s, set()) if t})
+            selected_tags = st.multiselect("Filter by Math Tag/Category:", union_tags)
+
+            if 'arXiv' in selected_sources:
+                year_range = st.slider("Filter by Year (for arXiv):", 1991, 2025, (1991, 2025))
+                journal_status = st.radio("Publication Status (for arXiv):", ["All", "Journal Article", "Preprint Only"], horizontal=True)
+
+            citation_range = st.slider("Filter by Citations:", 0, 1000, (0, 1000))
+            include_unknown_citations = st.checkbox(
+                "Include entries with unknown citation counts",
+                value=True,
+                help="If unchecked, results with unknown citation counts are excluded."
+            )
+            top_k_results = st.slider("Number of results to display:", 1, 20, 5)
+
+        filters = {
+            "authors": selected_authors,
+            "types": [t.lower() for t in selected_types],
+            "tags": selected_tags,
+            "sources": selected_sources,
+            "year_range": year_range,
+            "journal_status": journal_status,
+            "citation_range": citation_range,
+            "include_unknown_citations": include_unknown_citations,
+            "top_k": top_k_results
+        }
 
     user_query = st.text_input("Enter your query:", "")
-
     if st.button("Search") or user_query:
-        search_theorems(user_query, model, theorems_data, corpus_embeddings)
+        search_and_display_with_filters(user_query, model, theorems_data, corpus_embeddings, filters)
 else:
     st.error("Could not load the model or data from RDS. Please check your RDS database connection and credentials.")
 