"""Streamlit page that charts the top words of each reduced Top2Vec topic."""
from logging import getLogger
from pathlib import Path

import joblib
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from top2vec import Top2Vec

# Setting up Logger and proj_dir.
# Defined at module level (not only under the ``__main__`` guard) so that
# ``initialize_state`` can resolve them even when this module is imported.
logger = getLogger(__name__)
proj_dir = Path(__file__).parents[2]

# Number of topics the model is reduced to; also bounds the sidebar slider.
NUM_TOPICS = 20


def _read_topic_csv(csv_path):
    """Read *csv_path* and format its ``topic_id`` column as 2-digit strings."""
    frame = pd.read_csv(csv_path)
    frame['topic_id'] = frame['topic_id'].apply(lambda x: f'{x:02d}')
    return frame


def initialize_state():
    """Fill ``st.session_state`` with the model, UMAP model, data and topics.

    Each entry is loaded only when its key is missing from the session state,
    so reruns within one session do not reload anything.

    This function is deliberately NOT wrapped in ``st.cache``: a memoised call
    skips the body on a cache hit, which would leave the session state of
    every new session empty. Its only effect is writing to the session state.
    """
    with st.spinner("Loading app..."):
        if 'model' not in st.session_state:
            # NOTE(review): this path is relative to the working directory,
            # unlike the ``proj_dir``-based paths below -- confirm intended.
            model = Top2Vec.load('models/model.pkl')
            model._check_model_status()
            model.hierarchical_topic_reduction(num_topics=NUM_TOPICS)
            st.session_state.model = model
            st.session_state.umap_model = joblib.load(
                proj_dir / 'models' / 'umap.sav')

        if 'data' not in st.session_state:
            logger.info("loading data...")
            data = _read_topic_csv(proj_dir / 'data' / 'data.csv')
            st.session_state.data = data
            # Initially the selection is the full data set.
            st.session_state.selected_data = data
            st.session_state.all_topics = list(data.topic_id.unique())

        if 'topics' not in st.session_state:
            logger.info("loading topics...")
            st.session_state.topics = _read_topic_csv(
                proj_dir / 'data' / 'topics.csv')


def main():
    """Render the intro text, the sidebar topic slider and the word bar chart.

    Reads ``st.session_state.model``, so ``initialize_state`` must have run
    in the current session beforehand.
    """
    st.write("""
    A way to dive into each topic. Use the slider on the left to choose the topic.
    The `y` axis shows which words are closest to a topic centroid.
    The `x` axis shows how correlated they are.""")

    topic_num = st.sidebar.slider("Topic Number", 0, NUM_TOPICS - 1, value=0)

    model = st.session_state.model
    # Both sequences are reversed together so words and scores stay aligned
    # while the plotting order of the horizontal bars is flipped.
    fig = go.Figure(go.Bar(
        x=model.topic_word_scores_reduced[topic_num][::-1],
        y=model.topic_words_reduced[topic_num][::-1],
        orientation='h'))
    fig.update_layout(
        title=f'Words for Topic {topic_num}',
        yaxis_title='Top 20 topic words',
        xaxis_title='Distance to topic centroid'
    )
    st.plotly_chart(fig, use_container_width=True)


if __name__ == "__main__":
    # For max width tables
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    st.set_page_config(layout="wide")
    md_title = "# Topic Explorer 📚"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    initialize_state()
    main()