File size: 5,667 Bytes
74ce942
 
 
 
 
 
 
 
 
d5f15cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74ce942
ea72d75
74ce942
 
 
 
 
 
 
 
 
 
 
 
 
ea72d75
74ce942
 
 
 
ea72d75
74ce942
 
 
 
 
 
 
 
 
ea72d75
74ce942
 
ea72d75
 
356174d
ea72d75
 
74ce942
 
 
 
 
 
 
 
 
 
356174d
 
 
74ce942
 
176bc83
ea72d75
74ce942
 
 
 
 
 
ea72d75
74ce942
 
ea72d75
74ce942
 
 
 
 
 
 
 
 
 
ea72d75
 
74ce942
 
 
ea72d75
74ce942
 
 
 
 
 
 
 
 
d5f15cb
74ce942
 
 
 
d5f15cb
ea72d75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from logging import getLogger
from pathlib import Path

import pandas as pd
import plotly.express as px
import streamlit as st
from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
from streamlit_plotly_events import plotly_events

from utilities import initialization

# Invoke the helper imported from `utilities` as soon as the script starts,
# before any of the definitions below are used.
# NOTE(review): presumably this is what populates the st.session_state entries
# read further down (`data`, `topics`, `topic_str_to_word`) — confirm in
# utilities.initialization.
initialization()


# @st.cache(show_spinner=False)
# def initialize_state():
#     with st.spinner("Loading app..."):
#         if 'model' not in st.session_state:
#             model = Top2Vec.load('models/model.pkl')
#             model._check_model_status()
#             model.hierarchical_topic_reduction(num_topics=20)
#
#             st.session_state.model = model
#             st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
#             logger.info("loading data...")
#
#         if 'data' not in st.session_state:
#             logger.info("loading data...")
#             data = pd.read_csv(proj_dir / 'data' / 'data.csv')
#             data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
#             st.session_state.data = data
#             st.session_state.selected_data = data
#             st.session_state.all_topics = list(data.topic_id.unique())
#
#         if 'topics' not in st.session_state:
#             logger.info("loading topics...")
#             topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
#             topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
#             st.session_state.topics = topics


def reset():
    """Clear the current selection.

    Used as the ``on_click`` callback of the "Reset" button in ``main``:
    points ``st.session_state.selected_data`` back at the full ``data`` frame
    and empties ``st.session_state.selected_points``.

    This is the single definition of ``reset``; the file previously defined it
    twice and the later copy silently shadowed this one, dropping the log line.
    """
    logger.info("Resetting...")
    st.session_state.selected_data = st.session_state.data
    st.session_state.selected_points = []


def filter_df():
    """Restrict ``st.session_state.selected_data`` to the points selected in the plot.

    Reads ``st.session_state.selected_points`` (a list of records that carry at
    least ``x`` and ``y``) and inner-joins ``st.session_state.data`` against
    those coordinates. When nothing is selected the session state is left
    untouched and only a log line is written.
    """
    selected_points = st.session_state.selected_points
    if not selected_points:
        logger.info("Lame")
        return

    # De-duplicate the coordinates first: if the same (x, y) pair shows up more
    # than once in the selection, the merge below would otherwise repeat every
    # matching document once per occurrence.
    points_df = pd.DataFrame(selected_points).loc[:, ['x', 'y']].drop_duplicates()
    st.session_state.selected_data = st.session_state.data.merge(points_df, on=['x', 'y'])
    logger.info(f"Updates selected_data: {len(st.session_state.selected_data)}")


def _get_topics_counts(selected_data: pd.DataFrame, topics: pd.DataFrame) -> pd.DataFrame:
    """Count selected documents per topic and attach the topic metadata.

    Args:
        selected_data: frame with a ``topic_id`` column, one row per document.
        topics: frame with a ``topic_id`` column plus descriptive columns.

    Returns:
        One row per topic present in both frames, with ``topic_id`` first,
        ``topic_count`` second, followed by the remaining ``topics`` columns.
        Rows keep the ``value_counts`` order (most frequent topic first).
    """
    # Name the index and the counts explicitly instead of relying on what
    # `value_counts().to_frame()` calls them: that naming changed between
    # pandas 1.x and 2.x, which broke the old `topic_id_x` / `topic_id_y` handling.
    counts = (
        selected_data['topic_id']
        .value_counts()
        .rename_axis('topic_id')
        .reset_index(name='topic_count')
    )
    return counts.merge(topics, on='topic_id')


