File size: 8,656 Bytes
34b12ff
 
 
 
d9f2adf
 
 
34b12ff
576be81
11c8563
34b12ff
 
 
d9f2adf
 
 
 
 
 
 
d2b0a3c
 
 
34b12ff
 
d10a664
 
5ab63e8
34b12ff
 
 
 
 
 
 
 
d2b0a3c
 
 
 
d9f2adf
d2b0a3c
 
 
 
d9f2adf
d2b0a3c
 
d9f2adf
d2b0a3c
 
418c4bc
2afa46a
 
 
 
b61e914
 
d2b0a3c
 
34b12ff
d9f2adf
 
d2b0a3c
d9f2adf
 
 
 
34b12ff
d2b0a3c
 
 
 
 
 
 
 
d10a664
 
5ab63e8
d2b0a3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e53eea
e0f2df1
 
 
d2b0a3c
 
 
 
 
 
 
34b12ff
 
 
d2b0a3c
 
 
 
 
 
 
 
 
34b12ff
d2b0a3c
 
 
 
 
 
 
 
d10a664
 
34b12ff
d2b0a3c
 
34b12ff
d2b0a3c
 
 
 
 
34b12ff
d2b0a3c
34b12ff
d2b0a3c
 
 
 
 
34b12ff
d2b0a3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34b12ff
 
d2b0a3c
 
 
34b12ff
d2b0a3c
d10a664
 
34b12ff
d2b0a3c
b61e914
d2b0a3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ab63e8
d10a664
5ab63e8
 
37d2b76
34b12ff
5ab63e8
34b12ff
 
d2b0a3c
5ab63e8
d10a664
 
d2b0a3c
 
 
 
 
 
 
 
 
 
 
 
 
 
d10a664
 
d2b0a3c
 
34b12ff
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import networkx as nx
from streamlit.components.v1 import html
import streamlit as st
import helpers
import logging

# Setup Basic Configuration
# Page-wide settings (wide layout, browser tab title and icon). Kept ahead of
# every other st.* call — presumably required by the pinned Streamlit version
# for set_page_config; verify before moving it.
st.set_page_config(layout='wide',
                   page_title='STriP: Semantic Similarity of Scientific Papers!',
                   page_icon='🕸️'
                   )


# Timestamped INFO-level logging for the pipeline step messages in main().
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s: %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S')

# Shared logger for every function in this module.
logger = logging.getLogger('main')


def load_data():
    """Loads the data from the uploaded file.

    Falls back to the bundled ``data.csv`` when nothing has been uploaded.
    The user chooses which columns to analyze. Rows with a missing value in
    any selected column are dropped, and at most 200 rows (random sample,
    ``random_state=0``) are kept. The selected columns are then joined with
    the ``[SEP]`` separator into a new ``Text`` column.

    Returns:
        tuple: ``(data, selected_cols)``, the prepared DataFrame and the
        list of selected column names. ``selected_cols`` is empty when the
        user deselected every column; in that case ``data`` has no columns
        and no ``Text`` column.
    """

    st.header('📂 Load Data')
    about_load_data = read_md('markdown/load_data.md')
    st.markdown(about_load_data)

    uploaded_file = st.file_uploader("Choose a CSV file",
                                     help='Upload a CSV file with the following columns: Title, Abstract')

    if uploaded_file is not None:
        df = helpers.load_data(uploaded_file)
    else:
        df = helpers.load_data('data.csv')
    data = df.copy()

    # Column Selection. By default, any column called 'title' and 'abstract' are selected.
    # str() guards against non-string headers (e.g. integer column labels).
    st.subheader('Select columns to analyze')
    selected_cols = st.multiselect(label='Select one or more columns. All the selected columns are concatenated before analyzing', options=data.columns,
                                   default=[col for col in data.columns if str(col).lower() in ['title', 'abstract']])

    if not selected_cols:
        st.error('No columns selected! Please select some text columns to analyze')
        # Nothing to analyze: the caller skips all later steps when
        # selected_cols is empty, so stop here.
        return data[selected_cols], selected_cols

    data = data[selected_cols]

    # Minor cleanup: drop rows with a missing value in any selected column
    data = data.dropna()

    # Load max 200 rows only
    st.write(f'Number of rows: {len(data)}')
    n_rows = 200
    if len(data) > n_rows:
        data = data.sample(n_rows, random_state=0)
        st.write(f'Only random {n_rows} rows will be analyzed')

    data = data.reset_index(drop=True)

    # Prints
    st.write('First 5 rows of loaded data:')
    st.write(data[selected_cols].head())

    # Combine all selected columns into a single 'Text' column.
    # Iterate over selected_cols (not data.columns), so the new 'Text' column
    # is never appended to itself. Cast every column to str, so non-text
    # columns can be joined too.
    text = data[selected_cols[0]].astype(str)
    for column in selected_cols[1:]:
        text = text + '[SEP]' + data[column].astype(str)
    data['Text'] = text

    return data, selected_cols


