Paula Leonova committed on
Commit
ee24d8b
·
1 Parent(s): b1bf232

Append ground truth labels to matched table

Browse files
Files changed (1) hide show
  1. app.py +14 -31
app.py CHANGED
@@ -155,6 +155,11 @@ if submit_button or example_button:
155
  elif uploaded_csv_text_files is not None:
156
  text_df = pd.read_csv(uploaded_csv_text_files)
157
 
 
 
 
 
 
158
 
159
  with st.spinner('Breaking up text into more reasonable chunks (transformers cannot exceed a 1024 token max)...'):
160
  # For each body of text, create text chunks of a certain token size required for the transformer
@@ -187,10 +192,7 @@ if submit_button or example_button:
187
  kw_df0 = pd.DataFrame.from_dict(kw_dict).reset_index()
188
  kw_df0.rename(columns={'index': 'keyword'}, inplace=True)
189
  kw_df = pd.melt(kw_df0, id_vars=['keyword'], var_name='title', value_name='score').dropna()
190
- if len(text_input) != 0:
191
- title_element = []
192
- else:
193
- title_element = ['title']
194
  kw_column_list = ['keyword', 'score']
195
  kw_df = kw_df[kw_df['score'] > 0.25][title_element + kw_column_list].sort_values(title_element + ['score'], ascending=False).reset_index().drop(columns='index')
196
 
@@ -283,6 +285,14 @@ if submit_button or example_button:
283
  else:
284
  label_match_df = labels_full_df.copy()
285
 
 
 
 
 
 
 
 
 
286
  st.dataframe(label_match_df)
287
  st.download_button(
288
  label="Download data as CSV",
@@ -291,33 +301,6 @@ if submit_button or example_button:
291
  mime='title_label_sum_full/csv',
292
  )
293
 
294
- if len(glabels) > 0:
295
- gdata = pd.DataFrame({'label': glabels})
296
- gdata['is_true_label'] = int(1)
297
-
298
- data2 = pd.merge(data2, gdata, how = 'left', on = ['label'])
299
- data2['is_true_label'].fillna(0, inplace = True)
300
-
301
- st.markdown("### Data Table")
302
- with st.spinner('Generating a table of results and a download link...'):
303
- st.dataframe(data2)
304
-
305
- @st.cache
306
- def convert_df(df):
307
- # IMPORTANT: Cache the conversion to prevent computation on every rerun
308
- return df.to_csv().encode('utf-8')
309
- csv = convert_df(data2)
310
- st.download_button(
311
- label="Download data as CSV",
312
- data=csv,
313
- file_name='text_labels.csv',
314
- mime='text/csv',
315
- )
316
- # coded_data = base64.b64encode(data2.to_csv(index = False). encode ()).decode()
317
- # st.markdown(
318
- # f'<a href="data:file/csv;base64, {coded_data}" download = "data.csv">Click here to download the data</a>',
319
- # unsafe_allow_html = True
320
- # )
321
 
322
  if len(glabels) > 0:
323
  st.markdown("### Evaluation Metrics")
 
155
  elif uploaded_csv_text_files is not None:
156
  text_df = pd.read_csv(uploaded_csv_text_files)
157
 
158
+ # Which input was used? If text area was used, ignore the 'title'
159
+ if len(text_input) != 0:
160
+ title_element = []
161
+ else:
162
+ title_element = ['title']
163
 
164
  with st.spinner('Breaking up text into more reasonable chunks (transformers cannot exceed a 1024 token max)...'):
165
  # For each body of text, create text chunks of a certain token size required for the transformer
 
192
  kw_df0 = pd.DataFrame.from_dict(kw_dict).reset_index()
193
  kw_df0.rename(columns={'index': 'keyword'}, inplace=True)
194
  kw_df = pd.melt(kw_df0, id_vars=['keyword'], var_name='title', value_name='score').dropna()
195
+
 
 
 
196
  kw_column_list = ['keyword', 'score']
197
  kw_df = kw_df[kw_df['score'] > 0.25][title_element + kw_column_list].sort_values(title_element + ['score'], ascending=False).reset_index().drop(columns='index')
198
 
 
285
  else:
286
  label_match_df = labels_full_df.copy()
287
 
288
+ # TO DO: ADD Flexibility for csv import
289
+ if len(glabels) > 0:
290
+ gdata = pd.DataFrame({'label': glabels})
291
+ gdata['is_true_label'] = True
292
+
293
+ label_match_df = pd.merge(label_match_df, gdata, how = 'left', on = title_element + ['label'])
294
+ label_match_df['correct_match'].fillna(0, inplace = True)
295
+
296
  st.dataframe(label_match_df)
297
  st.download_button(
298
  label="Download data as CSV",
 
301
  mime='title_label_sum_full/csv',
302
  )
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
  if len(glabels) > 0:
306
  st.markdown("### Evaluation Metrics")