datavistics committed on
Commit
74ce942
•
1 Parent(s): d23c925

Init commit

Browse files
.gitattributes CHANGED
@@ -3,6 +3,7 @@
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ckpt filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
@@ -23,6 +24,7 @@
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
  *.safetensors filter=lfs diff=lfs merge=lfs -text
 
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
  *.tflite filter=lfs diff=lfs merge=lfs -text
 
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
  *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.csv filter=lfs diff=lfs merge=lfs -text
7
  *.ftz filter=lfs diff=lfs merge=lfs -text
8
  *.gz filter=lfs diff=lfs merge=lfs -text
9
  *.h5 filter=lfs diff=lfs merge=lfs -text
 
24
  *.pth filter=lfs diff=lfs merge=lfs -text
25
  *.rar filter=lfs diff=lfs merge=lfs -text
26
  *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ *.sav filter=lfs diff=lfs merge=lfs -text
28
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
29
  *.tar.* filter=lfs diff=lfs merge=lfs -text
30
  *.tflite filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.pyc
2
+ vscode
3
+ notebooks/.ipynb_checkpoints
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
5
  colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.17.0
8
- app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
5
  colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.17.0
8
+ app_file: app/Top2Vec.py
9
  pinned: false
10
  license: mit
11
  ---
app/Top2Vec.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st

# Landing-page copy, rendered as Markdown once the page has been configured.
LANDING_PAGE_MARKDOWN = """
# Introduction
This is [space](https://huggingface.co/spaces) dedicated to using [top2vec](https://github.com/ddangelov/Top2Vec) and showing what features are available for semantic searching and topic modeling.
Please check out this [readme](https://github.com/ddangelov/Top2Vec#how-does-it-work) to better understand how it works.

> Top2Vec is an algorithm for **topic modeling** and **semantic search**. It automatically detects topics present in text and generates jointly embedded topic, document and word vectors.


# Setup
I used the [20 NewsGroups](https://huggingface.co/datasets/SetFit/20_newsgroups) dataset with `top2vec`.
I fit on the dataset and reduced the topics to 20.
The topics are created from top2vec, not the labels.
No analysis on the top 20 topics vs labels is provided.

# Usage
Check out
- The [Topic Explorer](/Topic_Explorer) page to understand what topic were detected
- The [Document Explorer](/Document_Explorer) page to visually explore documents
- The [Semantic Search](/Semantic_Search) page to search by meaning
"""

st.set_page_config(page_title="Top2Vec", layout="wide")
st.markdown(LANDING_PAGE_MARKDOWN)
app/pages/01_Topic_Explorer_📚.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import getLogger
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+ import plotly.express as px
6
+ import plotly.graph_objects as go
7
+ import streamlit as st
8
+
9
+ from top2vec import Top2Vec
10
+
11
+
12
def initialize_state(num_topics=20):
    """Load the Top2Vec model into ``st.session_state`` once per session.

    The model is read from ``models/model.pkl`` (resolved against the current
    working directory), reduced to ``num_topics`` topics and stored as
    ``st.session_state.model``. When the key is already present the function
    returns immediately, so page reruns do not reload the model.

    Args:
        num_topics: Number of topics to keep after hierarchical reduction.

    Raises:
        RuntimeError: If the reduced model does not expose exactly
            ``num_topics`` topics.
    """
    if 'model' in st.session_state:
        return

    with st.spinner('Loading App...'):
        # NOTE(review): Top2Vec.load unpickles the file; only load model
        # files that ship with this repository.
        model = Top2Vec.load('models/model.pkl')
        model._check_model_status()
        model.hierarchical_topic_reduction(num_topics=num_topics)
        # Explicit check instead of ``assert``: assertions are stripped
        # when Python runs with -O.
        found = len(model.topic_words_reduced)
        if found != num_topics:
            raise RuntimeError(
                f"Expected {num_topics} reduced topics, found {found}"
            )
        st.session_state.model = model
