ZeroCommand committed on
Commit
21e0bb3
1 Parent(s): bedf925

Fix for flattened raw config

Browse files
app.py CHANGED
@@ -10,7 +10,7 @@ from run_jobs import start_process_run_job, stop_thread
10
  try:
11
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
12
  with gr.Tab("Text Classification"):
13
- get_demo_text_classification(demo)
14
  with gr.Tab("Leaderboard"):
15
  get_demo_leaderboard()
16
  with gr.Tab("Logs(Debug)"):
 
10
  try:
11
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="green")) as demo:
12
  with gr.Tab("Text Classification"):
13
+ get_demo_text_classification()
14
  with gr.Tab("Leaderboard"):
15
  get_demo_leaderboard()
16
  with gr.Tab("Logs(Debug)"):
app_text_classification.py CHANGED
@@ -12,7 +12,7 @@ from text_classification_ui_helpers import (check_dataset_and_get_config,
12
  write_column_mapping_to_config)
13
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
14
 
15
- MAX_LABELS = 20
16
  MAX_FEATURES = 20
17
 
18
  EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
@@ -20,7 +20,7 @@ EXAMPLE_DATA_ID = "tweet_eval"
20
  CONFIG_PATH = "./config.yaml"
21
 
22
 
23
- def get_demo(demo):
24
  with gr.Row():
25
  gr.Markdown(INTRODUCTION_MD)
26
  uid_label = gr.Textbox(
@@ -55,9 +55,11 @@ def get_demo(demo):
55
  column_mappings = []
56
  with gr.Row():
57
  with gr.Column():
 
58
  for _ in range(MAX_LABELS):
59
  column_mappings.append(gr.Dropdown(visible=False))
60
  with gr.Column():
 
61
  for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
62
  column_mappings.append(gr.Dropdown(visible=False))
63
 
@@ -138,9 +140,6 @@ def get_demo(demo):
138
  triggers=[label.change for label in column_mappings],
139
  fn=write_column_mapping_to_config,
140
  inputs=[
141
- dataset_id_input,
142
- dataset_config_input,
143
- dataset_split_input,
144
  uid_label,
145
  *column_mappings,
146
  ],
@@ -151,9 +150,6 @@ def get_demo(demo):
151
  triggers=[label.input for label in column_mappings],
152
  fn=write_column_mapping_to_config,
153
  inputs=[
154
- dataset_id_input,
155
- dataset_config_input,
156
- dataset_split_input,
157
  uid_label,
158
  *column_mappings,
159
  ],
@@ -172,6 +168,7 @@ def get_demo(demo):
172
  dataset_id_input,
173
  dataset_config_input,
174
  dataset_split_input,
 
175
  ],
176
  outputs=[
177
  example_input,
 
12
  write_column_mapping_to_config)
13
  from wordings import CONFIRM_MAPPING_DETAILS_MD, INTRODUCTION_MD
14
 
15
+ MAX_LABELS = 40
16
  MAX_FEATURES = 20
17
 
18
  EXAMPLE_MODEL_ID = "cardiffnlp/twitter-roberta-base-sentiment-latest"
 
20
  CONFIG_PATH = "./config.yaml"
21
 
22
 
23
+ def get_demo():
24
  with gr.Row():
25
  gr.Markdown(INTRODUCTION_MD)
26
  uid_label = gr.Textbox(
 
55
  column_mappings = []
56
  with gr.Row():
57
  with gr.Column():
58
+ gr.Markdown("# Label Mapping")
59
  for _ in range(MAX_LABELS):
60
  column_mappings.append(gr.Dropdown(visible=False))
61
  with gr.Column():
62
+ gr.Markdown("# Feature Mapping")
63
  for _ in range(MAX_LABELS, MAX_LABELS + MAX_FEATURES):
64
  column_mappings.append(gr.Dropdown(visible=False))
65
 
 
140
  triggers=[label.change for label in column_mappings],
141
  fn=write_column_mapping_to_config,
142
  inputs=[
 
 
 
143
  uid_label,
144
  *column_mappings,
145
  ],
 
150
  triggers=[label.input for label in column_mappings],
151
  fn=write_column_mapping_to_config,
152
  inputs=[
 
 
 
153
  uid_label,
154
  *column_mappings,
155
  ],
 
168
  dataset_id_input,
169
  dataset_config_input,
170
  dataset_split_input,
171
+ uid_label,
172
  ],
