# stripnet / app.py
# Author: stephenleo
# Commit 7e53eea: remove progress bar routing
# Raw file size: 8.66 kB
import networkx as nx
from streamlit.components.v1 import html
import streamlit as st
import helpers
import logging
# Setup Basic Configuration
# Page-level Streamlit settings, executed at import time before main() renders
# anything. page_icon is the spider-web emoji (U+1F578 U+FE0F).
st.set_page_config(
    layout='wide',
    page_title='STriP: Semantic Similarity of Scientific Papers!',
    page_icon='🕸️',
)

# Timestamped INFO-level log lines; the step markers logged from main() use this.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger('main')
def load_data():
    """Load the input table and build the combined ``Text`` column.

    The uploaded file (or the bundled ``data.csv`` when nothing has been
    uploaded) is passed to ``helpers.load_data``. The user then picks the
    columns to analyze; columns named 'title' or 'abstract'
    (case-insensitive) are preselected. Rows with missing values in the
    selected columns are dropped, at most 200 rows are kept (a reproducible
    random sample), and the selected columns are joined into a single
    ``Text`` column with ``'[SEP]'`` between the values.

    Returns:
        tuple: ``(data, selected_cols)``. ``data`` holds the selected
        columns plus ``Text``; ``selected_cols`` is the list of column names
        the user picked. When nothing is selected, ``selected_cols`` is empty
        and ``data`` has no ``Text`` column, so callers must check
        ``selected_cols`` before continuing.
    """
    st.header('📂 Load Data')
    st.markdown(read_md('markdown/load_data.md'))

    uploaded_file = st.file_uploader(
        "Choose a CSV file",
        help='Upload a CSV file with the following columns: Title, Abstract')
    if uploaded_file is not None:
        df = helpers.load_data(uploaded_file)
    else:
        df = helpers.load_data('data.csv')
    data = df.copy()

    # Column Selection. By default, any column called 'title' and 'abstract' are selected
    st.subheader('Select columns to analyze')
    selected_cols = st.multiselect(
        label='Select one or more columns. All the selected columns are concatenated before analyzing',
        options=data.columns,
        default=[col for col in data.columns if str(col).lower() in ['title', 'abstract']])
    if not selected_cols:
        st.error('No columns selected! Please select some text columns to analyze')
        # Nothing to analyze: stop here; the caller checks selected_cols.
        return data[selected_cols], selected_cols

    # Minor cleanup
    data = data[selected_cols].dropna()

    # Load max 200 rows only (random_state keeps the sample reproducible)
    st.write(f'Number of rows: {len(data)}')
    n_rows = 200
    if len(data) > n_rows:
        data = data.sample(n_rows, random_state=0)
        st.write(f'Only random {n_rows} rows will be analyzed')
    data = data.reset_index(drop=True)

    st.write('First 5 rows of loaded data:')
    st.write(data[selected_cols].head())

    # Combine all selected columns into 'Text'. Iterate over selected_cols, not
    # data.columns: once 'Text' is added, data.columns would include it too and
    # the text would be appended to itself. Every column, including the first,
    # is cast to str so non-string columns concatenate without a TypeError.
    text = data[selected_cols[0]].astype(str)
    for column in selected_cols[1:]:
        text = text + '[SEP]' + data[column].astype(str)
    data['Text'] = text

    return data, selected_cols
