kiyer commited on
Commit
fe4a4f7
β€’
1 Parent(s): 2c466c8

adding current streamlit files

Browse files
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Kartheik Iyer
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -10,4 +10,6 @@ pinned: false
10
  license: mit
11
  ---
12
 
 
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
10
  license: mit
11
  ---
12
 
13
+ An extension of chaotic_neural to visualize papers clustered using GPT-based embeddings
14
+
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
pages/.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
pages/1_paper_search.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime, os
2
+ from langchain.llms import OpenAI
3
+ from langchain.embeddings import OpenAIEmbeddings
4
+ import openai
5
+ import faiss
6
+ import streamlit as st
7
+ import feedparser
8
+ import urllib
9
+ import cloudpickle as cp
10
+ from urllib.request import urlopen
11
+ from summa import summarizer
12
+ import numpy as np
13
+
14
+ # openai.organization = "org-EBvNTjd2pLrK4vOhFlNMLr2v"
15
+ openai.organization = st.secrets.openai.org
16
+ openai.api_key = st.secrets.openai.api_key
17
+ os.environ["OPENAI_API_KEY"] = openai.api_key
18
+
19
+ @st.cache_data
20
+ def get_feeds_data(url):
21
+ data = cp.load(urlopen(url))
22
+ st.sidebar.success("Fetched data from API!")
23
+ return data
24
+
25
+ embeddings = OpenAIEmbeddings()
26
+
27
+ feeds_link = "https://drive.google.com/uc?export=download&id=1-IPk1voyUM9VqnghwyVrM1dY6rFnn1S_"
28
+ embed_link = "https://dl.dropboxusercontent.com/s/ob2betm29qrtb8v/astro_ph_ga_feeds_ada_embedding_18-Apr-2023.pkl?dl=0"
29
+ gal_feeds = get_feeds_data(feeds_link)
30
+ arxiv_ada_embeddings = get_feeds_data(embed_link)
31
+
32
+ ctr = -1
33
+ num_chunks = len(gal_feeds)
34
+ all_text, all_titles, all_arxivid, all_links, all_authors = [], [], [], [], []
35
+
36
+ for nc in range(num_chunks):
37
+
38
+ for i in range(len(gal_feeds[nc].entries)):
39
+ text = gal_feeds[nc].entries[i].summary
40
+ text = text.replace('\n', ' ')
41
+ text = text.replace('\\', '')
42
+ all_text.append(text)
43
+ all_titles.append(gal_feeds[nc].entries[i].title)
44
+ all_arxivid.append(gal_feeds[nc].entries[i].id.split('/')[-1][0:-2])
45
+ all_links.append(gal_feeds[nc].entries[i].links[1].href)
46
+ all_authors.append(gal_feeds[nc].entries[i].authors)
47
+
48
+ d = arxiv_ada_embeddings.shape[1] # dimension
49
+ nb = arxiv_ada_embeddings.shape[0] # database size
50
+ xb = arxiv_ada_embeddings.astype('float32')
51
+ index = faiss.IndexFlatL2(d)
52
+ index.add(xb)
53
+
54
+ def run_simple_query(search_query = 'all:sed+fitting', max_results = 10, start = 0, sort_by = 'lastUpdatedDate', sort_order = 'descending'):
55
+ """
56
+ Query ArXiv to return search results for a particular query
57
+ Parameters
58
+ ----------
59
+ query: str
60
+ query term. use prefixes ti, au, abs, co, jr, cat, m, id, all as applicable.
61
+ max_results: int, default = 10
62
+ number of results to return. numbers > 1000 generally lead to timeouts
63
+ start: int, default = 0
64
+ start index for results reported. use this if you're interested in running chunks.
65
+ Returns
66
+ -------
67
+ feed: dict
68
+ object containing requested results parsed with feedparser
69
+ Notes
70
+ -----
71
+ add functionality for chunk parsing, as well as storage and retreival
72
+ """
73
+
74
+ # Base api query url
75
+ base_url = 'http://export.arxiv.org/api/query?';
76
+ query = 'search_query=%s&start=%i&max_results=%i&sortBy=%s&sortOrder=%s' % (search_query,
77
+ start,
78
+ max_results,sort_by,sort_order)
79
+
80
+ response = urllib.request.urlopen(base_url+query).read()
81
+ feed = feedparser.parse(response)
82
+ return feed
83
+
84
+ def find_papers_by_author(auth_name):
85
+
86
+ doc_ids = []
87
+ for doc_id in range(len(all_authors)):
88
+ for auth_id in range(len(all_authors[doc_id])):
89
+ if auth_name.lower() in all_authors[doc_id][auth_id]['name'].lower():
90
+ print('Doc ID: ',doc_id, ' | arXiv: ', all_arxivid[doc_id], '| ', all_titles[doc_id],' | Author entry: ', all_authors[doc_id][auth_id]['name'])
91
+ doc_ids.append(doc_id)
92
+
93
+ return doc_ids
94
+
95
+ def faiss_based_indices(input_vector, nindex=10):
96
+ xq = input_vector.reshape(-1,1).T.astype('float32')
97
+ D, I = index.search(xq, nindex)
98
+ return I[0], D[0]
99
+
100
+
101
+ def list_similar_papers_v2(model_data,
102
+ doc_id = [], input_type = 'doc_id',
103
+ show_authors = False, show_summary = False,
104
+ return_n = 10):
105
+
106
+ arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
107
+
108
+ if input_type == 'doc_id':
109
+ print('Doc ID: ',doc_id,', title: ',all_titles[doc_id])
110
+ # inferred_vector = model.infer_vector(train_corpus[doc_id].words)
111
+ inferred_vector = arxiv_ada_embeddings[doc_id,0:]
112
+ start_range = 1
113
+ elif input_type == 'arxiv_id':
114
+ print('ArXiv id: ',doc_id)
115
+ arxiv_query_feed = run_simple_query(search_query='id:'+str(doc_id))
116
+ if len(arxiv_query_feed.entries) == 0:
117
+ print('error: arxiv id not found.')
118
+ return
119
+ else:
120
+ print('Title: '+arxiv_query_feed.entries[0].title)
121
+ inferred_vector = np.array(embeddings.embed_query(arxiv_query_feed.entries[0].summary))
122
+ # arxiv_query_tokens = gensim.utils.simple_preprocess(arxiv_query_feed.entries[0].summary)
123
+ # inferred_vector = model.infer_vector(arxiv_query_tokens)
124
+
125
+ start_range = 0
126
+ elif input_type == 'keywords':
127
+ # print('Keyword(s): ',[doc_id[i] for i in range(len(doc_id))])
128
+ # word_vector = model.wv[doc_id[0]]
129
+ # if len(doc_id) > 1:
130
+ # print('multi-keyword')
131
+ # for i in range(1,len(doc_id)):
132
+ # word_vector = word_vector + model.wv[doc_id[i]]
133
+ # # word_vector = model.infer_vector(doc_id)
134
+ # inferred_vector = word_vector
135
+ inferred_vector = np.array(embeddings.embed_query(doc_id))
136
+ start_range = 0
137
+ else:
138
+ print('unrecognized input type.')
139
+ return
140
+
141
+ # sims = model.docvecs.most_similar([inferred_vector], topn=len(model.docvecs))
142
+ sims, dists = faiss_based_indices(inferred_vector, return_n+2)
143
+ textstr = ''
144
+
145
+ textstr = textstr + '-----------------------------\n'
146
+ textstr = textstr + 'Most similar/relevant papers: \n'
147
+ textstr = textstr + '-----------------------------\n\n'
148
+ for i in range(start_range,start_range+return_n):
149
+
150
+ # print(i, all_titles[sims[i]], ' (Distance: %.2f' %dists[i] ,')')
151
+ textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
152
+ textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
153
+ if show_authors == True:
154
+ textstr = textstr + '**Authors:** '
155
+ temp = all_authors[sims[i]]
156
+ for ak in range(len(temp)):
157
+ if ak < len(temp)-1:
158
+ textstr = textstr + temp[ak].name + ', '
159
+ else:
160
+ textstr = textstr + temp[ak].name + ' \n'
161
+ if show_summary == True:
162
+ textstr = textstr + '**Summary:** '
163
+ text = all_text[sims[i]]
164
+ text = text.replace('\n', ' ')
165
+ textstr = textstr + summarizer.summarize(text) + ' \n'
166
+ if show_authors == True or show_summary == True:
167
+ textstr = textstr + ' '
168
+ textstr = textstr + ' \n'
169
+ return textstr
170
+
171
+
172
+ model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
173
+
174
+ st.title('ArXiv similarity search:')
175
+ st.markdown('Search for similar papers by arxiv id or phrase:')
176
+
177
+ search_type = st.radio(
178
+ "What are you searching by?",
179
+ ('arxiv id', 'text query'), index=1)
180
+
181
+ query = st.text_input('Search query or arxivid', value="what causes galaxy quenching?")
182
+ show_authors = st.checkbox('Show author information', value = True)
183
+ show_summary = st.checkbox('Show paper summary', value = True)
184
+ return_n = st.slider('How many papers should I show?', 1, 30, 10)
185
+
186
+ if search_type == 'arxiv id':
187
+ sims = list_similar_papers_v2(model_data, doc_id = query, input_type='arxiv_id', show_authors = show_authors, show_summary = show_summary, return_n = return_n)
188
+ else:
189
+ sims = list_similar_papers_v2(model_data, doc_id = query, input_type='keywords', show_authors = show_authors, show_summary = show_summary, return_n = return_n)
190
+
191
+ st.markdown(sims)
pages/2_arxiv_embedding.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import pickle
6
+ from bokeh.palettes import OrRd
7
+ from bokeh.plotting import figure, show
8
+ from bokeh.plotting import ColumnDataSource, figure, output_notebook, show
9
+ import cloudpickle as cp
10
+ from scipy import stats
11
+ from urllib.request import urlopen
12
+
13
+ st.title("ArXiv+GPT3 embedding explorer")
14
+ st.markdown("This is an explorer for astro-ph.GA papers on the arXiv (up to Apt 18th, 2023). The papers have been preprocessed with `chaotic_neural` [(link)](http://chaotic-neural.readthedocs.io/) after which the collected abstracts are run through `text-embedding-ada-002` with [langchain](https://python.langchain.com/en/latest/ecosystem/openai.html) to generate a unique vector correpsonding to each paper. These are then compressed using [umap](https://umap-learn.readthedocs.io/en/latest/) and shown here, and can be used for similarity searches with methods like [faiss](https://github.com/facebookresearch/faiss). The scatterplot here can be paired with a heatmap for more targeted searches looking at a specific topic or area (see sidebar). Upgrade to chaotic neural suggested by Jo Ciucă, thank you! More to come (hopefully) with GPT-4 and its applications!")
15
+ st.markdown("Interpreting the UMAP plot: the algorithm creates a 2d embedding from the high-dim vector space that tries to conserve as much similarity information as possible. Nearby points in UMAP space are similar, and grow dissimiliar as you move farther away. The axes do not have any physical meaning.")
16
+
17
+ @st.cache_data
18
+ def get_embedding_data(url):
19
+ data = cp.load(urlopen(url))
20
+ st.sidebar.success("Fetched data from API!")
21
+ return data
22
+
23
+ url = "https://drive.google.com/uc?export=download&id=1133tynMwsfdR1wxbkFLhbES3FwDWTPjP"
24
+ embedding, all_text, all_titles, all_arxivid, all_links = get_embedding_data(url)
25
+
26
+ def density_estimation(m1, m2, xmin=0, ymin=0, xmax=15, ymax=15):
27
+ X, Y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
28
+ positions = np.vstack([X.ravel(), Y.ravel()])
29
+ values = np.vstack([m1, m2])
30
+ kernel = stats.gaussian_kde(values)
31
+ Z = np.reshape(kernel(positions).T, X.shape)
32
+ return X, Y, Z
33
+
34
+ st.sidebar.markdown('This is a widget that allows you to look for papers containing specific phrases in the dataset and show it as a heatmap. Enter the phrase of interest, then change the size and opacity of the heatmap as desired to find the high-density regions. Hover over blue points to see the details of individual papers.')
35
+ st.sidebar.markdown('`Note`: (i) if you enter a query that is not in the corpus of abstracts, it will return an error. just enter a different query in that case. (ii) there are some empty tooltips when you hover, these correspond to the underlying hexbins, and can be ignored.')
36
+
37
+ st.sidebar.text_input("Search query", key="phrase", value="")
38
+ alpha_value = st.sidebar.slider("Pick the hexbin opacity",0.0,1.0,0.1)
39
+ size_value = st.sidebar.slider("Pick the hexbin size",0.0,2.0,0.2)
40
+
41
+ phrase=st.session_state.phrase
42
+
43
+ phrase_flags = np.zeros((len(all_text),))
44
+ for i in range(len(all_text)):
45
+ if phrase.lower() in all_text[i].lower():
46
+ phrase_flags[i] = 1
47
+
48
+
49
+ source = ColumnDataSource(data=dict(
50
+ x=embedding[0:,0],
51
+ y=embedding[0:,1],
52
+ title=all_titles,
53
+ link=all_links,
54
+ ))
55
+
56
+ TOOLTIPS = """
57
+ <div style="width:300px;">
58
+ ID: $index
59
+ ($x, $y)
60
+ @title <br>
61
+ @link <br> <br>
62
+ </div>
63
+ """
64
+
65
+ p = figure(width=700, height=583, tooltips=TOOLTIPS, x_range=(0, 15), y_range=(2.5,15),
66
+ title="UMAP projection of trained ArXiv corpus | heatmap keyword: "+phrase)
67
+
68
+ p.hexbin(embedding[phrase_flags==1,0],embedding[phrase_flags==1,1], size=size_value,
69
+ palette = np.flip(OrRd[8]), alpha=alpha_value)
70
+ p.circle('x', 'y', size=3, source=source, alpha=0.3)
71
+
72
+ st.bokeh_chart(p)
pages/3_qa_sources.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime, os
2
+ from langchain.llms import OpenAI
3
+ from langchain.embeddings import OpenAIEmbeddings
4
+ import openai
5
+ import faiss
6
+ import streamlit as st
7
+ import feedparser
8
+ import urllib
9
+ import cloudpickle as cp
10
+ from urllib.request import urlopen
11
+ from summa import summarizer
12
+ import numpy as np
13
+ import matplotlib.pyplot as plt
14
+
15
+ import requests
16
+ import json
17
+ from langchain.document_loaders import TextLoader
18
+ from langchain.indexes import VectorstoreIndexCreator
19
+ API_ENDPOINT = "https://api.openai.com/v1/chat/completions"
20
+
21
+ openai.organization = st.secrets.openai.org
22
+ openai.api_key = st.secrets.openai.api_key
23
+ os.environ["OPENAI_API_KEY"] = openai.api_key
24
+
25
+ @st.cache_data
26
+ def get_feeds_data(url):
27
+ data = cp.load(urlopen(url))
28
+ st.sidebar.success("Fetched data from API!")
29
+ return data
30
+
31
+ embeddings = OpenAIEmbeddings()
32
+
33
+ feeds_link = "https://drive.google.com/uc?export=download&id=1-IPk1voyUM9VqnghwyVrM1dY6rFnn1S_"
34
+ embed_link = "https://dl.dropboxusercontent.com/s/ob2betm29qrtb8v/astro_ph_ga_feeds_ada_embedding_18-Apr-2023.pkl?dl=0"
35
+ gal_feeds = get_feeds_data(feeds_link)
36
+ arxiv_ada_embeddings = get_feeds_data(embed_link)
37
+
38
+ @st.cache_data
39
+ def get_embedding_data(url):
40
+ data = cp.load(urlopen(url))
41
+ st.sidebar.success("Fetched data from API!")
42
+ return data
43
+
44
+ url = "https://drive.google.com/uc?export=download&id=1133tynMwsfdR1wxbkFLhbES3FwDWTPjP"
45
+ e2d, _, _, _, _ = get_embedding_data(url)
46
+
47
+ ctr = -1
48
+ num_chunks = len(gal_feeds)
49
+ all_text, all_titles, all_arxivid, all_links, all_authors = [], [], [], [], []
50
+
51
+ for nc in range(num_chunks):
52
+
53
+ for i in range(len(gal_feeds[nc].entries)):
54
+ text = gal_feeds[nc].entries[i].summary
55
+ text = text.replace('\n', ' ')
56
+ text = text.replace('\\', '')
57
+ all_text.append(text)
58
+ all_titles.append(gal_feeds[nc].entries[i].title)
59
+ all_arxivid.append(gal_feeds[nc].entries[i].id.split('/')[-1][0:-2])
60
+ all_links.append(gal_feeds[nc].entries[i].links[1].href)
61
+ all_authors.append(gal_feeds[nc].entries[i].authors)
62
+
63
+ d = arxiv_ada_embeddings.shape[1] # dimension
64
+ nb = arxiv_ada_embeddings.shape[0] # database size
65
+ xb = arxiv_ada_embeddings.astype('float32')
66
+ index = faiss.IndexFlatL2(d)
67
+ index.add(xb)
68
+
69
+ def run_simple_query(search_query = 'all:sed+fitting', max_results = 10, start = 0, sort_by = 'lastUpdatedDate', sort_order = 'descending'):
70
+ """
71
+ Query ArXiv to return search results for a particular query
72
+ Parameters
73
+ ----------
74
+ query: str
75
+ query term. use prefixes ti, au, abs, co, jr, cat, m, id, all as applicable.
76
+ max_results: int, default = 10
77
+ number of results to return. numbers > 1000 generally lead to timeouts
78
+ start: int, default = 0
79
+ start index for results reported. use this if you're interested in running chunks.
80
+ Returns
81
+ -------
82
+ feed: dict
83
+ object containing requested results parsed with feedparser
84
+ Notes
85
+ -----
86
+ add functionality for chunk parsing, as well as storage and retreival
87
+ """
88
+
89
+ base_url = 'http://export.arxiv.org/api/query?';
90
+ query = 'search_query=%s&start=%i&max_results=%i&sortBy=%s&sortOrder=%s' % (search_query,
91
+ start,
92
+ max_results,sort_by,sort_order)
93
+
94
+ response = urllib.request.urlopen(base_url+query).read()
95
+ feed = feedparser.parse(response)
96
+ return feed
97
+
98
+ def find_papers_by_author(auth_name):
99
+
100
+ doc_ids = []
101
+ for doc_id in range(len(all_authors)):
102
+ for auth_id in range(len(all_authors[doc_id])):
103
+ if auth_name.lower() in all_authors[doc_id][auth_id]['name'].lower():
104
+ print('Doc ID: ',doc_id, ' | arXiv: ', all_arxivid[doc_id], '| ', all_titles[doc_id],' | Author entry: ', all_authors[doc_id][auth_id]['name'])
105
+ doc_ids.append(doc_id)
106
+
107
+ return doc_ids
108
+
109
+ def faiss_based_indices(input_vector, nindex=10):
110
+ xq = input_vector.reshape(-1,1).T.astype('float32')
111
+ D, I = index.search(xq, nindex)
112
+ return I[0], D[0]
113
+
114
+ def list_similar_papers_v2(model_data,
115
+ doc_id = [], input_type = 'doc_id',
116
+ show_authors = False, show_summary = False,
117
+ return_n = 10):
118
+
119
+ arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
120
+
121
+ if input_type == 'doc_id':
122
+ print('Doc ID: ',doc_id,', title: ',all_titles[doc_id])
123
+ # inferred_vector = model.infer_vector(train_corpus[doc_id].words)
124
+ inferred_vector = arxiv_ada_embeddings[doc_id,0:]
125
+ start_range = 1
126
+ elif input_type == 'arxiv_id':
127
+ print('ArXiv id: ',doc_id)
128
+ arxiv_query_feed = run_simple_query(search_query='id:'+str(doc_id))
129
+ if len(arxiv_query_feed.entries) == 0:
130
+ print('error: arxiv id not found.')
131
+ return
132
+ else:
133
+ print('Title: '+arxiv_query_feed.entries[0].title)
134
+ inferred_vector = np.array(embeddings.embed_query(arxiv_query_feed.entries[0].summary))
135
+ start_range = 0
136
+ elif input_type == 'keywords':
137
+ inferred_vector = np.array(embeddings.embed_query(doc_id))
138
+ start_range = 0
139
+ else:
140
+ print('unrecognized input type.')
141
+ return
142
+
143
+ sims, dists = faiss_based_indices(inferred_vector, return_n+2)
144
+ textstr = ''
145
+ abstracts_relevant = []
146
+ fhdrs = []
147
+
148
+ for i in range(start_range,start_range+return_n):
149
+
150
+ abstracts_relevant.append(all_text[sims[i]])
151
+ fhdr = all_authors[sims[i]][0]['name'].split()[-1] + all_arxivid[sims[i]][0:2] +'_'+ all_arxivid[sims[i]]
152
+ fhdrs.append(fhdr)
153
+ textstr = textstr + str(i+1)+'. **'+ all_titles[sims[i]] +'** (Distance: %.2f' %dists[i]+') \n'
154
+ textstr = textstr + '**ArXiv:** ['+all_arxivid[sims[i]]+'](https://arxiv.org/abs/'+all_arxivid[sims[i]]+') \n'
155
+ if show_authors == True:
156
+ textstr = textstr + '**Authors:** '
157
+ temp = all_authors[sims[i]]
158
+ for ak in range(len(temp)):
159
+ if ak < len(temp)-1:
160
+ textstr = textstr + temp[ak].name + ', '
161
+ else:
162
+ textstr = textstr + temp[ak].name + ' \n'
163
+ if show_summary == True:
164
+ textstr = textstr + '**Summary:** '
165
+ text = all_text[sims[i]]
166
+ text = text.replace('\n', ' ')
167
+ textstr = textstr + summarizer.summarize(text) + ' \n'
168
+ if show_authors == True or show_summary == True:
169
+ textstr = textstr + ' '
170
+ textstr = textstr + ' \n'
171
+ return textstr, abstracts_relevant, fhdrs, sims
172
+
173
+
174
+ def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
175
+ headers = {
176
+ "Content-Type": "application/json",
177
+ "Authorization": f"Bearer {openai.api_key}",
178
+ }
179
+
180
+ data = {
181
+ "model": model,
182
+ "messages": messages,
183
+ "temperature": temperature,
184
+ }
185
+
186
+ if max_tokens is not None:
187
+ data["max_tokens"] = max_tokens
188
+ response = requests.post(API_ENDPOINT, headers=headers, data=json.dumps(data))
189
+ if response.status_code == 200:
190
+ return response.json()["choices"][0]["message"]["content"]
191
+ else:
192
+ raise Exception(f"Error {response.status_code}: {response.text}")
193
+
194
+
195
+ model_data = [arxiv_ada_embeddings, embeddings, all_titles, all_text, all_authors]
196
+
197
+ def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
198
+
199
+ show_authors = True
200
+ show_summary = True
201
+ sims, absts, fhdrs, simids = list_similar_papers_v2(model_data,
202
+ doc_id = query,
203
+ input_type='keywords',
204
+ show_authors = show_authors, show_summary = show_summary,
205
+ return_n = return_n)
206
+
207
+ temp_abst = ''
208
+ loaders = []
209
+ for i in range(len(absts)):
210
+ temp_abst = absts[i]
211
+
212
+ try:
213
+ text_file = open("absts/"+fhdrs[i]+".txt", "w")
214
+ except:
215
+ os.mkdir('absts')
216
+ text_file = open("absts/"+fhdrs[i]+".txt", "w")
217
+ n = text_file.write(temp_abst)
218
+ text_file.close()
219
+ loader = TextLoader("absts/"+fhdrs[i]+".txt")
220
+ loaders.append(loader)
221
+
222
+ lc_index = VectorstoreIndexCreator().from_loaders(loaders)
223
+
224
+ st.markdown('### User query: '+query)
225
+ if show_pure_answer == True:
226
+ st.markdown('pure answer:')
227
+ st.markdown(lc_index.query(query))
228
+ st.markdown(' ')
229
+ st.markdown('#### context-based answer from sources:')
230
+ output = lc_index.query_with_sources(query)
231
+ st.markdown(output['answer'])
232
+ opstr = '#### Primary sources: \n'
233
+ st.markdown(opstr)
234
+
235
+ # opstr = ''
236
+ # for i in range(len(output['sources'])):
237
+ # opstr = opstr +'\n'+ output['sources'][i]
238
+
239
+ textstr = ''
240
+ ng = len(output['sources'].split())
241
+ abs_indices = []
242
+
243
+ for i in range(ng):
244
+ if i == (ng-1):
245
+ tempid = output['sources'].split()[i].split('_')[1][0:-4]
246
+ else:
247
+ tempid = output['sources'].split()[i].split('_')[1][0:-5]
248
+ try:
249
+ abs_index = all_arxivid.index(tempid)
250
+ abs_indices.append(abs_index)
251
+ textstr = textstr + str(i+1)+'. **'+ all_titles[abs_index] +' \n'
252
+ textstr = textstr + '**ArXiv:** ['+all_arxivid[abs_index]+'](https://arxiv.org/abs/'+all_arxivid[abs_index]+') \n'
253
+ textstr = textstr + '**Authors:** '
254
+ temp = all_authors[abs_index]
255
+ for ak in range(4):
256
+ if ak < len(temp)-1:
257
+ textstr = textstr + temp[ak].name + ', '
258
+ else:
259
+ textstr = textstr + temp[ak].name + ' \n'
260
+ if len(temp) > 3:
261
+ textstr = textstr + ' et al. \n'
262
+ textstr = textstr + '**Summary:** '
263
+ text = all_text[abs_index]
264
+ text = text.replace('\n', ' ')
265
+ textstr = textstr + summarizer.summarize(text) + ' \n'
266
+ except:
267
+ textstr = textstr + output['sources'].split()[i]
268
+ # opstr = opstr + ' \n ' + output['sources'].split()[i][6:-5].split('_')[0]
269
+ # opstr = opstr + ' \n Arxiv id: ' + output['sources'].split()[i][6:-5].split('_')[1]
270
+
271
+ textstr = textstr + ' '
272
+ textstr = textstr + ' \n'
273
+ st.markdown(textstr)
274
+
275
+ fig = plt.figure(figsize=(9,9))
276
+ plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
277
+ plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
278
+ plt.scatter(e2d[abs_indices,0], e2d[abs_indices,1],s=100,color='k',marker='d')
279
+ st.pyplot(fig)
280
+
281
+ if show_all_sources == True:
282
+ st.markdown('\n #### Other interesting papers:')
283
+ st.markdown(sims)
284
+ return output
285
+
286
+ st.title('ArXiv-based question answering')
287
+ st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. Please use sparingly because it costs me money right now. You might need to wait for a few seconds for the GPT-4 query to return an answer (check top right corner to see if it is still running).')
288
+
289
+ query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
290
+ return_n = st.slider('How many papers should I show?', 1, 20, 10)
291
+
292
+ sims = run_query(query, return_n = return_n)
pages/Untitled.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib
2
+ bokeh==2.4.3
3
+ cloudpickle
4
+ scipy
5
+ summa
6
+ faiss-cpu
7
+ langchain
8
+ openai
9
+ feedparser
10
+ tiktoken
streamlit_app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ st.set_page_config(
4
+ page_title="arXiv-GPT",
5
+ page_icon="πŸ‘‹",
6
+ )
7
+
8
+ st.write("# Welcome to arXiv-GPT! πŸ‘‹")
9
+
10
+ st.sidebar.success("Select a function above.")
11
+ st.sidebar.markdown("Current functions include visualizing papers in the arxiv embedding, or searching for similar papers to an input paper or prompt phrase.")
12
+
13
+ st.markdown(
14
+ """
15
+ arXiv+GPT is a framework for searching and visualizing papers on
16
+ the [arXiv](https://arxiv.org/) using the context sensitivity from modern
17
+ large language models (LLMs) like GPT3 to better link paper contexts
18
+
19
+ **πŸ‘ˆ Select a tool from the sidebar** to see some examples
20
+ of what this framework can do!
21
+ ### Want to learn more?
22
+ - Check out `chaotic_neural` [(link)](http://chaotic-neural.readthedocs.io/)
23
+ - Jump into our [documentation](https://docs.streamlit.io)
24
+ - Contribute!
25
+ """
26
+ )