stephenleo committed
Commit: d2b0a3c
Parent: 7475828

refactor and adding progress bar for emb gen

Files changed (2):
  1. app.py +161 -122
  2. helpers.py +59 -10
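
In short: app.py's monolithic main() is refactored into one function per pipeline step (load_data, topic_modeling, strip_network, network_centrality, about_me) orchestrated by a slim main(); helpers.py gains st_redirect/st_stdout/st_stderr context managers that mirror stdout/stderr into the Streamlit page, so the progress bar printed during embedding generation becomes visible in the app; and calc_neighbors becomes a separate @st.cache() function that calc_optimal_threshold reuses during its threshold sweep instead of recomputing np.argwhere at every step.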
app.py CHANGED
@@ -18,167 +18,187 @@ logging.basicConfig(level=logging.INFO,
 logger = logging.getLogger('main')
 
 
-def main():
-    st.title('STriP (S3P): Semantic Similarity of Scientific Papers!')
+def load_data():
+    """Loads the data from the uploaded file.
+    """
 
     st.header('📂 Load Data')
     uploaded_file = st.file_uploader("Choose a CSV file",
                                      help='Upload a CSV file with the following columns: Title, Abstract')
 
-    ##########
-    # Load data
-    ##########
-    logger.info('========== Step1: Loading data ==========')
     if uploaded_file is not None:
         df = helpers.load_data(uploaded_file)
     else:
         df = helpers.load_data('data.csv')
-
     data = df.copy()
-    selected_cols = st.multiselect('Select columns to analyse', options=data.columns,
+
+    # Column Selection. By default, any column called 'title' and 'abstract' are selected
+    st.subheader('Select columns to analyze')
+    selected_cols = st.multiselect(label='Select one or more columns. All the selected columns are concatenated before analyzing', options=data.columns,
                                    default=[col for col in data.columns if col.lower() in ['title', 'abstract']])
+
+    if not selected_cols:
+        st.error('No columns selected! Please select some text columns to analyze')
+
     data = data[selected_cols]
+
+    # Minor cleanup
     data = data.dropna()
     data = data.reset_index(drop=True)
+
+    # Load max 200 rows only
     st.write(f'Number of rows: {len(data)}')
     if len(data) > 200:
         data = data.iloc[:200]
         st.write(f'Only first 200 rows will be analyzed')
+
+    # Prints
     st.write('First 5 rows of loaded data:')
     st.write(data[selected_cols].head())
 
+    # Combine all selected columns
     if (data is not None) and selected_cols:
-        # For 'allenai-specter'
         data['Text'] = data[data.columns[0]]
         for column in data.columns[1:]:
             data['Text'] = data['Text'] + '[SEP]' + data[column].astype(str)
 
-        ##########
-        # Topic modeling
-        ##########
-        logger.info('========== Step2: Topic modeling ==========')
-        st.header('🔥 Topic Modeling')
-
-        cols = st.columns(3)
-        with cols[0]:
-            min_topic_size = st.slider('Minimum topic size', key='min_topic_size', min_value=2,
-                                       max_value=min(round(len(data)*0.25), 100), step=1, value=min(round(len(data)/25), 10),
-                                       help='The minimum size of the topic. Increasing this value will lead to a lower number of clusters/topics.')
-        with cols[1]:
-            n_gram_range = st.slider('N-gram range', key='n_gram_range', min_value=1,
-                                     max_value=3, step=1, value=(1, 2),
-                                     help='N-gram range for the topic model')
-        with cols[2]:
-            st.text('')
-            st.text('')
-            st.button('Reset Defaults', on_click=helpers.reset_default_topic_sliders, key='reset_topic_sliders',
-                      kwargs={'min_topic_size': min(round(len(data)/25), 10), 'n_gram_range': (1, 2)})
-
-        with st.spinner('Topic Modeling'):
-            topic_data, topic_model, topics = helpers.topic_modeling(
-                data, min_topic_size=min_topic_size, n_gram_range=n_gram_range)
-
-        mapping = {
-            'Topic Keywords': topic_model.visualize_barchart,
-            'Topic Similarities': topic_model.visualize_heatmap,
-            'Topic Hierarchies': topic_model.visualize_hierarchy,
-            'Intertopic Distance': topic_model.visualize_topics
-        }
-
-        cols = st.columns(3)
-        with cols[0]:
-            topic_model_vis_option = st.selectbox(
-                'Select Topic Modeling Visualization', mapping.keys())
-        try:
-            fig = mapping[topic_model_vis_option](top_n_topics=10)
-            fig.update_layout(title='')
-            st.plotly_chart(fig, use_container_width=True)
-        except:
-            st.warning(
-                'No visualization available. Try a lower Minimum topic size!')
-
-        ##########
-        # STriP Network
-        ##########
-        logger.info('========== Step3: STriP Network ==========')
-        st.header('🚀 STriP Network')
-
-        with st.spinner('Cosine Similarity Calculation'):
-            cosine_sim_matrix = helpers.cosine_sim(data)
-
-        value, min_value = helpers.calc_optimal_threshold(
-            cosine_sim_matrix,
-            # 25% is a good value for the number of papers
-            max_connections=min(
-                helpers.calc_max_connections(len(data), 0.25), 5_000
-            )
-        )
-
-        cols = st.columns(3)
-        with cols[0]:
-            threshold = st.slider('Cosine Similarity Threshold', key='threshold', min_value=min_value,
-                                  max_value=1.0, step=0.01, value=value,
-                                  help='The minimum cosine similarity between papers to draw a connection. Increasing this value will lead to a lesser connections.')
-
-            neighbors, num_connections = helpers.calc_neighbors(
-                cosine_sim_matrix, threshold)
-            st.write(f'Number of connections: {num_connections}')
-
-        with cols[1]:
-            st.text('')
-            st.text('')
-            st.button('Reset Defaults', on_click=helpers.reset_default_threshold_slider, key='reset_threshold',
-                      kwargs={'threshold': value})
-
-        with st.spinner('Network Generation'):
-            nx_net, pyvis_net = helpers.network_plot(
-                topic_data, topics, neighbors)
-
-        # Save and read graph as HTML file (on Streamlit Sharing)
-        try:
-            path = '/tmp'
-            pyvis_net.save_graph(f'{path}/pyvis_graph.html')
-            HtmlFile = open(f'{path}/pyvis_graph.html',
-                            'r', encoding='utf-8')
-
-        # Save and read graph as HTML file (locally)
-        except:
-            path = '/html_files'
-            pyvis_net.save_graph(f'{path}/pyvis_graph.html')
-            HtmlFile = open(f'{path}/pyvis_graph.html',
-                            'r', encoding='utf-8')
-
-        # Load HTML file in HTML component for display on Streamlit page
-        html(HtmlFile.read(), height=800)
-
-        ##########
-        # Centrality
-        ##########
-        logger.info('========== Step4: Network Centrality ==========')
-        st.header('🏅 Most Important Papers')
-
-        centrality_mapping = {
-            'Closeness Centrality': nx.closeness_centrality,
-            'Degree Centrality': nx.degree_centrality,
-            'Eigenvector Centrality': nx.eigenvector_centrality,
-            'Betweenness Centrality': nx.betweenness_centrality,
-        }
-
-        cols = st.columns(3)
-        with cols[0]:
-            centrality_option = st.selectbox(
-                'Select Centrality Measure', centrality_mapping.keys())
-
-        # Calculate centrality
-        centrality = centrality_mapping[centrality_option](nx_net)
-
-        cols = st.columns([1, 10, 1])
-        with cols[1]:
-            with st.spinner('Network Centrality Calculation'):
-                fig = helpers.network_centrality(
-                    topic_data, centrality, centrality_option)
-                st.plotly_chart(fig, use_container_width=True)
+    return data, selected_cols
+
+
+def topic_modeling(data):
+    """Runs the topic modeling step.
+    """
+
+    st.header('🔥 Topic Modeling')
+    cols = st.columns(3)
+    with cols[0]:
+        min_topic_size = st.slider('Minimum topic size', key='min_topic_size', min_value=2,
+                                   max_value=min(round(len(data)*0.25), 100), step=1, value=min(round(len(data)/25), 10),
+                                   help='The minimum size of the topic. Increasing this value will lead to a lower number of clusters/topics.')
+    with cols[1]:
+        n_gram_range = st.slider('N-gram range', key='n_gram_range', min_value=1,
+                                 max_value=3, step=1, value=(1, 2),
+                                 help='N-gram range for the topic model')
+    with cols[2]:
+        st.text('')
+        st.text('')
+        st.button('Reset Defaults', on_click=helpers.reset_default_topic_sliders, key='reset_topic_sliders',
+                  kwargs={'min_topic_size': min(round(len(data)/25), 10), 'n_gram_range': (1, 2)})
+
+    with st.spinner('Topic Modeling'):
+        with helpers.st_stdout("success"), helpers.st_stderr("code"):
+            topic_data, topic_model, topics = helpers.topic_modeling(
+                data, min_topic_size=min_topic_size, n_gram_range=n_gram_range)
+
+    mapping = {
+        'Topic Keywords': topic_model.visualize_barchart,
+        'Topic Similarities': topic_model.visualize_heatmap,
+        'Topic Hierarchies': topic_model.visualize_hierarchy,
+        'Intertopic Distance': topic_model.visualize_topics
+    }
+
+    cols = st.columns(3)
+    with cols[0]:
+        topic_model_vis_option = st.selectbox(
+            'Select Topic Modeling Visualization', mapping.keys())
+    try:
+        fig = mapping[topic_model_vis_option](top_n_topics=10)
+        fig.update_layout(title='')
+        st.plotly_chart(fig, use_container_width=True)
+    except:
+        st.warning(
+            'No visualization available. Try a lower Minimum topic size!')
+
+    return topic_data, topics
+
+
+def strip_network(data, topic_data, topics):
+    """Generated the STriP network.
+    """
+
+    st.header('🚀 STriP Network')
+
+    with st.spinner('Cosine Similarity Calculation'):
+        cosine_sim_matrix = helpers.cosine_sim(data)
+
+    value, min_value = helpers.calc_optimal_threshold(
+        cosine_sim_matrix,
+        # 25% is a good value for the number of papers
+        max_connections=min(
+            helpers.calc_max_connections(len(data), 0.25), 5_000
+        )
+    )
+
+    cols = st.columns(3)
+    with cols[0]:
+        threshold = st.slider('Cosine Similarity Threshold', key='threshold', min_value=min_value,
+                              max_value=1.0, step=0.01, value=value,
+                              help='The minimum cosine similarity between papers to draw a connection. Increasing this value will lead to a lesser connections.')
+
+        neighbors, num_connections = helpers.calc_neighbors(
+            cosine_sim_matrix, threshold)
+        st.write(f'Number of connections: {num_connections}')
+
+    with cols[1]:
+        st.text('')
+        st.text('')
+        st.button('Reset Defaults', on_click=helpers.reset_default_threshold_slider, key='reset_threshold',
+                  kwargs={'threshold': value})
+
+    with st.spinner('Network Generation'):
+        nx_net, pyvis_net = helpers.network_plot(
+            topic_data, topics, neighbors)
+
+    # Save and read graph as HTML file (on Streamlit Sharing)
+    try:
+        path = '/tmp'
+        pyvis_net.save_graph(f'{path}/pyvis_graph.html')
+        HtmlFile = open(f'{path}/pyvis_graph.html',
+                        'r', encoding='utf-8')
+
+    # Save and read graph as HTML file (locally)
+    except:
+        path = '/html_files'
+        pyvis_net.save_graph(f'{path}/pyvis_graph.html')
+        HtmlFile = open(f'{path}/pyvis_graph.html',
+                        'r', encoding='utf-8')
+
+    # Load HTML file in HTML component for display on Streamlit page
+    html(HtmlFile.read(), height=800)
+
+    return nx_net
+
+
+def network_centrality(nx_net, topic_data):
+    """Finds most important papers using network centrality measures.
+    """
+
+    st.header('🏅 Most Important Papers')
+
+    centrality_mapping = {
+        'Closeness Centrality': nx.closeness_centrality,
+        'Degree Centrality': nx.degree_centrality,
+        'Eigenvector Centrality': nx.eigenvector_centrality,
+        'Betweenness Centrality': nx.betweenness_centrality,
+    }
+
+    cols = st.columns(3)
+    with cols[0]:
+        centrality_option = st.selectbox(
+            'Select Centrality Measure', centrality_mapping.keys())
+
+    # Calculate centrality
+    centrality = centrality_mapping[centrality_option](nx_net)
+
+    cols = st.columns([1, 10, 1])
+    with cols[1]:
+        with st.spinner('Network Centrality Calculation'):
+            fig = helpers.network_centrality(
+                topic_data, centrality, centrality_option)
+            st.plotly_chart(fig, use_container_width=True)
+
+
+def about_me():
     st.markdown(
         """
         💡🔥🚀 STriP v1.0 🚀🔥💡
@@ -194,5 +214,24 @@ def main():
     )
 
 
+def main():
+    st.title('STriP (S3P): Semantic Similarity of Scientific Papers!')
+
+    logger.info('========== Step1: Loading data ==========')
+    data, selected_cols = load_data()
+
+    if (data is not None) and selected_cols:
+        logger.info('========== Step2: Topic modeling ==========')
+        topic_data, topics = topic_modeling(data)
+
+        logger.info('========== Step3: STriP Network ==========')
+        nx_net = strip_network(data, topic_data, topics)
+
+        logger.info('========== Step4: Network Centrality ==========')
+        network_centrality(nx_net, topic_data)
+
+    about_me()
+
+
 if __name__ == '__main__':
     main()
helpers.py CHANGED
@@ -11,6 +11,13 @@ import networkx as nx
 import textwrap
 import logging
 
+from streamlit.report_thread import REPORT_CONTEXT_ATTR_NAME
+from threading import current_thread
+from contextlib import contextmanager
+from io import StringIO
+import sys
+import time
+
 logger = logging.getLogger('main')
 
 
@@ -70,6 +77,8 @@ def topic_modeling(data, min_topic_size, n_gram_range):
     # Optimization: Only take top 10 largest topics
     topics = topic_df.head(10).set_index('Topic').to_dict(orient='index')
 
+    logger.info('Topic Modeling Complete')
+
     return topic_data, topic_model, topics
 
 
@@ -91,6 +100,13 @@ def calc_max_connections(num_papers, ratio):
     return n*(n-1)/2
 
 
+@st.cache()
+def calc_neighbors(cosine_sim_matrix, threshold):
+    neighbors = np.argwhere(cosine_sim_matrix >= threshold).tolist()
+
+    return neighbors, len(neighbors)
+
+
 @st.cache()
 def calc_optimal_threshold(cosine_sim_matrix, max_connections):
     """Calculates the optimal threshold for the cosine similarity matrix.
@@ -99,21 +115,13 @@ def calc_optimal_threshold(cosine_sim_matrix, max_connections):
     logger.info('Calculating optimal threshold')
     thresh_sweep = np.arange(0.05, 1.05, 0.05)[::-1]
     for idx, threshold in enumerate(thresh_sweep):
-        neighbors = np.argwhere(cosine_sim_matrix >= threshold).tolist()
-        if len(neighbors) > max_connections:
+        _, num_neighbors = calc_neighbors(cosine_sim_matrix, threshold)
+        if num_neighbors > max_connections:
             break
 
     return round(thresh_sweep[idx-1], 2).item(), round(thresh_sweep[idx], 2).item()
 
 
-@st.cache()
-def calc_neighbors(cosine_sim_matrix, threshold):
-    logger.info('Calculating neighbors')
-    neighbors = np.argwhere(cosine_sim_matrix >= threshold).tolist()
-
-    return neighbors, len(neighbors)
-
-
 def nx_hash_func(nx_net):
     """Hash function for NetworkX graphs.
     """
@@ -219,3 +227,44 @@ def network_centrality(topic_data, centrality, centrality_option):
     fig.update_layout(yaxis={'categoryorder': 'total ascending', 'visible': False, 'showticklabels': False},
                       font={'size': 15}, height=800)
     return fig
+
+
+# Progress bar printer
+# https://github.com/BugzTheBunny/streamlit_logging_output_example/blob/main/app.py
+# https://discuss.streamlit.io/t/cannot-print-the-terminal-output-in-streamlit/6602/34
+@contextmanager
+def st_redirect(src, dst):
+    placeholder = st.empty()
+    output_func = getattr(placeholder, dst)
+
+    with StringIO() as buffer:
+        old_write = src.write
+
+        def new_write(b):
+            if getattr(current_thread(), REPORT_CONTEXT_ATTR_NAME, None):
+                buffer.write(b)
+                time.sleep(1)
+                buffer.seek(0)  # returns pointer to 0 position
+                output_func(b)
+            else:
+                old_write(b)
+
+        try:
+            src.write = new_write
+            yield
+        finally:
+            src.write = old_write
+
+
+@contextmanager
+def st_stdout(dst):
+    "this will show the prints"
+    with st_redirect(sys.stdout, dst):
+        yield
+
+
+@contextmanager
+def st_stderr(dst):
+    "This will show the logging"
+    with st_redirect(sys.stderr, dst):
+        yield
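
For reference, these redirect helpers are consumed in app.py's topic_modeling() step, as shown in the diff above. A minimal standalone sketch of the pattern (this assumes a pre-1.0 Streamlit release where streamlit.report_thread still exists, and the slider values shown here are hypothetical placeholders; data is the DataFrame returned by load_data()):

    import helpers

    # Anything written to stdout/stderr inside these blocks (for example, a
    # tqdm-style progress bar printed while embeddings are being generated)
    # is mirrored into st.success / st.code placeholders on the page, via
    # getattr(placeholder, dst) in st_redirect, instead of appearing only
    # in the server terminal.
    with helpers.st_stdout('success'), helpers.st_stderr('code'):
        topic_data, topic_model, topics = helpers.topic_modeling(
            data, min_topic_size=10, n_gram_range=(1, 2))

The dst argument names a method of an st.empty() placeholder ('success', 'code', and similar element types), which is how the same context manager can render captured output as different Streamlit elements.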