alpertml committed
Commit 06b4325
1 Parent(s): 1d4fc05

Upload 4 files

Files changed (4)
  1. app.py +112 -5
  2. config.py +51 -0
  3. pipeline.py +132 -0
  4. requirements.txt +9 -2
app.py CHANGED
@@ -1,9 +1,116 @@
+ # external libraries
  import streamlit as st
  from transformers import pipeline
-
- pipe = pipeline('sentiment-analysis')
- text = st.text_area('enter some text:')
-
- if text:
-     out = pipe(text)
-     st.json(out)
+
+ import pandas as pd
+
+ # internal libraries
+ from config import config
+ import pipeline
+
+
+ def main():
+
+     st.set_page_config(
+         layout="centered",             # Can be "centered" or "wide". In the future also "dashboard", etc.
+         initial_sidebar_state="auto",  # Can be "auto", "expanded", "collapsed"
+         page_title=config.main_title,  # String or None. Strings get appended with "• Streamlit".
+         page_icon=config.logo_path,    # String, anything supported by st.image, or None.
+     )
+
+     if "output" not in st.session_state:
+         st.session_state['data'] = pd.read_csv(config.sample_texts_path)
+         st.session_state['sample_text'] = None
+         generate_text()
+         st.session_state["output"] = False
+         st.session_state["output_text"] = ""
+         st.session_state['inputs'] = {}
+
+     col1, col2, col3 = st.columns(3)
+     col1.write(' ')
+     col2.image(config.logo_path)
+     col3.write(' ')
+
+     st.markdown(f"<h1 style='text-align: center;'>{config.main_title}</h1>", unsafe_allow_html=True)
+     st.markdown(f"<h3 style='text-align: center;'>{config.lecture_title}</h3>", unsafe_allow_html=True)
+
+     # topic modelling radio bar
+     input_topic_modelling = st.radio(
+         config.topic_modelling_title,
+         config.topic_modelling_answers,
+         horizontal=True)
+     st.session_state['inputs']['input_topic_modelling'] = input_topic_modelling
+
+     # input text area
+     input_text = st.text_area(config.input_text, st.session_state['sample_text'], height=300)
+     st.session_state['inputs']['input_text'] = input_text
+
+     # generate sample text button
+     st.button(config.button_text, on_click=generate_text)
+
+     # choosing segmenter radio bar
+     input_segmenter = st.radio(
+         config.segmenter_title,
+         config.segmenter_answers,
+         horizontal=True)
+     st.session_state['inputs']['input_segmenter'] = input_segmenter
+
+     # choosing summarizer algorithm radio bar
+     input_summarizer = st.radio(
+         config.summarizer_title,
+         config.summarizer_answers,
+         horizontal=True)
+     st.session_state['inputs']['input_summarizer'] = input_summarizer
+
+     # generating summary button
+     col1, col2, col3 = st.columns(3)
+     col1.header(' ')
+     col2.button(config.generate_text, on_click=generate_summary)
+     col3.header(' ')
+
+     if st.session_state["output"]:
+
+         TOPICS = [key for key in st.session_state["output_text"] if key != '#']
+
+         if config.filter_threshold_summaries:
+             TOPICS = [key for key in TOPICS if st.session_state["output_text"][key]['summary'] != config.threshold_error]
+
+         st.write(config.output_title)
+         options = {}
+         for topic in TOPICS:
+             option = st.checkbox(topic)
+             options[topic] = option
+
+         if len(options) == 0:
+             st.warning(config.warning_len_input_text, icon="⚠️")
+
+         for topic, option in options.items():
+             if option:
+                 st.text_area(topic,
+                              st.session_state["output_text"][topic]['summary'],
+                              disabled=True)
+
+ def generate_text():
+     # sample a sufficiently long, non-null row from the sample-text CSV
+     df = st.session_state['data']
+     df = df[~df['data'].isnull()]
+     df = df[df['data'].str.len().gt(100)]
+     st.session_state['sample_text'] = df.sample(1)['data'].values[0]
+
+ def generate_summary():
+     st.session_state["output"] = True
+
+     MODELS = {
+         'summarizer': st.session_state['inputs']['input_summarizer'],
+         'topic_modelling': st.session_state['inputs']['input_topic_modelling'],
+         'segmentizer': st.session_state['inputs']['input_segmenter']
+     }
+
+     with st.spinner('Generating the output of Topic Modeling for Summarization...'):
+         OUTPUT = pipeline.run(st.session_state['inputs']['input_text'], MODELS)
+
+     st.session_state["output_text"] = OUTPUT
+     st.success('Done!')
+
+
+ if __name__ == "__main__":
+     main()
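The output block in main() assumes st.session_state["output_text"] holds the dict returned by pipeline.run(): one entry per topic label, each carrying the source sentences and their summary. A minimal sketch of that shape (the topic labels and texts below are invented for illustration):

# Illustrative shape only; topic labels and sentences are made up.
st.session_state["output_text"] = {
    'space': {
        'source': ['NASA launched a new probe.', 'The mission will study the outer planets.'],
        'summary': 'NASA launched a probe to study the outer planets.'
    },
    'hockey': {
        'source': ['The home team won.'],
        # entries carrying the threshold message are hidden from the
        # checkbox list when config.filter_threshold_summaries is True
        'summary': 'X -> not possible to generate a summary due to threshold'
    }
}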
config.py ADDED
@@ -0,0 +1,51 @@
+ class config:
+
+     sample_texts_path = 'data/sample_text.csv'
+
+     logo_path = 'images/tum-logo.png'
+     main_title = 'Topic Modeling for Summarization'
+     lecture_title = 'Machine Learning for Natural Language Processing Applications (IN2106)'
+
+     input_text = 'Enter Some Text:'
+     button_text = 'Generate new sample text'
+
+     topic_modelling_title = 'Topic Modelling:'
+     topic_modelling_answers = ('BERTopic', 'LDA', 'CTM', 'NMF', 'Top2Vec')
+
+     segmenter_title = 'Segmentizer:'
+     segmenter_answers = ('Nltk', 'Spacy', 'Stanza')
+
+     summarizer_title = 'Summarizer:'
+     summarizer_answers = ('Bart', 'T5-base', 'Prophetnet', 'Pegasus')
+
+     generate_text = 'Generate topic based summaries'
+
+     output_title = 'Select topics you want to see summaries of:'
+
+     warning_len_input_text = 'The input text is not long enough to create topic-aware summaries. Try a longer text or different parameters!'
+     filter_threshold_summaries = True
+     threshold_error = 'X -> not possible to generate a summary due to threshold'
+
+     # model parameters
+     MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION = 2
+     PATH_20_NEWS_CLUSTERID_LABEL_WORDS = 'data/_20news_df_output_clusterId_label_words.csv'
+
+     PATH_20_NEWS_CLUSTERID_LABEL_WORDS_CTM = 'data/_20news_df_output_doc_topic_CTM_LIST.csv'
+     PATH_20_NEWS_CLUSTERID_LABEL_WORDS_LDA = 'data/_20news_df_output_doc_topic_LDA_LIST.csv'
+     PATH_20_NEWS_CLUSTERID_LABEL_WORDS_NMF = 'data/_20news_df_output_doc_topic_NMF_LIST.csv'
+     PATH_20_NEWS_CLUSTERID_LABEL_WORDS_TOP2VEC = 'data/_20news_df_output_doc_topic_Top2Vec_LIST.csv'
+
+     # model paths
+     nltk_path = 'models/nltkUtilsObj.pkl'
+     sent_trans_path = 'models/sentTransfModelUtilsObj.pkl'
+
+     pegasus_model_path = 'models/pegasus_model'
+     bart_model_path = 'models/bart_model'
+     t5_model_path = 'models/t5_model'
+     prophetnet_model_path = 'models/prophetnet_model'
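config is used as a plain namespace: its values are class attributes, read off the class without instantiating it. For example:

from config import config

config.main_title          # 'Topic Modeling for Summarization'
config.summarizer_answers  # ('Bart', 'T5-base', 'Prophetnet', 'Pegasus')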
pipeline.py ADDED
@@ -0,0 +1,132 @@
+ # external libraries
+ import pickle
+ import ast
+ import numpy as np
+ import pandas as pd
+ from transformers import BartForConditionalGeneration, BartTokenizer
+
+ # internal libraries
+ from config import config
+ from src.nltk_utilities import NltkSegmentizer
+ from src.stanza_utilities import StanzaSegmentizer
+ from src.spacy_utilities import SpacySegmentizer
+ from src.preprocessing import remove_patterns
+ from src.summarization_utilities import SummarizationUtilities, BARTSummarizer, T5Summarizer, ProphetNetSummarizer
+
+
+ # module-level state, reassigned per request in define_models()
+ nltkUtilsObj = None
+
+ sentTransfModelUtilsObj = pickle.load(open(config.sent_trans_path, 'rb'))
+ sentTransfModelUtilsObj.model = sentTransfModelUtilsObj.model.to('cpu')
+
+ TopicModelling = ''
+
+ summUtilsObj = None
+
+
+ def text_to_sentences(data):
+     # split the raw text into sentences with the selected segmenter
+     list_sentences = [*nltkUtilsObj.segment_into_sentences(data)]
+     return list_sentences
+
+
+ def preprocess(list_sentences, sentTransfModelUtilsObj):
+     # clean each sentence, then drop the ones that end up empty so the
+     # embedding list and the sentence list stay index-aligned
+     list_sentences = [remove_patterns(x) for x in list_sentences]
+     list_sentences = [x for x in list_sentences if len(x) > 0]
+     list_sentences_per_doc_embeddings = [sentTransfModelUtilsObj.get_embeddings(x) for x in list_sentences]
+     return list_sentences_per_doc_embeddings, list_sentences
+
+
+ def get_emb_cluster_topic(sentTransfModelUtilsObj):
+     # embed each topic's word list as a single pseudo-sentence
+     df_latVectorRep = pd.read_csv(TopicModelling)
+     df_latVectorRep["sentence_from_words"] = df_latVectorRep["list_topic_words"].map(lambda x: " ".join(ast.literal_eval(x)))
+     list_embeddings_cluster_sentences = list()
+
+     for index, row in df_latVectorRep.iterrows():
+         list_embeddings_cluster_sentences.append(sentTransfModelUtilsObj.get_embeddings(row["sentence_from_words"]))
+
+     return list_embeddings_cluster_sentences, df_latVectorRep
+
+
+ def compute_similarity_matrix(list_sentences_per_doc_embeddings, list_sentences, sentTransfModelUtilsObj):
+     # assign every sentence to the topic whose embedding it is most similar to
+     list_embeddings_cluster_sentences, df_latVectorRep = get_emb_cluster_topic(sentTransfModelUtilsObj)
+
+     similarity_matrix = np.zeros((len(list_embeddings_cluster_sentences), len(list_sentences_per_doc_embeddings)))
+
+     for i, cluster_embedding in enumerate(list_embeddings_cluster_sentences):
+         for j, sentence_embedding in enumerate(list_sentences_per_doc_embeddings):
+             similarity_matrix[i][j] = sentTransfModelUtilsObj.compute_cosine_similarity(cluster_embedding, sentence_embedding)
+
+     list_index_topics_within_matrix = np.argmax(similarity_matrix, axis=0)
+
+     dict_topic_sentences = dict()
+
+     for index_sentence, index_id_topic in enumerate(list_index_topics_within_matrix):
+         label_class = df_latVectorRep.iloc[index_id_topic]["label_class"]
+
+         if label_class not in dict_topic_sentences:
+             dict_topic_sentences[label_class] = list()
+         dict_topic_sentences[label_class].append(list_sentences[index_sentence])
+
+     return dict_topic_sentences
+
+
+ def summarize(dict_topic_sentences):
+     summaries_report = dict()
+     for class_label in dict_topic_sentences:
+
+         summaries_report[class_label] = {}
+
+         if len(dict_topic_sentences[class_label]) >= config.MIN_NUM_SENTENCES_FOR_SUMMARY_CREATION:
+             summaries_report[class_label]["source"] = dict_topic_sentences[class_label]
+             summaries_report[class_label]["summary"] = summUtilsObj.summarize(" ".join(dict_topic_sentences[class_label]))
+         else:
+             summaries_report[class_label]["summary"] = "X -> not possible to generate a summary due to threshold"
+             summaries_report[class_label]["source"] = dict_topic_sentences[class_label]
+
+     return summaries_report
+
+
+ def define_models(MODELS):
+     global TopicModelling
+     global summUtilsObj
+     global nltkUtilsObj
+
+     if MODELS['summarizer'] == 'Pegasus':
+         summUtilsObj = SummarizationUtilities()
+     elif MODELS['summarizer'] == 'Bart':
+         summUtilsObj = BARTSummarizer()
+     elif MODELS['summarizer'] == 'T5-base':
+         summUtilsObj = T5Summarizer()
+     elif MODELS['summarizer'] == 'Prophetnet':
+         summUtilsObj = ProphetNetSummarizer()
+
+     if MODELS['topic_modelling'] == 'BERTopic':
+         TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS
+     elif MODELS['topic_modelling'] == 'LDA':
+         TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_LDA
+     elif MODELS['topic_modelling'] == 'CTM':
+         TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_CTM
+     elif MODELS['topic_modelling'] == 'NMF':
+         TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_NMF
+     elif MODELS['topic_modelling'] == 'Top2Vec':
+         TopicModelling = config.PATH_20_NEWS_CLUSTERID_LABEL_WORDS_TOP2VEC
+
+     if MODELS['segmentizer'] == 'Nltk':
+         nltkUtilsObj = NltkSegmentizer()
+     elif MODELS['segmentizer'] == 'Spacy':
+         nltkUtilsObj = SpacySegmentizer()
+     elif MODELS['segmentizer'] == 'Stanza':
+         nltkUtilsObj = StanzaSegmentizer()
+
+
+ def run(data, MODELS):
+     define_models(MODELS)
+
+     data_sentences = text_to_sentences(data)
+     data_embed, list_sentences = preprocess(data_sentences, sentTransfModelUtilsObj)
+
+     dict_topic_sentences = compute_similarity_matrix(data_embed, list_sentences, sentTransfModelUtilsObj)
+     summaries_report = summarize(dict_topic_sentences)
+
+     return summaries_report
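run() is the only entry point app.py needs. A minimal sketch of driving the pipeline outside Streamlit, assuming the pickled models under models/ and the CSVs under data/ are in place (the input text here is invented):

# Hypothetical standalone usage; option strings must match the tuples in config.py.
import pipeline

MODELS = {
    'summarizer': 'Bart',
    'topic_modelling': 'BERTopic',
    'segmentizer': 'Nltk',
}

report = pipeline.run("A long, multi-sentence text to segment, assign to topics, and summarize ...", MODELS)
for topic, entry in report.items():
    print(topic, '->', entry['summary'])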
requirements.txt CHANGED
@@ -1,2 +1,9 @@
- torch
- transformers
+ nltk==3.8.1
+ numpy==1.24.3
+ pandas==2.0.0
+ sentence_transformers==2.2.2
+ spacy==3.5.2
+ stanza==1.5.0
+ streamlit==1.23.1
+ torch==2.0.1
+ transformers==4.29.1