kiyer committed on
Commit
4351936
·
1 Parent(s): 022c0b9

linked up everything, qn type, consensus

Browse files
app.py CHANGED
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
6
  from typing import List, Dict, Any, Tuple
7
  from collections import defaultdict
8
  from tqdm import tqdm
9
- import pandas as pd
10
  from datetime import datetime, date
11
  from datasets import load_dataset, load_from_disk
12
  from collections import Counter
@@ -16,24 +16,27 @@ import concurrent.futures
16
 
17
  from langchain import hub
18
  from langchain_openai import ChatOpenAI as openai_llm
19
- from langchain_core.runnables import RunnableConfig
 
 
20
  from langchain_community.callbacks import StreamlitCallbackHandler
21
-
22
- from langchain.agents import create_react_agent, Tool, AgentExecutor
23
  from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
 
 
 
 
 
 
 
24
 
25
-
26
- ts = time.time()
27
-
28
-
29
- anthropic_key = st.secrets["anthropic_key"]
30
-
31
- openai_key = st.secrets["openai_key"]
32
 
33
  from nltk.corpus import stopwords
34
  import nltk
35
  from openai import OpenAI
36
- import anthropic
37
  import cohere
38
  import faiss
39
 
@@ -50,12 +53,28 @@ except:
50
  nltk.download('stopwords')
51
  stopwords.words('english')
52
 
 
53
  from bokeh.plotting import figure
54
  from bokeh.models import ColumnDataSource
55
  from bokeh.io import output_notebook
56
  from bokeh.palettes import Spectral5
57
  from bokeh.transform import linear_cmap
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  st.image('local_files/pathfinder_logo.png')
61
 
@@ -63,15 +82,13 @@ st.expander("About", expanded=False).write(
63
  """
64
  Pathfinder v2.0 is a framework for searching and visualizing astronomy papers on the [arXiv](https://arxiv.org/) and [ADS](https://ui.adsabs.harvard.edu/) using the context
65
  sensitivity from modern large language models (LLMs) to better parse patterns in paper contexts.
66
-
67
  This tool was built during the [JSALT workshop](https://www.clsp.jhu.edu/2024-jelinek-summer-workshop-on-speech-and-language-technology/) to do awesome things.
68
 
69
- **👈 Select a tool from the sidebar** to see some examples
70
- of what this framework can do!
71
 
72
  ### Tool summary:
73
  - Please wait while the initial data loads and compiles, this takes about a minute initially.
74
- - `Paper search` looks for relevant papers given an arxiv id or a question.
75
 
76
  This is not meant to be a replacement to existing tools like the
77
  [ADS](https://ui.adsabs.harvard.edu/),
@@ -79,33 +96,34 @@ st.expander("About", expanded=False).write(
79
  that otherwise might be missed during a literature survey.
80
  It is trained on astro-ph (astrophysics of galaxies) papers up to last-year-ish mined from arxiv and supplemented with ADS metadata,
81
  if you are interested in extending it please reach out!
82
-
83
-
84
- Also add: more pages, actual generation, diff. toggles for retrieval/gen, feedback form, socials, literature, contact us, copyright, collaboration, etc.
85
 
86
  The image below shows a representation of all the astro-ph.GA papers that can be explored in more detail
87
  using the `Arxiv embedding` page. The papers tend to cluster together by similarity, and result in an
88
  atlas that shows well studied (forests) and currently uncharted areas (water).
89
  """
90
  )
91
-
92
-
93
-
94
-
95
  if 'arxiv_corpus' not in st.session_state:
96
  with st.spinner('loading data...'):
97
- try:
98
  arxiv_corpus = load_from_disk('data/')
99
  except:
100
  st.write('downloading data')
101
- arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data',split='train')
 
102
  arxiv_corpus.save_to_disk('data/')
103
  arxiv_corpus.add_faiss_index('embed')
104
  st.session_state.arxiv_corpus = arxiv_corpus
105
  st.toast('loaded arxiv corpus')
106
  else:
107
  arxiv_corpus = st.session_state.arxiv_corpus
108
-
109
  if 'ids' not in st.session_state:
110
  st.session_state.ids = arxiv_corpus['ads_id']
111
  st.session_state.titles = arxiv_corpus['title']
@@ -114,8 +132,8 @@ if 'ids' not in st.session_state:
114
  st.session_state.years = arxiv_corpus['date']
115
  st.session_state.kws = arxiv_corpus['keywords']
116
  st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
117
-
118
-
119
  #---------------------------------------------------------------
120
 
121
  # A hack to "clear" the previous result when submitting a new prompt. This avoids
@@ -144,186 +162,33 @@ def with_clear_container(submit_clicked: bool) -> bool:
144
  return True
145
 
146
  return False
147
-
148
- #----------------------------------------------------------------
149
-
150
- class Filter():
151
- def filter(self, query: str, arxiv_id: str) -> List[str]:
152
- pass
153
-
154
- class CitationFilter(Filter): # can do it with all metadata
155
- def __init__(self, corpus):
156
- self.corpus = corpus
157
- ids = ids
158
- cites = cites
159
- self.citation_counts = {ids[i]: cites[i] for i in range(len(ids))}
160
-
161
- def citation_weight(self, x, shift, scale):
162
- return 1 / (1 + np.exp(-1 * (x - shift) / scale)) # sigmoid function
163
-
164
- def filter(self, doc_scores, weight = 0.1): # additive weighting
165
- citation_count = np.array([self.citation_counts[doc[0]] for doc in doc_scores])
166
- cmean, cstd = np.median(citation_count), np.std(citation_count)
167
- citation_score = self.citation_weight(citation_count, cmean, cstd)
168
-
169
- for i, doc in enumerate(doc_scores):
170
- doc_scores[i][2] += weight * citation_score[i]
171
-
172
- class DateFilter(Filter): # include time weighting eventually
173
- def __init__(self, document_dates):
174
- self.document_dates = document_dates
175
-
176
- def parse_date(self, arxiv_id: str) -> datetime: # only for documents
177
- if arxiv_id.startswith('astro-ph'):
178
- arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
179
- try:
180
- year = int("20" + arxiv_id[:2])
181
- month = int(arxiv_id[2:4])
182
- except:
183
- year = 2023
184
- month = 1
185
- return date(year, month, 1)
186
-
187
- def weight(self, time, shift, scale):
188
- return 1 / (1 + np.exp((time - shift) / scale))
189
-
190
- def evaluate_filter(self, year, filter_string):
191
- try:
192
- # Use ast.literal_eval to safely evaluate the expression
193
- result = eval(filter_string, {"__builtins__": None}, {"year": year})
194
- return result
195
- except Exception as e:
196
- print(f"Error evaluating filter: {e}")
197
- return False
198
-
199
- def filter(self, docs, boolean_date = None, min_date = None, max_date = None, time_score = 0):
200
- filtered = []
201
-
202
- if boolean_date is not None:
203
- boolean_date = boolean_date.replace("AND", "and").replace("OR", "or")
204
- for doc in docs:
205
- if self.evaluate_filter(self.document_dates[doc[0]].year, boolean_date):
206
- filtered.append(doc)
207
-
208
- else:
209
- if min_date == None: min_date = date(1990, 1, 1)
210
- if max_date == None: max_date = date(2024, 7, 3)
211
-
212
- for doc in docs:
213
- if self.document_dates[doc[0]] >= min_date and self.document_dates[doc[0]] <= max_date:
214
- filtered.append(doc)
215
-
216
- if time_score is not None: # apply time weighting
217
- for i, item in enumerate(filtered):
218
- time_diff = (max_date - self.document_dates[filtered[i][0]]).days / 365
219
- filtered[i][2] += time_score * 0.1 * self.weight(time_diff, 5, 5)
220
-
221
- return filtered
222
-
223
- class KeywordFilter(Filter):
224
- def __init__(self, corpus,
225
- remove_capitals: bool = True, metadata = None, ne_only = True, verbose = False):
226
-
227
- self.index_path = 'keyword_index.json'
228
- # self.metadata = metadata
229
- self.remove_capitals = remove_capitals
230
- self.ne_only = ne_only
231
- self.stopwords = set(stopwords.words('english'))
232
- self.verbose = verbose
233
- self.index = None
234
- self.kws = st.session_state.kws
235
- self.ids = st.session_state.ids
236
- self.titles = st.session_state.titles
237
-
238
- self.load_or_build_index()
239
-
240
- def preprocess_text(self, text: str) -> str:
241
- text = ''.join(char for char in text if char.isalnum() or char.isspace())
242
- if self.remove_capitals: text = text.lower()
243
- return ' '.join(word for word in text.split() if word.lower() not in self.stopwords)
244
-
245
- def build_index(self): # include the title in the index
246
- print("Building index...")
247
- self.index = {}
248
-
249
- for i in range(len(self.kws)):
250
- paper = self.ids[i]
251
- title = self.titles[i]
252
- title_keywords = set()
253
- for keyword in set(self.kws[i]) | title_keywords:
254
- term = ' '.join(word for word in keyword.lower().split() if word.lower() not in self.stopwords)
255
- if term not in self.index:
256
- self.index[term] = []
257
- self.index[term].append(self.ids[i])
258
-
259
- with open(self.index_path, 'w') as f:
260
- json.dump(self.index, f)
261
-
262
- def load_index(self):
263
- print("Loading existing index...")
264
- with open(self.index_path, 'rb') as f:
265
- self.index = json.load(f)
266
-
267
- print("Index loaded successfully.")
268
-
269
- def load_or_build_index(self):
270
- if os.path.exists(self.index_path):
271
- self.load_index()
272
- else:
273
- self.build_index()
274
 
