Paula Leonova committed on
Commit
009207e
1 Parent(s): 3d4d8f3

Update label section to include multiple text inputs for summary and full text

Browse files
Files changed (1) hide show
  1. app.py +37 -19
app.py CHANGED
@@ -217,6 +217,7 @@ if submit_button or example_button:
217
 
218
  sum_df = pd.DataFrame.from_dict(sum_dict).T.reset_index()
219
  sum_df.columns = ['title', 'summary_text']
 
220
 
221
  st.dataframe(sum_df)
222
  st.download_button(
@@ -226,30 +227,47 @@ if submit_button or example_button:
226
  mime='title_summary/csv',
227
  )
228
 
229
- if ((len(text_input) == 0 and uploaded_text_files is None or uploaded_csv_text_files is None)
230
  or (len(labels) == 0 and uploaded_labels_file is None)):
231
  st.error('Enter some text and at least one possible topic to see label predictions.')
232
  else:
233
  st.markdown("### Top Label Predictions on Summary vs Full Text")
 
 
 
 
 
 
 
 
234
  with st.spinner('Matching labels...'):
235
- topics, scores = md.classifier_zero(classifier, sequence=final_summary, labels=labels, multi_class=True)
236
- # st.markdown("### Top Label Predictions: Combined Summary")
237
- # plot_result(topics[::-1][:], scores[::-1][:])
238
- # st.markdown("### Download Data")
239
- data = pd.DataFrame({'label': topics, 'scores_from_summary': scores})
240
- # st.dataframe(data)
241
- # coded_data = base64.b64encode(data.to_csv(index = False). encode ()).decode()
242
- # st.markdown(
243
- # f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Download Data</a>',
244
- # unsafe_allow_html = True
245
- # )
246
-
247
- topics_ex_text, scores_ex_text = md.classifier_zero(classifier, sequence=text_input, labels=labels, multi_class=True)
248
- plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
249
-
250
- data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
251
-
252
- data2 = pd.merge(data, data_ex_text, on = ['label'])
 
 
 
 
 
 
 
 
 
253
 
254
  if len(glabels) > 0:
255
  gdata = pd.DataFrame({'label': glabels})
 
217
 
218
  sum_df = pd.DataFrame.from_dict(sum_dict).T.reset_index()
219
  sum_df.columns = ['title', 'summary_text']
220
+ # TO DO: Make sure summary_text does not exceed the token length
221
 
222
  st.dataframe(sum_df)
223
  st.download_button(
 
227
  mime='title_summary/csv',
228
  )
229
 
230
+ if ((len(text_input) == 0 and uploaded_text_files is None and uploaded_csv_text_files is None)
231
  or (len(labels) == 0 and uploaded_labels_file is None)):
232
  st.error('Enter some text and at least one possible topic to see label predictions.')
233
  else:
234
  st.markdown("### Top Label Predictions on Summary vs Full Text")
235
+
236
+ if uploaded_labels_file is not None:
237
+ labels_df = pd.read_csv(uploaded_labels_file)
238
+ label_list = labels_df.iloc[:, 0]
239
+ else:
240
+ label_list = labels
241
+ st.write(label_list)
242
+
243
  with st.spinner('Matching labels...'):
244
+
245
+ labels_sum_col_list = ['title', 'label', 'scores_from_summary']
246
+ labels_sum_df = pd.DataFrame(columns=labels_sum_col_list)
247
+
248
+ labels_full_col_list = ['title', 'label', 'scores_from_full_text']
249
+ labels_full_df = pd.DataFrame(columns=labels_full_col_list)
250
+
251
+ for i in range(0, len(text_df)):
252
+
253
+ s_topics, s_scores = md.classifier_zero(classifier, sequence=sum_df['summary_text'][i], labels=label_list, multi_class=True)
254
+ ls_df = pd.DataFrame({'label': s_topics, 'scores_from_summary': s_scores})
255
+ ls_df['title'] = text_df['title'][i]
256
+ labels_sum_df = pd.concat([labels_sum_df, ls_df[labels_sum_col_list]])
257
+
258
+ f_topics, f_scores = md.classifier_zero(classifier, sequence=text_df['text'][i], labels=label_list, multi_class=True)
259
+ lf_df = pd.DataFrame({'label': f_topics, 'scores_from_full_text': f_scores})
260
+ lf_df['title'] = text_df['title'][i]
261
+ labels_full_df = pd.concat([labels_full_df, lf_df[labels_full_col_list]])
262
+
263
+ label_match_df = pd.merge(labels_sum_df, labels_full_df, on=['title','label'])
264
+ st.dataframe(label_match_df)
265
+ st.download_button(
266
+ label="Download data as CSV",
267
+ data=label_match_df.to_csv().encode('utf-8'),
268
+ file_name='title_label_sum_full.csv',
269
+ mime='title_label_sum_full/csv',
270
+ )
271
 
272
  if len(glabels) > 0:
273
  gdata = pd.DataFrame({'label': glabels})