20
+
21
def main():
    """Render the per-topic bar chart of topic words and their scores.

    Reads the reduced-topic model from ``st.session_state.model``, lets the
    user pick a topic with a sidebar slider, and plots that topic's words
    (``topic_words_reduced``) against their scores
    (``topic_word_scores_reduced``) as a horizontal bar chart.
    """
    st.write("""
A way to dive into each topic. Use the slider on the left to choose the topic.

The `y` axis shows which words are closest to a topic centroid. The `x` axis shows how correlated they are.""")

    model = st.session_state.model
    # Derive the slider range from the model instead of hard-coding 19, so
    # the page stays in step with the number of reduced topics.
    last_topic = len(model.topic_words_reduced) - 1
    topic_num = st.sidebar.slider("Topic Number", 0, last_topic, value=0)

    # Both sequences are reversed so the first entry is drawn at the top of
    # the horizontal bar chart.
    fig = go.Figure(go.Bar(
        x=model.topic_word_scores_reduced[topic_num][::-1],
        y=model.topic_words_reduced[topic_num][::-1],
        orientation='h'))
    fig.update_layout(
        title=f'Words for Topic {topic_num}',
        yaxis_title='Top 20 topic words',
        xaxis_title='Distance to topic centroid'
    )

    st.plotly_chart(fig, use_container_width=True)
39
+
40
if __name__ == "__main__":
    # Logger and project root (the third ancestor of this file's path).
    logger = getLogger(__name__)
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    st.set_page_config(layout="wide")
    md_title = "# Topic Explorer 📚"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    initialize_state()
    main()
app/pages/02_Document_Explorer_📖.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.fancy_getopt import wrap_text
2
+ from logging import getLogger
3
+ from pathlib import Path
4
+
5
+ import pandas as pd
6
+ import plotly.express as px
7
+ import plotly.graph_objects as go
8
+ import streamlit as st
9
+ from st_aggrid import AgGrid, ColumnsAutoSizeMode, GridOptionsBuilder
10
+ from streamlit_plotly_events import plotly_events
11
+
12
+ from top2vec import Top2Vec
13
+
14
+
15
def initialize_state():
    """Populate ``st.session_state`` with the document and topic tables.

    ``data/data.csv`` and ``data/topics.csv`` (under ``proj_dir``) are each
    read only when their key is missing from the session state. In both
    tables ``topic_id`` is rewritten as a zero-padded two-character string.
    """
    state = st.session_state
    data_dir = proj_dir / 'data'

    def zero_pad(frame):
        # Rewrite topic_id in place as a two-digit, zero-padded string.
        frame['topic_id'] = frame['topic_id'].apply(lambda x: f'{x:02d}')
        return frame

    if 'data' not in state:
        logger.info("loading data...")
        documents = zero_pad(pd.read_csv(data_dir / 'data.csv'))
        state.data = documents
        state.selected_data = documents
        state.all_topics = list(documents.topic_id.unique())

    if 'topics' not in state:
        logger.info("loading topics...")
        state.topics = zero_pad(pd.read_csv(data_dir / 'topics.csv'))

    # NOTE(review): indentation is lost in the diff view; this reset is
    # assumed to run on every call rather than only when topics are loaded.
    state.selected_points = []
31
+
32
def reset():
    """Clear the selection so the full document table is selected again.

    Used as the ``on_click`` callback of the "Reset" button. The file
    previously defined this function twice; the second definition replaced
    the first and dropped its log line, so the two are merged here.
    """
    logger.info("Resetting...")
    st.session_state.selected_data = st.session_state.data
    st.session_state.selected_points = []