275
- def parse_doc(self, doc):
276
- local_kws = []
277
-
278
- for phrase in doc._.phrases:
279
- local_kws.append(phrase.text.lower())
280
-
281
- return [self.preprocess_text(word) for word in local_kws]
282
-
283
- def get_propn(self, doc):
284
- result = []
285
-
286
- working_str = ''
287
- for token in doc:
288
- if(token.text in nlp.Defaults.stop_words or token.text in punctuation):
289
- if working_str != '':
290
- result.append(working_str.strip())
291
- working_str = ''
292
-
293
- if(token.pos_ == "PROPN"):
294
- working_str += token.text + ' '
295
-
296
- if working_str != '': result.append(working_str.strip())
297
-
298
- return [self.preprocess_text(word) for word in result]
299
-
300
- def filter(self, query: str, doc_ids = None):
301
- doc = nlp(query)
302
- query_keywords = self.parse_doc(doc)
303
- nouns = self.get_propn(doc)
304
- if self.verbose: print('keywords:', query_keywords)
305
- if self.verbose: print('proper nouns:', nouns)
306
-
307
- filtered = set()
308
- if len(query_keywords) > 0 and not self.ne_only:
309
- for keyword in query_keywords:
310
- if keyword != '' and keyword in self.index.keys(): filtered |= set(self.index[keyword])
311
-
312
- if len(nouns) > 0:
313
- ne_results = set()
314
- for noun in nouns:
315
- if noun in self.index.keys(): ne_results |= set(self.index[noun])
316
-
317
- if self.ne_only: filtered = ne_results # keep only named entity results
318
- else: filtered &= ne_results # take the intersection
319
-
320
- if doc_ids is not None: filtered &= doc_ids # apply filter to results
321
- return filtered
322
 
323
  class EmbeddingRetrievalSystem():
324
 
325
  def __init__(self, weight_citation = False, weight_date = False, weight_keywords = False):
326
-
327
  self.ids = st.session_state.ids
328
  self.years = st.session_state.years
329
  self.abstract = st.session_state.abstracts
@@ -331,7 +196,8 @@ class EmbeddingRetrievalSystem():
331
  self.embed_model = "text-embedding-3-small"
332
  self.dataset = arxiv_corpus
333
  self.kws = st.session_state.kws
334
-
 
335
  self.weight_citation = weight_citation
336
  self.weight_date = weight_date
337
  self.weight_keywords = weight_keywords
@@ -339,7 +205,7 @@ class EmbeddingRetrievalSystem():
339
 
340
  # self.citation_filter = CitationFilter(self.dataset)
341
  # self.date_filter = DateFilter(self.dataset['date'])
342
- self.keyword_filter = KeywordFilter(corpus=self.dataset, remove_capitals=True)
343
 
344
  def parse_date(self, id):
345
  # indexval = np.where(self.ids == id)[0][0]
@@ -354,12 +220,6 @@ class EmbeddingRetrievalSystem():
354
  embeddings = self.client.embeddings.create(input=texts, model=self.embed_model).data
355
  return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]
356
 
357
- def init_filters(self):
358
-
359
- self.citation_filter = []
360
- self.date_filter = []
361
- self.keyword_filter = []
362
-
363
  def get_query_embedding(self, query):
364
  return self.make_embedding(query)
365
 
@@ -370,22 +230,77 @@ class EmbeddingRetrievalSystem():
370
  # xq = query_embedding.reshape(-1,1).T.astype('float32')
371
  # D, I = self.index.search(xq, top_k)
372
  # return I[0], D[0]
373
- tmp = self.dataset.search('embed',query_embedding, k=top_k)
374
  return [tmp.indices, tmp.scores]
375
-
376
  def rank_and_filter(self, query, query_embedding, query_date, top_k = 10, return_scores=False, time_result=None):
377
 
378
-
379
- topk_indices, similarities = self.calc_faiss(np.array(query_embedding), top_k = 300)
 
 
 
 
 
 
 
 
380
 
381
- if self.weight_keywords:
382
- keyword_matches = self.keyword_filter.filter(query)
383
- kw_indices = np.zeros_like(similarities)
384
- for s in keyword_matches:
385
- if self.id_to_index[s] in topk_indices:
386
- # print('yes', self.id_to_index[s], topk_indices[np.where(topk_indices == self.id_to_index[s])[0]])
387
- similarities[np.where(topk_indices == self.id_to_index[s])[0]] = similarities[np.where(topk_indices == self.id_to_index[s])[0]] * 10.
388
- similarities = similarities / 10.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
 
390
  filtered_results = [[topk_indices[i], similarities[i]] for i in range(len(similarities))]
391
  top_results = sorted(filtered_results, key=lambda x: x[1], reverse=True)[:top_k]
@@ -395,43 +310,43 @@ class EmbeddingRetrievalSystem():
395
 
396
  # Only keep the document IDs
397
  top_results = [doc[0] for doc in top_results]
398
- return top_results
399
-
400
  def retrieve(self, query, top_k, time_result=None, query_date = None, return_scores = False):
401
 
402
  query_embedding = self.get_query_embedding(query)
403
 
404
  # Judge time relevance
405
  if time_result is None:
406
- if self.weight_date:
407
  time_result, time_taken = self.analyze_temporal_query(query, self.anthropic_client)
408
- else:
409
  time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
410
 
411
- top_results = self.rank_and_filter(query,
412
- query_embedding,
413
- query_date,
414
- top_k,
415
- return_scores = return_scores,
416
  time_result = time_result)
417
-
418
  return top_results
419
 
420
  class HydeRetrievalSystem(EmbeddingRetrievalSystem):
421
- def __init__(self, generation_model: str = "claude-3-haiku-20240307",
422
- embedding_model: str = "text-embedding-3-small",
423
- temperature: float = 0.5,
424
- max_doclen: int = 500,
425
- generate_n: int = 1,
426
- embed_query = True,
427
  conclusion = False, **kwargs):
428
-
429
  # Handle the kwargs for the superclass init -- filters/citation weighting
430
  super().__init__(**kwargs)
431
-
432
  if max_doclen * generate_n > 8191:
433
  raise ValueError("Too many tokens. Please reduce max_doclen or generate_n.")
434
-
435
  self.embedding_model = embedding_model
436
  self.generation_model = generation_model
437
 
@@ -442,58 +357,67 @@ class HydeRetrievalSystem(EmbeddingRetrievalSystem):
442
  self.embed_query = embed_query # embed the query vector?
443
  self.conclusion = conclusion # generate conclusion as well?
444
 
445
- self.anthropic_key = anthropic_key
446
- self.generation_client = anthropic.Anthropic(api_key = self.anthropic_key)
447
-
 
448
  def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
449
  if time_result is None:
450
  if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
451
  else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
452
 
453
  docs = self.generate_docs(query)
 
 
454
  doc_embeddings = self.embed_docs(docs)
455
 
456
- if self.embed_query:
457
  query_emb = self.embed_docs([query])[0]
458
  doc_embeddings.append(query_emb)
459
-
460
  embedding = np.mean(np.array(doc_embeddings), axis = 0)
461
 
462
  top_results = self.rank_and_filter(query, embedding, query_date=None, top_k = top_k, return_scores = return_scores, time_result = time_result)
463
-
464
  return top_results
465
 
466
  def generate_doc(self, query: str):
467
- prompt = """You are an expert astronomer. Given a scientific query, generate the abstract"""
468
- if self.conclusion:
469
- prompt += " and conclusion"
470
- prompt += """ of an expert-level research paper
471
  that answers the question. Stick to a maximum length of {} tokens and return just the text of the abstract and conclusion.
472
  Do not include labels for any section. Use research-specific jargon.""".format(self.max_doclen)
473
-
474
-
475
- message = self.generation_client.messages.create(
476
- model = self.generation_model,
477
- max_tokens = self.max_doclen,
478
- temperature = self.temperature,
479
- system = prompt,
480
- messages=[{ "role": "user",
481
- "content": [{"type": "text", "text": query,}] }]
482
- )
483
-
484
- return message.content[0].text
485
-
 
 
 
 
486
  def generate_docs(self, query: str):
487
  docs = []
488
- with concurrent.futures.ThreadPoolExecutor() as executor:
489
- future_to_query = {executor.submit(self.generate_doc, query): query for i in range(self.generate_n)}
490
- for future in concurrent.futures.as_completed(future_to_query):
491
- query = future_to_query[future]
492
- try:
493
- data = future.result()
494
- docs.append(data)
495
- except Exception as exc:
496
- pass
 
 
 
 
 
497
  return docs
498
 
499
  def embed_docs(self, docs: List[str]):
@@ -503,35 +427,35 @@ class HydeCohereRetrievalSystem(HydeRetrievalSystem):
503
  def __init__(self, **kwargs):
504
  super().__init__(**kwargs)
505
 
506
- self.cohere_key = "Of1MjzFjGmvzBAqdvNHTQLkAjecPcOKpiIPAnFMn"
507
  self.cohere_client = cohere.Client(self.cohere_key)
508
 
509
- def retrieve(self, query: str,
510
- top_k: int = 10,
511
  rerank_top_k: int = 250,
512
  return_scores = False, time_result = None,
513
  reweight = False) -> List[Tuple[str, str, float]]:
514
-
515
  if time_result is None:
516
  if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
517
  else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
518
-
519
  top_results = super().retrieve(query, top_k = rerank_top_k, time_result = time_result)
520
-
521
  # doc_texts = self.get_document_texts(top_results)
522
  # docs_for_rerank = [f"Abstract: {doc['abstract']}\nConclusions: {doc['conclusions']}" for doc in doc_texts]
523
  docs_for_rerank = [self.abstract[i] for i in top_results]
524
-
525
  if len(docs_for_rerank) == 0:
526
  return []
527
-
528
  reranked_results = self.cohere_client.rerank(
529
  query=query,
530
  documents=docs_for_rerank,
531
  model='rerank-english-v3.0',
532
  top_n=top_k
533
  )
534
-
535
  final_results = []
