Spaces: Runtime error
Upload 4 files
- app.py +112 -5
- config.py +51 -0
- pipeline.py +132 -0
- requirements.txt +9 -2
app.py
CHANGED
@@ -1,9 +1,116 @@

# external libraries
import streamlit as st
from transformers import pipeline  # unused here: shadowed by the local pipeline module imported below

import pandas as pd

# internal libraries
from config import config
import pipeline


def main():

    st.set_page_config(
        layout="centered",                # Can be "centered" or "wide". In the future also "dashboard", etc.
        initial_sidebar_state="auto",     # Can be "auto", "expanded", "collapsed"
        page_title=config.main_title,     # String or None. Strings get appended with "• Streamlit".
        page_icon=config.logo_path,       # String, anything supported by st.image, or None.
    )

    # first run: load the sample texts, draw an initial sample and initialise the session state
    if "output" not in st.session_state:
        st.session_state['data'] = pd.read_csv(config.sample_texts_path)
        st.session_state['sample_text'] = None
        generate_text()
        st.session_state["output"] = False
        st.session_state["output_text"] = ""
        st.session_state['inputs'] = {}

    col1, col2, col3 = st.columns(3)
    col1.write(' ')
    col2.image(config.logo_path)
    col3.write(' ')

    st.markdown(f"<h1 style='text-align: center;'>{config.main_title}</h1>", unsafe_allow_html=True)
    st.markdown(f"<h3 style='text-align: center;'>{config.lecture_title}</h3>", unsafe_allow_html=True)

    # topic modelling radio bar
    input_topic_modelling = st.radio(
        config.topic_modelling_title,
        config.topic_modelling_answers,
        horizontal=True)
    st.session_state['inputs']['input_topic_modelling'] = input_topic_modelling

    # input text area
    input_text = st.text_area(config.input_text, st.session_state['sample_text'], height=300)
    st.session_state['inputs']['input_text'] = input_text

    # generate sample text button
    st.button(config.button_text, on_click=generate_text)

    # choosing segmenter radio bar
    input_segmenter = st.radio(
        config.segmenter_title,
        config.segmenter_answers,
        horizontal=True)
    st.session_state['inputs']['input_segmenter'] = input_segmenter

    # choosing summarizer algorithm radio bar
    input_summarizer = st.radio(
        config.summarizer_title,
        config.summarizer_answers,
        horizontal=True)
    st.session_state['inputs']['input_summarizer'] = input_summarizer

    # generating summary button
    col1, col2, col3 = st.columns(3)
    col1.header(' ')
    col2.button(config.generate_text, on_click=generate_summary)
    col3.header(' ')

    if st.session_state["output"]:

        TOPICS = [key for key in st.session_state["output_text"] if key != '#']

        if config.filter_threshold_summaries:
            TOPICS = [key for key in TOPICS if st.session_state["output_text"][key]['summary'] != config.threshold_error]

        st.write(config.output_title)
        options = {}
        for topic in TOPICS:
            option = st.checkbox(topic)
            options[topic] = option

        if len(options) == 0:
            st.warning(config.warning_len_input_text, icon="⚠️")

        for topic, option in options.items():
            if option:
                st.text_area(topic,
                             st.session_state["output_text"][topic]['summary'],
                             disabled=True)


def generate_text():
    # draw a random sample text (non-empty, longer than 100 characters) from the CSV
    df = st.session_state['data']
    df = df[~df['data'].isnull()]
    df = df[df['data'].str.len().gt(100)]
    st.session_state['sample_text'] = df.sample(1)['data'].values[0]


def generate_summary():
    # collect the selected models and run the topic-aware summarization pipeline
    st.session_state["output"] = True

    MODELS = {
        'summarizer': st.session_state['inputs']['input_summarizer'],
        'topic_modelling': st.session_state['inputs']['input_topic_modelling'],
        'segmentizer': st.session_state['inputs']['input_segmenter']
    }

    with st.spinner('Generating the output of Topic Modeling for Summarization...'):
        OUTPUT = pipeline.run(st.session_state['inputs']['input_text'], MODELS)

    st.session_state["output_text"] = OUTPUT
    st.success('Done!')


if __name__ == "__main__":
    main()
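For orientation, the output block at the end of main() assumes that pipeline.run returns a dict keyed by topic label, where each entry carries the source sentences and a summary; keys equal to '#' and summaries equal to config.threshold_error are filtered out before the checkboxes are rendered. A minimal sketch of that structure (the topic labels and texts below are made up; only the shape mirrors what pipeline.summarize builds):

# hypothetical contents of st.session_state["output_text"]
output_text = {
    '#': {'source': ['...'], 'summary': '...'},                    # outlier topic, dropped by the key != '#' filter
    'space': {'source': ['Sentence A.', 'Sentence B.'],
              'summary': 'A short summary of both sentences.'},
    'politics': {'source': ['Sentence C.'],
                 'summary': 'X -> not possible to generate a summary due to threshold'},  # dropped when filter_threshold_summaries is True
}

With this input, only 'space' would be offered as a selectable topic checkbox.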
config.py
ADDED
@@ -0,0 +1,51 @@

class config:

    sample_texts_path = 'data/sample_text.csv'

    logo_path = 'images/tum-logo.png'
    main_title = 'Topic Modeling for Summarization'
    lecture_title = 'Machine Learning for Natural Language Processing Applications (IN2106)'

    input_text = 'Enter Some Text:'
    button_text = 'Generate new sample text'

    topic_modelling_title = 'Topic Modelling:'
    topic_modelling_answers = ('BERTopic', 'LDA', 'CTM', 'NMF', 'Top2Vec')

    segmenter_title = 'Segmentizer:'
    segmenter_answers = ('Nltk', 'Spacy', 'Stanza')

    summarizer_title = 'Summarizer:'
    summarizer_answers = ('Bart', 'T5-base', 'Prophetnet', 'Pegasus')

    generate_text = 'Generate topic based summaries'

    output_title = 'Select topics you want to see summaries of:'

    warning_len_input_text = 'The input text is not long enough to create topic-aware summaries; try a longer text or different parameters.'
    filter_threshold_summaries = True
    # must match the placeholder string produced in pipeline.summarize()
    threshold_error = 'X -> not possible to generate a summary due to threshold'

    # model parameters
    MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION = 2
    PATH_20_NEWS_CLUSTERID_LABEL_WORDS = 'data/_20news_df_output_clusterId_label_words.csv'

    PATH_20_NEWS_CLUSTERID_LABEL_WORDS_CTM = 'data/_20news_df_output_doc_topic_CTM_LIST.csv'
    PATH_20_NEWS_CLUSTERID_LABEL_WORDS_LDA = 'data/_20news_df_output_doc_topic_LDA_LIST.csv'
    PATH_20_NEWS_CLUSTERID_LABEL_WORDS_NMF = 'data/_20news_df_output_doc_topic_NMF_LIST.csv'
    PATH_20_NEWS_CLUSTERID_LABEL_WORDS_TOP2VEC = 'data/_20news_df_output_doc_topic_Top2Vec_LIST.csv'

    # model paths
    nltk_path = 'models/nltkUtilsObj.pkl'
    sent_trans_path = 'models/sentTransfModelUtilsObj.pkl'

    pegasus_model_path = 'models/pegasus_model'
    bart_model_path = 'models/bart_model'
    t5_model_path = 'models/t5_model'
    prophetnet_model_path = 'models/prophetnet_model'
pipeline.py
ADDED
@@ -0,0 +1,132 @@

# external libraries
import pickle
import numpy as np
import pandas as pd
import ast
from transformers import BartForConditionalGeneration, BartTokenizer  # not used directly in this module

# internal libraries
from config import config
from src.nltk_utilities import NltkSegmentizer
from src.stanza_utilities import StanzaSegmentizer
from src.spacy_utilities import SpacySegmentizer
from src.preprocessing import remove_patterns
from src.summarization_utilities import SummarizationUtilities, BARTSummarizer, T5Summarizer, ProphetNetSummarizer


# module-level state, filled in by define_models() before run() uses it
nltkUtilsObj = None            # the selected segmenter (NLTK, spaCy or Stanza)

# the sentence-transformer wrapper is unpickled once at import time and kept on CPU
with open(config.sent_trans_path, 'rb') as f:
    sentTransfModelUtilsObj = pickle.load(f)
sentTransfModelUtilsObj.model = sentTransfModelUtilsObj.model.to('cpu')

TopicModelling = ''            # path to the topic table of the selected topic model

summUtilsObj = None            # the selected summarizer


def text_to_sentences(data):
    # split the raw input text into sentences with the selected segmenter
    list_sentences = [*nltkUtilsObj.segment_into_sentences(data)]
    return list_sentences


def preprocess(list_sentences, sentTransfModelUtilsObj):
    # clean each sentence and embed the non-empty ones
    list_sentences = [remove_patterns(x) for x in list_sentences]
    list_sentences_per_doc_embeddings = [sentTransfModelUtilsObj.get_embeddings(x) for x in list_sentences if len(x) > 0]
    return list_sentences_per_doc_embeddings, list_sentences


def get_emb_cluster_topic(sentTransfModelUtilsObj):
    # load the topic table of the selected topic model and embed each topic's word list as one "sentence"
    df_latVectorRep = pd.read_csv(TopicModelling)
    df_latVectorRep["sentence_from_words"] = df_latVectorRep["list_topic_words"].map(lambda x: " ".join(ast.literal_eval(x)))
    list_embeddings_cluster_sentences = list()

    for index, row in df_latVectorRep.iterrows():
        list_embeddings_cluster_sentences.append(sentTransfModelUtilsObj.get_embeddings(row["sentence_from_words"]))

    return list_embeddings_cluster_sentences, df_latVectorRep


def compute_similarity_matrix(list_sentences_per_doc_embeddings, list_sentences, sentTransfModelUtilsObj):
    # assign every sentence to the topic whose embedding is most cosine-similar
    list_embeddings_cluster_sentences, df_latVectorRep = get_emb_cluster_topic(sentTransfModelUtilsObj)

    similarity_matrix = np.zeros((len(list_embeddings_cluster_sentences), len(list_sentences_per_doc_embeddings)))

    for i, cluster_embedding in enumerate(list_embeddings_cluster_sentences):
        for j, sentence_embedding in enumerate(list_sentences_per_doc_embeddings):
            similarity_matrix[i][j] = sentTransfModelUtilsObj.compute_cosine_similarity(cluster_embedding, sentence_embedding)

    list_index_topics_within_matrix = np.argmax(similarity_matrix, axis=0)

    dict_topic_sentences = dict()

    for index_sentence, index_id_topic in enumerate(list_index_topics_within_matrix):
        label_class = df_latVectorRep.iloc[index_id_topic]["label_class"]

        if label_class not in dict_topic_sentences:
            dict_topic_sentences[label_class] = list()
        dict_topic_sentences[label_class].append(list_sentences[index_sentence])

    return dict_topic_sentences


def summarize(dict_topic_sentences):
    # summarize the sentences of each topic, or emit the threshold placeholder if there are too few of them
    summaries_report = dict()
    for class_label in dict_topic_sentences.keys():

        summaries_report[class_label] = {}

        if len(dict_topic_sentences[class_label]) >= config.MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION:
            summaries_report[class_label]["source"] = dict_topic_sentences[class_label]
            summaries_report[class_label]["summary"] = summUtilsObj.summarize(" ".join(dict_topic_sentences[class_label]))
            print(dict_topic_sentences[class_label])
        else:
            # keep in sync with config.threshold_error, which app.py uses to filter these entries out
            summaries_report[class_label]["summary"] = "X -> not possible to generate a summary due to threshold"
            summaries_report[class_label]["source"] = dict_topic_sentences[class_label]

    return summaries_report


def define_models(MODELS):
    # resolve the UI selections into a summarizer object, a topic-table path and a segmenter object
    global TopicModelling
    global summUtilsObj
    global nltkUtilsObj

    if MODELS['summarizer'] == 'Pegasus':
        summUtilsObj = SummarizationUtilities()
    elif MODELS['summarizer'] == 'Bart':
        summUtilsObj = BARTSummarizer()
    elif MODELS['summarizer'] == 'T5-base':
        summUtilsObj = T5Summarizer()
    elif MODELS['summarizer'] == 'Prophetnet':
        summUtilsObj = ProphetNetSummarizer()

    if MODELS['topic_modelling'] == 'BERTopic':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS
    elif MODELS['topic_modelling'] == 'LDA':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_LDA
    elif MODELS['topic_modelling'] == 'CTM':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_CTM
    elif MODELS['topic_modelling'] == 'NMF':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_NMF
    elif MODELS['topic_modelling'] == 'Top2Vec':
        TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_TOP2VEC

    if MODELS['segmentizer'] == 'Nltk':
        nltkUtilsObj = NltkSegmentizer()
    elif MODELS['segmentizer'] == 'Spacy':
        nltkUtilsObj = SpacySegmentizer()
    elif MODELS['segmentizer'] == 'Stanza':
        nltkUtilsObj = StanzaSegmentizer()


def run(data, MODELS):
    # full pipeline: segment -> embed -> assign sentences to topics -> summarize per topic
    define_models(MODELS)

    data_sentences = text_to_sentences(data)
    data_embed, list_sentences = preprocess(data_sentences, sentTransfModelUtilsObj)

    dict_topic_sentences = compute_similarity_matrix(data_embed, list_sentences, sentTransfModelUtilsObj)
    summaries_report = summarize(dict_topic_sentences)

    return summaries_report
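The heart of compute_similarity_matrix above is a single argmax over the topic axis: rows of the matrix are topic embeddings, columns are sentence embeddings, and every sentence is routed to the topic with the highest cosine similarity. A toy sketch with made-up similarity values (2 topics, 3 sentences; the numbers are hypothetical):

import numpy as np

# hypothetical cosine similarities between topic embeddings (rows) and sentence embeddings (columns)
similarity_matrix = np.array([
    [0.82, 0.10, 0.40],   # similarities of topic 0 to sentences 0, 1, 2
    [0.15, 0.77, 0.55],   # similarities of topic 1 to sentences 0, 1, 2
])

best_topic_per_sentence = np.argmax(similarity_matrix, axis=0)
print(best_topic_per_sentence)   # [0 1 1] -> sentence 0 joins topic 0, sentences 1 and 2 join topic 1

In the real pipeline the row indices are then mapped to the label_class column of the topic table, which is where the topic names shown in the app come from.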
requirements.txt
CHANGED
@@ -1,2 +1,9 @@

nltk==3.8.1
numpy==1.24.3
pandas==2.0.0
sentence_transformers==2.2.2
spacy==3.5.2
stanza==1.5.0
streamlit==1.23.1
torch==2.0.1
transformers==4.29.1