derek-thomas HF staff committed on
Commit
b64c266
•
1 Parent(s): c70a53f

Added same init across pages

Browse files
app/pages/01_Topic_Explorer_📚.py CHANGED
@@ -1,5 +1,6 @@
1
  from logging import getLogger
2
  from pathlib import Path
 
3
 
4
  import pandas as pd
5
  import plotly.express as px
@@ -9,14 +10,40 @@ import streamlit as st
9
  from top2vec import Top2Vec
10
 
11
 
 
12
  def initialize_state():
13
- with st.spinner('Loading App...'):
14
  if 'model' not in st.session_state:
15
  model = Top2Vec.load('models/model.pkl')
16
  model._check_model_status()
17
  model.hierarchical_topic_reduction(num_topics=20)
18
- assert len(model.topic_words_reduced) == 20
19
  st.session_state.model = model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def main():
22
  st.write("""
 
1
  from logging import getLogger
2
  from pathlib import Path
3
+ import joblib
4
 
5
  import pandas as pd
6
  import plotly.express as px
 
10
  from top2vec import Top2Vec
11
 
12
 
13
+ @st.cache(show_spinner=False)
14
  def initialize_state():
15
+ with st.spinner("Loading app..."):
16
  if 'model' not in st.session_state:
17
  model = Top2Vec.load('models/model.pkl')
18
  model._check_model_status()
19
  model.hierarchical_topic_reduction(num_topics=20)
20
+
21
  st.session_state.model = model
22
+ st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
23
+ logger.info("loading data...")
24
+
25
+ data = pd.read_csv(proj_dir/'data'/'data.csv')
26
+ data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
27
+ st.session_state.data = data
28
+
29
+ topics = pd.read_csv(proj_dir/'data'/'topics.csv')
30
+ topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
31
+
32
+ st.session_state.topics = topics
33
+
34
+ if 'data' not in st.session_state:
35
+ logger.info("loading data...")
36
+ data = pd.read_csv(proj_dir/'data'/'data.csv')
37
+ data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
38
+ st.session_state.data = data
39
+ st.session_state.selected_data = data
40
+ st.session_state.all_topics = list(data.topic_id.unique())
41
+
42
+ if 'topics' not in st.session_state:
43
+ logger.info("loading topics...")
44
+ topics = pd.read_csv(proj_dir/'data'/'topics.csv')
45
+ topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
46
+ st.session_state.topics = topics
47
 
48
  def main():
49
  st.write("""
app/pages/02_Document_Explorer_📖.py CHANGED
@@ -1,6 +1,7 @@
1
  from distutils.fancy_getopt import wrap_text
2
  from logging import getLogger
3
  from pathlib import Path
 
4
 
5
  import pandas as pd
6
  import plotly.express as px
@@ -12,22 +13,40 @@ from streamlit_plotly_events import plotly_events
12
  from top2vec import Top2Vec
13
 
14
 
 
15
  def initialize_state():
16
- if 'data' not in st.session_state:
17
- logger.info("loading data...")
18
- data = pd.read_csv(proj_dir/'data'/'data.csv')
19
- data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
20
- st.session_state.data = data
21
- st.session_state.selected_data = data
22
- st.session_state.all_topics = list(data.topic_id.unique())
23
-
24
- if 'topics' not in st.session_state:
25
- logger.info("loading topics...")
26
- topics = pd.read_csv(proj_dir/'data'/'topics.csv')
27
- topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
28
- st.session_state.topics = topics
29
-
30
- st.session_state.selected_points = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def reset():
33
  logger.info("Resetting...")
 
1
  from distutils.fancy_getopt import wrap_text
2
  from logging import getLogger
3
  from pathlib import Path
4
+ import joblib
5
 
6
  import pandas as pd
7
  import plotly.express as px
 
13
  from top2vec import Top2Vec
14
 
15
 
16
+ @st.cache(show_spinner=False)
17
  def initialize_state():
18
+ with st.spinner("Loading app..."):
19
+ if 'model' not in st.session_state:
20
+ model = Top2Vec.load('models/model.pkl')
21
+ model._check_model_status()
22
+ model.hierarchical_topic_reduction(num_topics=20)
23
+
24
+ st.session_state.model = model
25
+ st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
26
+ logger.info("loading data...")
27
+
28
+ data = pd.read_csv(proj_dir/'data'/'data.csv')
29
+ data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
30
+ st.session_state.data = data
31
+
32
+ topics = pd.read_csv(proj_dir/'data'/'topics.csv')
33
+ topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
34
+
35
+ st.session_state.topics = topics
36
+
37
+ if 'data' not in st.session_state:
38
+ logger.info("loading data...")
39
+ data = pd.read_csv(proj_dir/'data'/'data.csv')
40
+ data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
41
+ st.session_state.data = data
42
+ st.session_state.selected_data = data
43
+ st.session_state.all_topics = list(data.topic_id.unique())
44
+
45
+ if 'topics' not in st.session_state:
46
+ logger.info("loading topics...")
47
+ topics = pd.read_csv(proj_dir/'data'/'topics.csv')
48
+ topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
49
+ st.session_state.topics = topics
50
 
51
  def reset():
52
  logger.info("Resetting...")
app/pages/03_Semantic_Search_🔍.py CHANGED
@@ -17,6 +17,8 @@ def initialize_state():
17
  if 'model' not in st.session_state:
18
  model = Top2Vec.load('models/model.pkl')
19
  model._check_model_status()
 
 
20
  st.session_state.model = model
21
  st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
22
  logger.info("loading data...")
@@ -30,6 +32,22 @@ def initialize_state():
30
 
31
  st.session_state.topics = topics
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def main():
34
 
35
  max_docs = st.sidebar.slider("# docs", 10, 100, value=50)
 
17
  if 'model' not in st.session_state:
18
  model = Top2Vec.load('models/model.pkl')
19
  model._check_model_status()
20
+ model.hierarchical_topic_reduction(num_topics=20)
21
+
22
  st.session_state.model = model
23
  st.session_state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')
24
  logger.info("loading data...")
 
32
 
33
  st.session_state.topics = topics
34
 
35
+ if 'data' not in st.session_state:
36
+ logger.info("loading data...")
37
+ data = pd.read_csv(proj_dir/'data'/'data.csv')
38
+ data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
39
+ st.session_state.data = data
40
+ st.session_state.selected_data = data
41
+ st.session_state.all_topics = list(data.topic_id.unique())
42
+
43
+ if 'topics' not in st.session_state:
44
+ logger.info("loading topics...")
45
+ topics = pd.read_csv(proj_dir/'data'/'topics.csv')
46
+ topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
47
+ st.session_state.topics = topics
48
+
49
+ st.session_state.selected_points = []
50
+
51
  def main():
52
 
53
  max_docs = st.sidebar.slider("# docs", 10, 100, value=50)