ZeroCommand committed on
Commit
536b2a2
1 Parent(s): 9037bf7

add pre-check for column mapping values

Browse files
Files changed (2) hide show
  1. app.py +17 -9
  2. text_classification.py +11 -2
app.py CHANGED
@@ -10,7 +10,7 @@ import json
10
 
11
  from transformers.pipelines import TextClassificationPipeline
12
 
13
- from text_classification import text_classification_fix_column_mapping
14
 
15
 
16
  HF_REPO_ID = 'HF_REPO_ID'
@@ -233,15 +233,23 @@ with gr.Blocks(theme=theme) as iface:
233
  column_mapping = '{}'
234
  if id2label_mapping_dataframe is not None:
235
  column_mapping = id2label_mapping_dataframe.to_json(orient="split")
236
- print(column_mapping)
237
- if model_id and dataset_id and dataset_config and dataset_split:
238
- return try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping)
239
- else:
240
  return (gr.update(interactive=False),
241
- gr.update(visible=True),
242
- gr.update(visible=False),
243
- gr.update(visible=False),
244
- gr.update(visible=False))
 
 
 
 
 
 
 
 
 
 
245
  with gr.Row():
246
  gr.Markdown('''
247
  <h1 style="text-align: center;">
 
10
 
11
  from transformers.pipelines import TextClassificationPipeline
12
 
13
+ from text_classification import check_column_mapping_keys_validity, text_classification_fix_column_mapping
14
 
15
 
16
  HF_REPO_ID = 'HF_REPO_ID'
 
233
  column_mapping = '{}'
234
  if id2label_mapping_dataframe is not None:
235
  column_mapping = id2label_mapping_dataframe.to_json(orient="split")
236
+ if check_column_mapping_keys_validity(column_mapping) is False:
237
+ gr.Warning('Label mapping table has invalid contents. Please check again.')
 
 
238
  return (gr.update(interactive=False),
239
+ gr.update(),
240
+ gr.update(),
241
+ gr.update(),
242
+ gr.update(),
243
+ gr.update())
244
+ else:
245
+ if model_id and dataset_id and dataset_config and dataset_split:
246
+ return try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping)
247
+ else:
248
+ return (gr.update(interactive=False),
249
+ gr.update(visible=True),
250
+ gr.update(visible=False),
251
+ gr.update(visible=False),
252
+ gr.update(visible=False))
253
  with gr.Row():
254
  gr.Markdown('''
255
  <h1 style="text-align: center;">
text_classification.py CHANGED
@@ -1,7 +1,6 @@
1
  import datasets
2
-
3
  import logging
4
-
5
  import pandas as pd
6
 
7
 
@@ -35,6 +34,16 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features)
35
 
36
  return id2label_mapping, dataset_labels
37
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
40
  # We assume dataset is ok here
 
1
  import datasets
 
2
  import logging
3
+ import json
4
  import pandas as pd
5
 
6
 
 
34
 
35
  return id2label_mapping, dataset_labels
36
 
37
def check_column_mapping_keys_validity(column_mapping):
    """Return whether a serialized label-mapping table is internally consistent.

    The table is valid when the set of values in its first column equals the
    set of values in its second column (the user-side and model-side labels,
    respectively), i.e. every label used on one side also appears on the other.

    Args:
        column_mapping: JSON text. When it decodes to an object with a
            ``"data"`` key, that value is read as a list of rows whose first
            two cells are compared (the layout produced by
            ``DataFrame.to_json(orient="split")``).

    Returns:
        True if there is no ``"data"`` table to validate (e.g. ``'{}'``) or if
        both columns contain the same set of labels. False if the label sets
        differ, if the payload is not a JSON object, or if the table is
        malformed (rows with fewer than two cells, a non-iterable ``"data"``
        value, or unhashable cells).

    Raises:
        json.JSONDecodeError: If ``column_mapping`` is not valid JSON.
    """
    parsed = json.loads(column_mapping)

    # A payload that is not a JSON object cannot be a mapping table at all.
    if not isinstance(parsed, dict):
        return False

    # No table supplied (the empty default mapping): nothing to validate.
    if "data" not in parsed:
        return True

    rows = parsed["data"]
    try:
        user_labels = {row[0] for row in rows}
        model_labels = {row[1] for row in rows}
    except (IndexError, KeyError, TypeError):
        # Short rows, a non-iterable table, or unhashable cells mean the table
        # contents are invalid; report that instead of crashing the caller.
        return False

    return user_labels == model_labels
46
+
47
 
48
  def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
49
  # We assume dataset is ok here