536
  for result in reranked_results.results:
537
  doc_id = top_results[result.index]
@@ -542,9 +466,9 @@ class HydeCohereRetrievalSystem(HydeRetrievalSystem):
542
  if reweight:
543
  if time_result['has_temporal_aspect']:
544
  final_results = self.date_filter.filter(final_results, time_score = time_result['expected_recency_weight'])
545
-
546
  if self.weight_citation: self.citation_filter.filter(final_results)
547
-
548
  if return_scores:
549
  return {result[0]: result[2] for result in final_results}
550
 
@@ -554,40 +478,113 @@ class HydeCohereRetrievalSystem(HydeRetrievalSystem):
554
  return self.embed_batch(docs)
555
 
556
  # ----------------------------------------------------------------
557
-
558
-
559
  if 'ec' not in st.session_state:
560
- ec = EmbeddingRetrievalSystem(weight_keywords=True)
561
  st.session_state.ec = ec
562
  st.toast('loaded retrieval system')
563
  else:
564
  ec = st.session_state.ec
565
-
566
- # Function to simulate question answering (replace with actual implementation)
567
- def answer_question(question, top_k, keywords, toggles, method, question_type):
568
- # Simulated answer (replace with actual logic)
569
- # return f"Answer to '{question}' using method {method} for {question_type} question."
570
- return run_ret(question, top_k)
571
 
 
 
 
 
572
 
573
- def get_papers(ids):
574
-
575
- papers, scores, links = [], [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  for i in ids:
577
  papers.append(st.session_state.titles[i])
578
  scores.append(ids[i])
579
  links.append('https://ui.adsabs.harvard.edu/abs/'+st.session_state.arxiv_corpus['bibcode'][i]+'/abstract')
580
-
 
 
 
581
  return pd.DataFrame({
582
  'Title': papers,
583
  'Relevance': scores,
584
- 'Link': links
 
 
 
585
  })
586
 
587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
 
589
  def create_embedding_plot(rs):
590
-
 
 
591
 
592
  pltsource = ColumnDataSource(data=dict(
593
  x=st.session_state.arxiv_corpus['umap_x'],
@@ -595,10 +592,14 @@ def create_embedding_plot(rs):
595
  title=st.session_state.titles,
596
  link=st.session_state.arxiv_corpus['bibcode'],
597
  ))
598
-
599
  rsflag = np.zeros((len(st.session_state.ids),))
600
  rsflag[np.array([k for k in rs])] = 1
 
 
 
601
  pltsource.data['colors'] = rsflag * 0.8 + 0.1
 
602
  pltsource.data['sizes'] = (rsflag + 1)**5 / 100
603
 
604
  TOOLTIPS = """
@@ -609,22 +610,21 @@ def create_embedding_plot(rs):
609
  @link <br> <br>
610
  </div>
611
  """
612
-
613
  mapper = linear_cmap(field_name="colors", palette=Spectral5, low=0., high=1.)
614
 
615
  p = figure(width=700, height=900, tooltips=TOOLTIPS, x_range=(0, 20), y_range=(-4.2,18),
616
  title="UMAP projection of embeddings for the astro-ph corpus")
617
-
618
  p.axis.visible=False
619
  p.grid.visible=False
620
  p.outline_line_alpha = 0.
621
-
622
  p.circle('x', 'y', radius='sizes', source=pltsource, alpha=0.3, fill_color=mapper, fill_alpha='colors', line_color="lightgrey",line_alpha=0.1)
623
-
624
  return p
625
 
626
- # Function to simulate keyword extraction (replace with actual implementation)
627
- def extract_keywords(question):
628
  # Simulated keyword extraction (replace with actual logic)
629
  return ['keyword1', 'keyword2', 'keyword3']
630
 
@@ -633,184 +633,401 @@ def estimate_consensus():
633
  # Simulated consensus estimation (replace with actual calculation)
634
  return 0.75
635
 
636
- def run_ret(query, top_k):
637
- rs = ec.retrieve(query, top_k, return_scores=True)
638
- output_str = ''
639
- for i in rs:
640
- if rs[i] > 0.5:
641
- output_str = output_str + '---> ' + st.session_state.abstracts[i] + '(score: %.2f) \n' %rs[i]
642
- else:
643
- output_str = output_str + st.session_state.abstracts[i] + '(score: %.2f) \n' %rs[i]
644
- return output_str, rs
645
 
646
- def Library(query, top_k=7):
647
- rs = ec.retrieve(query, top_k, return_scores=True)
648
- op_docs = ''
649
- for paperno, i in enumerate(rs):
650
- # op_docs.append(abstracts[i])
651
- op_docs = op_docs + 'Paper %.0f:' %(paperno+1) +' (published in '+st.session_state.arxiv_corpus['bibcode'][i][0:4] + ') ' + st.session_state.titles[i] + '\n' + st.session_state.abstracts[i] + '\n\n'
652
- # st.write(op_docs)
653
- return op_docs
 
 
 
 
 
 
 
 
654
 
655
- search = DuckDuckGoSearchAPIWrapper()
656
- tools = [
657
- Tool(
658
- name="Library",
659
- func=Library,
660
- description="A source of information pertinent to your question. Do not answer a question without consulting this!"
661
- ),
662
- Tool(
663
- name="Search",
664
- func=search.run,
665
- description="useful for when you need to look up knowledge about common topics or current events",
666
- )
667
- ]
668
-
669
- if 'tools' not in st.session_state:
670
- st.session_state.tools = tools
671
-
672
- # for another question type:
673
- # First, find the quotes from the document that are most relevant to answering the question, and then print them in numbered order.
674
- # Quotes should be relatively short. If there are no relevant quotes, write “No relevant quotes” instead.
675
 
676
- gen_llm = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
677
 
678
- template = """You are an expert astronomer and cosmologist.
679
- Answer the following question as best you can using information from the library, but speaking in a concise and factual manner.
680
- If you can not come up with an answer, say you do not know.
681
- Try to break the question down into smaller steps and solve it in a logical manner.
682
 
683
- You have access to the following tools:
684
 
685
- {tools}
 
 
 
686
 
687
- Use the following format:
688
 
689
- Question: the input question you must answer
690
- Thought: you should always think about what to do
691
- Action: the action to take, should be one of [{tool_names}]
692
- Action Input: the input to the action
693
- Observation: the result of the action
694
- ... (this Thought/Action/Action Input/Observation can repeat N times)
695
- Thought: I now know the final answer
696
- Final Answer: the final answer to the original input question. provide information about how you arrived at the answer, and any nuances or uncertainties the reader should be aware of
697
 
698
- Begin! Remember to speak in a pedagogical and factual manner."
699
 
700
- Question: {input}
701
- Thought:{agent_scratchpad}"""
 
 
 
 
 
 
702
 
 
703
 
704
- prompt = hub.pull("hwchase17/react")
705
- prompt.template=template
706
 
707
- from langchain.callbacks import FileCallbackHandler
708
- from langchain.callbacks.manager import CallbackManager
709
- # timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
710
- # file_path = f"agent_trace_{timestamp}.txt"
711
- file_path = "agent_trace.txt"
712
- file_handler = FileCallbackHandler(file_path)
713
- callback_manager=CallbackManager([file_handler])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
714
 
715
- tool_names = [tool.name for tool in st.session_state.tools]
716
- if 'agent' not in st.session_state:
717
- # agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
718
- agent = create_react_agent(llm=gen_llm, tools=tools, prompt=prompt)
719
- st.session_state.agent = agent
720
 
721
- if 'agent_executor' not in st.session_state:
722
- agent_executor = AgentExecutor(agent=st.session_state.agent, tools=st.session_state.tools, verbose=True, handle_parsing_errors=True, callbacks=CallbackManager([file_handler]))
723
- st.session_state.agent_executor = agent_executor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
 
725
 
726
  # Streamlit app
727
  def main():
728
-
729
  # st.title("Question Answering App")
730
-
731
-
732
  # Sidebar (Inputs)
733
  st.sidebar.header("Fine-tune the search")
734
  top_k = st.sidebar.slider("Number of papers to retrieve:", 3, 30, 10)
735
  extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
736
-
737
  st.sidebar.subheader("Toggles")
738
- toggle_a = st.sidebar.checkbox("Weight by keywords")
739
- toggle_b = st.sidebar.checkbox("weight by time")
740
- toggle_c = st.sidebar.checkbox("Weight by citations")
741
-
742
- method = st.sidebar.radio("Choose a method:", ["Semantic search", "Semantic search + HyDE", "Semantic search + HyDE + CoHERE"])
743
- question_type = st.sidebar.selectbox("Select question type:", ["Single paper", "Multi-paper", "Summary"])
744
- # store_output = st.sidebar.checkbox("Store the output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
 
746
-
747
  store_output = st.sidebar.button("Save output")
748
 
749
  # Main page (Outputs)
750
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
751
  query = st.text_input("Ask me anything:")
752
  submit_button = st.button("Submit")
753
-
754
  if submit_button:
755
-
756
- # Process inputs
757
- keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
758
- toggles = {'Keyword weighting': toggle_a, 'Time weighting': toggle_b, 'Citation weighting': toggle_c}
759
-
760
- # Generate outputs
761
- answer, rs = answer_question(query, top_k, keywords, toggles, method, question_type)
762
- papers_df = get_papers(rs)
763
- embedding_plot = create_embedding_plot(rs)
764
- triggered_keywords = extract_keywords(query)
765
- consensus = estimate_consensus()
766
-
767
-
768
- # Display outputs
769
- answer = st.session_state.agent_executor.invoke({"input": query,})
770
- st.write(answer["output"])
771
-
772
- with open(file_path, 'r') as file:
773
- intermediate_steps = file.read()
774
-
775
- st.expander('Intermediate steps', expanded=False).write(intermediate_steps)
776
- # st.write(answer)
777
-
778
-
779
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
  with st.expander("Relevant papers", expanded=True):
781
  # st.dataframe(papers_df, hide_index=True)
782
  st.data_editor(papers_df,
783
- column_config = {'Link':st.column_config.LinkColumn(display_text= 'https://ui.adsabs.harvard.edu/abs/(.*?)/abstract')}
784
  )
785
 
786
  with st.expander("Embedding map", expanded=False):
787
  st.bokeh_chart(embedding_plot)
788
-
789
  col1, col2 = st.columns(2)
790
-
791
  with col1:
792
-
793
- st.subheader("Question Type")
794
- st.write(question_type)
795
-
796
- st.subheader("Triggered Keywords")
797
- st.write(", ".join(triggered_keywords))
798
-
 
 
 
799
  with col2:
800
-
801
- st.subheader("Consensus Estimate")
802
- st.write(f"{consensus:.2%}")
803
-
804
- # st.subheader("Papers Used")
805
- # st.dataframe(papers_df)
806
-
807
-
808
-
 
 
809
  else:
810
- st.info("Use the sidebar to input parameters and submit to see results.")
811
-
812
  if store_output:
813
  st.toast("Output stored successfully!")
814
 
815
  if __name__ == "__main__":
816
- main()
 
6
  from typing import List, Dict, Any, Tuple
7
  from collections import defaultdict
8
  from tqdm import tqdm
9
+ import pandas as pd
10
  from datetime import datetime, date
11
  from datasets import load_dataset, load_from_disk
12
  from collections import Counter
 
16
 
17
  from langchain import hub
18
  from langchain_openai import ChatOpenAI as openai_llm
19
+ from langchain_openai import OpenAIEmbeddings
20
+ from langchain_core.runnables import RunnableConfig, RunnablePassthrough, RunnableParallel
21
+ from langchain_core.prompts import PromptTemplate
22
  from langchain_community.callbacks import StreamlitCallbackHandler
 
 
23
  from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
24
+ from langchain_community.vectorstores import Chroma
25
+ from langchain_community.document_loaders import TextLoader
26
+ from langchain.agents import create_react_agent, Tool, AgentExecutor
27
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
28
+ from langchain_core.output_parsers import StrOutputParser
29
+ from langchain.callbacks import FileCallbackHandler
30
+ from langchain.callbacks.manager import CallbackManager
31
 
32
+ import instructor
33
+ from pydantic import BaseModel, Field
34
+ from typing import List, Literal
 
 
 
 
35
 
36
  from nltk.corpus import stopwords
37
  import nltk
38
  from openai import OpenAI
39
+ # import anthropic
40
  import cohere
41
  import faiss
42
 
 
53
  nltk.download('stopwords')
54
  stopwords.words('english')
55
 
56
+
57
  from bokeh.plotting import figure
58
  from bokeh.models import ColumnDataSource
59
  from bokeh.io import output_notebook
60
  from bokeh.palettes import Spectral5
61
  from bokeh.transform import linear_cmap
62
 
63
# Wall-clock start; used further down to report how long the initial caching took.
ts = time.time()


# API credentials are read from Streamlit's secrets store.
# anthropic_key = st.secrets["anthropic_key"]
openai_key = st.secrets["openai_key"]
cohere_key = st.secrets['cohere_key']

# Shared chat model used by the RAG chain and the ReAct agent defined below.
gen_llm = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
# OpenAI client patched by `instructor`; called with `response_model=` for consensus evaluation.
consensus_client = instructor.patch(OpenAI(api_key=openai_key))

# A raw OpenAI client plus a LangChain embeddings wrapper, both for the same embedding model.
embed_client = OpenAI(api_key = openai_key)
embed_model = "text-embedding-3-small"
embeddings = OpenAIEmbeddings(model = embed_model, api_key = openai_key)
76
+
77
+
78
 
79
  st.image('local_files/pathfinder_logo.png')
80
 
 
82
  """
83
  Pathfinder v2.0 is a framework for searching and visualizing astronomy papers on the [arXiv](https://arxiv.org/) and [ADS](https://ui.adsabs.harvard.edu/) using the context
84
  sensitivity from modern large language models (LLMs) to better parse patterns in paper contexts.
85
+
86
  This tool was built during the [JSALT workshop](https://www.clsp.jhu.edu/2024-jelinek-summer-workshop-on-speech-and-language-technology/) to do awesome things.
87
 
88
+ **👈 Use the sidebar to tweak the search parameters to get better results**.
 
89
 
90
  ### Tool summary:
91
  - Please wait while the initial data loads and compiles, this takes about a minute initially.
 
92
 
93
  This is not meant to be a replacement to existing tools like the
94
  [ADS](https://ui.adsabs.harvard.edu/),
 
96
  that otherwise might be missed during a literature survey.
97
  It is trained on astro-ph (astrophysics of galaxies) papers up to last-year-ish mined from arxiv and supplemented with ADS metadata,
98
  if you are interested in extending it please reach out!
99
+
100
+
101
+ Also add: feedback form, socials, literature, contact us, copyright, collaboration, etc.
102
 
103
  The image below shows a representation of all the astro-ph.GA papers that can be explored in more detail
104
  using the `Arxiv embedding` page. The papers tend to cluster together by similarity, and result in an
105
  atlas that shows well studied (forests) and currently uncharted areas (water).
106
  """
107
  )