def topic_modeling(data):
    """Runs the topic modeling step.

    Renders the topic-model controls, fits the model via
    ``helpers.topic_modeling`` and plots the visualization chosen by the user.

    Args:
        data: DataFrame produced by ``load_data``. Its row count drives the
            slider bounds.

    Returns:
        tuple: ``(topic_data, topics)`` as returned by
        ``helpers.topic_modeling``.
    """

    st.header('🔥 Topic Modeling')
    about_topic_modeling = read_md('markdown/topic_modeling.md')
    st.markdown(about_topic_modeling)

    # The slider bounds scale with the number of documents. Clamp them so that
    # lower < upper and lower <= default <= upper; otherwise st.slider rejects
    # the arguments on small datasets. For typical sizes (>= ~40 rows) the
    # values are unchanged.
    min_topic_size_lower = 2
    min_topic_size_upper = max(min(round(len(data) * 0.25), 100),
                               min_topic_size_lower + 1)
    default_min_topic_size = min(max(min(round(len(data) / 25), 10),
                                     min_topic_size_lower),
                                 min_topic_size_upper)
    default_n_gram_range = (1, 2)

    cols = st.columns(3)
    with cols[0]:
        min_topic_size = st.slider('Minimum topic size', key='min_topic_size', min_value=min_topic_size_lower,
                                   max_value=min_topic_size_upper, step=1, value=default_min_topic_size,
                                   help='The minimum size of the topic. Increasing this value will lead to a lower number of clusters/topics.')
    with cols[1]:
        n_gram_range = st.slider('N-gram range', key='n_gram_range', min_value=1,
                                 max_value=3, step=1, value=default_n_gram_range,
                                 help='N-gram range for the topic model')
    with cols[2]:
        # Two empty lines vertically align the button with the sliders.
        st.text('')
        st.text('')
        st.button('Reset Defaults', on_click=helpers.reset_default_topic_sliders, key='reset_topic_sliders',
                  kwargs={'min_topic_size': default_min_topic_size, 'n_gram_range': default_n_gram_range})

    with st.spinner('Topic Modeling'):
        # Enable this to route terminal logs to the frontend:
        # with helpers.st_stdout("success"), helpers.st_stderr("code"):
        topic_data, topic_model, topics = helpers.topic_modeling(
            data, min_topic_size=min_topic_size, n_gram_range=n_gram_range)

        mapping = {
            'Topic Keywords': topic_model.visualize_barchart,
            'Topic Similarities': topic_model.visualize_heatmap,
            'Topic Hierarchies': topic_model.visualize_hierarchy,
            'Intertopic Distance': topic_model.visualize_topics
        }

        cols = st.columns(3)
        with cols[0]:
            topic_model_vis_option = st.selectbox(
                'Select Topic Modeling Visualization', mapping.keys())
        try:
            fig = mapping[topic_model_vis_option](top_n_topics=10)
            fig.update_layout(title='')
            st.plotly_chart(fig, use_container_width=True)
        except Exception:
            # A visualization may fail, e.g. when too few topics were found.
            # `except Exception` (not a bare except) lets KeyboardInterrupt,
            # SystemExit and other BaseException-derived signals propagate.
            # (Possibly also Streamlit's rerun/stop signals, depending on the
            # version; confirm.) The cause is logged rather than discarded.
            logger.warning('Topic visualization %r failed',
                           topic_model_vis_option, exc_info=True)
            st.warning(
                'No visualization available. Try a lower Minimum topic size!')

    return topic_data, topics


def _save_and_read_graph(pyvis_net, directories=('/tmp', '/html_files'),
                         file_name='pyvis_graph.html'):
    """Save a pyvis network as an HTML file and return that file's contents.

    Each directory is tried in order: '/tmp' for Streamlit Sharing, then
    '/html_files' for local runs. The first directory that can be written to
    and read back from wins.

    Raises:
        OSError: if no directory worked (the last failure is re-raised).
    """
    last_error = OSError(f'No directory available to save {file_name}')
    for directory in directories:
        file_path = f'{directory}/{file_name}'
        try:
            pyvis_net.save_graph(file_path)
            # Context manager guarantees the file handle is closed.
            with open(file_path, 'r', encoding='utf-8') as html_file:
                return html_file.read()
        except OSError as err:
            logger.info('Could not save graph to %s (%s); trying next location',
                        file_path, err)
            last_error = err
    raise last_error


