ZeroCommand committed on
Commit
2569404
1 Parent(s): 8e32a09

Fix feature mapping by adding the multi-label case

Browse files
text_classification.py CHANGED
@@ -22,7 +22,12 @@ class HuggingFaceInferenceAPIResponse:
22
  def get_labels_and_features_from_dataset(ds):
23
  try:
24
  dataset_features = ds.features
25
- label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
 
 
 
 
 
26
  if len(label_keys) == 0: # no labels found
27
  # return everything for post processing
28
  return list(dataset_features.keys()), list(dataset_features.keys())
@@ -32,7 +37,6 @@ def get_labels_and_features_from_dataset(ds):
32
  labels = label_feat.names
33
  else:
34
  labels = dataset_features[label_keys[0]].names
35
- features = [f for f in dataset_features.keys() if not f.startswith("label")]
36
  return labels, features
37
  except Exception as e:
38
  logging.warning(
 
22
  def get_labels_and_features_from_dataset(ds):
23
  try:
24
  dataset_features = ds.features
25
+ label_keys = [i for i in dataset_features.keys() if i == 'label']
26
+ features = [f for f in dataset_features.keys() if not f.startswith("label")]
27
+ if len(label_keys) == 0: # no labels found
28
+ label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
29
+ features += label_keys
30
+
31
  if len(label_keys) == 0: # no labels found
32
  # return everything for post processing
33
  return list(dataset_features.keys()), list(dataset_features.keys())
 
37
  labels = label_feat.names
38
  else:
39
  labels = dataset_features[label_keys[0]].names
 
40
  return labels, features
41
  except Exception as e:
42
  logging.warning(
text_classification_ui_helpers.py CHANGED
@@ -138,7 +138,7 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels,
138
  ds_labels = list(shared_labels)
139
  if len(ds_labels) > MAX_LABELS:
140
  ds_labels = ds_labels[:MAX_LABELS]
141
- gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
142
 
143
  # sort labels to make sure the order is consistent
144
  # prediction gives the order based on probability
@@ -393,10 +393,12 @@ def enable_run_btn(uid, run_inference, inference_token, model_id, dataset_id, da
393
  def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
394
  label_mapping = {}
395
  if len(all_mappings["labels"].keys()) != len(ds_labels):
396
- logger.warn("Label mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
 
397
 
398
  if len(all_mappings["features"].keys()) != len(ds_features):
399
- logger.warn("Feature mapping corrupted: " + CONFIRM_MAPPING_DETAILS_FAIL_RAW)
 
400
 
401
  for i, label in zip(range(len(ds_labels)), ds_labels):
402
  # align the saved labels with dataset labels order
@@ -405,7 +407,10 @@ def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
405
  if "features" not in all_mappings.keys():
406
  logger.warning("features not in all_mappings")
407
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
 
 
408
  feature_mapping = all_mappings["features"]
 
409
  return label_mapping, feature_mapping
410
 
411
  def show_hf_token_info(token):
 
138
  ds_labels = list(shared_labels)
139
  if len(ds_labels) > MAX_LABELS:
140
  ds_labels = ds_labels[:MAX_LABELS]
141
+ gr.Warning(f"Too many labels to display for this spcae. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd.")
142
 
143
  # sort labels to make sure the order is consistent
144
  # prediction gives the order based on probability
 
393
  def construct_label_and_feature_mapping(all_mappings, ds_labels, ds_features):
394
  label_mapping = {}
395
  if len(all_mappings["labels"].keys()) != len(ds_labels):
396
+ logger.warn(f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
397
+ \nall_mappings: {all_mappings}\nds_labels: {ds_labels}""")
398
 
399
  if len(all_mappings["features"].keys()) != len(ds_features):
400
+ logger.warn(f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}.
401
+ \nall_mappings: {all_mappings}\nds_features: {ds_features}""")
402
 
403
  for i, label in zip(range(len(ds_labels)), ds_labels):
404
  # align the saved labels with dataset labels order
 
407
  if "features" not in all_mappings.keys():
408
  logger.warning("features not in all_mappings")
409
  gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
410
+
411
+ special_classlabel = [feature for feature in ds_features if feature.startswith("label")]
412
  feature_mapping = all_mappings["features"]
413
+ feature_mapping.update({"label": special_classlabel[0] if special_classlabel else "label"})
414
  return label_mapping, feature_mapping
415
 
416
  def show_hf_token_info(token):