def main():
    """Render the page: intro text, reset button, scatter plot and result tabs.

    Reads ``data``, ``topics`` and ``topic_str_to_word`` from
    ``st.session_state``; writes ``selected_points`` (and, through
    ``filter_df``, ``selected_data``). The "Docs" tab lists the selected
    documents, the "Topics" tab shows per-topic counts for the selection.
    """
    st.write(""" 
    # Topic Modeling
    This shows a 2d representation of documents embeded in a semantic space. Each dot is a document
    and the dots close represent documents that are close in meaning.
    
    Zoom in and explore a topic of your choice. You can see the documents you select with the `lasso` or `box`
    tool below in the corresponding tabs."""
             )

    st.button("Reset", help="Will Reset the selected points and the selected topics", on_click=reset)

    # Sorting by topic keeps the legend ordered
    # (https://bioinformatics.stackexchange.com/a/18847).
    data_to_model = st.session_state.data.sort_values(by='topic_id', ascending=True)
    # Assign the result back instead of a chained `inplace=True` replace, which
    # is not guaranteed to update the frame on current pandas.
    data_to_model['topic_id'] = data_to_model['topic_id'].replace(st.session_state.topic_str_to_word)
    fig = px.scatter(data_to_model, x='x', y='y', color='topic_id', template='plotly_dark',
                     hover_data=['id', 'topic_id', 'x', 'y'])
    st.session_state.selected_points = plotly_events(fig, select_event=True, click_event=False)
    # One call is enough: nothing below changes the selection during this run.
    filter_df()

    empty_hint = 'Select points in the graph with the `lasso` or `box` select tools to populate this table.'
    tab1, tab2 = st.tabs(["Docs", "Topics"])

    with tab1:
        if st.session_state.selected_points:
            ordered_cols = ['id', 'topic_id', 'topic_word', 'documents']
            selected = st.session_state.selected_data[['id', 'topic_id', 'documents']]
            # `.assign` returns a new frame, so the session data is never
            # written through a slice (avoids SettingWithCopy issues).
            docs = selected.assign(
                topic_word=selected['topic_id'].replace(st.session_state.topic_str_to_word)
            )[ordered_cols]
            builder = GridOptionsBuilder.from_dataframe(docs)
            builder.configure_pagination()
            grid_options = builder.build()
            AgGrid(docs, theme='streamlit', gridOptions=grid_options,
                   columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS)
        else:
            st.markdown(empty_hint)

    with tab2:
        if st.session_state.selected_points:
            cols = ['topic_id', 'topic_count', 'topic_0']
            topic_counts = _get_topics_counts(
                st.session_state.selected_data, st.session_state.topics
            )[cols]
            builder = GridOptionsBuilder.from_dataframe(topic_counts)
            builder.configure_pagination()
            builder.configure_column('topic_0', header_name='Topic Word', wrap_text=True)
            grid_options = builder.build()
            AgGrid(topic_counts, theme='streamlit', gridOptions=grid_options,
                   columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW)
        else:
            st.markdown(empty_hint)


if __name__ == "__main__":
    # Script entry point: set up the globals the functions above rely on
    # (`logger` is read by reset/filter_df), then draw the title and the page.

    # Setting up Logger and proj_dir
    logger = getLogger(__name__)
    # Directory two levels above the one containing this file.
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    # NOTE(review): this only affects pandas' own text rendering; current pandas
    # documents `None` as the "no limit" value — confirm 0 is still intended.
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    # st.set_page_config(layout="wide")
    # The emoji was stored as mojibake ("πŸ“–"): the UTF-8 bytes of the
    # open-book emoji decoded with a single-byte codepage. Restored here.
    md_title = "# Document Explorer 📖"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    # initialize_state()
    main()