def strip_network(data, topic_data, topics):
    """Generates the STriP network.

    Computes pairwise cosine similarities, lets the user choose a similarity
    threshold, builds the network via ``helpers.network_plot`` and embeds the
    pyvis rendering in the page.

    Returns:
        The networkx graph returned by ``helpers.network_plot``.
    """

    st.header('🚀 STriP Network')
    about_stripnet = read_md('markdown/stripnet.md')
    st.markdown(about_stripnet)

    with st.spinner('Cosine Similarity Calculation'):
        cosine_sim_matrix = helpers.cosine_sim(data)

    # value: suggested default threshold; min_value: lowest selectable one.
    value, min_value = helpers.calc_optimal_threshold(
        cosine_sim_matrix,
        # 25% is a good value for the number of papers
        max_connections=min(
            helpers.calc_max_connections(len(data), 0.25), 5_000
        )
    )

    cols = st.columns(3)
    with cols[0]:
        threshold = st.slider('Cosine Similarity Threshold', key='threshold', min_value=min_value,
                              max_value=1.0, step=0.01, value=value,
                              help='The minimum cosine similarity between papers to draw a connection. Increasing this value will lead to a lesser connections.')

        neighbors, num_connections = helpers.calc_neighbors(
            cosine_sim_matrix, threshold)
        st.write(f'Number of connections: {num_connections}')

    with cols[1]:
        # Two empty lines vertically align the button with the slider.
        st.text('')
        st.text('')
        st.button('Reset Defaults', on_click=helpers.reset_default_threshold_slider, key='reset_threshold',
                  kwargs={'threshold': value})

    with st.spinner('Network Generation'):
        nx_net, pyvis_net = helpers.network_plot(
            topic_data, topics, neighbors)

        # Load the saved HTML into an HTML component for display on the page
        html(_save_and_read_graph(pyvis_net), height=800)

    return nx_net


def network_centrality(nx_net, topic_data):
    """Finds most important papers using network centrality measures.

    The user picks a networkx centrality measure. It is computed on *nx_net*
    and plotted via ``helpers.network_centrality``. If the measure cannot be
    computed for this graph, a warning is shown instead of crashing the page.

    Args:
        nx_net: networkx graph returned by ``strip_network``.
        topic_data: Topic data passed through to the plotting helper.
    """

    st.header('🏅 Most Important Papers')
    about_centrality = read_md('markdown/centrality.md')
    st.markdown(about_centrality)

    centrality_mapping = {
        'Betweenness Centrality': nx.betweenness_centrality,
        'Closeness Centrality': nx.closeness_centrality,
        'Degree Centrality': nx.degree_centrality,
        'Eigenvector Centrality': nx.eigenvector_centrality,
    }

    cols = st.columns(3)
    with cols[0]:
        centrality_option = st.selectbox(
            'Select Centrality Measure', centrality_mapping.keys())

    cols = st.columns([1, 10, 1])
    with cols[1]:
        with st.spinner('Network Centrality Calculation'):
            # Compute inside the spinner so it is visible while the measure runs.
            try:
                centrality = centrality_mapping[centrality_option](nx_net)
            except (nx.PowerIterationFailedConvergence,
                    nx.NetworkXPointlessConcept) as err:
                # eigenvector_centrality may fail to converge, or reject an
                # empty graph; report it instead of crashing the whole page.
                logger.warning('%s failed: %s', centrality_option, err)
                st.warning(f'{centrality_option} could not be computed for this '
                           'network. Try another centrality measure.')
                return
            fig = helpers.network_centrality(
                topic_data, centrality, centrality_option)
            st.plotly_chart(fig, use_container_width=True)


def read_md(file_path, encoding='utf-8'):
    """Reads a markdown file and returns the contents.

    Args:
        file_path: Path of the markdown file to read.
        encoding: Text encoding of the file. Defaults to UTF-8, so files with
            emoji or other non-ASCII text decode the same way on every
            platform. The locale default (e.g. cp1252 on Windows) may fail
            on such bytes.

    Returns:
        str: The full text of the file.
    """
    with open(file_path, 'r', encoding=encoding) as f:
        return f.read()


def main():
    """Build the STriPNet page from top to bottom.

    Renders the title and intro text, then loads the data. If data is present
    and at least one column was selected, it runs topic modeling, builds the
    STriP network and ranks papers by centrality. The "about me" section is
    always rendered last.
    """
    st.title('STriPNet: Semantic Similarity of Scientific Papers!')
    intro_md = read_md('markdown/about_stripnet.md')
    st.markdown(intro_md)

    logger.info('========== Step1: Loading data ==========')
    data, selected_cols = load_data()

    # The analysis steps only make sense with data and at least one column.
    have_input = data is not None and bool(selected_cols)
    if have_input:
        logger.info('========== Step2: Topic modeling ==========')
        topic_data, topics = topic_modeling(data)

        logger.info('========== Step3: STriP Network ==========')
        graph = strip_network(data, topic_data, topics)

        logger.info('========== Step4: Network Centrality ==========')
        network_centrality(graph, topic_data)

    footer_md = read_md('markdown/about_me.md')
    st.markdown(footer_md)


# Entry point: build the page when this file is executed as the main script.
if __name__ == '__main__':
    main()