108
+
109
+
110
+ # ---------------- get data and set up session state ---------------------------
111
+
112
# Load the arXiv corpus once per session; later script reruns reuse the copy
# stored in st.session_state.
if 'arxiv_corpus' not in st.session_state:
    with st.spinner('loading data...'):
        try:
            arxiv_corpus = load_from_disk('data/')
        except Exception:
            # Any failure to read the local copy falls back to downloading it.
            # `except Exception` (not a bare `except:`) so BaseException-derived
            # signals such as KeyboardInterrupt/SystemExit are not swallowed here.
            st.write('downloading data')
            # arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data',split='train')
            arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data_galaxy',split='train')
            arxiv_corpus.save_to_disk('data/')
        # The FAISS index is (re)built on the 'embed' column on both paths.
        arxiv_corpus.add_faiss_index('embed')
        st.session_state.arxiv_corpus = arxiv_corpus
        st.toast('loaded arxiv corpus')
else:
    arxiv_corpus = st.session_state.arxiv_corpus
126
+
127
  if 'ids' not in st.session_state:
128
  st.session_state.ids = arxiv_corpus['ads_id']
129
  st.session_state.titles = arxiv_corpus['title']
 
132
  st.session_state.years = arxiv_corpus['date']
133
  st.session_state.kws = arxiv_corpus['keywords']
134
  st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
135
+
136
+
137
  #---------------------------------------------------------------
138
 
139
  # A hack to "clear" the previous result when submitting a new prompt. This avoids
 
162
  return True
163
 
164
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ # ---------------- define embedding retrieval systems --------------------------
167
+
168
def get_keywords(text):
    """Return the proper-noun, adjective and noun tokens of `text`.

    The text is lower-cased and run through the module-level `nlp` pipeline.
    Tokens found in `nlp.Defaults.stop_words` or in `punctuation` are dropped;
    of the rest, only tokens tagged PROPN, ADJ or NOUN are kept, in document
    order (duplicates are preserved).
    """
    wanted_pos = ['PROPN', 'ADJ', 'NOUN']
    return [
        tok.text
        for tok in nlp(text.lower())
        if not (tok.text in nlp.Defaults.stop_words or tok.text in punctuation)
        and tok.pos_ in wanted_pos
    ]
178
+
179
def parse_doc(text, nret = 10):
    """Return the text of the first `nret` top-ranked phrases found in `text`.

    `text` is processed by the module-level `nlp` pipeline and the phrases are
    read from `doc._.phrases` (assumes a phrase-ranking pipeline component
    registers that extension — TODO confirm).
    """
    top_phrases = nlp(text)._.phrases[:nret]
    return [phrase.text for phrase in top_phrases]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  class EmbeddingRetrievalSystem():
189
 
190
  def __init__(self, weight_citation = False, weight_date = False, weight_keywords = False):
191
+
192
  self.ids = st.session_state.ids
193
  self.years = st.session_state.years
194
  self.abstract = st.session_state.abstracts
 
196
  self.embed_model = "text-embedding-3-small"
197
  self.dataset = arxiv_corpus
198
  self.kws = st.session_state.kws
199
+ self.cites = st.session_state.cites
200
+
201
  self.weight_citation = weight_citation
202
  self.weight_date = weight_date
203
  self.weight_keywords = weight_keywords
 
205
 
206
  # self.citation_filter = CitationFilter(self.dataset)
207
  # self.date_filter = DateFilter(self.dataset['date'])
208
+ # self.keyword_filter = KeywordFilter(corpus=self.dataset, remove_capitals=True)
209
 
210
  def parse_date(self, id):
211
  # indexval = np.where(self.ids == id)[0][0]
 
220
  embeddings = self.client.embeddings.create(input=texts, model=self.embed_model).data
221
  return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]
222
 
 
 
 
 
 
 
223
    def get_query_embedding(self, query):
        """Return the embedding for `query` by delegating to `self.make_embedding`."""
        return self.make_embedding(query)
225
 
 
230
  # xq = query_embedding.reshape(-1,1).T.astype('float32')
231
  # D, I = self.index.search(xq, top_k)
232
  # return I[0], D[0]
233
+ tmp = self.dataset.search('embed', query_embedding, k=top_k)
234
  return [tmp.indices, tmp.scores]
235
+
236
  def rank_and_filter(self, query, query_embedding, query_date, top_k = 10, return_scores=False, time_result=None):
237
 
