top2vec / app / pages / 01_Topic_Explorer_📚.py
derek-thomas's picture
derek-thomas HF staff
Added same init across pages
b64c266
raw
history blame
2.84 kB
from logging import getLogger
from pathlib import Path
import joblib
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from top2vec import Top2Vec
def _read_with_padded_topic_ids(csv_path):
    """Read ``csv_path`` and render its integer ``topic_id`` column as two-digit strings (7 -> '07')."""
    frame = pd.read_csv(csv_path)
    frame['topic_id'] = frame['topic_id'].apply(lambda x: f'{x:02d}')
    return frame


def initialize_state() -> None:
    """Populate ``st.session_state`` with what this page needs, once per session.

    Each resource is loaded only if its key is missing from the session:

    - ``model``: Top2Vec model, hierarchically reduced to 20 topics;
      ``umap_model``: the saved UMAP model loaded with joblib.
    - ``data``: the document table; ``selected_data``: initially the same
      frame; ``all_topics``: list of the unique ``topic_id`` values in it.
    - ``topics``: the topic table.

    ``topic_id`` columns of both tables are converted to zero-padded strings.

    Relies on the module globals ``proj_dir`` and ``logger``, which the script
    entry point defines before calling this.

    Deliberately NOT wrapped in ``st.cache``: the function's whole purpose is
    writing per-session state. A cache hit on this argument-less function
    would skip the body for every later session and leave its
    ``session_state`` empty.
    """
    with st.spinner("Loading app..."):
        if 'model' not in st.session_state:
            # Resolve against the project root like the other artifacts,
            # rather than against the current working directory.
            model = Top2Vec.load(str(proj_dir / 'models' / 'model.pkl'))
            # Private Top2Vec helper, kept as in the original; presumably it
            # validates/initializes the loaded embedding model -- TODO confirm.
            model._check_model_status()
            model.hierarchical_topic_reduction(num_topics=20)
            st.session_state.model = model
            st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
        # Data and topics are loaded only in their own branches. Loading them
        # inside the model branch used to skip this branch, so
        # selected_data/all_topics were never set.
        if 'data' not in st.session_state:
            logger.info("loading data...")
            data = _read_with_padded_topic_ids(proj_dir / 'data' / 'data.csv')
            st.session_state.data = data
            st.session_state.selected_data = data
            st.session_state.all_topics = list(data.topic_id.unique())
        if 'topics' not in st.session_state:
            logger.info("loading topics...")
            st.session_state.topics = _read_with_padded_topic_ids(proj_dir / 'data' / 'topics.csv')
def main() -> None:
    """Render the Topic Explorer page: a topic picker plus a bar chart of that topic's words.

    Reads the reduced Top2Vec model from ``st.session_state.model``, so
    ``initialize_state`` must have run first.
    """
    st.write("""
A way to dive into each topic. Use the slider on the left to choose the topic.
The `y` axis shows which words are closest to a topic centroid. The `x` axis shows how correlated they are.""")
    model = st.session_state.model
    # Derive the slider bound from the model instead of hard-coding 19, so it
    # stays in sync with whatever num_topics the model was reduced to.
    max_topic = len(model.topic_words_reduced) - 1
    topic_num = st.sidebar.slider("Topic Number", 0, max_topic, value=0)
    # Reversed so the first-listed word ends up at the top: Plotly draws the
    # first category of a horizontal bar chart at the bottom. Assumes Top2Vec
    # lists each topic's words best-first -- consistent with its docs.
    scores = model.topic_word_scores_reduced[topic_num][::-1]
    words = model.topic_words_reduced[topic_num][::-1]
    fig = go.Figure(go.Bar(x=scores, y=words, orientation='h'))
    fig.update_layout(
        title=f'Words for Topic {topic_num}',
        # Label with the number of words actually plotted, not a fixed 20.
        yaxis_title=f'Top {len(words)} topic words',
        # Top2Vec word scores are cosine similarities (higher = closer), not distances.
        xaxis_title='Cosine similarity to topic centroid',
    )
    st.plotly_chart(fig, use_container_width=True)
if __name__ == "__main__":
    # Module globals read by initialize_state(): a logger and the project root.
    # This file lives at <root>/app/pages/, hence parents[2].
    logger = getLogger(__name__)
    proj_dir = Path(__file__).parents[2]

    # Original intent: "max width tables" (pandas column-width display option).
    pd.set_option('display.max_colwidth', 0)

    # Per the Streamlit docs, set_page_config must precede other st.* calls on the page.
    st.set_page_config(layout="wide")

    # Title emoji was mojibake (UTF-8 bytes of 📚 decoded with a Windows codepage).
    md_title = "# Topic Explorer 📚"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    initialize_state()
    main()