Paula Leonova
committed on
Commit
·
43481f8
1
Parent(s):
ee24d8b
Comment out the confusion matrix - temporary
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ ex_long_text = example_long_text_load()
|
|
21 |
st.markdown("### Long Text Summarization & Multi-Label Classification")
|
22 |
st.write("This app summarizes and then classifies your long text(s) with multiple labels using [BART Large MNLI](https://huggingface.co/facebook/bart-large-mnli). The keywords are generated using [KeyBERT](https://github.com/MaartenGr/KeyBERT).")
|
23 |
st.write("__Inputs__: User enters their own custom text(s) and labels.")
|
24 |
-
st.write("__Outputs__: A summary of the text, likelihood
|
25 |
Includes additional options to generate a list of keywords and/or evaluate results against a list of ground truth labels, if available.")
|
26 |
|
27 |
example_button = st.button(label='See Example')
|
@@ -75,33 +75,27 @@ with st.form(key='my_form'):
|
|
75 |
uploaded_labels_file = st.file_uploader("Choose a CSV file with one column and no header, where each cell is a separate label",
|
76 |
key='labels_uploader')
|
77 |
|
78 |
-
# summary_option = st.multiselect(
|
79 |
-
# "Match labels to text using?",
|
80 |
-
# ['Summary', 'Full Text'],
|
81 |
-
# ['Summary', 'Full Text']
|
82 |
-
# )
|
83 |
-
|
84 |
st.text("\n\n\n")
|
85 |
st.markdown("##### Step 3: Provide Ground Truth Labels (_Optional_)")
|
86 |
glabels = st.text_input('If available, enter ground truth topic labels to evaluate results, otherwise leave blank (comma-separated):',input_glabels, max_chars=2000)
|
87 |
glabels = list(set([x.strip() for x in glabels.strip().split(',') if len(x.strip()) > 0]))
|
88 |
|
89 |
|
90 |
-
glabels_csv_expander = st.expander(label=f'Have a file with labels for the text? Click here to upload your CSV file.', expanded=False)
|
91 |
-
with glabels_csv_expander:
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
99 |
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
0.0, 1.0, (0.5))
|
105 |
|
106 |
submit_button = st.form_submit_button(label='Submit')
|
107 |
|
@@ -205,8 +199,9 @@ if submit_button or example_button:
|
|
205 |
)
|
206 |
|
207 |
|
208 |
-
|
209 |
if gen_summary == 'Yes':
|
|
|
210 |
with st.spinner(f'Generating summaries for {len(text_df)} texts consisting of a total of {text_chunk_counter} chunks (this may take a minute)...'):
|
211 |
sum_dict = dict()
|
212 |
for i, key in enumerate(text_chunks_lib):
|
@@ -274,24 +269,24 @@ if submit_button or example_button:
|
|
274 |
labels_full_df = pd.concat([labels_full_df, lf_df[labels_full_col_list]])
|
275 |
|
276 |
with st.expander(f'({i+1}/{len(text_df)}) See intermediate label matching results'):
|
277 |
-
st.write(f"Results for {text_df['title'][i]}")
|
278 |
if gen_summary == 'Yes':
|
279 |
st.dataframe(pd.merge(labels_sum_df, labels_full_df, on=['title','label']))
|
280 |
else:
|
281 |
st.dataframe(labels_full_df)
|
282 |
|
283 |
if gen_summary == 'Yes':
|
284 |
-
label_match_df = pd.merge(labels_sum_df, labels_full_df, on=['
|
285 |
else:
|
286 |
label_match_df = labels_full_df.copy()
|
287 |
|
288 |
-
# TO DO: ADD Flexibility for csv import
|
289 |
if len(glabels) > 0:
|
290 |
gdata = pd.DataFrame({'label': glabels})
|
291 |
-
gdata['
|
292 |
|
293 |
-
label_match_df = pd.merge(label_match_df, gdata, how = 'left', on =
|
294 |
-
label_match_df['correct_match'].fillna(
|
295 |
|
296 |
st.dataframe(label_match_df)
|
297 |
st.download_button(
|
@@ -302,20 +297,20 @@ if submit_button or example_button:
|
|
302 |
)
|
303 |
|
304 |
|
305 |
-
if len(glabels) > 0:
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
|
320 |
st.success('All done!')
|
321 |
-
st.balloons()
|
|
|
21 |
st.markdown("### Long Text Summarization & Multi-Label Classification")
|
22 |
st.write("This app summarizes and then classifies your long text(s) with multiple labels using [BART Large MNLI](https://huggingface.co/facebook/bart-large-mnli). The keywords are generated using [KeyBERT](https://github.com/MaartenGr/KeyBERT).")
|
23 |
st.write("__Inputs__: User enters their own custom text(s) and labels.")
|
24 |
+
st.write("__Outputs__: A summary of the text, likelihood match score for each label and a downloadable csv of the results. \
|
25 |
Includes additional options to generate a list of keywords and/or evaluate results against a list of ground truth labels, if available.")
|
26 |
|
27 |
example_button = st.button(label='See Example')
|
|
|
75 |
uploaded_labels_file = st.file_uploader("Choose a CSV file with one column and no header, where each cell is a separate label",
|
76 |
key='labels_uploader')
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
st.text("\n\n\n")
|
79 |
st.markdown("##### Step 3: Provide Ground Truth Labels (_Optional_)")
|
80 |
glabels = st.text_input('If available, enter ground truth topic labels to evaluate results, otherwise leave blank (comma-separated):',input_glabels, max_chars=2000)
|
81 |
glabels = list(set([x.strip() for x in glabels.strip().split(',') if len(x.strip()) > 0]))
|
82 |
|
83 |
|
84 |
+
# glabels_csv_expander = st.expander(label=f'Have a file with labels for the text? Click here to upload your CSV file.', expanded=False)
|
85 |
+
# with glabels_csv_expander:
|
86 |
+
# st.markdown('##### Choose one of the options below:')
|
87 |
+
# st.write("__Option A:__")
|
88 |
+
# uploaded_onetext_glabels_file = st.file_uploader("Single Text: Choose a CSV file with one column and no header, where each cell is a separate label",
|
89 |
+
# key = 'onetext_glabels_uploader')
|
90 |
+
# st.write("__Option B:__")
|
91 |
+
# uploaded_multitext_glabels_file = st.file_uploader('Multiple Text: Choose a CSV file with two columns "title" and "label", with the cells in the title column matching the name of the files uploaded in step #1.',
|
92 |
+
# key = 'multitext_glabels_uploader')
|
93 |
+
#
|
94 |
|
95 |
|
96 |
+
# threshold_value = st.slider(
|
97 |
+
# 'Select a threshold cutoff for matching percentage (used for ground truth label evaluation)',
|
98 |
+
# 0.0, 1.0, (0.5))
|
|
|
99 |
|
100 |
submit_button = st.form_submit_button(label='Submit')
|
101 |
|
|
|
199 |
)
|
200 |
|
201 |
|
202 |
+
|
203 |
if gen_summary == 'Yes':
|
204 |
+
st.markdown("### Summary")
|
205 |
with st.spinner(f'Generating summaries for {len(text_df)} texts consisting of a total of {text_chunk_counter} chunks (this may take a minute)...'):
|
206 |
sum_dict = dict()
|
207 |
for i, key in enumerate(text_chunks_lib):
|
|
|
269 |
labels_full_df = pd.concat([labels_full_df, lf_df[labels_full_col_list]])
|
270 |
|
271 |
with st.expander(f'({i+1}/{len(text_df)}) See intermediate label matching results'):
|
272 |
+
st.write(f"Results for: {text_df['title'][i]}")
|
273 |
if gen_summary == 'Yes':
|
274 |
st.dataframe(pd.merge(labels_sum_df, labels_full_df, on=['title','label']))
|
275 |
else:
|
276 |
st.dataframe(labels_full_df)
|
277 |
|
278 |
if gen_summary == 'Yes':
|
279 |
+
label_match_df = pd.merge(labels_sum_df, labels_full_df, on=title_element + ['label'])
|
280 |
else:
|
281 |
label_match_df = labels_full_df.copy()
|
282 |
|
283 |
+
# TO DO: ADD Flexibility for csv import and multiple texts
|
284 |
if len(glabels) > 0:
|
285 |
gdata = pd.DataFrame({'label': glabels})
|
286 |
+
gdata['correct_match'] = True
|
287 |
|
288 |
+
label_match_df = pd.merge(label_match_df, gdata, how = 'left', on = ['label'])
|
289 |
+
label_match_df['correct_match'].fillna(False, inplace=True)
|
290 |
|
291 |
st.dataframe(label_match_df)
|
292 |
st.download_button(
|
|
|
297 |
)
|
298 |
|
299 |
|
300 |
+
# if len(glabels) > 0:
|
301 |
+
# st.markdown("### Evaluation Metrics")
|
302 |
+
# with st.spinner('Evaluating output against ground truth...'):
|
303 |
+
#
|
304 |
+
# section_header_description = ['Summary Label Performance', 'Original Full Text Label Performance']
|
305 |
+
# data_headers = ['scores_from_summary', 'scores_from_full_text']
|
306 |
+
# for i in range(0,2):
|
307 |
+
# st.markdown(f"###### {section_header_description[i]}")
|
308 |
+
# report = classification_report(y_true = data2[['is_true_label']],
|
309 |
+
# y_pred = (data2[[data_headers[i]]] >= threshold_value) * 1.0,
|
310 |
+
# output_dict=True)
|
311 |
+
# df_report = pd.DataFrame(report).transpose()
|
312 |
+
# st.markdown(f"Threshold set for: {threshold_value}")
|
313 |
+
# st.dataframe(df_report)
|
314 |
|
315 |
st.success('All done!')
|
316 |
+
# st.balloons()
|