def filter_df():
    """Narrow ``selected_data`` to the points selected in the scatter plot.

    When ``st.session_state.selected_points`` is non-empty, its ``x``/``y``
    values are inner-joined against ``st.session_state.data`` and the result
    is stored as ``st.session_state.selected_data``. With no selection the
    stored table is left untouched.
    """
    selected_points = st.session_state.selected_points
    if not selected_points:
        logger.info("Lame")
        return

    # NOTE(review): the join matches on exact float equality of x and y;
    # this presumably works because the plot echoes the plotted values
    # back unchanged. Confirm.
    points_df = pd.DataFrame(selected_points).loc[:, ['x', 'y']]
    st.session_state.selected_data = st.session_state.data.merge(points_df, on=['x', 'y'])
    logger.info("Updates selected_data: %d", len(st.session_state.selected_data))
50
+
51
+
52
def get_topics_counts(selected_data, topics) -> pd.DataFrame:
    """Count documents per topic and attach the topic descriptions.

    Args:
        selected_data: Frame with a ``topic_id`` column, one row per document.
        topics: Frame with a ``topic_id`` column plus descriptive columns.

    Returns:
        A frame whose columns are ``topic_id``, ``topic_count`` and then the
        remaining columns of ``topics``, ordered by descending count.
    """
    # Name the index and the count column explicitly so the result does not
    # depend on how the installed pandas labels ``value_counts`` output.
    counts = (
        selected_data["topic_id"]
        .value_counts()
        .rename_axis("topic_id")
        .reset_index(name="topic_count")
    )
    return counts.merge(topics, on="topic_id")


def main():
    """Render the document scatter plot and the selection-driven tables.

    Plots every document from ``st.session_state.data`` coloured by topic,
    captures lasso/box selections through ``plotly_events``, and shows the
    selected documents and their topic counts in two tabs.
    """
    st.write("""
# Topic Modeling
This shows a 2d representation of documents embeded in a semantic space. Each dot is a document
and the dots close represent documents that are close in meaning.

Zoom in and explore a topic of your choice. You can see the documents you select with the `lasso` or `box`
tool below in the corresponding tabs."""
    )

    st.button("Reset", help="Will Reset the selected points and the selected topics", on_click=reset)
    data_to_model = st.session_state.data.sort_values(by='topic_id', ascending=True)  # to make legend sorted https://bioinformatics.stackexchange.com/a/18847
    fig = px.scatter(data_to_model, x='x', y='y', color='topic_id', template='plotly_dark', hover_data=['id', 'topic_id', 'x', 'y'])
    st.session_state.selected_points = plotly_events(fig, select_event=True, click_event=False)
    # Filter once per run; both tabs read the same selected_data afterwards.
    filter_df()

    has_selection = bool(st.session_state.selected_points)
    empty_hint = 'Select points in the graph with the `lasso` or `box` select tools to populate this table.'

    tab1, tab2 = st.tabs(["Docs", "Topics"])

    with tab1:
        if has_selection:
            docs = st.session_state.selected_data[['id', 'topic_id', 'documents']]
            builder = GridOptionsBuilder.from_dataframe(docs)
            builder.configure_pagination()
            # Named grid_options so the plotly ``go`` import is not shadowed.
            grid_options = builder.build()
            AgGrid(docs, theme='streamlit', gridOptions=grid_options, columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS)
        else:
            st.markdown(empty_hint)

    with tab2:
        if has_selection:
            topic_counts = get_topics_counts(st.session_state.selected_data, st.session_state.topics)
            topic_table = topic_counts.loc[:, ['topic_id', 'topic_count', 'topic_0']]
            builder = GridOptionsBuilder.from_dataframe(topic_table)
            builder.configure_pagination()
            builder.configure_column('topic_0', header_name='Topic Word', wrap_text=True)
            grid_options = builder.build()
            AgGrid(topic_table, theme='streamlit', gridOptions=grid_options, columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW)
        else:
            st.markdown(empty_hint)
103
+
104
if __name__ == "__main__":
    # Module-level logger and project root, read by the functions above.
    logger = getLogger(__name__)
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    st.set_page_config(layout="wide")
    md_title = "# Document Explorer 📖"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    initialize_state()
    main()
