kiyer committed on
Commit
f28b621
•
1 Parent(s): 182832e

updates to codebase for embeddings and RAG QA.

Browse files
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
absts/.DS_Store ADDED
Binary file (6.15 kB). View file
 
pages/{2_arxiv_embedding.py → 1_arxiv_embedding_explorer.py} RENAMED
@@ -74,9 +74,9 @@ def density_estimation(m1, m2, xmin=0, ymin=0, xmax=15, ymax=15):
74
  st.sidebar.markdown('This is a widget that allows you to look for papers containing specific phrases in the dataset and show it as a heatmap. Enter the phrase of interest, then change the size and opacity of the heatmap as desired to find the high-density regions. Hover over blue points to see the details of individual papers.')
75
  st.sidebar.markdown('`Note`: (i) if you enter a query that is not in the corpus of abstracts, it will return an error. just enter a different query in that case. (ii) there are some empty tooltips when you hover, these correspond to the underlying hexbins, and can be ignored.')
76
 
77
- st.sidebar.text_input("Search query", key="phrase", value="")
78
- alpha_value = st.sidebar.slider("Pick the hexbin opacity",0.0,1.0,0.1)
79
- size_value = st.sidebar.slider("Pick the hexbin size",0.0,2.0,0.2)
80
 
81
  phrase=st.session_state.phrase
82
 
@@ -103,10 +103,19 @@ ID: $index
103
  """
104
 
105
  p = figure(width=700, height=583, tooltips=TOOLTIPS, x_range=(0, 15), y_range=(2.5,15),
106
- title="UMAP projection of trained ArXiv corpus | heatmap keyword: "+phrase)
107
 
108
- p.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1], size=size_value,
109
- palette = np.flip(OrRd[8]), alpha=alpha_value)
110
  p.circle('x', 'y', size=3, source=source, alpha=0.3)
111
-
112
  st.bokeh_chart(p)
 
 
 
 
 
 
 
 
 
 
 
74
  st.sidebar.markdown('This is a widget that allows you to look for papers containing specific phrases in the dataset and show it as a heatmap. Enter the phrase of interest, then change the size and opacity of the heatmap as desired to find the high-density regions. Hover over blue points to see the details of individual papers.')
75
  st.sidebar.markdown('`Note`: (i) if you enter a query that is not in the corpus of abstracts, it will return an error. just enter a different query in that case. (ii) there are some empty tooltips when you hover, these correspond to the underlying hexbins, and can be ignored.')
76
 
77
+ st.sidebar.text_input("Search query", key="phrase", value="Quenching")
78
+ alpha_value = st.sidebar.slider("Pick the hexbin opacity",0.0,1.0,0.81)
79
+ size_value = st.sidebar.slider("Pick the hexbin gridsize",10,50,20)
80
 
81
  phrase=st.session_state.phrase
82
 
 
103
  """
104
 
105
  p = figure(width=700, height=583, tooltips=TOOLTIPS, x_range=(0, 15), y_range=(2.5,15),
106
+ title="UMAP projection of embeddings for the astro-ph.GA corpus"+phrase)
107
 
108
+ # p.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1], size=size_value,
109
+ # palette = np.flip(OrRd[8]), alpha=alpha_value)
110
  p.circle('x', 'y', size=3, source=source, alpha=0.3)
 
111
  st.bokeh_chart(p)
112
+
113
+ fig = plt.figure(figsize=(10.5,9*0.8328))
114
+ plt.scatter(embedding[0:,0], embedding[0:,1],s=2,alpha=0.1)
115
+ plt.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1],
116
+ gridsize=size_value, cmap = 'viridis', alpha=alpha_value,extent=(-1,16,1.5,16),mincnt=10)
117
+ plt.title("UMAP localization of heatmap keyword: "+phrase)
118
+ plt.axis([0,15,2.5,15]);
119
+ clbr = plt.colorbar(); clbr.set_label('# papers')
120
+ plt.axis('off')
121
+ st.pyplot(fig)
pages/{1_paper_search.py → 2_paper_search.py} RENAMED
File without changes
pages/{3_qa_sources_v2.py → 3_answering_questions.py} RENAMED
@@ -1,4 +1,3 @@
1
- # set the environment variables needed for openai package to know to reach out to azure
2
  import os
