pathfinder / pages /1_arxiv_embedding_explorer.py
kiyer's picture
updates to codebase for embeddings and RAG QA.
f28b621
raw
history blame
5.4 kB
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
from bokeh.palettes import OrRd
from bokeh.plotting import figure, show
from bokeh.plotting import ColumnDataSource, figure, output_notebook, show
import cloudpickle as cp
import pickle
from scipy import stats
from urllib.request import urlopen
@st.cache_data
def get_feeds_data(url):
    """Unpickle and return the object stored in the file at ``url``.

    ``url`` is handed straight to ``open``, so despite its name it is used
    as a local file path. The function is wrapped in ``st.cache_data`` and
    posts a success message to the sidebar once the file has been read.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point
    this at files you trust.
    """
    with open(url, "rb") as pickle_file:
        payload = pickle.load(pickle_file)
    st.sidebar.success("Fetched data from API!")
    return payload
# Snapshot date embedded in the name of every data file used by this page.
dateval = "27-Jun-2023"

# Pickled inputs, both looked up relative to the working directory.
feeds_link = f"local_files/astro_ph_ga_feeds_upto_{dateval}.pkl"
embed_link = f"local_files/astro_ph_ga_feeds_ada_embedding_{dateval}.pkl"

# Indexable collection of feed chunks; each chunk exposes ``.entries``.
gal_feeds = get_feeds_data(feeds_link)
# NOTE(review): not referenced again in the visible part of this file.
arxiv_ada_embeddings = get_feeds_data(embed_link)
@st.cache_data
def get_embedding_data(url):
    """Load one pickled object from ``url`` and return it.

    ``url`` is opened with the built-in ``open`` in binary mode, i.e. it is
    treated as a path on disk. After a successful read a confirmation is
    shown in the Streamlit sidebar. Wrapped in ``st.cache_data``.

    NOTE(review): unpickling runs arbitrary code from the file; use trusted
    inputs only.
    """
    with open(url, "rb") as source_file:
        loaded = pickle.load(source_file)
    st.sidebar.success("Fetched data from API!")
    return loaded
# Pickled array of plot coordinates; the code below reads columns 0 and 1.
url = f"local_files/astro_ph_ga_embedding_{dateval}.pkl"
embedding = get_embedding_data(url)
# Page header and explanatory text.
st.title("ArXiv+GPT3 embedding explorer")
st.markdown(f"[Includes papers up to: `{dateval}`]")
# The corpus date is taken from `dateval` so this paragraph cannot drift out
# of sync with the data files that are actually loaded.
st.markdown(
    "This is an explorer for astro-ph.GA papers on the arXiv "
    f"(up to {dateval}). "
    "The papers have been preprocessed with `chaotic_neural` "
    "[(link)](http://chaotic-neural.readthedocs.io/) after which the "
    "collected abstracts are run through `text-embedding-ada-002` with "
    "[langchain](https://python.langchain.com/en/latest/ecosystem/openai.html) "
    "to generate a unique vector corresponding to each paper. "
    "These are then compressed using "
    "[umap](https://umap-learn.readthedocs.io/en/latest/) and shown here, "
    "and can be used for similarity searches with methods like "
    "[faiss](https://github.com/facebookresearch/faiss). "
    "The scatterplot here can be paired with a heatmap for more targeted "
    "searches looking at a specific topic or area (see sidebar). "
    "Upgrade to chaotic neural suggested by Jo Ciucă, thank you! "
    "More to come (hopefully) with GPT-4 and its applications!"
)
st.markdown(
    "Interpreting the UMAP plot: the algorithm creates a 2d embedding from "
    "the high-dim vector space that tries to conserve as much similarity "
    "information as possible. Nearby points in UMAP space are similar, and "
    "grow dissimilar as you move farther away. "
    "The axes do not have any physical meaning."
)
import re

from tqdm import tqdm

# Trailing arXiv version tag ("v1", "v2", ..., "v12"). Slicing off a fixed
# two characters only works for single-digit versions.
_ARXIV_VERSION_SUFFIX = re.compile(r"v\d+$")

ctr = -1  # NOTE(review): unused in the visible code; kept for compatibility.
num_chunks = len(gal_feeds)

# Flatten every entry of every feed chunk into parallel lists, so index k in
# each list refers to the same paper.
all_text = []
all_titles = []
all_arxivid = []
all_links = []
# The outer loop stays index-based because the container type of `gal_feeds`
# is not visible here (it comes out of a pickle).
for nc in tqdm(range(num_chunks)):
    for entry in gal_feeds[nc].entries:
        # Abstract as a single line with backslashes removed.
        text = entry.summary.replace('\n', ' ').replace('\\', '')
        all_text.append(text)
        all_titles.append(entry.title)
        # Last path component of the entry id, minus any version suffix.
        all_arxivid.append(_ARXIV_VERSION_SUFFIX.sub('', entry.id.split('/')[-1]))
        # Assumes every entry carries at least two links — TODO confirm.
        all_links.append(entry.links[1].href)
