Spaces:
Runtime error
Runtime error
adding current streamlit files
Browse files- LICENSE +21 -0
- README.md +2 -0
- pages/.ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
- pages/1_paper_search.py +191 -0
- pages/2_arxiv_embedding.py +72 -0
- pages/3_qa_sources.py +292 -0
- pages/Untitled.ipynb +6 -0
- requirements.txt +10 -0
- streamlit_app.py +26 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Kartheik Iyer
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -10,4 +10,6 @@ pinned: false
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
|
|
|
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
10 |
license: mit
|
11 |
---
|
12 |
|
13 |
+
An extension of chaotic_neural to visualize papers clustered using GPT-based embeddings
|
14 |
+
|
15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
pages/.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [],
|
3 |
+
"metadata": {},
|
4 |
+
"nbformat": 4,
|
5 |
+
"nbformat_minor": 5
|
6 |
+
}
|
pages/1_paper_search.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime, os
|
2 |
+
from langchain.llms import OpenAI
|
3 |
+
from langchain.embeddings import OpenAIEmbeddings
|
4 |
+
import openai
|
5 |
+
import faiss
|
6 |
+
import streamlit as st
|
7 |
+
import feedparser
|
8 |
+
import urllib
|
9 |
+
import cloudpickle as cp
|
10 |
+
from urllib.request import urlopen
|
11 |
+
from summa import summarizer
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
# openai.organization = "org-EBvNTjd2pLrK4vOhFlNMLr2v"
|
15 |
+
openai.organization = st.secrets.openai.org
|
16 |
+
openai.api_key = st.secrets.openai.api_key
|
17 |
+
os.environ["OPENAI_API_KEY"] = openai.api_key
|
18 |
+
|
19 |
+
@st.cache_data
|
20 |
+
def get_feeds_data(url):
|
21 |
+
data = cp.load(urlopen(url))
|
22 |
+
st.sidebar.success("Fetched data from API!")
|
23 |
+
return data
|
24 |
+
|
25 |
+
embeddings = OpenAIEmbeddings()
|
26 |
+
|
27 |
+
feeds_link = "https://drive.google.com/uc?export=download&id=1-IPk1voyUM9VqnghwyVrM1dY6rFnn1S_"
|
28 |
+
embed_link = "https://dl.dropboxusercontent.com/s/ob2betm29qrtb8v/astro_ph_ga_feeds_ada_embedding_18-Apr-2023.pkl?dl=0"
|
29 |
+
gal_feeds = get_feeds_data(feeds_link)
|
30 |
+
arxiv_ada_embeddings = get_feeds_data(embed_link)
|
31 |
+
|
32 |
+
ctr = -1
|
33 |
+
num_chunks = len(gal_feeds)
|
34 |
+
all_text, all_titles, all_arxivid, all_links, all_authors = [], [], [], [], []
|
35 |
+
|
36 |
+
for nc in range(num_chunks):
|
37 |
+
|
38 |
+
for i in range(len(gal_feeds[nc].entries)):
|
39 |
+
text = gal_feeds[nc].entries[i].summary
|
40 |
+
text = text.replace('\n', ' ')
|
41 |
+
text = text.replace('\\', '')
|
42 |
+
all_text.append(text)
|
43 |
+
all_titles.append(gal_feeds[nc].entries[i].title)
|
44 |
+
all_arxivid.append(gal_feeds[nc].entries[i].id.split('/')[-1][0:-2])
|
45 |
+
all_links.append(gal_feeds[nc].entries[i].links[1].href)
|
46 |
+
all_authors.append(gal_feeds[nc].entries[i].authors)
|
47 |
+
|
48 |
+
d = arxiv_ada_embeddings.shape[1] # dimension
|
49 |
+
nb = arxiv_ada_embeddings.shape[0] # database size
|
50 |
+
xb = arxiv_ada_embeddings.astype('float32')
|
51 |
+
index = faiss.IndexFlatL2(d)
|
52 |
+
index.add(xb)
|
53 |
+
|
54 |
+
def run_simple_query(search_query = 'all:sed+fitting', max_results = 10, start = 0, sort_by = 'lastUpdatedDate', sort_order = 'descending'):
|
55 |
+
"""
|
56 |
+
Query ArXiv to return search results for a particular query
|
57 |
+
Parameters
|
58 |
+
----------
|
59 |
+
query: str
|
60 |
+
query term. use prefixes ti, au, abs, co, jr, cat, m, id, all as applicable.
|
61 |
+
max_results: int, default = 10
|
62 |
+
number of results to return. numbers > 1000 generally lead to timeouts
|
63 |
+
start: int, default = 0
|
64 |
+
start index for results reported. use this if you're interested in running chunks.
|
65 |
+
Returns
|
66 |
+
-------
|
67 |
+
feed: dict
|
68 |
+
object containing requested results parsed with feedparser
|
69 |
+
Notes
|
70 |
+
-----
|
71 |
+
add functionality for chunk parsing, as well as storage and retreival
|
72 |
+
"""
|
73 |
+
|
74 |
+
# Base api query url
|
75 |
+
base_url = 'http://export.arxiv.org/api/query?';
|
76 |
+
query = 'search_query=%s&start=%i&max_results=%i&sortBy=%s&sortOrder=%s' % (search_query,
|
77 |
+
start,
|
78 |
+
max_results,sort_by,sort_order)
|
79 |
+
|
80 |
+
response = urllib.request.urlopen(base_url+query).read()
|
81 |
+
feed = feedparser.parse(response)
|
82 |
+
return feed
|
83 |
+
|
84 |
+
def find_papers_by_author(auth_name):
|
85 |
+
|
86 |
+
doc_ids = []
|
87 |
+
for doc_id in range(len(all_authors)):
|
88 |
+
for auth_id in range(len(all_authors[doc_id])):
|
89 |
+
if auth_name.lower() in all_authors[doc_id][auth_id]['name'].lower():
|
90 |
+
print('Doc ID: ',doc_id, ' | arXiv: ', all_arxivid[doc_id], '| ', all_titles[doc_id],' | Author entry: ', all_authors[doc_id][auth_id]['name'])
|
91 |
+
doc_ids.append(doc_id)
|
92 |
+
|
93 |
+
return doc_ids
|
94 |
+
|
95 |
+
def faiss_based_indices(input_vector, nindex=10):
|
96 |
+
xq = input_vector.reshape(-1,1).T.astype('float32')
|
97 |
+
D, I = index.search(xq, nindex)
|
98 |
+
return I[0], D[0]
|
99 |
+
|
100 |
+
|
101 |
+
def list_similar_papers_v2(model_data,
|
102 |
+
doc_id = [], input_type = 'doc_id',
|
103 |
+
show_authors = False, show_summary = False,
|
104 |
+
return_n = 10):
|
105 |
+
|
106 |
+
arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
|
107 |
+
|
108 |
+
if input_type == 'doc_id':
|
109 |
+
print('Doc ID: ',doc_id,', title: ',all_titles[doc_id])
|
110 |
+
# inferred_vector = model.infer_vector(train_corpus[doc_id].words)
|
111 |
+
inferred_vector = arxiv_ada_embeddings[doc_id,0:]
|
112 |
+
start_range = 1
|
113 |
+
elif input_type == 'arxiv_id':
|
114 |
+
print('ArXiv id: ',doc_id)
|
115 |
+
arxiv_query_feed = run_simple_query(search_query='id:'+str(doc_id))
|
116 |
+
if len(arxiv_query_feed.entries) == 0:
|
117 |
+
print('error: arxiv id not found.')
|
118 |
+
return
|
119 |
+
else:
|
120 |
+
print('Title: '+arxiv_query_feed.entries[0].title)
|
121 |
+
inferred_vector = np.array(embeddings.embed_query(arxiv_query_feed.entries[0].summary))
|
122 |
+
# arxiv_query_tokens = gensim.utils.simple_preprocess(arxiv_query_feed.entries[0].summary)
|
123 |
+
# inferred_vector = model.infer_vector(arxiv_query_tokens)
|
124 |
+
|
125 |
+
start_range = 0
|
126 |
+
elif input_type == 'keywords':
|
127 |
+
# print('Keyword(s): ',[doc_id[i] for i in range(len(doc_id))])
|
128 |
+
# word_vector = model.wv[doc_id[0]]
|
129 |
+
# if len(doc_id) > 1:
|
130 |
+
# print('multi-keyword')
|
131 |
+
# for i in range(1,len(doc_id)):
|
132 |
+
# word_vector = word_vector + model.wv[doc_id[i]]
|
133 |
+
# # word_vector = model.infer_vector(doc_id)
|
134 |
+
# inferred_vector = word_vector
|
135 |
+
inferred_vector = np.array(embeddings.embed_query(doc_id))
|
136 |
+
start_range = 0
|
137 |
+
else:
|
138 |
+
print('unrecognized input type.')
|
139 |
+
return
|
140 |
+
|
141 |
+
# sims = model.docvecs.most_similar([inferred_vector], topn=len(model.docvecs))
|
142 |
+
sims, dists = faiss_based_indices(inferred_vector, return_n+2)
|
143 |
+
textstr = ''
|
144 |
+
|
145 |
+
textstr = textstr + '-----------------------------\n'
|
146 |
+
textstr = textstr + 'Most similar/relevant papers: \n'
|
147 |
+
textstr = textstr + '-----------------------------\n\n'
|
148 |
+
for i in range(start_range,start_range+return_n):
|
149 |
+
|
150 |
+
# print(i, all_titles[sims[i]], ' (Distance: %.2f' %dists[i] ,')')
|
151 |
+
textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
|
152 |
+
textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
|
153 |
+
if show_authors == True:
|
154 |
+
textstr = textstr + '**Authors:** '
|
155 |
+
temp = all_authors[sims[i]]
|
156 |
+
for ak in range(len(temp)):
|
157 |
+
if ak < len(temp)-1:
|
158 |
+
textstr = textstr + temp[ak].name + ', '
|
159 |
+
else:
|
160 |
+
textstr = textstr + temp[ak].name + ' \n'
|
161 |
+
if show_summary == True:
|
162 |
+
textstr = textstr + '**Summary:** '
|
163 |
+
text = all_text[sims[i]]
|
164 |
+
text = text.replace('\n', ' ')
|
165 |
+
textstr = textstr + summarizer.summarize(text) + ' \n'
|
166 |
+
if show_authors == True or show_summary == True:
|
167 |
+
textstr = textstr + ' '
|
168 |
+
textstr = textstr + ' \n'
|
169 |
+
return textstr
|
170 |
+
|
171 |
+
|
172 |
+
model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
|
173 |
+
|
174 |
+
st.title('ArXiv similarity search:')
|
175 |
+
st.markdown('Search for similar papers by arxiv id or phrase:')
|
176 |
+
|
177 |
+
search_type = st.radio(
|
178 |
+
"What are you searching by?",
|
179 |
+
('arxiv id', 'text query'), index=1)
|
180 |
+
|
181 |
+
query = st.text_input('Search query or arxivid', value="what causes galaxy quenching?")
|
182 |
+
show_authors = st.checkbox('Show author information', value = True)
|
183 |
+
show_summary = st.checkbox('Show paper summary', value = True)
|
184 |
+
return_n = st.slider('How many papers should I show?', 1, 30, 10)
|
185 |
+
|
186 |
+
if search_type == 'arxiv id':
|
187 |
+
sims = list_similar_papers_v2(model_data, doc_id = query, input_type='arxiv_id', show_authors = show_authors, show_summary = show_summary, return_n = return_n)
|
188 |
+
else:
|
189 |
+
sims = list_similar_papers_v2(model_data, doc_id = query, input_type='keywords', show_authors = show_authors, show_summary = show_summary, return_n = return_n)
|
190 |
+
|
191 |
+
st.markdown(sims)
|
pages/2_arxiv_embedding.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import pickle
|
6 |
+
from bokeh.palettes import OrRd
|
7 |
+
from bokeh.plotting import figure, show
|
8 |
+
from bokeh.plotting import ColumnDataSource, figure, output_notebook, show
|
9 |
+
import cloudpickle as cp
|
10 |
+
from scipy import stats
|
11 |
+
from urllib.request import urlopen
|
12 |
+
|
13 |
+
st.title("ArXiv+GPT3 embedding explorer")
|
14 |
+
st.markdown("This is an explorer for astro-ph.GA papers on the arXiv (up to Apt 18th, 2023). The papers have been preprocessed with `chaotic_neural` [(link)](http://chaotic-neural.readthedocs.io/) after which the collected abstracts are run through `text-embedding-ada-002` with [langchain](https://python.langchain.com/en/latest/ecosystem/openai.html) to generate a unique vector correpsonding to each paper. These are then compressed using [umap](https://umap-learn.readthedocs.io/en/latest/) and shown here, and can be used for similarity searches with methods like [faiss](https://github.com/facebookresearch/faiss). The scatterplot here can be paired with a heatmap for more targeted searches looking at a specific topic or area (see sidebar). Upgrade to chaotic neural suggested by Jo Ciucฤ, thank you! More to come (hopefully) with GPT-4 and its applications!")
|
15 |
+
st.markdown("Interpreting the UMAP plot: the algorithm creates a 2d embedding from the high-dim vector space that tries to conserve as much similarity information as possible. Nearby points in UMAP space are similar, and grow dissimiliar as you move farther away. The axes do not have any physical meaning.")
|
16 |
+
|
17 |
+
@st.cache_data
|
18 |
+
def get_embedding_data(url):
|
19 |
+
data = cp.load(urlopen(url))
|
20 |
+
st.sidebar.success("Fetched data from API!")
|
21 |
+
return data
|
22 |
+
|
23 |
+
url = "https://drive.google.com/uc?export=download&id=1133tynMwsfdR1wxbkFLhbES3FwDWTPjP"
|
24 |
+
embedding, all_text, all_titles, all_arxivid, all_links = get_embedding_data(url)
|
25 |
+
|
26 |
+
def density_estimation(m1, m2, xmin=0, ymin=0, xmax=15, ymax=15):
|
27 |
+
X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
|
28 |
+
positions = np.vstack([X.ravel(), Y.ravel()])
|
29 |
+
values = np.vstack([m1, m2])
|
30 |
+
kernel = stats.gaussian_kde(values)
|
31 |
+
Z = np.reshape(kernel(positions).T, X.shape)
|
32 |
+
return X, Y, Z
|
33 |
+
|
34 |
+
st.sidebar.markdown('This is a widget that allows you to look for papers containing specific phrases in the dataset and show it as a heatmap. Enter the phrase of interest, then change the size and opacity of the heatmap as desired to find the high-density regions. Hover over blue points to see the details of individual papers.')
|
35 |
+
st.sidebar.markdown('`Note`: (i) if you enter a query that is not in the corpus of abstracts, it will return an error. just enter a different query in that case. (ii) there are some empty tooltips when you hover, these correspond to the underlying hexbins, and can be ignored.')
|
36 |
+
|
37 |
+
st.sidebar.text_input("Search query", key="phrase", value="")
|
38 |
+
alpha_value = st.sidebar.slider("Pick the hexbin opacity",0.0,1.0,0.1)
|
39 |
+
size_value = st.sidebar.slider("Pick the hexbin size",0.0,2.0,0.2)
|
40 |
+
|
41 |
+
phrase=st.session_state.phrase
|
42 |
+
|
43 |
+
phrase_flags = np.zeros((len(all_text),))
|
44 |
+
for i in range(len(all_text)):
|
45 |
+
if phrase.lower() in all_text[i].lower():
|
46 |
+
phrase_flags[i] = 1
|
47 |
+
|
48 |
+
|
49 |
+
source = ColumnDataSource(data=dict(
|
50 |
+
x=embedding[0:,0],
|
51 |
+
y=embedding[0:,1],
|
52 |
+
title=all_titles,
|
53 |
+
link=all_links,
|
54 |
+
))
|
55 |
+
|
56 |
+
TOOLTIPS = """
|
57 |
+
<div style="width:300px;">
|
58 |
+
ID: $index
|
59 |
+
($x, $y)
|
60 |
+
@title <br>
|
61 |
+
@link <br> <br>
|
62 |
+
</div>
|
63 |
+
"""
|
64 |
+
|
65 |
+
p = figure(width=700, height=583, tooltips=TOOLTIPS, x_range=(0, 15), y_range=(2.5,15),
|
66 |
+
title="UMAP projection of trained ArXiv corpus | heatmap keyword: "+phrase)
|
67 |
+
|
68 |
+
p.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1], size=size_value,
|
69 |
+
palette = np.flip(OrRd[8]), alpha=alpha_value)
|
70 |
+
p.circle('x', 'y', size=3, source=source, alpha=0.3)
|
71 |
+
|
72 |
+
st.bokeh_chart(p)
|
pages/3_qa_sources.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime, os
|
2 |
+
from langchain.llms import OpenAI
|
3 |
+
from langchain.embeddings import OpenAIEmbeddings
|
4 |
+
import openai
|
5 |
+
import faiss
|
6 |
+
import streamlit as st
|
7 |
+
import feedparser
|
8 |
+
import urllib
|
9 |
+
import cloudpickle as cp
|
10 |
+
from urllib.request import urlopen
|
11 |
+
from summa import summarizer
|
12 |
+
import numpy as np
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
import requests
|
16 |
+
import json
|
17 |
+
from langchain.document_loaders import TextLoader
|
18 |
+
from langchain.indexes import VectorstoreIndexCreator
|
19 |
+
API_ENDPOINT = "https://api.openai.com/v1/chat/completions"
|
20 |
+
|
21 |
+
openai.organization = st.secrets.openai.org
|
22 |
+
openai.api_key = st.secrets.openai.api_key
|
23 |
+
os.environ["OPENAI_API_KEY"] = openai.api_key
|
24 |
+
|
25 |
+
@st.cache_data
|
26 |
+
def get_feeds_data(url):
|
27 |
+
data = cp.load(urlopen(url))
|
28 |
+
st.sidebar.success("Fetched data from API!")
|
29 |
+
return data
|
30 |
+
|
31 |
+
embeddings = OpenAIEmbeddings()
|
32 |
+
|
33 |
+
feeds_link = "https://drive.google.com/uc?export=download&id=1-IPk1voyUM9VqnghwyVrM1dY6rFnn1S_"
|
34 |
+
embed_link = "https://dl.dropboxusercontent.com/s/ob2betm29qrtb8v/astro_ph_ga_feeds_ada_embedding_18-Apr-2023.pkl?dl=0"
|
35 |
+
gal_feeds = get_feeds_data(feeds_link)
|
36 |
+
arxiv_ada_embeddings = get_feeds_data(embed_link)
|
37 |
+
|
38 |
+
@st.cache_data
|
39 |
+
def get_embedding_data(url):
|
40 |
+
data = cp.load(urlopen(url))
|
41 |
+
st.sidebar.success("Fetched data from API!")
|
42 |
+
return data
|
43 |
+
|
44 |
+
url = "https://drive.google.com/uc?export=download&id=1133tynMwsfdR1wxbkFLhbES3FwDWTPjP"
|
45 |
+
e2d, _, _, _, _ = get_embedding_data(url)
|
46 |
+
|
47 |
+
ctr = -1
|
48 |
+
num_chunks = len(gal_feeds)
|
49 |
+
all_text, all_titles, all_arxivid, all_links, all_authors = [], [], [], [], []
|
50 |
+
|
51 |
+
for nc in range(num_chunks):
|
52 |
+
|
53 |
+
for i in range(len(gal_feeds[nc].entries)):
|
54 |
+
text = gal_feeds[nc].entries[i].summary
|
55 |
+
text = text.replace('\n', ' ')
|
56 |
+
text = text.replace('\\', '')
|
57 |
+
all_text.append(text)
|
58 |
+
all_titles.append(gal_feeds[nc].entries[i].title)
|
59 |
+
all_arxivid.append(gal_feeds[nc].entries[i].id.split('/')[-1][0:-2])
|
60 |
+
all_links.append(gal_feeds[nc].entries[i].links[1].href)
|
61 |
+
all_authors.append(gal_feeds[nc].entries[i].authors)
|
62 |
+
|
63 |
+
d = arxiv_ada_embeddings.shape[1] # dimension
|
64 |
+
nb = arxiv_ada_embeddings.shape[0] # database size
|
65 |
+
xb = arxiv_ada_embeddings.astype('float32')
|
66 |
+
index = faiss.IndexFlatL2(d)
|
67 |
+
index.add(xb)
|
68 |
+
|
69 |
+
def run_simple_query(search_query = 'all:sed+fitting', max_results = 10, start = 0, sort_by = 'lastUpdatedDate', sort_order = 'descending'):
|
70 |
+
"""
|
71 |
+
Query ArXiv to return search results for a particular query
|
72 |
+
Parameters
|
73 |
+
----------
|
74 |
+
query: str
|
75 |
+
query term. use prefixes ti, au, abs, co, jr, cat, m, id, all as applicable.
|
76 |
+
max_results: int, default = 10
|
77 |
+
number of results to return. numbers > 1000 generally lead to timeouts
|
78 |
+
start: int, default = 0
|
79 |
+
start index for results reported. use this if you're interested in running chunks.
|
80 |
+
Returns
|
81 |
+
-------
|
82 |
+
feed: dict
|
83 |
+
object containing requested results parsed with feedparser
|
84 |
+
Notes
|
85 |
+
-----
|
86 |
+
add functionality for chunk parsing, as well as storage and retreival
|
87 |
+
"""
|
88 |
+
|
89 |
+
base_url = 'http://export.arxiv.org/api/query?';
|
90 |
+
query = 'search_query=%s&start=%i&max_results=%i&sortBy=%s&sortOrder=%s' % (search_query,
|
91 |
+
start,
|
92 |
+
max_results,sort_by,sort_order)
|
93 |
+
|
94 |
+
response = urllib.request.urlopen(base_url+query).read()
|
95 |
+
feed = feedparser.parse(response)
|
96 |
+
return feed
|
97 |
+
|
98 |
+
def find_papers_by_author(auth_name):
|
99 |
+
|
100 |
+
doc_ids = []
|
101 |
+
for doc_id in range(len(all_authors)):
|
102 |
+
for auth_id in range(len(all_authors[doc_id])):
|
103 |
+
if auth_name.lower() in all_authors[doc_id][auth_id]['name'].lower():
|
104 |
+
print('Doc ID: ',doc_id, ' | arXiv: ', all_arxivid[doc_id], '| ', all_titles[doc_id],' | Author entry: ', all_authors[doc_id][auth_id]['name'])
|
105 |
+
doc_ids.append(doc_id)
|
106 |
+
|
107 |
+
return doc_ids
|
108 |
+
|
109 |
+
def faiss_based_indices(input_vector, nindex=10):
|
110 |
+
xq = input_vector.reshape(-1,1).T.astype('float32')
|
111 |
+
D, I = index.search(xq, nindex)
|
112 |
+
return I[0], D[0]
|
113 |
+
|
114 |
+
def list_similar_papers_v2(model_data,
|
115 |
+
doc_id = [], input_type = 'doc_id',
|
116 |
+
show_authors = False, show_summary = False,
|
117 |
+
return_n = 10):
|
118 |
+
|
119 |
+
arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
|
120 |
+
|
121 |
+
if input_type == 'doc_id':
|
122 |
+
print('Doc ID: ',doc_id,', title: ',all_titles[doc_id])
|
123 |
+
# inferred_vector = model.infer_vector(train_corpus[doc_id].words)
|
124 |
+
inferred_vector = arxiv_ada_embeddings[doc_id,0:]
|
125 |
+
start_range = 1
|
126 |
+
elif input_type == 'arxiv_id':
|
127 |
+
print('ArXiv id: ',doc_id)
|
128 |
+
arxiv_query_feed = run_simple_query(search_query='id:'+str(doc_id))
|
129 |
+
if len(arxiv_query_feed.entries) == 0:
|
130 |
+
print('error: arxiv id not found.')
|
131 |
+
return
|
132 |
+
else:
|
133 |
+
print('Title: '+arxiv_query_feed.entries[0].title)
|
134 |
+
inferred_vector = np.array(embeddings.embed_query(arxiv_query_feed.entries[0].summary))
|
135 |
+
start_range = 0
|
136 |
+
elif input_type == 'keywords':
|
137 |
+
inferred_vector = np.array(embeddings.embed_query(doc_id))
|
138 |
+
start_range = 0
|
139 |
+
else:
|
140 |
+
print('unrecognized input type.')
|
141 |
+
return
|
142 |
+
|
143 |
+
sims, dists = faiss_based_indices(inferred_vector, return_n+2)
|
144 |
+
textstr = ''
|
145 |
+
abstracts_relevant = []
|
146 |
+
fhdrs = []
|
147 |
+
|
148 |
+
for i in range(start_range,start_range+return_n):
|
149 |
+
|
150 |
+
abstracts_relevant.append(all_text[sims[i]])
|
151 |
+
fhdr = all_authors[sims[i]][0]['name'].split()[-1] + all_arxivid[sims[i]][0:2] +'_'+ all_arxivid[sims[i]]
|
152 |
+
fhdrs.append(fhdr)
|
153 |
+
textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
|
154 |
+
textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
|
155 |
+
if show_authors == True:
|
156 |
+
textstr = textstr + '**Authors:** '
|
157 |
+
temp = all_authors[sims[i]]
|
158 |
+
for ak in range(len(temp)):
|
159 |
+
if ak < len(temp)-1:
|
160 |
+
textstr = textstr + temp[ak].name + ', '
|
161 |
+
else:
|
162 |
+
textstr = textstr + temp[ak].name + ' \n'
|
163 |
+
if show_summary == True:
|
164 |
+
textstr = textstr + '**Summary:** '
|
165 |
+
text = all_text[sims[i]]
|
166 |
+
text = text.replace('\n', ' ')
|
167 |
+
textstr = textstr + summarizer.summarize(text) + ' \n'
|
168 |
+
if show_authors == True or show_summary == True:
|
169 |
+
textstr = textstr + ' '
|
170 |
+
textstr = textstr + ' \n'
|
171 |
+
return textstr, abstracts_relevant, fhdrs, sims
|
172 |
+
|
173 |
+
|
174 |
+
def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
|
175 |
+
headers = {
|
176 |
+
"Content-Type": "application/json",
|
177 |
+
"Authorization": f"Bearer {openai.api_key}",
|
178 |
+
}
|
179 |
+
|
180 |
+
data = {
|
181 |
+
"model": model,
|
182 |
+
"messages": messages,
|
183 |
+
"temperature": temperature,
|
184 |
+
}
|
185 |
+
|
186 |
+
if max_tokens is not None:
|
187 |
+
data["max_tokens"] = max_tokens
|
188 |
+
response = requests.post(API_ENDPOINT, headers=headers, data=json.dumps(data))
|
189 |
+
if response.status_code == 200:
|
190 |
+
return response.json()["choices"][0]["message"]["content"]
|
191 |
+
else:
|
192 |
+
raise Exception(f"Error {response.status_code}: {response.text}")
|
193 |
+
|
194 |
+
|
195 |
+
model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
|
196 |
+
|
197 |
+
def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
|
198 |
+
|
199 |
+
show_authors = True
|
200 |
+
show_summary = True
|
201 |
+
sims, absts, fhdrs, simids = list_similar_papers_v2(model_data,
|
202 |
+
doc_id = query,
|
203 |
+
input_type='keywords',
|
204 |
+
show_authors = show_authors, show_summary = show_summary,
|
205 |
+
return_n = return_n)
|
206 |
+
|
207 |
+
temp_abst = ''
|
208 |
+
loaders = []
|
209 |
+
for i in range(len(absts)):
|
210 |
+
temp_abst = absts[i]
|
211 |
+
|
212 |
+
try:
|
213 |
+
text_file = open("absts/"+fhdrs[i]+".txt", "w")
|
214 |
+
except:
|
215 |
+
os.mkdir('absts')
|
216 |
+
text_file = open("absts/"+fhdrs[i]+".txt", "w")
|
217 |
+
n = text_file.write(temp_abst)
|
218 |
+
text_file.close()
|
219 |
+
loader = TextLoader("absts/"+fhdrs[i]+".txt")
|
220 |
+
loaders.append(loader)
|
221 |
+
|
222 |
+
lc_index = VectorstoreIndexCreator().from_loaders(loaders)
|
223 |
+
|
224 |
+
st.markdown('### User query: '+query)
|
225 |
+
if show_pure_answer == True:
|
226 |
+
st.markdown('pure answer:')
|
227 |
+
st.markdown(lc_index.query(query))
|
228 |
+
st.markdown(' ')
|
229 |
+
st.markdown('#### context-based answer from sources:')
|
230 |
+
output = lc_index.query_with_sources(query)
|
231 |
+
st.markdown(output['answer'])
|
232 |
+
opstr = '#### Primary sources: \n'
|
233 |
+
st.markdown(opstr)
|
234 |
+
|
235 |
+
# opstr = ''
|
236 |
+
# for i in range(len(output['sources'])):
|
237 |
+
# opstr = opstr +'\n'+ output['sources'][i]
|
238 |
+
|
239 |
+
textstr = ''
|
240 |
+
ng = len(output['sources'].split())
|
241 |
+
abs_indices = []
|
242 |
+
|
243 |
+
for i in range(ng):
|
244 |
+
if i == (ng-1):
|
245 |
+
tempid = output['sources'].split()[i].split('_')[1][0:-4]
|
246 |
+
else:
|
247 |
+
tempid = output['sources'].split()[i].split('_')[1][0:-5]
|
248 |
+
try:
|
249 |
+
abs_index = all_arxivid.index(tempid)
|
250 |
+
abs_indices.append(abs_index)
|
251 |
+
textstr = textstr + str(i+1)+'. **'+ all_titles[abs_index] +' \n'
|
252 |
+
textstr = textstr + '**ArXiv:** ['+all_arxivid[abs_index]+'](https://arxiv.org/abs/'+all_arxivid[abs_index]+') \n'
|
253 |
+
textstr = textstr + '**Authors:** '
|
254 |
+
temp = all_authors[abs_index]
|
255 |
+
for ak in range(4):
|
256 |
+
if ak < len(temp)-1:
|
257 |
+
textstr = textstr + temp[ak].name + ', '
|
258 |
+
else:
|
259 |
+
textstr = textstr + temp[ak].name + ' \n'
|
260 |
+
if len(temp) > 3:
|
261 |
+
textstr = textstr + ' et al. \n'
|
262 |
+
textstr = textstr + '**Summary:** '
|
263 |
+
text = all_text[abs_index]
|
264 |
+
text = text.replace('\n', ' ')
|
265 |
+
textstr = textstr + summarizer.summarize(text) + ' \n'
|
266 |
+
except:
|
267 |
+
textstr = textstr + output['sources'].split()[i]
|
268 |
+
# opstr = opstr + ' \n ' + output['sources'].split()[i][6:-5].split('_')[0]
|
269 |
+
# opstr = opstr + ' \n Arxiv id: ' + output['sources'].split()[i][6:-5].split('_')[1]
|
270 |
+
|
271 |
+
textstr = textstr + ' '
|
272 |
+
textstr = textstr + ' \n'
|
273 |
+
st.markdown(textstr)
|
274 |
+
|
275 |
+
fig = plt.figure(figsize=(9,9))
|
276 |
+
plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
|
277 |
+
plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
|
278 |
+
plt.scatter(e2d[abs_indices,0], e2d[abs_indices,1],s=100,color='k',marker='d')
|
279 |
+
st.pyplot(fig)
|
280 |
+
|
281 |
+
if show_all_sources == True:
|
282 |
+
st.markdown('\n #### Other interesting papers:')
|
283 |
+
st.markdown(sims)
|
284 |
+
return output
|
285 |
+
|
286 |
+
st.title('ArXiv-based question answering')
|
287 |
+
st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. Please use sparingly because it costs me money right now. You might need to wait for a few seconds for the GPT-4 query to return an answer (check top right corner to see if it is still running).')
|
288 |
+
|
289 |
+
query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
|
290 |
+
return_n = st.slider('How many papers should I show?', 1, 20, 10)
|
291 |
+
|
292 |
+
sims = run_query(query, return_n = return_n)
|
pages/Untitled.ipynb
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [],
|
3 |
+
"metadata": {},
|
4 |
+
"nbformat": 4,
|
5 |
+
"nbformat_minor": 5
|
6 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
matplotlib
|
2 |
+
bokeh==2.4.3
|
3 |
+
cloudpickle
|
4 |
+
scipy
|
5 |
+
summa
|
6 |
+
faiss-cpu
|
7 |
+
langchain
|
8 |
+
openai
|
9 |
+
feedparser
|
10 |
+
tiktoken
|
streamlit_app.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
st.set_page_config(
|
4 |
+
page_title="arXiv-GPT",
|
5 |
+
page_icon="๐",
|
6 |
+
)
|
7 |
+
|
8 |
+
st.write("# Welcome to arXiv-GPT! ๐")
|
9 |
+
|
10 |
+
st.sidebar.success("Select a function above.")
|
11 |
+
st.sidebar.markdown("Current functions include visualizing papers in the arxiv embedding, or searching for similar papers to an input paper or prompt phrase.")
|
12 |
+
|
13 |
+
st.markdown(
|
14 |
+
"""
|
15 |
+
arXiv+GPT is a framework for searching and visualizing papers on
|
16 |
+
the [arXiv](https://arxiv.org/) using the context sensitivity from modern
|
17 |
+
large language models (LLMs) like GPT3 to better link paper contexts
|
18 |
+
|
19 |
+
**๐ Select a tool from the sidebar** to see some examples
|
20 |
+
of what this framework can do!
|
21 |
+
### Want to learn more?
|
22 |
+
- Check out `chaotic_neural` [(link)](http://chaotic-neural.readthedocs.io/)
|
23 |
+
- Jump into our [documentation](https://docs.streamlit.io)
|
24 |
+
- Contribute!
|
25 |
+
"""
|
26 |
+
)
|