def topic_modeling(data):
    """Run topic modeling on the loaded data and show one visualization.

    Renders sliders for the minimum topic size and the n-gram range, passes
    them to ``helpers.topic_modeling``, and plots the user-selected built-in
    visualization of the fitted model (top 10 topics).

    Args:
        data: DataFrame returned by ``load_data``. Its length sets the slider
            bounds; the frame itself is passed through to the helper.

    Returns:
        tuple: ``(topic_data, topics)`` as returned by ``helpers.topic_modeling``.
    """
    st.header('🔥 Topic Modeling')
    st.markdown(read_md('markdown/topic_modeling.md'))

    # Slider bounds scale with the number of documents. Streamlit rejects a
    # slider whose default lies outside [min_value, max_value] or whose max
    # does not exceed its min. The raw formulas hit that for small uploads
    # (e.g. fewer than 38 rows gives a default below 2), so both are clamped.
    n_docs = len(data)
    min_topic_floor = 2
    max_topic_size = max(min(round(n_docs * 0.25), 100), min_topic_floor + 1)
    default_topic_size = min(
        max(min(round(n_docs / 25), 10), min_topic_floor), max_topic_size)
    default_n_gram_range = (1, 2)

    cols = st.columns(3)
    with cols[0]:
        min_topic_size = st.slider(
            'Minimum topic size', key='min_topic_size', min_value=min_topic_floor,
            max_value=max_topic_size, step=1, value=default_topic_size,
            help='The minimum size of the topic. Increasing this value will lead to a lower number of clusters/topics.')
    with cols[1]:
        n_gram_range = st.slider(
            'N-gram range', key='n_gram_range', min_value=1,
            max_value=3, step=1, value=default_n_gram_range,
            help='N-gram range for the topic model')
    with cols[2]:
        # Spacer lines, presumably to align the button with the sliders.
        st.text('')
        st.text('')
        # Reset uses the same (clamped) defaults the sliders were created with.
        st.button('Reset Defaults', on_click=helpers.reset_default_topic_sliders, key='reset_topic_sliders',
                  kwargs={'min_topic_size': default_topic_size, 'n_gram_range': default_n_gram_range})

    with st.spinner('Topic Modeling'):
        # enable this to route terminal logs to the frontend
        # with helpers.st_stdout("success"), helpers.st_stderr("code"):
        topic_data, topic_model, topics = helpers.topic_modeling(
            data, min_topic_size=min_topic_size, n_gram_range=n_gram_range)

    mapping = {
        'Topic Keywords': topic_model.visualize_barchart,
        'Topic Similarities': topic_model.visualize_heatmap,
        'Topic Hierarchies': topic_model.visualize_hierarchy,
        'Intertopic Distance': topic_model.visualize_topics
    }
    cols = st.columns(3)
    with cols[0]:
        topic_model_vis_option = st.selectbox(
            'Select Topic Modeling Visualization', mapping.keys())
    try:
        fig = mapping[topic_model_vis_option](top_n_topics=10)
        fig.update_layout(title='')
        st.plotly_chart(fig, use_container_width=True)
    except Exception:
        # The hint suggests failures come from too few topics. Log the real
        # cause instead of hiding it. Catching Exception (not a bare except)
        # lets BaseException-derived control flow (KeyboardInterrupt,
        # SystemExit, framework rerun/stop signals) propagate.
        logger.exception('Topic visualization %r failed', topic_model_vis_option)
        st.warning(
            'No visualization available. Try a lower Minimum topic size!')
    return topic_data, topics
def strip_network(data, topic_data, topics):
    """Build and display the STriP similarity network.

    Computes pairwise cosine similarities (``helpers.cosine_sim``), lets the
    user choose a similarity threshold (default from
    ``helpers.calc_optimal_threshold``), connects papers above it, and embeds
    the resulting pyvis graph in the page as HTML.

    Args:
        data: DataFrame returned by ``load_data``.
        topic_data: First value returned by ``topic_modeling``.
        topics: Second value returned by ``topic_modeling``.

    Returns:
        The ``nx_net`` graph returned by ``helpers.network_plot``.
    """
    st.header('🚀 STriP Network')
    st.markdown(read_md('markdown/stripnet.md'))

    with st.spinner('Cosine Similarity Calculation'):
        cosine_sim_matrix = helpers.cosine_sim(data)
        value, min_value = helpers.calc_optimal_threshold(
            cosine_sim_matrix,
            # 25% is a good value for the number of papers
            max_connections=min(
                helpers.calc_max_connections(len(data), 0.25), 5_000
            )
        )

    cols = st.columns(3)
    with cols[0]:
        threshold = st.slider(
            'Cosine Similarity Threshold', key='threshold', min_value=min_value,
            max_value=1.0, step=0.01, value=value,
            help='The minimum cosine similarity between papers to draw a connection. Increasing this value will lead to a lesser connections.')
        neighbors, num_connections = helpers.calc_neighbors(
            cosine_sim_matrix, threshold)
        st.write(f'Number of connections: {num_connections}')
    with cols[1]:
        st.text('')
        st.text('')
        st.button('Reset Defaults', on_click=helpers.reset_default_threshold_slider, key='reset_threshold',
                  kwargs={'threshold': value})

    with st.spinner('Network Generation'):
        nx_net, pyvis_net = helpers.network_plot(
            topic_data, topics, neighbors)

        # pyvis writes the graph to an HTML file, which is read back for display.
        # '/tmp' is tried first (Streamlit Sharing); '/html_files' is the local
        # fallback when writing there fails.
        try:
            graph_path = '/tmp/pyvis_graph.html'
            pyvis_net.save_graph(graph_path)
        except OSError:
            graph_path = '/html_files/pyvis_graph.html'
            pyvis_net.save_graph(graph_path)

        # Read through a context manager so the file handle is always closed.
        with open(graph_path, 'r', encoding='utf-8') as html_file:
            graph_html = html_file.read()

        # Load HTML file in HTML component for display on Streamlit page
        html(graph_html, height=800)

    return nx_net
