Spaces:
Sleeping
Sleeping
from distutils.fancy_getopt import wrap_text | |
from top2vec import Top2Vec | |
import joblib | |
import streamlit as st | |
import pandas as pd | |
from pathlib import Path | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from streamlit_plotly_events import plotly_events | |
from st_aggrid import AgGrid, GridOptionsBuilder, ColumnsAutoSizeMode | |
from logging import getLogger | |
def _read_topic_csv(path: Path) -> pd.DataFrame:
    """Read a CSV and format its integer ``topic_id`` column as two-digit strings.

    Zero-padding (``3`` -> ``'03'``) keeps lexical order equal to numeric
    order for ids 0-99, so later string sorts on ``topic_id`` look numeric.
    """
    frame = pd.read_csv(path)
    frame['topic_id'] = frame['topic_id'].apply(lambda x: f'{x:02d}')
    return frame


def initialize_state():
    """Populate ``st.session_state`` with the model, UMAP reducer, documents and topics.

    Each resource is loaded only when its session-state key is missing.
    Loading is independent per key, so ``data`` always comes with its derived
    ``selected_data`` / ``all_topics`` entries. ``selected_points`` is reset
    to an empty list on every call.

    Relies on the module-level ``proj_dir`` and ``logger`` globals that are
    set in the ``__main__`` guard.
    """
    with st.spinner("Loading app..."):
        if 'model' not in st.session_state:
            # Resolve against proj_dir (as umap.sav is) rather than the current
            # working directory, so the app does not depend on where it is launched.
            model = Top2Vec.load(str(proj_dir / 'models' / 'model.pkl'))
            model._check_model_status()
            model.hierarchical_topic_reduction(num_topics=20)
            st.session_state.model = model
            st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
            # Data and topics are loaded by their own branches below; loading them
            # here as well used to skip the 'data' branch and leave
            # selected_data / all_topics unset.

        if 'data' not in st.session_state:
            logger.info("loading data...")
            data = _read_topic_csv(proj_dir / 'data' / 'data.csv')
            st.session_state.data = data
            st.session_state.selected_data = data
            st.session_state.all_topics = list(data.topic_id.unique())

        if 'topics' not in st.session_state:
            logger.info("loading topics...")
            st.session_state.topics = _read_topic_csv(proj_dir / 'data' / 'topics.csv')

        st.session_state.selected_points = []
_INTRO_MD = """
# Semantic Search
This shows a 2d representation of documents embedded in a semantic space. Each dot is a document
and dots that are close together represent documents that are close in meaning.
Note that the distance metrics were computed at a higher dimension so take the representation with
a grain of salt.
The query is shown in yellow alongside the documents.
"""


def _topic_counts(docs: pd.DataFrame, topics: pd.DataFrame) -> pd.DataFrame:
    """Count documents per topic and join the counts onto the topic table.

    Returns one row per topic present in *docs*, in the order produced by
    ``value_counts``, with columns ``topic_id``, ``topic_count`` and then the
    remaining columns of *topics*.

    The count column and its index are named explicitly because the names
    that ``value_counts`` returns changed in pandas 2.0; this works on 1.x
    and 2.x alike.
    """
    counts = (
        docs['topic_id']
        .value_counts()
        .rename('topic_count')
        .rename_axis('topic_id')
        .reset_index()
    )
    return counts.merge(topics, on='topic_id')


def _run_search(query: str, max_docs: int) -> None:
    """Embed *query*, retrieve the closest documents and store the results.

    Writes ``search_raw_df``, ``data_to_model``, ``data_to_model_with_point``
    and ``data_to_model_without_point`` into ``st.session_state``. The
    ``*_with_point`` frame ends with one extra row for the query itself
    (id ``'Point'``, topic ``'Query'``, score ``0``); ``*_without_point``
    excludes that row.
    """
    model = st.session_state.model
    with st.spinner('Embedding Query...'):
        vector = model.embed([query])
    with st.spinner('Dimension Reduction...'):
        point = st.session_state.umap_model.transform(vector.reshape(1, -1))

    _, document_scores, document_ids = model.search_documents_by_vector(
        vector.flatten(), num_docs=max_docs
    )
    search_raw_df = pd.DataFrame({'document_ids': document_ids, 'document_scores': document_scores})
    st.session_state.search_raw_df = search_raw_df

    results = (
        st.session_state.data
        .merge(search_raw_df, left_on='id', right_on='document_ids')
        .drop(['document_ids'], axis=1)
        # Best matches first: this is the row order shown in the Docs table.
        .sort_values(by='document_scores', ascending=False)
    )
    # Append the query as a pseudo-document row. Values are positional.
    # NOTE(review): assumes the merged columns are ordered
    # id, x, y, documents, topic_id, document_scores -- confirm against data.csv.
    results.loc[len(results.index)] = ['Point', *point[0].tolist(), query, 'Query', 0]

    st.session_state.data_to_model = results
    st.session_state.data_to_model_with_point = results
    st.session_state.data_to_model_without_point = results.iloc[:-1]