238
+ # st.write('status')
239
+
240
+ # st.write('toggles', self.toggles)
241
+ # st.write('question_type', self.question_type)
242
+ # st.write('rag method', self.rag_method)
243
+ # st.write('gen method', self.gen_method)
244
+
245
+ self.weight_keywords = self.toggles["Keyword weighting"]
246
+ self.weight_date = self.toggles["Time weighting"]
247
+ self.weight_citation = self.toggles["Citation weighting"]
248
 
249
+ topk_indices, similarities = self.calc_faiss(np.array(query_embedding), top_k = 1000)
250
+ similarities = 1/similarities # converting from a distance (less is better) to a similarity (more is better)
251
+
252
+ query_kws = get_keywords(query)
253
+ input_kws = self.query_input_keywords
254
+ query_kws = query_kws + input_kws
255
+ self.query_kws = query_kws
256
+
257
+ if self.weight_keywords == True:
258
+ sub_kws = [self.kws[i] for i in topk_indices]
259
+ kw_weight = np.zeros((len(topk_indices),)) + 0.1
260
+
261
+ for k in query_kws:
262
+ for i in (range(len(topk_indices))):
263
+ for j in range(len(sub_kws[i])):
264
+ if k.lower() in sub_kws[i][j].lower():
265
+ kw_weight[i] = kw_weight[i] + 0.1
266
+ # print(i, k, sub_kws[i][j])
267
+
268
+ # kw_weight = kw_weight**0.36 / np.amax(kw_weight**0.36)
269
+ kw_weight = kw_weight / np.amax(kw_weight)
270
+ else:
271
+ kw_weight = np.ones((len(topk_indices),))
272
+
273
+ if self.weight_date == True:
274
+ sub_dates = [self.years[i] for i in topk_indices]
275
+ date = datetime.now().date()
276
+ date_diff = np.array([((date - i).days / 365.) for i in sub_dates])
277
+ # age_weight = (1 + np.exp(date_diff/2.1))**(-1) + 0.5
278
+ age_weight = (1 + np.exp(date_diff/0.7))**(-1)
279
+ age_weight = age_weight / np.amax(age_weight)
280
+ else:
281
+ age_weight = np.ones((len(topk_indices),))
282
+
283
+ if self.weight_citation == True:
284
+ # st.write('weighting by citations')
285
+ sub_cites = np.array([self.cites[i] for i in topk_indices])
286
+ temp = sub_cites.copy()
287
+ temp[sub_cites > 300] = 300.
288
+ cite_weight = (1 + np.exp((300-temp)/42.0))**(-1.)
289
+ cite_weight = cite_weight / np.amax(cite_weight)
290
+ else:
291
+ cite_weight = np.ones((len(topk_indices),))
292
+
293
+ similarities = similarities * (kw_weight) * (age_weight) * (cite_weight)
294
+
295
+ # if self.weight_keywords:
296
+ # keyword_matches = self.keyword_filter.filter(query)
297
+ # self.query_kws = keyword_matches
298
+ # kw_indices = np.zeros_like(similarities)
299
+ # for s in keyword_matches:
300
+ # if self.id_to_index[s] in topk_indices:
301
+ # # print('yes', self.id_to_index[s], topk_indices[np.where(topk_indices == self.id_to_index[s])[0]])
302
+ # similarities[np.where(topk_indices == self.id_to_index[s])[0]] = similarities[np.where(topk_indices == self.id_to_index[s])[0]] * 10.
303
+ # similarities = similarities / 10.
304
 
305
  filtered_results = [[topk_indices[i], similarities[i]] for i in range(len(similarities))]
306
  top_results = sorted(filtered_results, key=lambda x: x[1], reverse=True)[:top_k]
 
310
 
311
  # Only keep the document IDs
312
  top_results = [doc[0] for doc in top_results]
313
+ return top_results
314
+
315
  def retrieve(self, query, top_k, time_result=None, query_date = None, return_scores = False):
316
 
317
  query_embedding = self.get_query_embedding(query)
318
 
319
  # Judge time relevance
320
  if time_result is None:
321
+ if self.weight_date:
322
  time_result, time_taken = self.analyze_temporal_query(query, self.anthropic_client)
323
+ else:
324
  time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
325
 
326
+ top_results = self.rank_and_filter(query,
327
+ query_embedding,
328
+ query_date,
329
+ top_k,
330
+ return_scores = return_scores,
331
  time_result = time_result)
332
+
333
  return top_results
334
 
335
  class HydeRetrievalSystem(EmbeddingRetrievalSystem):
336
+ def __init__(self, generation_model: str = "claude-3-haiku-20240307",
337
+ embedding_model: str = "text-embedding-3-small",
338
+ temperature: float = 0.5,
339
+ max_doclen: int = 500,
340
+ generate_n: int = 1,
341
+ embed_query = True,
342
  conclusion = False, **kwargs):
343
+
344
  # Handle the kwargs for the superclass init -- filters/citation weighting
345
  super().__init__(**kwargs)
346
+
347
  if max_doclen * generate_n > 8191:
348
  raise ValueError("Too many tokens. Please reduce max_doclen or generate_n.")
349
+
350
  self.embedding_model = embedding_model
351
  self.generation_model = generation_model
352
 
 
357
  self.embed_query = embed_query # embed the query vector?
358
  self.conclusion = conclusion # generate conclusion as well?
359
 
360
+ # self.anthropic_key = anthropic_key
361
+ # self.generation_client = anthropic.Anthropic(api_key = self.anthropic_key)
362
+ self.generation_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
363
+
364
  def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
365
  if time_result is None:
366
  if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
367
  else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
368
 
369
  docs = self.generate_docs(query)
370
+ st.expander('Abstract generated with hyde', expanded=False).write(docs)
371
+
372
  doc_embeddings = self.embed_docs(docs)
373
 
374
+ if self.embed_query:
375
  query_emb = self.embed_docs([query])[0]
376
  doc_embeddings.append(query_emb)
377
+
378
  embedding = np.mean(np.array(doc_embeddings), axis = 0)
379
 
380
  top_results = self.rank_and_filter(query, embedding, query_date=None, top_k = top_k, return_scores = return_scores, time_result = time_result)
381
+
382
  return top_results
383
 
384
  def generate_doc(self, query: str):
385
+ prompt = """You are an expert astronomer. Given a scientific query, generate the abstract of an expert-level research paper
 
 
 
386
  that answers the question. Stick to a maximum length of {} tokens and return just the text of the abstract and conclusion.
387
  Do not include labels for any section. Use research-specific jargon.""".format(self.max_doclen)
388
+ # st.write('invoking hyde generation')
389
+
390
+ # message = self.generation_client.messages.create(
391
+ # model = self.generation_model,
392
+ # max_tokens = self.max_doclen,
393
+ # temperature = self.temperature,
394
+ # system = prompt,
395
+ # messages=[{ "role": "user",
396
+ # "content": [{"type": "text", "text": query,}] }]
397
+ # )
398
+ # return message.content[0].text
399
+
400
+ messages = [("system",prompt,),("human", query),]
401
+ return self.generation_client.invoke(messages).content
402
+
403
+
404
+
405
  def generate_docs(self, query: str):
406
  docs = []
407
+ for i in range(self.generate_n):
408
+ # st.write('invoking hyde generation2')
409
+
410
+ docs.append(self.generate_doc(query))
411
+ # with concurrent.futures.ThreadPoolExecutor() as executor:
412
+ # st.write('invoking hyde generation2')
413
+ # future_to_query = {executor.submit(self.generate_doc, query): query for i in range(self.generate_n)}
414
+ # for future in concurrent.futures.as_completed(future_to_query):
415
+ # query = future_to_query[future]
416
+ # try:
417
+ # data = future.result()
418
+ # docs.append(data)
419
+ # except Exception as exc:
420
+ # pass
421
  return docs
422
 
423
  def embed_docs(self, docs: List[str]):
 
427
    def __init__(self, **kwargs):
        """Initialise the parent retrieval system and attach a Cohere client.

        Args:
            **kwargs: forwarded unchanged to the parent class initialiser.
        """
        super().__init__(**kwargs)

        # `cohere_key` is the module-level key read from Streamlit secrets.
        self.cohere_key = cohere_key
        self.cohere_client = cohere.Client(self.cohere_key)
432
 
433
+ def retrieve(self, query: str,
434
+ top_k: int = 10,
435
  rerank_top_k: int = 250,
436
  return_scores = False, time_result = None,
437
  reweight = False) -> List[Tuple[str, str, float]]:
438
+
439
  if time_result is None:
440
  if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
441
  else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
442
+
443
  top_results = super().retrieve(query, top_k = rerank_top_k, time_result = time_result)
444
+
445
  # doc_texts = self.get_document_texts(top_results)
446
  # docs_for_rerank = [f"Abstract: {doc['abstract']}\nConclusions: {doc['conclusions']}" for doc in doc_texts]
447
  docs_for_rerank = [self.abstract[i] for i in top_results]
448
+
449
  if len(docs_for_rerank) == 0:
450
  return []
451
+
452
  reranked_results = self.cohere_client.rerank(
453
  query=query,
454
  documents=docs_for_rerank,
455
  model='rerank-english-v3.0',
456
  top_n=top_k
457
  )
458
+
459
  final_results = []
460
  for result in reranked_results.results:
461
  doc_id = top_results[result.index]
 
466
  if reweight:
467
  if time_result['has_temporal_aspect']:
468
  final_results = self.date_filter.filter(final_results, time_score = time_result['expected_recency_weight'])
469
+
470
  if self.weight_citation: self.citation_filter.filter(final_results)
471
+
472
  if return_scores:
473
  return {result[0]: result[2] for result in final_results}
474
 
 
478
  return self.embed_batch(docs)
479
 
480
  # ----------------------------------------------------------------
481
+
 
482
# Build the default retrieval system once per session and reuse it on reruns.
if 'ec' not in st.session_state:
    ec = HydeCohereRetrievalSystem(weight_keywords=True)
    st.session_state.ec = ec
    st.toast('loaded retrieval system')