app/pages/03_Semantic_Search_🔍.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from distutils.fancy_getopt import wrap_text
2
+ from top2vec import Top2Vec
3
+ import joblib
4
+ import streamlit as st
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ import plotly.express as px
8
+ import plotly.graph_objects as go
9
+ from streamlit_plotly_events import plotly_events
10
+ from st_aggrid import AgGrid, GridOptionsBuilder, ColumnsAutoSizeMode
11
+ from logging import getLogger
12
+
13
+
14
def initialize_state():
    """Populate ``st.session_state`` with the models and tables this page needs.

    Each entry is loaded only when its key is missing, so reruns within a
    session are cheap. The function is deliberately not wrapped in
    ``st.cache``: its only effect is writing to the session state, and a
    cache hit would skip the body and leave a new session without ``model``,
    ``umap_model``, ``data`` and ``topics``.
    """
    state = st.session_state
    with st.spinner("Loading app..."):
        # NOTE(review): Top2Vec.load and joblib.load unpickle their input;
        # only load artefacts that ship with this repository.
        if 'model' not in state:
            model = Top2Vec.load('models/model.pkl')
            model._check_model_status()
            state.model = model

        if 'umap_model' not in state:
            state.umap_model = joblib.load(proj_dir / 'models' / 'umap.sav')

        if 'data' not in state:
            logger.info("loading data...")
            data = pd.read_csv(proj_dir / 'data' / 'data.csv')
            # Zero-padded string ids, e.g. 3 -> '03'.
            data['topic_id'] = data['topic_id'].apply(lambda x: f'{x:02d}')
            state.data = data

        if 'topics' not in state:
            topics = pd.read_csv(proj_dir / 'data' / 'topics.csv')
            topics['topic_id'] = topics['topic_id'].apply(lambda x: f'{x:02d}')
            state.topics = topics
32
+
33
def get_topics_counts(documents, topics) -> pd.DataFrame:
    """Count documents per topic and attach the topic descriptions.

    Args:
        documents: Frame with a ``topic_id`` column, one row per document.
        topics: Frame with a ``topic_id`` column plus descriptive columns.

    Returns:
        A frame whose columns are ``topic_id``, ``topic_count`` and then the
        remaining columns of ``topics``, ordered by descending count.
    """
    # Name the index and the count column explicitly so the result does not
    # depend on how the installed pandas labels ``value_counts`` output.
    counts = (
        documents["topic_id"]
        .value_counts()
        .rename_axis("topic_id")
        .reset_index(name="topic_count")
    )
    return counts.merge(topics, on="topic_id")


