import time
import pandas as pd
import base64  # only referenced by the commented-out base64 download-link fallback below
import streamlit as st
from sklearn.metrics import classification_report


# Local helpers: `models` wraps the summarization and zero-shot classification pipelines,
# `utils` provides plotting and example-loading utilities.
import models as md
from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load

ex_text, ex_license, ex_labels, ex_glabels = examples_load()
ex_long_text = example_long_text_load()
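
# App flow: a form collects the text, candidate labels, and optional ground truth labels;
# the summarizer and zero-shot classifier are loaded; the text is summarized in
# token-limited chunks; the combined summary and the full text are classified; and the
# results (plots, a data table, a CSV download, and optional evaluation metrics) are shown.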


# if __name__ == '__main__':
st.header("Summarization & Multi-label Classification for Long Text")
st.write("This app summarizes and then classifies your long text with multiple labels.")
st.write("__Inputs__: User enters their own custom text and labels.")
st.write("__Outputs__: A summary of the text, likelihood percentages for each label, and a downloadable CSV of the results. \
    Option to evaluate results against a list of ground truth labels, if available.")

with st.form(key='my_form'):
    # Pre-fill the text area with the example excerpt; if the user leaves it unchanged,
    # fall back to the raw example text (without the license header) for processing.
    example_text = ex_long_text  # ex_text
    display_text = "[Excerpt from Project Gutenberg: Frankenstein]\n" + example_text + "\n\n" + ex_license
    text_input = st.text_area("Input any text you want to summarize & classify here (keep in mind very long text will take a while to process):", display_text)

    if text_input == display_text:
        text_input = example_text

    # Split the comma-separated inputs, strip whitespace, and drop duplicates and empty entries
    labels = st.text_input('Enter possible labels (comma-separated):', ex_labels, max_chars=1000)
    labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))

    glabels = st.text_input('If available, enter ground truth labels to evaluate results, otherwise leave blank (comma-separated):', ex_glabels, max_chars=1000)
    glabels = list(set([x.strip() for x in glabels.strip().split(',') if len(x.strip()) > 0]))

    threshold_value = st.slider(
         'Select a threshold cutoff for matching percentage (used for ground truth label evaluation)',
         0.0, 1.0, (0.5))
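    # Example: with the default threshold of 0.5, a label scored at 0.73 is counted as predicted (1)
    # and one scored at 0.31 as not predicted (0) when the evaluation metrics are computed below.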

    submit_button = st.form_submit_button(label='Submit')


with st.spinner('Loading pretrained summarization model...'):
    start = time.time()
    summarizer = md.load_summary_model()
    st.success(f'Time taken to load summarization model: {round(time.time() - start, 4)} seconds')

with st.spinner('Loading pretrained classifier mnli model...'):
    start = time.time()
    classifier = md.load_model()
    st.success(f'Time taken to load classifier mnli model: {round(time.time() - start, 4)} seconds')


if submit_button:
    if len(labels) == 0:
        st.write('Enter some text and at least one possible topic to see predictions.')
        st.stop()  # stop here so the classifier is not called with an empty label list
    
    with st.spinner('Generating summaries and matching labels...'):
        my_expander = st.expander(label='Expand to see summary generation details')
        with my_expander:
            # For each body of text, create text chunks of a certain token size required for the transformer
            nested_sentences = md.create_nest_sentences(document = text_input, token_max_length = 1024)

            summary = []
            # st.markdown("### Text Chunk & Summaries")
            st.markdown("_Breaks up the original text into sections with complete sentences totaling \
                less than 1024 tokens, a requirement for the summarizer. Each block of text is than summarized separately \
                and then combined at the very end to generate the final summary._")

            # For each chunk of sentences (within the token max), generate a summary
            for n in range(0, len(nested_sentences)):
                text_chunk = " ".join(map(str, nested_sentences[n]))
                st.markdown(f"###### Original Text Chunk {n+1}/{len(nested_sentences)}")
                st.markdown(text_chunk)

                chunk_summary = md.summarizer_gen(summarizer, sequence=text_chunk, maximum_tokens=300, minimum_tokens=20)
                summary.append(chunk_summary)
                st.markdown(f"###### Partial Summary {n+1}/{len(nested_sentences)}")
                st.markdown(chunk_summary)

            # Combine all the partial summaries into one final document
            final_summary = " \n\n".join(summary)

        # final_summary = md.summarizer_gen(summarizer, sequence=text_input, maximum_tokens = 30, minimum_tokens = 100)
        st.markdown("### Combined Summary")
        st.markdown(final_summary)
    

        st.markdown("### Top Label Predictions on Summary & Full Text")
        with st.spinner('Matching labels...'):
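            # multi_class=True is expected to score each label independently, so scores need not sum to 1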
            topics, scores = md.classifier_zero(classifier, sequence=final_summary, labels=labels, multi_class=True)
            # st.markdown("### Top Label Predictions: Combined Summary")
            # plot_result(topics[::-1][:], scores[::-1][:])
            # st.markdown("### Download Data")
            data = pd.DataFrame({'label': topics, 'scores_from_summary': scores})
            # st.dataframe(data)
            # coded_data = base64.b64encode(data.to_csv(index = False). encode ()).decode()
            # st.markdown(
            #     f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Download Data</a>',
            #     unsafe_allow_html = True
            #     )

            topics_ex_text, scores_ex_text = md.classifier_zero(classifier, sequence=example_text, labels=labels, multi_class=True)
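            # Compare label scores computed from the summary with those computed from the full text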
            plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)

            data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
            
            # Join the summary-based and full-text-based scores on the shared label column
            data2 = pd.merge(data, data_ex_text, on=['label'])

            if len(glabels) > 0:
                # Flag ground truth labels with 1; after the left join, unmatched labels become 0
                gdata = pd.DataFrame({'label': glabels})
                gdata['is_true_label'] = 1

                data2 = pd.merge(data2, gdata, how='left', on=['label'])
                data2['is_true_label'].fillna(0, inplace=True)

            st.markdown("### Data Table")
            with st.spinner('Generating a table of results and a download link...'):
                st.dataframe(data2)

                @st.cache
                def convert_df(df):
                    # IMPORTANT: Cache the conversion to prevent computation on every rerun
                    return df.to_csv().encode('utf-8')

                csv = convert_df(data2)
                st.download_button(
                    label="Download data as CSV",
                    data=csv,
                    file_name='text_labels.csv',
                    mime='text/csv',
                )
                # coded_data = base64.b64encode(data2.to_csv(index = False). encode ()).decode()
                # st.markdown(
                #     f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Click here to download the data</a>',
                #     unsafe_allow_html = True
                #     )

            if len(glabels) > 0:
                st.markdown("### Evaluation Metrics")
                with st.spinner('Evaluating output against ground truth...'):

                    section_header_description = ['Summary Label Performance', 'Original Full Text Label Performance']
                    data_headers = ['scores_from_summary', 'scores_from_full_text']
                    # Binarize each score column at the chosen threshold and score it against the ground truth flags
                    for i in range(0, 2):
                        st.markdown(f"##### {section_header_description[i]}")
                        report = classification_report(y_true=data2[['is_true_label']],
                            y_pred=(data2[[data_headers[i]]] >= threshold_value) * 1.0,
                            output_dict=True)
                        df_report = pd.DataFrame(report).transpose()
                        st.markdown(f"Threshold set to: {threshold_value}")
                        st.dataframe(df_report)

            st.success('All done!')
            st.balloons()