def _render_scatter() -> None:
    """Plot the retrieved documents coloured by topic, plus the query point in yellow."""
    with_point = st.session_state.data_to_model_with_point
    # Split off the query row explicitly instead of sorting it along with the
    # documents and relying on 'Query' sorting after every topic id.
    # Sorting by topic_id makes the legend sorted: https://bioinformatics.stackexchange.com/a/18847
    docs = with_point.iloc[:-1].sort_values(by='topic_id', ascending=True)
    query_point = with_point.tail(1)

    fig = px.scatter(
        docs, x='x', y='y', color='topic_id', template='plotly_dark',
        hover_data=['id', 'topic_id', 'x', 'y'],
    )
    fig.add_traces(
        px.scatter(query_point, x="x", y="y")
        .update_traces(marker_size=10, marker_color="yellow")
        .data
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_docs_table() -> None:
    """Show the retrieved documents (best match first) in a paginated grid."""
    cols = ['id', 'document_scores', 'topic_id', 'documents']
    docs = st.session_state.data_to_model_without_point.loc[:, cols]
    builder = GridOptionsBuilder.from_dataframe(docs)
    builder.configure_pagination()
    builder.configure_column(
        'document_scores',
        type=["numericColumn", "numberColumnFilter", "customNumericFormat"],
        precision=2,
    )
    # Named grid_options (not `go`) so it does not shadow plotly.graph_objects.
    grid_options = builder.build()
    AgGrid(docs, theme='streamlit', gridOptions=grid_options,
           columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS)


def _render_topics_table() -> None:
    """Show how many retrieved documents fall into each topic, with its top word."""
    cols = ['topic_id', 'topic_count', 'topic_0']
    topic_counts = _topic_counts(
        st.session_state.data_to_model_without_point, st.session_state.topics
    ).loc[:, cols]
    builder = GridOptionsBuilder.from_dataframe(topic_counts)
    builder.configure_pagination()
    builder.configure_column('topic_0', header_name='Topic Word', wrap_text=True)
    grid_options = builder.build()
    AgGrid(topic_counts, theme='streamlit', gridOptions=grid_options,
           columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW)


def main():
    """Render the search page: query input, scatter plot, and Docs/Topics tables.

    Expects ``initialize_state()`` to have populated ``st.session_state``
    (model, umap_model, data, topics).
    """
    max_docs = st.sidebar.slider("# docs", 10, 100, value=50)
    to_search = st.text_input("Write your query here", "") or ""
    _run_search(to_search, max_docs)

    st.write(_INTRO_MD)
    _render_scatter()

    tab1, tab2 = st.tabs(["Docs", "Topics"])
    with tab1:
        _render_docs_table()
    with tab2:
        _render_topics_table()
if __name__ == "__main__":
    # Module-level globals read by initialize_state(): the project root
    # (two levels above this file's directory) and the module logger.
    proj_dir = Path(__file__).parents[2]
    logger = getLogger(__name__)

    # Pandas display option for max-width tables.
    pd.set_option('display.max_colwidth', 0)

    # Page layout is configured before any other Streamlit output.
    st.set_page_config(layout="wide")

    # Same title in the main area and in the sidebar.
    # NOTE(review): "π" looks like a mis-decoded emoji (UTF-8 bytes read with a
    # legacy code page) -- confirm the intended character.
    md_title = "# Semantic Search π"
    for surface in (st, st.sidebar):
        surface.markdown(md_title)

    initialize_state()
    main()