173
  outputs=[
174
  example_input,
io_utils.py CHANGED
@@ -76,7 +76,6 @@ def read_column_mapping(uid):
76
  config = yaml.load(f, Loader=yaml.FullLoader)
77
  if config:
78
  column_mapping = config.get("column_mapping", dict())
79
- f.close()
80
  return column_mapping
81
 
82
 
@@ -84,7 +83,6 @@ def read_column_mapping(uid):
84
  def write_column_mapping(mapping, uid):
85
  with open(get_yaml_path(uid), "r") as f:
86
  config = yaml.load(f, Loader=yaml.FullLoader)
87
- f.close()
88
 
89
  if config is None:
90
  return
@@ -92,10 +90,9 @@ def write_column_mapping(mapping, uid):
92
  del config["column_mapping"]
93
  else:
94
  config["column_mapping"] = mapping
95
-
96
  with open(get_yaml_path(uid), "w") as f:
97
- yaml.dump(config, f, Dumper=Dumper)
98
- f.close()
99
 
100
 
101
  # convert column mapping dataframe to json
 
76
  config = yaml.load(f, Loader=yaml.FullLoader)
77
  if config:
78
  column_mapping = config.get("column_mapping", dict())
 
79
  return column_mapping
80
 
81
 
 
83
  def write_column_mapping(mapping, uid):
84
  with open(get_yaml_path(uid), "r") as f:
85
  config = yaml.load(f, Loader=yaml.FullLoader)
 
86
 
87
  if config is None:
88
  return
 
90
  del config["column_mapping"]
91
  else:
92
  config["column_mapping"] = mapping
 
93
  with open(get_yaml_path(uid), "w") as f:
94
+ # yaml Dumper will by default sort the keys
95
+ yaml.dump(config, f, Dumper=Dumper, sort_keys=False)
96
 
97
 
98
  # convert column mapping dataframe to json
text_classification.py CHANGED
@@ -16,15 +16,16 @@ def get_labels_and_features_from_dataset(dataset_id, dataset_config, split):
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
  label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
19
- if len(label_keys) == 0:
20
- raise ValueError("Dataset does not have label column")
 
21
  if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
22
  if hasattr(dataset_features[label_keys[0]], 'feature'):
23
  label_feat = dataset_features[label_keys[0]].feature
24
  labels = label_feat.names
25
  else:
26
  labels = [dataset_features[label_keys[0]].names]
27
- features = [f for f in dataset_features.keys() if f != "label"]
28
  return labels, features
29
  except Exception as e:
30
  logging.warning(
 
16
  ds = datasets.load_dataset(dataset_id, dataset_config)[split]
17
  dataset_features = ds.features
18
  label_keys = [i for i in dataset_features.keys() if i.startswith('label')]
19
+ if len(label_keys) == 0: # no labels found
20
+ # return everything for post processing
21
+ return list(dataset_features.keys()), list(dataset_features.keys())
22
  if not isinstance(dataset_features[label_keys[0]], datasets.ClassLabel):
23
  if hasattr(dataset_features[label_keys[0]], 'feature'):
24
  label_feat = dataset_features[label_keys[0]].feature
25
  labels = label_feat.names
26
  else:
27
  labels = [dataset_features[label_keys[0]].names]
28
+ features = [f for f in dataset_features.keys() if not f.startswith("label")]
29
  return labels, features
30
  except Exception as e:
31
  logging.warning(
text_classification_ui_helpers.py CHANGED
@@ -18,7 +18,7 @@ from wordings import (CHECK_CONFIG_OR_SPLIT_RAW,
18
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
19
  MAPPING_STYLED_ERROR_WARNING)
20
 
21
- MAX_LABELS = 20
22
  MAX_FEATURES = 20
23
 
24
  HF_REPO_ID = "HF_REPO_ID"
@@ -68,46 +68,62 @@ def deselect_run_inference(run_local):
68
 
69
 
70
  def write_column_mapping_to_config(
71
- dataset_id, dataset_config, dataset_split, uid, *labels
72
  ):
73
  # TODO: Substitute 'text' with more features for zero-shot
74
  # we are not using ds features because we only support "text" for now
75
- ds_labels, _ = get_labels_and_features_from_dataset(
76
- dataset_id, dataset_config, dataset_split
77
- )
78
  if labels is None:
79
  return
 
 
80
 
81
- all_mappings = dict()
82
-
83
- if "labels" not in all_mappings.keys():
84
- all_mappings["labels"] = dict()
85
- for i, label in enumerate(labels[:MAX_LABELS]):
86
- if label:
87
- all_mappings["labels"][label] = ds_labels[i % len(ds_labels)]
88
- if "features" not in all_mappings.keys():
89
- all_mappings["features"] = dict()
90
- for _, feat in enumerate(labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]):
91
- if feat:
92
- # TODO: Substitute 'text' with more features for zero-shot
93
- all_mappings["features"]["text"] = feat
94
  write_column_mapping(all_mappings, uid)
95
 
96
-
97
- def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  model_labels = list(model_id2label.values())
99
- len_model_labels = len(model_labels)
 
 
 
 
 
 
 
 
 
 
 
 
100
  lables = [
101
  gr.Dropdown(
102
  label=f"{label}",
103
  choices=model_labels,
104
- value=model_id2label[i % len_model_labels],
105
  interactive=True,
106
  visible=True,
107
  )
108
- for i, label in enumerate(ds_labels[:MAX_LABELS])
109
  ]
110
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
 
 
111
  # TODO: Substitute 'text' with more features for zero-shot
112
  features = [
113
  gr.Dropdown(
@@ -122,11 +138,14 @@ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label
122
  features += [
123
  gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
124
  ]
 
 
 
125
  return lables + features
126
 
127
 
128
  def check_model_and_show_prediction(
129
- model_id, dataset_id, dataset_config, dataset_split
130
  ):
131
  ppl = check_model(model_id)
132
  if ppl is None or not isinstance(ppl, TextClassificationPipeline):
@@ -168,6 +187,7 @@ def check_model_and_show_prediction(
168
  ds_labels,
169
  ds_features,
170
  model_id2label,
 
171
  )
172
 
173
  # when labels or features are not aligned
 
18
  CONFIRM_MAPPING_DETAILS_FAIL_RAW,
19
  MAPPING_STYLED_ERROR_WARNING)
20
 
21
+ MAX_LABELS = 40
22
  MAX_FEATURES = 20
23
 
24
  HF_REPO_ID = "HF_REPO_ID"
 
68
 
69
 
70
  def write_column_mapping_to_config(
71
+ uid, *labels
72
  ):
73
  # TODO: Substitute 'text' with more features for zero-shot
74
  # we are not using ds features because we only support "text" for now
75
+ all_mappings = read_column_mapping(uid)
76
+
 
77
  if labels is None:
78
  return
79
+ all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])
80
+ all_mappings = export_mappings(all_mappings, "features", ["text"], labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)])
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  write_column_mapping(all_mappings, uid)
83
 