else:
    ec = st.session_state.ec
 
 
 
 
 
 
488
 
489
def get_topk(query, top_k):
    """Run the session's active retrieval system on `query`.

    Calls `st.session_state.ec.retrieve` with `return_scores=True` and returns
    its result unchanged.
    """
    print('running retrieval')
    retrieval_system = st.session_state.ec
    return retrieval_system.retrieve(query, top_k, return_scores=True)
493
 
494
def Library(query, top_k = 7):
    """Retrieve papers for `query` and format them as one text block.

    Used as the "Library" tool of the ReAct agent. Each retrieved paper
    becomes an entry of the form
    ``Paper N: (published in YYYY) <title>\\n<abstract>\\n\\n`` where YYYY is
    the first four characters of the paper's bibcode.

    Args:
        query: the search text passed to `get_topk`.
        top_k: number of papers to retrieve.

    Returns:
        str: the concatenated entries (empty string when nothing is retrieved).
    """
    rs = get_topk(query, top_k = top_k)

    # Fetch each column once rather than once per paper; indexing the corpus
    # by column name inside the loop repeats that lookup on every iteration.
    bibcodes = st.session_state.arxiv_corpus['bibcode']
    titles = st.session_state.titles
    abstracts = st.session_state.abstracts

    entries = [
        f"Paper {paperno}: (published in {bibcodes[i][0:4]}) {titles[i]}\n{abstracts[i]}\n\n"
        for paperno, i in enumerate(rs, start=1)
    ]
    return ''.join(entries)
501
+
502
def Library2(query, top_k = 7):
    """Retrieve papers for `query` and return their abstracts and bibcodes.

    Args:
        query: the search text passed to `get_topk`.
        top_k: number of papers to retrieve.

    Returns:
        tuple: `(absts, fnames, rs)` — the abstracts, the matching bibcodes
        (used by the caller as file-name stems), and the raw retrieval result.
    """
    rs = get_topk(query, top_k = top_k)

    # Fetch each column once rather than once per retrieved paper.
    bibcodes = st.session_state.arxiv_corpus['bibcode']
    abstracts = st.session_state.abstracts

    absts = [abstracts[i] for i in rs]
    fnames = [bibcodes[i] for i in rs]
    return absts, fnames, rs
509
+
510
def get_paper_df(ids):
    """Build a summary table for the retrieved papers.

    Args:
        ids: mapping from corpus row index to relevance score; iterated for
            its keys and indexed with `ids[i]` for the score.

    Returns:
        pandas.DataFrame with columns Title, Relevance, Year (first four
        characters of the bibcode), ADS Link, Citations and Keywords, one row
        per key of `ids`, in iteration order.
    """
    corpus = st.session_state.arxiv_corpus
    titles = st.session_state.titles

    # Fetch each corpus column once instead of once (or twice) per row.
    bibcodes = corpus['bibcode']
    cite_counts = corpus['cites']
    ads_keywords = corpus['ads_keywords']

    rows = list(ids)
    return pd.DataFrame({
        'Title': [titles[i] for i in rows],
        'Relevance': [ids[i] for i in rows],
        'Year': [bibcodes[i][0:4] for i in rows],
        'ADS Link': ['https://ui.adsabs.harvard.edu/abs/'+bibcodes[i]+'/abstract' for i in rows],
        'Citations': [cite_counts[i] for i in rows],
        'Keywords': [ads_keywords[i] for i in rows],
    })
529
 
530
 
531
+ # def find_outliers(inp_simids, arxiv_cutoff_distance = 0.8):
532
+ #
533
+ # inp_simids = np.array(inp_simids)
534
+ #
535
+ # # Calculate the centroid for each point, excluding itself
536
+ # orange_black_points = st.session_state.embed[inp_simids]
537
+ #
538
+ # topk_dists = []
539
+ # for i, point in enumerate(orange_black_points):
540
+ # # Exclude the current point
541
+ # other_points = np.delete(orange_black_points, i, axis=0)
542
+ # # Calculate centroid of other points
543
+ # centroid = np.mean(other_points, axis=0)
544
+ # # Calculate distance from the point to this centroid
545
+ # dist = np.sqrt(np.sum((point - centroid)**2))
546
+ # topk_dists.append(dist)
547
+ #
548
+ # topk_dists = np.array(topk_dists)
549
+ #
550
+ # # Separate distances for orange and black points
551
+ # orange_distances = topk_dists[:len(inp_simids)]
552
+ # black_distances = topk_dists[len(inp_simids):]
553
+ #
554
+ # # Calculate the median of distances
555
+ # orange_black_distances = topk_dists
556
+ # median_topk_distance = np.median(orange_black_distances)
557
+ #
558
+ # # def get_sims_and_dists(inp_data):
559
+ #
560
+ # # all_sims, all_dists = [], []
561
+ #
562
+ # # np.random.seed(12)
563
+ # # rand_indices = np.random.choice(inp_data.shape[0], size=return_n, replace=False)
564
+ #
565
+ # # for j in tqdm(range(len(rand_indices))):
566
+ #
567
+ # # i = rand_indices[j]
568
+ # # inferred_vector = inp_data[i,:]
569
+ # # sims, dists = find_closest_dists(i, inp_data, return_n + 1)
570
+ # # all_sims.append(sims[1:])
571
+ # # all_dists.append(dists[1:])
572
+ #
573
+ # # return np.array(all_sims), np.array(all_dists)
574
+ #
575
+ # # # Identify papers with distances greater than the 95th percentile
576
+ # # _, all_dists = get_sims_and_dists(arxiv_ada_embeddings)
577
+ # # arxiv_cutoff_distance = find_cutoff_dist(all_dists)
578
+ # # hardcoding for now
579
+ # outlier_indices = inp_simids[np.where(orange_black_distances > arxiv_cutoff_distance)[0]]
580
+ # # outlier_titles = [titles[i] for i in outlier_indices]
581
+ #
582
+ # return outlier_indices #, outlier_titles
583
 
584
  def create_embedding_plot(rs):
585
+ """
586
+ function to create embedding plot
587
+ """
588
 
589
  pltsource = ColumnDataSource(data=dict(
590
  x=st.session_state.arxiv_corpus['umap_x'],
 
592
  title=st.session_state.titles,
593
  link=st.session_state.arxiv_corpus['bibcode'],
594
  ))
595
+
596
  rsflag = np.zeros((len(st.session_state.ids),))
597
  rsflag[np.array([k for k in rs])] = 1
598
+
599
+ # outflag = np.zeros((len(st.session_state.ids),))
600
+ # outflag[np.array([k for k in find_outliers(rs)])] = 1
601
  pltsource.data['colors'] = rsflag * 0.8 + 0.1
602
+ # pltsource.data['colors'][outflag] = 0.5
603
  pltsource.data['sizes'] = (rsflag + 1)**5 / 100
604
 
605
  TOOLTIPS = """
 
610
  @link <br> <br>
611
  </div>
612
  """
613
+
614
  mapper = linear_cmap(field_name="colors", palette=Spectral5, low=0., high=1.)
615
 
616
  p = figure(width=700, height=900, tooltips=TOOLTIPS, x_range=(0, 20), y_range=(-4.2,18),
617
  title="UMAP projection of embeddings for the astro-ph corpus")
618
+
619
  p.axis.visible=False
620
  p.grid.visible=False
621
  p.outline_line_alpha = 0.
622
+
623
  p.circle('x', 'y', radius='sizes', source=pltsource, alpha=0.3, fill_color=mapper, fill_alpha='colors', line_color="lightgrey",line_alpha=0.1)
624
+
625
  return p
626
 
627
def extract_keywords(question, ec):
    """Placeholder: return a fixed list of three dummy keywords.

    Both `question` and `ec` are currently ignored.
    """
    # Simulated keyword extraction (replace with actual logic)
    return ['keyword1', 'keyword2', 'keyword3']
630
 
 
633
  # Simulated consensus estimation (replace with actual calculation)
634
  return 0.75
635
 
 
 
 
 
 
 
 
 
 
636
 
637
def run_agent_qa(query, top_k):
    """Answer `query` with a ReAct agent that can use the library and web search.

    The agent has two tools: "Library" (the `Library` function above) and
    "Search" (DuckDuckGo). The tool list, the agent and the executor are each
    created on the first call and cached in `st.session_state`; later calls
    reuse them.

    Args:
        query: the user's question.
        top_k: currently unused in this function — the Library tool is invoked
            by the agent with the query only, so `Library` uses its own
            default `top_k`.

    Returns:
        The result of `AgentExecutor.invoke` for `{"input": query}`.
    """

    # define tools
    search = DuckDuckGoSearchAPIWrapper()
    tools = [
        Tool(
            name="Library",
            func=Library,
            description="A source of information pertinent to your question. Do not answer a question without consulting this!"
        ),
        Tool(
            name="Search",
            func=search.run,
            description="useful for when you need to look up knowledge about common topics or current events",
        )
    ]

    if 'tools' not in st.session_state:
        st.session_state.tools = tools

    # define prompt

    # for another question type:
    # First, find the quotes from the document that are most relevant to answering the question, and then print them in numbered order.
    # Quotes should be relatively short. If there are no relevant quotes, write “No relevant quotes” instead.

    template = """You are an expert astronomer and cosmologist.
    Answer the following question as best you can using information from the library, but speaking in a concise and factual manner.
    If you can not come up with an answer, say you do not know.
    Try to break the question down into smaller steps and solve it in a logical manner.

    You have access to the following tools:

    {tools}

    Use the following format:

    Question: the input question you must answer
    Thought: you should always think about what to do
    Action: the action to take, should be one of [{tool_names}]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can repeat N times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question. provide information about how you arrived at the answer, and any nuances or uncertainties the reader should be aware of

    Begin! Remember to speak in a pedagogical and factual manner."

    Question: {input}
    Thought:{agent_scratchpad}"""

    # The hub prompt is pulled only for its structure; its text is replaced.
    prompt = hub.pull("hwchase17/react")
    prompt.template=template

    # path to write intermediate trace to

    file_path = "agent_trace.txt"
    try:
        os.remove(file_path)
    except OSError:
        # Best effort: the trace file may not exist yet (or may not be removable).
        pass
    file_handler = FileCallbackHandler(file_path)

    # define and execute agent

    if 'agent' not in st.session_state:
        # agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
        agent = create_react_agent(llm=gen_llm, tools=tools, prompt=prompt)
        st.session_state.agent = agent

    if 'agent_executor' not in st.session_state:
        # NOTE(review): the executor is cached, so it keeps the handler created
        # on the first call; handlers created on later calls are never attached.
        # Confirm whether the trace file is meant to be refreshed per query.
        agent_executor = AgentExecutor(agent=st.session_state.agent, tools=st.session_state.tools, verbose=True, handle_parsing_errors=True, callbacks=CallbackManager([file_handler]))
        st.session_state.agent_executor = agent_executor

    answer = st.session_state.agent_executor.invoke({"input": query,})
    return answer
