inoki-giskard commited on
Commit
54410d4
·
1 Parent(s): 27381a7

Fix dataset validation and label mapping

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. text_classification.py +5 -2
app.py CHANGED
@@ -270,8 +270,8 @@ with gr.Blocks(theme=theme) as iface:
270
  placeholder="tweet_eval",
271
  )
272
  with gr.Row():
273
- dataset_config_input = gr.Dropdown(['default'], value=['default'], label='Dataset Config', visible=False)
274
- dataset_split_input = gr.Dropdown(['default'], value=['default'], label='Dataset Split', visible=False)
275
 
276
  dataset_id_input.change(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
277
  dataset_config_input.change(
 
270
  placeholder="tweet_eval",
271
  )
272
  with gr.Row():
273
+ dataset_config_input = gr.Dropdown(['default'], value='default', label='Dataset Config', visible=False)
274
+ dataset_split_input = gr.Dropdown(['default'], value='default', label='Dataset Split', visible=False)
275
 
276
  dataset_id_input.change(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
277
  dataset_config_input.change(
text_classification.py CHANGED
@@ -100,6 +100,9 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
100
 
101
  # Infer labels
102
  id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
 
 
 
103
  if "data" in column_mapping.keys():
104
  if isinstance(column_mapping["data"], list):
105
  # Use the column mapping passed by user
@@ -116,13 +119,13 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
116
  }
117
  id2label_df = pd.DataFrame({
118
  "Dataset Labels": dataset_labels,
119
- "Model Prediction Labels": [id2label_mapping[label] for label in dataset_labels],
120
  })
121
 
122
  if "data" not in column_mapping.keys():
123
  # Column mapping should contain original model labels
124
  column_mapping["label"] = {
125
- str(i): id2label_mapping[label] for i, label in zip(id2label.keys(), dataset_labels)
126
  }
127
 
128
  return column_mapping, prediction_input, prediction_result, id2label_df
 
100
 
101
  # Infer labels
102
  id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
103
+ id2label_mapping_dataset_model = {
104
+ v: k for k, v in id2label_mapping.items()
105
+ }
106
  if "data" in column_mapping.keys():
107
  if isinstance(column_mapping["data"], list):
108
  # Use the column mapping passed by user
 
119
  }
120
  id2label_df = pd.DataFrame({
121
  "Dataset Labels": dataset_labels,
122
+ "Model Prediction Labels": [id2label_mapping_dataset_model[label] for label in dataset_labels],
123
  })
124
 
125
  if "data" not in column_mapping.keys():
126
  # Column mapping should contain original model labels
127
  column_mapping["label"] = {
128
+ str(i): id2label_mapping_dataset_model[label] for i, label in zip(id2label.keys(), dataset_labels)
129
  }
130
 
131
  return column_mapping, prediction_input, prediction_result, id2label_df