naveed92 committed on
Commit
3153cb6
1 Parent(s): a676b41

Upload app.py

Files changed (1)
  1. app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ # TODO: Visualizations for topic, text
+ # TODO: BERT model switch
+ # TODO: Default input
+ # TODO: Pre-segmented sentences
+ # TODO: Progress bar
+
+ import streamlit as st
+
+ from utils import window, get_depths, get_local_maxima, compute_threshold, get_threshold_segments
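+
+ # NOTE: `utils` is a local module that is not part of this commit. Judging
+ # from how its helpers are called below, they are assumed to behave roughly
+ # like this (a sketch of the expected contract, not the actual implementation):
+ #   window(seq, n)                  -- sliding windows of length n over seq
+ #   get_depths(sims)                -- a depth score for each gap between windows
+ #   get_local_maxima(depths, order) -- depth scores filtered to local maxima only
+ #   compute_threshold(depths)       -- a scalar cutoff derived from the scores
+ #   get_threshold_segments(d, t)    -- indices (numpy array) where d exceeds t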
+
+ st.write("loading ...")
+
+ # Load the spaCy English model (installed via `python -m spacy download en_core_web_sm`)
+ import spacy
+ nlp = spacy.load('en_core_web_sm')
+
+ # Render a list of strings as markdown bullets
+ def print_list(lst):
+     for e in lst:
+         st.markdown("- " + e)
+
+ # Demo start
+
+ st.subheader("Topic Segmentation Demo")
+
+ uploaded_file = st.file_uploader("Choose a text file", type=["txt"])
+
+ # An uploaded file pre-fills the text area below through its session-state key
+ if uploaded_file is not None:
+     st.session_state["text"] = uploaded_file.getvalue().decode('utf-8')
+
+ st.write("OR")
+
+ input_text = st.text_area(
+     label="Enter text separated by newlines",
+     value="",
+     key="text",
+     height=150,
+ )
+
+ button = st.button('Get Segments')
+
+ # Model selector: LDA topics or BERT. The BERT path is still a TODO,
+ # so the selection is not consumed anywhere yet.
+ select_names = ["LDA Topic", "BERT"]
+ model = st.radio(label='Select model', options=select_names, index=0)
+
+ if button and input_text != "":
+
+     # Split the input on newlines, then let spaCy break each chunk into sentences
+     texts = input_text.split('\n')
+     sents = []
+     for text in texts:
+         doc = nlp(text)
+         for sent in doc.sents:
+             sents.append(sent)
+
+     # Keep lemmatized, lowercased tokens, dropping stopwords, punctuation,
+     # and tokens shorter than MIN_LENGTH characters
+     MIN_LENGTH = 3
+     tokenized_sents = [[token.lemma_.lower() for token in sent
+                         if not token.is_stop and not token.is_punct
+                         and token.text.strip() and len(token) >= MIN_LENGTH]
+                        for sent in sents]
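+     # For example, "The cats are sleeping." would come out as ['cat', 'sleep']:
+     # stopwords and punctuation are dropped and the survivors are lemmatized
+     # (illustrative output; exact tokens depend on the spaCy model's tagging)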
+
+     st.write("building topic model ...")
+
+     # Build a gensim dictionary and LDA topic model, treating each sentence
+     # as a short bag-of-words document
+     from gensim import corpora, models
+     import numpy as np
+
+     np.random.seed(123)  # seed numpy's global RNG for repeatability
+
+     N_TOPICS = 5
+     N_PASSES = 5
+
+     dictionary = corpora.Dictionary(tokenized_sents)
+     bow = [dictionary.doc2bow(sent) for sent in tokenized_sents]
+     topic_model = models.LdaModel(corpus=bow, id2word=dictionary, num_topics=N_TOPICS, passes=N_PASSES)
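+     # doc2bow represents each sentence as sparse (token_id, count) pairs, e.g.
+     # a sentence whose tokens map to dictionary ids 0 and 5 becomes [(0, 1), (5, 1)]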
+
+     st.write("inferring topics ...")
+
+     # Infer per-sentence topic distributions, keeping only topics whose
+     # probability is at least THRESHOLD
+     THRESHOLD = 0.05
+     doc_topics = list(topic_model.get_document_topics(bow, minimum_probability=THRESHOLD))
+
+     # Keep the ids of the top-k most probable topics for each sentence
+     k = 3
+     top_k_topics = [[t[0] for t in sorted(sent_topics, key=lambda x: x[1], reverse=True)][:k]
+                     for sent_topics in doc_topics]
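+     # e.g. sent_topics [(1, 0.61), (3, 0.28), (0, 0.06)] -> [1, 3, 0]
+     # (illustrative values)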
+
+     # Smooth per-sentence topics with a sliding window: each window is
+     # flattened into the set of distinct topic ids it contains
+     from itertools import chain
+
+     WINDOW_SIZE = 3
+
+     window_topics = window(top_k_topics, n=WINDOW_SIZE)
+     # assert(len(window_topics) == (len(tokenized_sents) - WINDOW_SIZE + 1))
+     window_topics = [list(set(chain.from_iterable(win))) for win in window_topics]
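+     # For example, with WINDOW_SIZE=3, per-sentence topics [[0], [0, 2], [2], [4]]
+     # would pool into [{0, 2}, {0, 2, 4}] (illustrative, assuming window() yields
+     # overlapping windows as sketched in the NOTE above)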
+
+     # One-hot encode each window's topic set as a binary vector of length
+     # N_TOPICS so windows can be compared with cosine similarity
+     from sklearn.preprocessing import MultiLabelBinarizer
+
+     binarizer = MultiLabelBinarizer(classes=range(N_TOPICS))
+     encoded_topic = binarizer.fit_transform(window_topics)
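+     # e.g. with N_TOPICS=5, the topic set [0, 2] encodes to [1, 0, 1, 0, 0]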
+
+     st.write("generating segments ...")
+
+     # Cosine similarity between each pair of adjacent windows; a dip in
+     # similarity suggests a topic shift at that gap
+     from sklearn.metrics.pairwise import cosine_similarity
+
+     sims_topic = [cosine_similarity([pair[0]], [pair[1]])[0][0]
+                   for pair in zip(encoded_topic, encoded_topic[1:])]
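+     # e.g. cosine([1, 0, 1, 0, 0], [1, 0, 1, 0, 1]) = 2 / (sqrt(2) * sqrt(3)) ~ 0.82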
+     # plot
+
+     # Depth score for each gap: how far the similarity dips relative to the
+     # peaks around it (TextTiling-style; exact definition assumed, see NOTE above)
+     depths_topic = get_depths(sims_topic)
+     # plot
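+     # e.g. for sims [0.9, 0.2, 0.8] the middle gap would score
+     # (0.9 - 0.2) + (0.8 - 0.2) = 1.3 under the classic TextTiling definition
+     # (illustrative; the actual formula lives in utils)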
+
+     # Keep only gaps that are local maxima of the depth score
+     filtered_topic = get_local_maxima(depths_topic, order=1)
+     # plot
+
+     # Compute a cutoff automatically from the filtered depth scores
+     threshold_topic = compute_threshold(filtered_topic)
+
+     # Boundaries are the gaps whose depth exceeds the threshold
+     threshold_segments_topic = get_threshold_segments(filtered_topic, threshold_topic)
+
+     # Boundary indices refer to gaps between windows, so shift them by
+     # WINDOW_SIZE to convert back to sentence indices
+     segment_ids = threshold_segments_topic + WINDOW_SIZE
+
+     # Bracket the boundaries with the document start and end, then turn
+     # consecutive boundary pairs into sentence slices
+     segment_ids = [0] + segment_ids.tolist() + [len(sents)]
+     slices = list(zip(segment_ids[:-1], segment_ids[1:]))
+
+     segmented = [sents[s[0]: s[1]] for s in slices]
+
+     # Render each segment as a bullet list, separated by horizontal rules
+     for segment in segmented[:-1]:
+         print_list([s.text for s in segment])
+         st.markdown("""---""")
+
+     print_list([s.text for s in segmented[-1]])
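+
+ # To try the demo locally (assuming the repo's utils.py sits alongside this file):
+ #   pip install streamlit spacy gensim scikit-learn
+ #   python -m spacy download en_core_web_sm
+ #   streamlit run app.py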