derek-thomas HF staff commited on
Commit
d5f15cb
β€’
1 Parent(s): ea72d75

Abstracted `session_state`. Should work for multiple users now

Browse files
app/Top2Vec.py CHANGED
@@ -1,6 +1,10 @@
1
  import streamlit as st
2
 
 
 
3
  st.set_page_config(page_title="Top2Vec", layout="wide")
 
 
4
 
5
  st.markdown(
6
  """
 
1
  import streamlit as st
2
 
3
+ from utilities import initialization
4
+
5
  st.set_page_config(page_title="Top2Vec", layout="wide")
6
+ initialization()
7
+
8
 
9
  st.markdown(
10
  """
app/pages/01_Topic_Explorer_πŸ“š.py CHANGED
@@ -1,47 +1,40 @@
1
  from logging import getLogger
2
  from pathlib import Path
3
 
4
- import joblib
5
  import pandas as pd
6
  import plotly.graph_objects as go
7
  import streamlit as st
8
- from top2vec import Top2Vec
9
 
10
-
11
- @st.cache(show_spinner=False)
12
- def initialize_state():
13
- with st.spinner("Loading app..."):
14
- if 'model' not in st.session_state:
15
- model = Top2Vec.load('models/model.pkl')
16
- model._check_model_status()
17
- model.hierarchical_topic_reduction(num_topics=20)
18
-
19
- st.session_state.model = model
20
- st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
21
- logger.info("loading data...")
22
-
23
- data = pd.read_csv(proj_dir / 'data' / 'data.csv')
24
- data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
25
- st.session_state.data = data
26
-
27
- topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
28
- topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
29
-
30
- st.session_state.topics = topics
31
-
32
- if 'data' not in st.session_state:
33
- logger.info("loading data...")
34
- data = pd.read_csv(proj_dir / 'data' / 'data.csv')
35
- data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
36
- st.session_state.data = data
37
- st.session_state.selected_data = data
38
- st.session_state.all_topics = list(data.topic_id.unique())
39
-
40
- if 'topics' not in st.session_state:
41
- logger.info("loading topics...")
42
- topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
43
- topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
44
- st.session_state.topics = topics
45
 
46
 
47
  def main():
@@ -73,10 +66,10 @@ if __name__ == "__main__":
73
  pd.set_option('display.max_colwidth', 0)
74
 
75
  # Streamlit settings
76
- st.set_page_config(layout="wide")
77
  md_title = "# Topic Explorer πŸ“š"
78
  st.markdown(md_title)
79
  st.sidebar.markdown(md_title)
80
 
81
- initialize_state()
82
  main()
 
1
  from logging import getLogger
2
  from pathlib import Path
3
 
 
4
  import pandas as pd
5
  import plotly.graph_objects as go
6
  import streamlit as st
 
7
 
8
+ from utilities import initialization
9
+
10
+ initialization()
11
+
12
+
13
+ # @st.cache(show_spinner=False)
14
+ # def initialize_state():
15
+ # with st.spinner("Loading app..."):
16
+ # if 'model' not in st.session_state:
17
+ # model = Top2Vec.load('models/model.pkl')
18
+ # model._check_model_status()
19
+ # model.hierarchical_topic_reduction(num_topics=20)
20
+ #
21
+ # st.session_state.model = model
22
+ # st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
23
+ # logger.info("loading data...")
24
+ #
25
+ # if 'data' not in st.session_state:
26
+ # logger.info("loading data...")
27
+ # data = pd.read_csv(proj_dir / 'data' / 'data.csv')
28
+ # data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
29
+ # st.session_state.data = data
30
+ # st.session_state.selected_data = data
31
+ # st.session_state.all_topics = list(data.topic_id.unique())
32
+ #
33
+ # if 'topics' not in st.session_state:
34
+ # logger.info("loading topics...")
35
+ # topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
36
+ # topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
37
+ # st.session_state.topics = topics
 
 
 
 
 
38
 
39
 
40
  def main():
 
66
  pd.set_option('display.max_colwidth', 0)
67
 
68
  # Streamlit settings
69
+ # st.set_page_config(layout="wide")
70
  md_title = "# Topic Explorer πŸ“š"
71
  st.markdown(md_title)
72
  st.sidebar.markdown(md_title)
73
 
74
+ # initialize_state()
75
  main()
app/pages/02_Document_Explorer_πŸ“–.py CHANGED
@@ -1,49 +1,42 @@
1
  from logging import getLogger
2
  from pathlib import Path
3
 
4
- import joblib
5
  import pandas as pd
6
  import plotly.express as px
7
  import streamlit as st
8
  from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
9
  from streamlit_plotly_events import plotly_events
10
- from top2vec import Top2Vec
11
 
12
-
13
- @st.cache(show_spinner=False)
14
- def initialize_state():
15
- with st.spinner("Loading app..."):
16
- if 'model' not in st.session_state:
17
- model = Top2Vec.load('models/model.pkl')
18
- model._check_model_status()
19
- model.hierarchical_topic_reduction(num_topics=20)
20
-
21
- st.session_state.model = model
22
- st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
23
- logger.info("loading data...")
24
-
25
- data = pd.read_csv(proj_dir / 'data' / 'data.csv')
26
- data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
27
- st.session_state.data = data
28
-
29
- topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
30
- topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
31
-
32
- st.session_state.topics = topics
33
-
34
- if 'data' not in st.session_state:
35
- logger.info("loading data...")
36
- data = pd.read_csv(proj_dir / 'data' / 'data.csv')
37
- data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
38
- st.session_state.data = data
39
- st.session_state.selected_data = data
40
- st.session_state.all_topics = list(data.topic_id.unique())
41
-
42
- if 'topics' not in st.session_state:
43
- logger.info("loading topics...")
44
- topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
45
- topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
46
- st.session_state.topics = topics
47
 
48
 
49
  def reset():
@@ -131,10 +124,10 @@ if __name__ == "__main__":
131
  pd.set_option('display.max_colwidth', 0)
132
 
133
  # Streamlit settings
134
- st.set_page_config(layout="wide")
135
  md_title = "# Document Explorer πŸ“–"
136
  st.markdown(md_title)
137
  st.sidebar.markdown(md_title)
138
 
139
- initialize_state()
140
  main()
 
1
  from logging import getLogger
2
  from pathlib import Path
3
 
 
4
  import pandas as pd
5
  import plotly.express as px
6
  import streamlit as st
7
  from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
8
  from streamlit_plotly_events import plotly_events
 
9
 
10
+ from utilities import initialization
11
+
12
+ initialization()
13
+
14
+
15
+ # @st.cache(show_spinner=False)
16
+ # def initialize_state():
17
+ # with st.spinner("Loading app..."):
18
+ # if 'model' not in st.session_state:
19
+ # model = Top2Vec.load('models/model.pkl')
20
+ # model._check_model_status()
21
+ # model.hierarchical_topic_reduction(num_topics=20)
22
+ #
23
+ # st.session_state.model = model
24
+ # st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
25
+ # logger.info("loading data...")
26
+ #
27
+ # if 'data' not in st.session_state:
28
+ # logger.info("loading data...")
29
+ # data = pd.read_csv(proj_dir / 'data' / 'data.csv')
30
+ # data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
31
+ # st.session_state.data = data
32
+ # st.session_state.selected_data = data
33
+ # st.session_state.all_topics = list(data.topic_id.unique())
34
+ #
35
+ # if 'topics' not in st.session_state:
36
+ # logger.info("loading topics...")
37
+ # topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
38
+ # topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
39
+ # st.session_state.topics = topics
 
 
 
 
 
40
 
41
 
42
  def reset():
 
124
  pd.set_option('display.max_colwidth', 0)
125
 
126
  # Streamlit settings
127
+ # st.set_page_config(layout="wide")
128
  md_title = "# Document Explorer πŸ“–"
129
  st.markdown(md_title)
130
  st.sidebar.markdown(md_title)
131
 
132
+ # initialize_state()
133
  main()
app/pages/03_Semantic_Search_πŸ”.py CHANGED
@@ -1,50 +1,43 @@
1
  from logging import getLogger
2
  from pathlib import Path
3
 
4
- import joblib
5
  import pandas as pd
6
  import plotly.express as px
7
  import streamlit as st
8
  from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
9
- from top2vec import Top2Vec
10
 
11
-
12
- @st.cache(show_spinner=False)
13
- def initialize_state():
14
- with st.spinner("Loading app..."):
15
- if 'model' not in st.session_state:
16
- model = Top2Vec.load('models/model.pkl')
17
- model._check_model_status()
18
- model.hierarchical_topic_reduction(num_topics=20)
19
-
20
- st.session_state.model = model
21
- st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
22
- logger.info("loading data...")
23
-
24
- data = pd.read_csv(proj_dir / 'data' / 'data.csv')
25
- data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
26
- st.session_state.data = data
27
-
28
- topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
29
- topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
30
-
31
- st.session_state.topics = topics
32
-
33
- if 'data' not in st.session_state:
34
- logger.info("loading data...")
35
- data = pd.read_csv(proj_dir / 'data' / 'data.csv')
36
- data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
37
- st.session_state.data = data
38
- st.session_state.selected_data = data
39
- st.session_state.all_topics = list(data.topic_id.unique())
40
-
41
- if 'topics' not in st.session_state:
42
- logger.info("loading topics...")
43
- topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
44
- topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
45
- st.session_state.topics = topics
46
-
47
- st.session_state.selected_points = []
48
 
49
 
50
  def main():
@@ -124,10 +117,10 @@ if __name__ == "__main__":
124
  pd.set_option('display.max_colwidth', 0)
125
 
126
  # Streamlit settings
127
- st.set_page_config(layout="wide")
128
  md_title = "# Semantic Search πŸ”"
129
  st.markdown(md_title)
130
  st.sidebar.markdown(md_title)
131
 
132
- initialize_state()
133
  main()
 
1
  from logging import getLogger
2
  from pathlib import Path
3
 
 
4
  import pandas as pd
5
  import plotly.express as px
6
  import streamlit as st
7
  from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
 
8
 
9
+ from utilities import initialization
10
+
11
+ initialization()
12
+
13
+
14
+ # @st.cache(show_spinner=False)
15
+ # def initialize_state():
16
+ # with st.spinner("Loading app..."):
17
+ # if 'model' not in st.session_state:
18
+ # model = Top2Vec.load('models/model.pkl')
19
+ # model._check_model_status()
20
+ # model.hierarchical_topic_reduction(num_topics=20)
21
+ #
22
+ # st.session_state.model = model
23
+ # st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
24
+ # logger.info("loading data...")
25
+ #
26
+ # if 'data' not in st.session_state:
27
+ # logger.info("loading data...")
28
+ # data = pd.read_csv(proj_dir / 'data' / 'data.csv')
29
+ # data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
30
+ # st.session_state.data = data
31
+ # st.session_state.selected_data = data
32
+ # st.session_state.all_topics = list(data.topic_id.unique())
33
+ #
34
+ # if 'topics' not in st.session_state:
35
+ # logger.info("loading topics...")
36
+ # topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
37
+ # topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
38
+ # st.session_state.topics = topics
39
+ #
40
+ # st.session_state.selected_points = []
 
 
 
 
 
41
 
42
 
43
  def main():
 
117
  pd.set_option('display.max_colwidth', 0)
118
 
119
  # Streamlit settings
120
+ # st.set_page_config(layout="wide")
121
  md_title = "# Semantic Search πŸ”"
122
  st.markdown(md_title)
123
  st.sidebar.markdown(md_title)
124
 
125
+ # initialize_state()
126
  main()
app/utilities.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import getLogger
2
+ from pathlib import Path
3
+
4
+ import joblib
5
+ import pandas as pd
6
+ import streamlit as st
7
+ from top2vec import Top2Vec
8
+
9
+ logger = getLogger(__name__)
10
+
11
+ proj_dir = Path(__file__).parents[1]
12
+
13
+
14
+ def initialization():
15
+ with st.spinner("Loading app..."):
16
+ if 'model' not in st.session_state:
17
+ model = Top2Vec.load('models/model.pkl')
18
+ model._check_model_status()
19
+ model.hierarchical_topic_reduction(num_topics=20)
20
+
21
+ st.session_state.model = model
22
+ st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
23
+ logger.info("loading data...")
24
+
25
+ if 'data' not in st.session_state:
26
+ logger.info("loading data...")
27
+ data = pd.read_csv(proj_dir / 'data' / 'data.csv')
28
+ data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
29
+ st.session_state.data = data
30
+ st.session_state.selected_data = data
31
+ st.session_state.all_topics = list(data.topic_id.unique())
32
+
33
+ if 'topics' not in st.session_state:
34
+ logger.info("loading topics...")
35
+ topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
36
+ topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
37
+ st.session_state.topics = topics
38
+
39
+ if 'selected_points' not in st.session_state:
40
+ st.session_state.selected_points = []