3
  import datetime
4
  import faiss
@@ -181,7 +180,7 @@ def list_similar_papers_v2(model_data,
181
  for i in range(start_range,start_range+return_n):
182
 
183
  abstracts_relevant.append(all_text[sims[i]])
184
- fhdr = all_authors[sims[i]][0]['name'].split()[-1] + all_arxivid[sims[i]][0:2] +'_'+ all_arxivid[sims[i]]
185
  fhdrs.append(fhdr)
186
  textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
187
  textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
@@ -325,7 +324,7 @@ def run_rag(query, return_n = 10, show_authors = True, show_summary = True):
325
  temp = temp[0:-2] + ' et al. 19' + temp[-2:]
326
  temp = '['+temp+']('+all_links[int(srcnames[i].split('_')[0].split('/')[1])]+')'
327
  st.markdown(temp)
328
- simids = np.array(srcindices)
329
 
330
  fig = plt.figure(figsize=(9,9))
331
  plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
@@ -338,100 +337,15 @@ def run_rag(query, return_n = 10, show_authors = True, show_summary = True):
338
 
339
  return rag_answer
340
 
341
- def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
342
-
343
- show_authors = True
344
- show_summary = True
345
- sims, absts, fhdrs, simids = list_similar_papers_v2(model_data,
346
- doc_id = query,
347
- input_type='keywords',
348
- show_authors = show_authors, show_summary = show_summary,
349
- return_n = return_n)
350
-
351
- temp_abst = ''
352
- loaders = []
353
- for i in range(len(absts)):
354
- temp_abst = absts[i]
355
-
356
- try:
357
- text_file = open("absts/"+fhdrs[i]+".txt", "w")
358
- except:
359
- os.mkdir('absts')
360
- text_file = open("absts/"+fhdrs[i]+".txt", "w")
361
- n = text_file.write(temp_abst)
362
- text_file.close()
363
- loader = TextLoader("absts/"+fhdrs[i]+".txt")
364
- loaders.append(loader)
365
-
366
- lc_index = VectorstoreIndexCreator().from_loaders(loaders)
367
-
368
- st.markdown('### User query: '+query)
369
- if show_pure_answer == True:
370
- st.markdown('pure answer:')
371
- st.markdown(lc_index.query(query))
372
- st.markdown(' ')
373
- st.markdown('#### context-based answer from sources:')
374
- output = lc_index.query_with_sources(query + ' Let\'s work this out in a step by step way to be sure we have the right answer.' ) #zero-shot in-context prompting from Zhou+22, Kojima+22
375
- st.markdown(output['answer'])
376
- opstr = '#### Primary sources: \n'
377
- st.markdown(opstr)
378
-
379
- # opstr = ''
380
- # for i in range(len(output['sources'])):
381
- # opstr = opstr +'\n'+ output['sources'][i]
382
-
383
- textstr = ''
384
- ng = len(output['sources'].split())
385
- abs_indices = []
386
-
387
- for i in range(ng):
388
- if i == (ng-1):
389
- tempid = output['sources'].split()[i].split('_')[1][0:-4]
390
- else:
391
- tempid = output['sources'].split()[i].split('_')[1][0:-5]
392
- try:
393
- abs_index = all_arxivid.index(tempid)
394
- abs_indices.append(abs_index)
395
- textstr = textstr + str(i+1)+'. **'+ all_titles[abs_index] +' \n'
396
- textstr = textstr + '**ArXiv:** ['+all_arxivid[abs_index]+'](https://arxiv.org/abs/'+all_arxivid[abs_index]+') \n'
397
- textstr = textstr + '**Authors:** '
398
- temp = all_authors[abs_index]
399
- for ak in range(4):
400
- if ak < len(temp)-1:
401
- textstr = textstr + temp[ak].name + ', '
402
- else:
403
- textstr = textstr + temp[ak].name + ' \n'
404
- if len(temp) > 3:
405
- textstr = textstr + ' et al. \n'
406
- textstr = textstr + '**Summary:** '
407
- text = all_text[abs_index]
408
- text = text.replace('\n', ' ')
409
- textstr = textstr + summarizer.summarize(text) + ' \n'
410
- except:
411
- textstr = textstr + output['sources'].split()[i]
412
- # opstr = opstr + ' \n ' + output['sources'].split()[i][6:-5].split('_')[0]
413
- # opstr = opstr + ' \n Arxiv id: ' + output['sources'].split()[i][6:-5].split('_')[1]
414
-
415
- textstr = textstr + ' '
416
- textstr = textstr + ' \n'
417
- st.markdown(textstr)
418
-
419
- fig = plt.figure(figsize=(9,9))
420
- plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
421
- plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
422
- plt.scatter(e2d[abs_indices,0], e2d[abs_indices,1],s=100,color='k',marker='d')
423
- st.pyplot(fig)
424
-
425
- if show_all_sources == True:
426
- st.markdown('\n #### Other interesting papers:')
427
- st.markdown(sims)
428
- return output
429
 
430
  st.title('ArXiv-based question answering')
431
  st.markdown('[Includes papers up to: `'+dateval+'`]')
432
- st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. Please use sparingly because it costs me money right now. You might need to wait for a few seconds for the GPT-4 query to return an answer (check top right corner to see if it is still running).')
 
 
433
 
434
- query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
435
- return_n = st.slider('How many papers should I show?', 1, 20, 10)
 
436
 
437
- sims = run_query(query, return_n = return_n)
 
 
1
  import os
2
  import datetime
3
  import faiss
 
180
  for i in range(start_range,start_range+return_n):
181
 
182
  abstracts_relevant.append(all_text[sims[i]])
183
+ fhdr = str(sims[i])+'_'+all_authors[sims[i]][0]['name'].split()[-1] + all_arxivid[sims[i]][0:2] +'_'+ all_arxivid[sims[i]]
184
  fhdrs.append(fhdr)
185
  textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
186
  textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
 
324
  temp = temp[0:-2] + ' et al. 19' + temp[-2:]
325
  temp = '['+temp+']('+all_links[int(srcnames[i].split('_')[0].split('/')[1])]+')'
326
  st.markdown(temp)
327
+ abs_indices = np.array(srcindices)
328
 
329
  fig = plt.figure(figsize=(9,9))
330
  plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
 
337
 
338
  return rag_answer
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
  st.title('ArXiv-based question answering')
342
  st.markdown('[Includes papers up to: `'+dateval+'`]')
343
+ st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. You might need to wait for a few seconds for the GPT-4 query to return an answer (check top right corner to see if it is still running).')
344
+ st.markdown('The answers are followed by relevant source(s) used in the answer, a graph showing which part of the astro-ph.GA manifold it drew the answer from (tightly clustered points generally indicate high quality/consensus answers) followed by a bunch of relevant papers used by the RAG to compose the answer.')
345
+ st.markdown('If this does not satisfactorily answer your question or rambles too much, you can also try the older `qa_sources_v1` page.')
346
 
347
+ query = st.text_input('Your question here:',
348
+ value="What causes galaxy quenching at high redshifts?")
349
+ return_n = st.slider('How many papers should I show?', 1, 30, 10)
350
 
351
+ sims = run_rag(query, return_n = return_n)
pages/{3_qa_sources_v1.py → 4_qa_sources_v1.py} RENAMED
@@ -118,7 +118,7 @@ def find_papers_by_author(auth_name):
118
 
119
  return doc_ids
120
 
121
- def faiss_based_indices(input_vector, nindex=10):
122
  xq = input_vector.reshape(-1,1).T.astype('float32')
123
  D, I = index.search(xq, nindex)
124
  return I[0], D[0]
@@ -126,7 +126,7 @@ def faiss_based_indices(input_vector, nindex=10):
126
  def list_similar_papers_v2(model_data,
127
  doc_id = [], input_type = 'doc_id',
128
  show_authors = False, show_summary = False,
129
- return_n = 10):
130
 
131
  arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
132
 
@@ -152,7 +152,7 @@ def list_similar_papers_v2(model_data,
152
  print('unrecognized input type.')
153
  return
154
 
155
- sims, dists = faiss_based_indices(inferred_vector, return_n+2)
156
  textstr = ''
157
  abstracts_relevant = []
158
  fhdrs = []
@@ -182,30 +182,9 @@ def list_similar_papers_v2(model_data,
182
  textstr = textstr + ' \n'
183
  return textstr, abstracts_relevant, fhdrs, sims
184
 
185
-
186
- def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
187
- headers = {
188
- "Content-Type": "application/json",
189
- "Authorization": f"Bearer {openai.api_key}",
190
- }
191
-
192
- data = {
193
- "model": model,
194
- "messages": messages,
195
- "temperature": temperature,
196
- }
197
-
198
- if max_tokens is not None:
199
- data["max_tokens"] = max_tokens
200
- response = requests.post(API_ENDPOINT, headers=headers, data=json.dumps(data))
201
- if response.status_code == 200:
202
- return response.json()["choices"][0]["message"]["content"]
203
- else:
204
- raise Exception(f"Error {response.status_code}: {response.text}")
205
-
206
  model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
207
 
208
- def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
209
 
210
  show_authors = True
211
  show_summary = True
@@ -213,7 +192,7 @@ def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources =
213
  doc_id = query,
214
  input_type='keywords',
215
  show_authors = show_authors, show_summary = show_summary,
216
- return_n = return_n)
217
 
218
  temp_abst = ''
219
  loaders = []
@@ -300,5 +279,8 @@ st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. Please
300
 
301
  query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
302
  return_n = st.slider('How many papers should I show?', 1, 20, 10)
 
 
 
303
 
304
- sims = run_query(query, return_n = return_n)
 
118
 
119
  return doc_ids
120
 
121
+ def faiss_based_indices(input_vector, nindex=10, yrmin = 1990, yrmax = 2024):
122
  xq = input_vector.reshape(-1,1).T.astype('float32')
123
  D, I = index.search(xq, nindex)
124
  return I[0], D[0]
 
126
  def list_similar_papers_v2(model_data,
127
  doc_id = [], input_type = 'doc_id',
128
  show_authors = False, show_summary = False,
129
+ return_n = 10, yrmin = 1990, yrmax = 2024):
130
 
131
  arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
132
 
 
152
  print('unrecognized input type.')
153
  return
154
 
155
+ sims, dists = faiss_based_indices(inferred_vector, return_n+2, yrmin = 1990, yrmax = 2024)
156
  textstr = ''
157
  abstracts_relevant = []
158
  fhdrs = []
 
182
  textstr = textstr + ' \n'
183
  return textstr, abstracts_relevant, fhdrs, sims
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
186
 
187
+ def run_query(query, return_n = 3, yrmin = 1990, yrmax = 2024, show_pure_answer = False, show_all_sources = True):
188
 
189
  show_authors = True
190
  show_summary = True
 
192
  doc_id = query,
193
  input_type='keywords',
194
  show_authors = show_authors, show_summary = show_summary,
195
+ return_n = return_n, yrmin = 1990, yrmax = 2024)
196
 
197
  temp_abst = ''
198
  loaders = []
 
279
 
280
  query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
281
  return_n = st.slider('How many papers should I show?', 1, 20, 10)
282
+ yrmin = st.slider('Min year', 1990,2023, 1990)
283
+ yrmax = st.slider('Max year', 1990, 2024, 2024)
284
+
285
 
286
+ sims = run_query(query, return_n = return_n, yrmin = yrmin, yrmax = yrmax)