def density_estimation(m1, m2, xmin=0, ymin=0, xmax=15, ymax=15):
    """Evaluate a 2-D Gaussian KDE of the samples on a regular 100x100 grid.

    Parameters
    ----------
    m1, m2 : array-like
        First and second coordinate of each sample; they are stacked as the
        two rows of the dataset given to ``scipy.stats.gaussian_kde``.
    xmin, xmax, ymin, ymax : float
        Inclusive bounds of the evaluation grid along each coordinate.

    Returns
    -------
    tuple
        ``(X, Y, Z)``: the two grid coordinate arrays and the estimated
        density at each grid node, all of shape ``(100, 100)``.
    """
    # A complex step of 100j asks mgrid for 100 points including both ends.
    grid_x, grid_y = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
    samples = np.vstack((m1, m2))
    estimator = stats.gaussian_kde(samples)
    grid_points = np.vstack((grid_x.ravel(), grid_y.ravel()))
    # The estimator returns one value per grid point in ravel order, so
    # reshaping to the grid shape puts each density back on its node.
    density = estimator(grid_points).reshape(grid_x.shape)
    return grid_x, grid_y, density
# Sidebar: usage notes followed by the search box and the heatmap controls.
st.sidebar.markdown(
    'This is a widget that allows you to look for papers containing specific '
    'phrases in the dataset and show it as a heatmap. Enter the phrase of '
    'interest, then change the size and opacity of the heatmap as desired to '
    'find the high-density regions. Hover over blue points to see the details '
    'of individual papers.'
)
st.sidebar.markdown(
    '`Note`: (i) if you enter a query that is not in the corpus of abstracts, '
    'it will return an error. just enter a different query in that case. '
    '(ii) there are some empty tooltips when you hover, these correspond to '
    'the underlying hexbins, and can be ignored.'
)
# The text input stores its value in session state under the key "phrase".
st.sidebar.text_input("Search query", key="phrase", value="Quenching")
alpha_value = st.sidebar.slider("Pick the hexbin opacity", 0.0, 1.0, 0.81)
size_value = st.sidebar.slider("Pick the hexbin gridsize", 10, 50, 20)

phrase = st.session_state.phrase

# Float array with 1.0 where the abstract contains the phrase
# (case-insensitive substring match) and 0.0 elsewhere.
needle = phrase.lower()
phrase_flags = np.array(
    [float(needle in abstract.lower()) for abstract in all_text],
    dtype=float,
)
# Interactive scatter of every paper. Each row of the data source carries the
# two plot coordinates plus the title and link shown in the hover tooltip.
source = ColumnDataSource(data=dict(
    x=embedding[:, 0],
    y=embedding[:, 1],
    title=all_titles,
    link=all_links,
))

# HTML template for the hover tooltip; `$` fields are Bokeh built-ins and
# `@` fields are columns of `source`.
TOOLTIPS = """
<div style="width:300px;">
ID: $index
($x, $y)
@title <br>
@link <br> <br>
</div>
"""

# The search phrase is separated from the fixed title text; previously it was
# appended directly to the word "corpus".
p = figure(width=700, height=583, tooltips=TOOLTIPS, x_range=(0, 15), y_range=(2.5, 15),
           title=f"UMAP projection of embeddings for the astro-ph.GA corpus: {phrase}")

# The Bokeh hexbin overlay of matching papers is disabled; the keyword
# heatmap is drawn with matplotlib instead.
p.circle('x', 'y', size=3, source=source, alpha=0.3)

st.bokeh_chart(p)
# Static figure: every paper as a faint background scatter, with a hexbin
# heatmap of the papers whose abstract contains the search phrase on top.
fig, ax = plt.subplots(figsize=(10.5, 9 * 0.8328))
ax.scatter(embedding[:, 0], embedding[:, 1], s=2, alpha=0.1)

matched = phrase_flags == 1
heatmap = None
clbr = None
if matched.any():
    # mincnt=10 hides hexagons that contain fewer than ten matching papers.
    heatmap = ax.hexbin(embedding[matched, 0], embedding[matched, 1],
                        gridsize=size_value, cmap='viridis', alpha=alpha_value,
                        extent=(-1, 16, 1.5, 16), mincnt=10)
else:
    # Nothing to bin: tell the user instead of handing empty arrays to hexbin.
    st.warning("No abstracts contain the search phrase; showing all papers only.")

ax.set_title("UMAP localization of heatmap keyword: " + phrase)
ax.axis([0, 15, 2.5, 15])
if heatmap is not None:
    clbr = fig.colorbar(heatmap, ax=ax)
    clbr.set_label('# papers')
ax.axis('off')

st.pyplot(fig)
# Streamlit re-executes this script on every interaction; close the figure so
# pyplot does not keep one more open figure alive per rerun.
plt.close(fig)