84
+ def export_mappings(all_mappings, key, subkeys, values):
85
+ if key not in all_mappings.keys():
86
+ all_mappings[key] = dict()
87
+ if subkeys is None:
88
+ subkeys = list(all_mappings[key].keys())
89
+
90
+ if not subkeys:
91
+ logging.debug(f"subkeys is empty for {key}")
92
+ return all_mappings
93
+
94
+ for i, subkey in enumerate(subkeys):
95
+ if subkey:
96
+ all_mappings[key][subkey] = values[i % len(values)]
97
+ return all_mappings
98
+
99
+ def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label, uid):
100
  model_labels = list(model_id2label.values())
101
+ all_mappings = read_column_mapping(uid)
102
+ # For flattened raw datasets with no labels
103
+ # check if there are shared labels between model and dataset
104
+ shared_labels = set(model_labels).intersection(set(ds_labels))
105
+ if shared_labels:
106
+ ds_labels = list(shared_labels)
107
+ if len(ds_labels) > MAX_LABELS:
108
+ ds_labels = ds_labels[:MAX_LABELS]
109
+ gr.Warning(f"The number of labels is truncated to length {MAX_LABELS}")
110
+
111
+ ds_labels.sort()
112
+ model_labels.sort()
113
+
114
  lables = [
115
  gr.Dropdown(
116
  label=f"{label}",
117
  choices=model_labels,
118
+ value=model_id2label[i % len(model_labels)],
119
  interactive=True,
120
  visible=True,
121
  )
122
+ for i, label in enumerate(ds_labels)
123
  ]
124
  lables += [gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(lables))]
125
+ all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)
126
+
127
  # TODO: Substitute 'text' with more features for zero-shot
128
  features = [
129
  gr.Dropdown(
 
138
  features += [
139
  gr.Dropdown(visible=False) for _ in range(MAX_FEATURES - len(features))
140
  ]
141
+ all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
142
+ write_column_mapping(all_mappings, uid)
143
+
144
  return lables + features
145
 
146
 
147
  def check_model_and_show_prediction(
148
+ model_id, dataset_id, dataset_config, dataset_split, uid
149
  ):
150
  ppl = check_model(model_id)
151
  if ppl is None or not isinstance(ppl, TextClassificationPipeline):
 
187
  ds_labels,
188
  ds_features,
189
  model_id2label,
190
+ uid,
191
  )
192
 
193
  # when labels or features are not aligned