def main():
    """Embed the query, fetch the nearest documents and render the results.

    The query text is embedded with the Top2Vec model, projected with the
    stored UMAP model, and plotted in yellow on top of the matching
    documents. Two tabs list the matching documents and their topic counts.
    """
    max_docs = st.sidebar.slider("# docs", 10, 100, value=50)
    to_search = st.text_input("Write your query here", "") or ""
    with st.spinner('Embedding Query...'):
        vector = st.session_state.model.embed([to_search])
    with st.spinner('Dimension Reduction...'):
        point = st.session_state.umap_model.transform(vector.reshape(1, -1))

    documents, document_scores, document_ids = st.session_state.model.search_documents_by_vector(vector.flatten(), num_docs=max_docs)
    st.session_state.search_raw_df = pd.DataFrame({'document_ids': document_ids, 'document_scores': document_scores})

    hits = (
        st.session_state.data
        .merge(st.session_state.search_raw_df, left_on='id', right_on='document_ids')
        .drop(['document_ids'], axis=1)
        .sort_values(by='document_scores', ascending=False)
    )

    # Build the query marker by column name rather than by position, so it
    # does not depend on the column order of data.csv.
    query_x, query_y = point[0].tolist()
    query_row = pd.DataFrame([{
        'id': 'Point',
        'x': query_x,
        'y': query_y,
        'documents': to_search,
        'topic_id': 'Query',
        'document_scores': 0,
    }])
    with_point = pd.concat([hits, query_row], ignore_index=True)

    # Same session keys as before, in case other code reads them.
    st.session_state.data_to_model = with_point
    st.session_state.data_to_model_with_point = with_point
    st.session_state.data_to_model_without_point = hits

    st.write("""
# Semantic Search
This shows a 2d representation of documents embeded in a semantic space. Each dot is a document
and the dots close represent documents that are close in meaning.

Note that the distance metrics were computed at a higher dimension so take the representation with
a grain of salt.

The Query is shown with the documents in yellow.
"""
    )

    # Documents and query marker are plotted from separate frames instead of
    # assuming the 'Query' row sorts last.
    docs_by_topic = hits.sort_values(by='topic_id', ascending=True)  # to make legend sorted https://bioinformatics.stackexchange.com/a/18847
    fig = px.scatter(docs_by_topic, x='x', y='y', color='topic_id', template='plotly_dark', hover_data=['id', 'topic_id', 'x', 'y'])
    fig.add_traces(px.scatter(query_row, x="x", y="y").update_traces(marker_size=10, marker_color="yellow").data)
    st.plotly_chart(fig, use_container_width=True)
    tab1, tab2 = st.tabs(["Docs", "Topics"])

    with tab1:
        doc_table = hits.loc[:, ['id', 'document_scores', 'topic_id', 'documents']]
        builder = GridOptionsBuilder.from_dataframe(doc_table)
        builder.configure_pagination()
        builder.configure_column('document_scores', type=["numericColumn", "numberColumnFilter", "customNumericFormat"], precision=2)
        # Named grid_options so the plotly ``go`` import is not shadowed.
        grid_options = builder.build()
        AgGrid(doc_table, theme='streamlit', gridOptions=grid_options, columns_auto_size_mode=ColumnsAutoSizeMode.FIT_CONTENTS)

    with tab2:
        topic_counts = get_topics_counts(hits, st.session_state.topics)
        topic_table = topic_counts.loc[:, ['topic_id', 'topic_count', 'topic_0']]
        builder = GridOptionsBuilder.from_dataframe(topic_table)
        builder.configure_pagination()
        builder.configure_column('topic_0', header_name='Topic Word', wrap_text=True)
        grid_options = builder.build()
        AgGrid(topic_table, theme='streamlit', gridOptions=grid_options, columns_auto_size_mode=ColumnsAutoSizeMode.FIT_ALL_COLUMNS_TO_VIEW)
95
+
96
+
97
if __name__ == "__main__":
    # Module-level logger and project root, read by the functions above.
    logger = getLogger(__name__)
    proj_dir = Path(__file__).parents[2]

    # For max width tables
    pd.set_option('display.max_colwidth', 0)

    # Streamlit settings
    st.set_page_config(layout="wide")
    md_title = "# Semantic Search 🔍"
    st.markdown(md_title)
    st.sidebar.markdown(md_title)

    initialize_state()
    main()
bootstrap.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Start the Streamlit app in ``app/Top2Vec.py`` from a plain Python process."""
from pathlib import Path

import streamlit.web.bootstrap
from streamlit import config as _config

proj_dir = Path(__file__).parent
filename = proj_dir / "app" / "Top2Vec.py"

_config.set_option("server.headless", True)
args = []

# The last argument (``flag_options``) is a mapping of config overrides;
# pass an empty dict rather than a string.
# NOTE(review): confirm against the installed Streamlit version.
streamlit.web.bootstrap.run(str(filename), "", args, {})
notebooks/explore.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ top2vec[sentence_transformers]==1.0.27
2
+ scikit-learn==1.1.1
3
+ jupyter==1.0.0
4
+ streamlit==1.16.0
5
+ streamlit-aggrid==0.3.3
6
+ streamlit-plotly-events==0.0.6
7
+ plotly==5.9.0
8
+ datasets==2.8.0
9
+ keybert==0.7.0