716
+
717
def make_rag_qa_answer(query, top_k = 10):
    """Answer `query` with a basic retrieve-then-generate (RAG) chain.

    Steps:
      1. Retrieve abstracts with `Library2`.
      2. Write each abstract to a temporary text file under ``absts/`` and
         load it back through `TextLoader`; the files are removed as soon as
         the documents are loaded.
      3. Split the documents into chunks, index them in a temporary Chroma
         collection, and answer the question from the 6 most similar chunks
         using `gen_llm`.

    Args:
        query: the user's question.
        top_k: number of papers to retrieve.

    Returns:
        tuple: `(rag_answer, rs)` — the chain output (a dict holding the
        retrieved context, the question and the answer) and the raw retrieval
        result from `Library2`.
    """

    absts, fhdrs, rs = Library2(query, top_k = top_k)

    os.makedirs('absts', exist_ok=True)
    written_paths = []
    try:
        for abst, fhdr in zip(absts, fhdrs):
            path = "absts/"+fhdr+".txt"
            with open(path, "w") as text_file:
                # Recorded as soon as the file exists so cleanup also covers a failed write.
                written_paths.append(path)
                text_file.write(abst)
        source_docs = [TextLoader(path).load()[0] for path in written_paths]
    finally:
        # Temp files are removed even when writing or loading fails part-way.
        for path in set(written_paths):
            os.remove(path)

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=150, chunk_overlap=50, add_start_index=True)

    splits = text_splitter.split_documents(source_docs)
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, collection_name='retdoc4')
    try:
        # retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 6, "fetch_k": len(splits)})
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 6})

        template = """You are an expert astronomer and cosmologist.
    Answer the following question as best you can using information from the library, but speaking in a concise and factual manner.
    If you can not come up with an answer, say you do not know.
    Try to break the question down into smaller steps and solve it in a logical manner.

    Provide information about how you arrived at the answer, and any nuances or uncertainties the reader should be aware of.

    Begin! Remember to speak in a pedagogical and factual manner."

    Relevant documents:{context}

    Question: {question}
    Answer:"""
        prompt = PromptTemplate.from_template(template)

        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain_from_docs = (
            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
            | prompt
            | gen_llm
            | StrOutputParser()
        )

        rag_chain_with_source = RunnableParallel(
            {"context": retriever, "question": RunnablePassthrough()}
        ).assign(answer=rag_chain_from_docs)

        rag_answer = rag_chain_with_source.invoke(query, )
    finally:
        # The collection name is fixed, so it is always dropped — presumably a
        # leftover collection would otherwise be reused by the next call.
        vectorstore.delete_collection()

    return rag_answer, rs
780
+
781
def guess_question_type(query: str):
    """Ask the chat model to classify `query` into one of nine question types.

    Uses the module-level `gen_llm` rather than the active retrieval system's
    `generation_client`: in the visible code only the HyDE-based retrieval
    systems create that attribute, while `gen_llm` is configured identically
    and is available whichever retrieval method is selected.

    Args:
        query: the user's question.

    Returns:
        str: the raw text of the model's reply, which the prompt asks to be
        wrapped in a ``<categorization>`` block with ``Category:`` and
        ``Explanation:`` lines.
    """
    categorization_prompt = """You are an expert astrophysicist and computer scientist specializing in linguistics and semantics. Your task is to categorize a given query into one of the following categories:

    1. Summarization
    2. Single-paper factual
    3. Multi-paper factual
    4. Named entity recognition
    5. Jargon-specific questions / overloaded words
    6. Time-sensitive
    7. Consensus evaluation
    8. What-ifs and counterfactuals
    9. Compositional

    Analyze the query carefully, considering its content, structure, and implications. Then, determine which of the above categories best fits the query.

    In your analysis, consider the following:
    - Does the query ask for a well-known datapoint or mechanism?
    - Can it be answered by a single paper or does it require multiple sources?
    - Does it involve proper nouns or specific scientific terms?
    - Is it time-dependent or likely to change in the near future?
    - Does it require evaluating consensus across multiple sources?
    - Is it a hypothetical or counterfactual question?
    - Does it need to be broken down into sub-queries (i.e. compositional)?

    After your analysis, categorize the query into one of the nine categories listed above.

    Provide a brief explanation for your categorization, highlighting the key aspects of the query that led to your decision.

    Present your final answer in the following format:

    <categorization>
    Category: [Selected category]
    Explanation: [Your explanation for the categorization]
    </categorization>"""

    messages = [("system",categorization_prompt,),("human", query),]
    return gen_llm.invoke(messages).content
829
+
830
+
831
# Structured result produced by `evaluate_overall_consensus`; passed to the
# instructor-patched client as `response_model`.
class OverallConsensusEvaluation(BaseModel):
    # Categorical consensus level, restricted to the seven labels in the Literal.
    consensus: Literal["Strong Agreement", "Moderate Agreement", "Weak Agreement", "No Clear Consensus", "Weak Disagreement", "Moderate Disagreement", "Strong Disagreement"] = Field(
        ...,
        description="The overall level of consensus between the query and the abstracts"
    )
    # Free-text justification for the chosen consensus level.
    explanation: str = Field(
        ...,
        description="A detailed explanation of the consensus evaluation"
    )
    # Overall relevance of the abstracts to the query, validated to lie in [0, 1].
    relevance_score: float = Field(
        ...,
        description="A score from 0 to 1 indicating how relevant the abstracts are to the query overall",
        ge=0,
        le=1
    )
846
+
847
def evaluate_overall_consensus(query: str, abstracts: List[str]) -> OverallConsensusEvaluation:
    """
    Judge, in one LLM call, how strongly the abstracts agree with the query.

    The abstracts are numbered from 1 and joined with single spaces into the
    user prompt. The request goes through `consensus_client` with
    `response_model=OverallConsensusEvaluation`, and that call's result is
    returned as-is.
    """
    numbered_abstracts = ' '.join(
        [f"Abstract {idx}: {text}" for idx, text in enumerate(abstracts, start=1)]
    )

    user_prompt = f"""
    Query: {query}

    You will be provided with {len(abstracts)} scientific abstracts. Your task is to:
    1. Evaluate the overall consensus between the query and the abstracts.
    2. Provide a detailed explanation of your consensus evaluation.
    3. Assign an overall relevance score from 0 to 1, where 0 means completely irrelevant and 1 means highly relevant.

    For the consensus evaluation, use one of the following levels:
    Strong Agreement, Moderate Agreement, Weak Agreement, No Clear Consensus, Weak Disagreement, Moderate Disagreement, Strong Disagreement

    Here are the abstracts:

    {numbered_abstracts}

    Provide your evaluation in a structured format.
    """

    system_prompt = """You are an assistant with expertise in astrophysics for question-answering tasks.
    Evaluate the overall consensus of the retrieved scientific abstracts in relation to a given query.
    If you don't know the answer, just say that you don't know.
    Use six sentences maximum and keep the answer concise."""

    chat_messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    return consensus_client.chat.completions.create(
        model="gpt-4",
        response_model=OverallConsensusEvaluation,
        messages=chat_messages,
        temperature=0,
    )
883
 
884
 
885
  # Streamlit app
886
  def main():
887
+
888
  # st.title("Question Answering App")
889
+
890
+
891
  # Sidebar (Inputs)
892
  st.sidebar.header("Fine-tune the search")
893
  top_k = st.sidebar.slider("Number of papers to retrieve:", 3, 30, 10)
894
  extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
895
+
896
  st.sidebar.subheader("Toggles")
897
+ toggle_a = st.sidebar.toggle("Weight by keywords", value = False)
898
+ toggle_b = st.sidebar.toggle("Weight by date", value = False)
899
+ toggle_c = st.sidebar.toggle("Weight by citations", value = False)
900
+
901
+ method = st.sidebar.radio("Retrieval method:", ["Semantic search", "Semantic search + HyDE", "Semantic search + HyDE + CoHERE"], index=2)
902
+ if (method == "Semantic search"):
903
+ with st.spinner('set retrieval method to'+ method):
904
+ st.session_state.ec = EmbeddingRetrievalSystem(weight_keywords=True)
905
+ elif (method == "Semantic search + HyDE"):
906
+ with st.spinner('set retrieval method to'+ method):
907
+ st.session_state.ec = HydeRetrievalSystem(weight_keywords=True)
908
+ elif (method == "Semantic search + HyDE + CoHERE"):
909
+ with st.spinner('set retrieval method to'+ method):
910
+ st.session_state.ec = HydeCohereRetrievalSystem(weight_keywords=True)
911
+
912
+ method2 = st.sidebar.radio("Generation complexity:", ["Basic RAG","ReAct Agent"])
913
+ if method2 == "Basic RAG":
914
+ st.session_state.gen_method = 'rag'
915
+ elif method2 == "ReAct Agent":
916
+ st.session_state.gen_method = 'agent'
917
+
918
 