def network_centrality(nx_net, topic_data):
    """Rank papers by a user-selected network centrality measure and plot them.

    Args:
        nx_net: networkx graph returned by ``strip_network``.
        topic_data: Topic data passed through to ``helpers.network_centrality``
            for plotting.
    """
    st.header('🏅 Most Important Papers')
    st.markdown(read_md('markdown/centrality.md'))

    centrality_mapping = {
        'Betweenness Centrality': nx.betweenness_centrality,
        'Closeness Centrality': nx.closeness_centrality,
        'Degree Centrality': nx.degree_centrality,
        'Eigenvector Centrality': nx.eigenvector_centrality,
    }

    cols = st.columns(3)
    with cols[0]:
        centrality_option = st.selectbox(
            'Select Centrality Measure', centrality_mapping.keys())

    cols = st.columns([1, 10, 1])
    with cols[1]:
        # The centrality computation itself runs under the spinner.
        with st.spinner('Network Centrality Calculation'):
            try:
                centrality = centrality_mapping[centrality_option](nx_net)
            except nx.NetworkXException as err:
                # e.g. nx.eigenvector_centrality raises
                # PowerIterationFailedConvergence when its power iteration does
                # not converge, which a very sparse graph (high threshold) can cause.
                logger.warning('%s failed: %s', centrality_option, err)
                st.warning(
                    f'{centrality_option} could not be computed for this network: {err}')
                return
            fig = helpers.network_centrality(
                topic_data, centrality, centrality_option)
        st.plotly_chart(fig, use_container_width=True)
def read_md(file_path, encoding='utf-8'):
    """Read a markdown file and return its contents.

    Args:
        file_path: Path to the markdown file.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding
            (which ``open()`` would otherwise use).

    Returns:
        str: The full contents of the file.
    """
    with open(file_path, 'r', encoding=encoding) as f:
        return f.read()
def main():
    """Render the intro, then run the STriPNet pipeline step by step.

    Step 1 always loads data. Steps 2-4 (topic modeling, STriP network,
    centrality) run only when data was loaded and at least one column was
    selected. The about-me section is rendered last in every case.
    """
    st.title('STriPNet: Semantic Similarity of Scientific Papers!')
    st.markdown(read_md('markdown/about_stripnet.md'))

    logger.info('========== Step1: Loading data ==========')
    data, selected_cols = load_data()

    ready = data is not None and bool(selected_cols)
    if ready:
        logger.info('========== Step2: Topic modeling ==========')
        topic_data, topics = topic_modeling(data)

        logger.info('========== Step3: STriP Network ==========')
        graph = strip_network(data, topic_data, topics)

        logger.info('========== Step4: Network Centrality ==========')
        network_centrality(graph, topic_data)

    st.markdown(read_md('markdown/about_me.md'))
# Run the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()