919
+ question_type = st.sidebar.selectbox("Select question type:", ["Single paper", "Multi-paper", "Summary"])
920
  store_output = st.sidebar.button("Save output")
921
 
922
  # Main page (Outputs)
923
+ # st.markdown("""
924
+ # <style>
925
+ # .stTextInput > div > div { font-size: 50px; }
926
+ # </style>
927
+ # """, unsafe_allow_html=True)
928
+
929
+ # st.markdown(
930
+ # """
931
+ # <style>
932
+ # textarea {
933
+ # font-size: 3rem !important;
934
+ # font-weight: bold;
935
+ # font-family: "Times New Roman", Times, serif;
936
+ # }
937
+ # input {
938
+ # font-size: 3rem !important;
939
+ # font-weight: bold;
940
+ # font-family: "Times New Roman", Times, serif;
941
+ # }
942
+ # </style>
943
+ # """,
944
+ # unsafe_allow_html=True,
945
+ # )
946
+ # query = st.text_area("Ask me anything:", height=30)
947
+
948
  query = st.text_input("Ask me anything:")
949
  submit_button = st.button("Submit")
950
+
951
  if submit_button:
952
+
953
+ search_text_list = ['rooting around in the paper pile...','looking for clarity...','scanning the event horizon...','peering into the abyss...','potatoes power this ongoing search...']
954
+
955
+ with st.spinner(search_text_list[np.random.choice(len(search_text_list))]):
956
+
957
+ # Process inputs
958
+ keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
959
+ toggles = {'Keyword weighting': toggle_a, 'Time weighting': toggle_b, 'Citation weighting': toggle_c}
960
+ # Generate outputs
961
+
962
+ st.session_state.ec.query_input_keywords = keywords
963
+ st.session_state.ec.toggles = toggles
964
+ st.session_state.ec.question_type = question_type
965
+ st.session_state.ec.rag_method = method
966
+ st.session_state.ec.gen_method = method2
967
+
968
+ # Display outputs
969
+ if st.session_state.gen_method == 'agent':
970
+ answer = run_agent_qa(query, top_k)
971
+ rs = get_topk(query, top_k)
972
+
973
+ st.write(answer["output"])
974
+
975
+ file_path = "agent_trace.txt"
976
+ with open(file_path, 'r') as file:
977
+ intermediate_steps = file.read()
978
+
979
+ st.expander('Intermediate steps', expanded=False).write(intermediate_steps)
980
+
981
+ elif st.session_state.gen_method == 'rag':
982
+ answer, rs = make_rag_qa_answer(query, top_k)
983
+ st.write(answer['answer'])
984
+
985
+ papers_df = get_paper_df(rs)
986
+ embedding_plot = create_embedding_plot(rs)
987
+ triggered_keywords = st.session_state.ec.query_kws
988
+ st.write('**Triggered keywords:** `'+ "`, `".join(triggered_keywords)+'`')
989
+ # consensus = estimate_consensus()
990
+
991
+
992
  with st.expander("Relevant papers", expanded=True):
993
  # st.dataframe(papers_df, hide_index=True)
994
  st.data_editor(papers_df,
995
+ column_config = {'ADS Link':st.column_config.LinkColumn(display_text= 'https://ui.adsabs.harvard.edu/abs/(.*?)/abstract')}
996
  )
997
 
998
  with st.expander("Embedding map", expanded=False):
999
  st.bokeh_chart(embedding_plot)
1000
+
1001
  col1, col2 = st.columns(2)
1002
+
1003
  with col1:
1004
+
1005
+ st.subheader("Question type suggestion")
1006
+ question_type_gen = guess_question_type(query)
1007
+ if '<categorization>' in question_type_gen:
1008
+ question_type_gen = question_type_gen.split('<categorization>')[1]
1009
+ if '</categorization>' in question_type_gen:
1010
+ question_type_gen = question_type_gen.split('</categorization>')[0]
1011
+ question_type_gen = question_type_gen.replace('\n',' \n')
1012
+ st.markdown(question_type_gen)
1013
+
1014
  with col2:
1015
+
1016
+ # st.subheader("Triggered Keywords")
1017
+ # st.write(", ".join(triggered_keywords))
1018
+
1019
+ consensus_answer = evaluate_overall_consensus(query, [st.session_state.abstracts[i] for i in rs])
1020
+ st.subheader("Consensus: "+consensus_answer.consensus)
1021
+ st.markdown(consensus_answer.explanation)
1022
+ st.markdown('Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score)
1023
+
1024
+ # st.write(f"{consensus:.2%}")
1025
+
1026
  else:
1027
+ st.info("Use the sidebar to tweak the search parameters to get better results.")
1028
+
1029
  if store_output:
1030
  st.toast("Output stored successfully!")
1031
 
1032
  # Call main() only when this module is run as the main script, not on import.
  if __name__ == "__main__":
1033
+ main()
data/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2534048757c630a5a9addf362d3077da0427e55ae1cae0c93dd213363ddfbcc7
3
+ size 498031096
data/dataset_info.json ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "builder_name": "parquet",
3
+ "citation": "",
4
+ "config_name": "default",
5
+ "dataset_name": "pathfinder_arxiv_data_galaxy",
6
+ "dataset_size": 505886100,
7
+ "description": "",
8
+ "download_checksums": {
9
+ "hf://datasets/kiyer/pathfinder_arxiv_data_galaxy@29754b03f3cd82e4051ece1cf96605f8756bc197/data/train-00000-of-00001.parquet": {
10
+ "num_bytes": 379674094,
11
+ "checksum": null
12
+ }
13
+ },
14
+ "download_size": 379674094,
15
+ "features": {
16
+ "ads_id": {
17
+ "dtype": "string",
18
+ "_type": "Value"
19
+ },
20
+ "arxiv_id": {
21
+ "dtype": "string",
22
+ "_type": "Value"
23
+ },
24
+ "title": {
25
+ "dtype": "string",
26
+ "_type": "Value"
27
+ },
28
+ "abstract": {
29
+ "dtype": "string",
30
+ "_type": "Value"
31
+ },
32
+ "embed": {
33
+ "feature": {
34
+ "dtype": "float32",
35
+ "_type": "Value"
36
+ },
37
+ "_type": "Sequence"
38
+ },
39
+ "umap_x": {
40
+ "dtype": "float32",
41
+ "_type": "Value"
42
+ },
43
+ "umap_y": {
44
+ "dtype": "float32",
45
+ "_type": "Value"
46
+ },
47
+ "date": {
48
+ "dtype": "date32",
49
+ "_type": "Value"
50
+ },
51
+ "cites": {
52
+ "dtype": "int64",
53
+ "_type": "Value"
54
+ },
55
+ "bibcode": {
56
+ "dtype": "string",
57
+ "_type": "Value"
58
+ },
59
+ "keywords": {
60
+ "feature": {
61
+ "dtype": "string",
62
+ "_type": "Value"
63
+ },
64
+ "_type": "Sequence"
65
+ },
66
+ "ads_keywords": {
67
+ "feature": {
68
+ "dtype": "string",
69
+ "_type": "Value"
70
+ },
71
+ "_type": "Sequence"
72
+ },
73
+ "read_count": {
74
+ "dtype": "int64",
75
+ "_type": "Value"
76
+ },
77
+ "doi": {
78
+ "feature": {
79
+ "dtype": "string",
80
+ "_type": "Value"
81
+ },
82
+ "_type": "Sequence"
83
+ },
84
+ "authors": {
85
+ "feature": {
86
+ "dtype": "string",
87
+ "_type": "Value"
88
+ },
89
+ "_type": "Sequence"
90
+ },
91
+ "aff": {
92
+ "feature": {
93
+ "dtype": "string",
94
+ "_type": "Value"
95
+ },
96
+ "_type": "Sequence"
97
+ },
98
+ "cite_bibcodes": {
99
+ "feature": {
100
+ "dtype": "string",
101
+ "_type": "Value"
102
+ },
103
+ "_type": "Sequence"
104
+ },
105
+ "ref_bibcodes": {
106
+ "feature": {
107
+ "dtype": "string",
108
+ "_type": "Value"
109
+ },
110
+ "_type": "Sequence"
111
+ }
112
+ },
113
+ "homepage": "",
114
+ "license": "",
115
+ "size_in_bytes": 885560194,
116
+ "splits": {
117
+ "train": {
118
+ "name": "train",
119
+ "num_bytes": 505886100,
120
+ "num_examples": 41195,
121
+ "shard_lengths": [
122
+ 41000,
123
+ 195
124
+ ],
125
+ "dataset_name": "pathfinder_arxiv_data_galaxy"
126
+ }
127
+ },
128
+ "version": {
129
+ "version_str": "0.0.0",
130
+ "major": 0,
131
+ "minor": 0,
132
+ "patch": 0
133
+ }
134
+ }
data/state.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_data_files": [
3
+ {
4
+ "filename": "data-00000-of-00001.arrow"
5
+ }
6
+ ],
7
+ "_fingerprint": "61bcd9aec14a17d4",
8
+ "_format_columns": null,
9
+ "_format_kwargs": {},
10
+ "_format_type": null,
11
+ "_output_all_columns": false,
12
+ "_split": "train"
13
+ }
requirements.txt CHANGED
@@ -10,7 +10,8 @@ langchain_community
10
  langchain_core
11
  langchainhub
12
  openai
13
- anthropic
 
14
  feedparser
15
  tiktoken
16
  chromadb
 
10
  langchain_core
11
  langchainhub
12
  openai
13
+ instructor
14
+ pydantic
15
  feedparser
16
  tiktoken
17
  chromadb