ZeroCommand commited on
Commit
ac0eaff
1 Parent(s): 536b2a2

polish up and add more information

Browse files
Files changed (2) hide show
  1. app.py +11 -12
  2. text_classification.py +9 -4
app.py CHANGED
@@ -59,11 +59,10 @@ def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
59
  return dataset_id, None, None
60
  return dataset_id, dataset_config, dataset_split
61
 
62
- def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping='{}'):
63
  # Validate model
64
- m_id, ppl = check_model(model_id=model_id)
65
  if m_id is None:
66
- gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
67
  return (
68
  gr.update(interactive=False), # Submit button
69
  gr.update(visible=True), # Loading row
@@ -73,7 +72,7 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
73
  gr.update(visible=False), # Label mapping preview
74
  )
75
  if isinstance(ppl, Exception):
76
- gr.Warning(f'Failed to load "{model_id} model": {ppl}')
77
  return (
78
  gr.update(interactive=False), # Submit button
79
  gr.update(visible=True), # Loading row
@@ -124,8 +123,6 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
124
 
125
  column_mapping = json.dumps(column_mapping, indent=2)
126
 
127
- del ppl
128
-
129
  if prediction_result is None:
130
  gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
131
  return (
@@ -212,7 +209,6 @@ with gr.Blocks(theme=theme) as iface:
212
  def check_dataset_and_get_config(dataset_id):
213
  try:
214
  configs = datasets.get_dataset_config_names(dataset_id)
215
- print(configs)
216
  return gr.Dropdown(configs, value=configs[0], visible=True)
217
  except Exception:
218
  # Dataset may not exist
@@ -221,19 +217,19 @@ with gr.Blocks(theme=theme) as iface:
221
  def check_dataset_and_get_split(dataset_config, dataset_id):
222
  try:
223
  splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
224
- print('splits: ',splits)
225
  return gr.Dropdown(splits, value=splits[0], visible=True)
226
  except Exception as e:
227
  # Dataset may not exist
228
- print(e)
229
  pass
230
 
231
  def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
232
- print('model_id: ',model_id)
233
  column_mapping = '{}'
 
 
234
  if id2label_mapping_dataframe is not None:
235
  column_mapping = id2label_mapping_dataframe.to_json(orient="split")
236
- if check_column_mapping_keys_validity(column_mapping) is False:
237
  gr.Warning('Label mapping table has invalid contents. Please check again.')
238
  return (gr.update(interactive=False),
239
  gr.update(),
@@ -243,12 +239,15 @@ with gr.Blocks(theme=theme) as iface:
243
  gr.update())
244
  else:
245
  if model_id and dataset_id and dataset_config and dataset_split:
246
- return try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping)
247
  else:
 
 
248
  return (gr.update(interactive=False),
249
  gr.update(visible=True),
250
  gr.update(visible=False),
251
  gr.update(visible=False),
 
252
  gr.update(visible=False))
253
  with gr.Row():
254
  gr.Markdown('''
 
59
  return dataset_id, None, None
60
  return dataset_id, dataset_config, dataset_split
61
 
62
+ def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping='{}'):
63
  # Validate model
 
64
  if m_id is None:
65
+ gr.Warning('Model is not accessible. Please set your HF_TOKEN if it is a private model.')
66
  return (
67
  gr.update(interactive=False), # Submit button
68
  gr.update(visible=True), # Loading row
 
72
  gr.update(visible=False), # Label mapping preview
73
  )
74
  if isinstance(ppl, Exception):
75
+ gr.Warning(f'Failed to load model": {ppl}')
76
  return (
77
  gr.update(interactive=False), # Submit button
78
  gr.update(visible=True), # Loading row
 
123
 
124
  column_mapping = json.dumps(column_mapping, indent=2)
125
 
 
 
126
  if prediction_result is None:
127
  gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
128
  return (
 
209
  def check_dataset_and_get_config(dataset_id):
210
  try:
211
  configs = datasets.get_dataset_config_names(dataset_id)
 
212
  return gr.Dropdown(configs, value=configs[0], visible=True)
213
  except Exception:
214
  # Dataset may not exist
 
217
  def check_dataset_and_get_split(dataset_config, dataset_id):
218
  try:
219
  splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
 
220
  return gr.Dropdown(splits, value=splits[0], visible=True)
221
  except Exception as e:
222
  # Dataset may not exist
223
+ gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
224
  pass
225
 
226
  def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
 
227
  column_mapping = '{}'
228
+ m_id, ppl = check_model(model_id=model_id)
229
+
230
  if id2label_mapping_dataframe is not None:
231
  column_mapping = id2label_mapping_dataframe.to_json(orient="split")
232
+ if check_column_mapping_keys_validity(column_mapping, ppl) is False:
233
  gr.Warning('Label mapping table has invalid contents. Please check again.')
234
  return (gr.update(interactive=False),
235
  gr.update(),
 
239
  gr.update())
240
  else:
241
  if model_id and dataset_id and dataset_config and dataset_split:
242
+ return try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping)
243
  else:
244
+ del ppl
245
+
246
  return (gr.update(interactive=False),
247
  gr.update(visible=True),
248
  gr.update(visible=False),
249
  gr.update(visible=False),
250
+ gr.update(visible=False),
251
  gr.update(visible=False))
252
  with gr.Row():
253
  gr.Markdown('''
text_classification.py CHANGED
@@ -34,15 +34,19 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features)
34
 
35
  return id2label_mapping, dataset_labels
36
 
37
- def check_column_mapping_keys_validity(column_mapping):
 
38
  # get the element in all the list elements
39
  column_mapping = json.loads(column_mapping)
40
  if "data" not in column_mapping.keys():
41
  return True
42
  user_labels = set([pair[0] for pair in column_mapping["data"]])
43
  model_labels = set([pair[1] for pair in column_mapping["data"]])
 
 
 
44
 
45
- return user_labels == model_labels
46
 
47
 
48
  def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
@@ -100,7 +104,6 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
100
  if isinstance(column_mapping["data"], list):
101
  # Use the column mapping passed by user
102
  for user_label, model_label in column_mapping["data"]:
103
- print(user_label, model_label)
104
  id2label_mapping[model_label] = user_label
105
  elif None in id2label_mapping.values():
106
  column_mapping["label"] = {
@@ -108,7 +111,9 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
108
  }
109
  return column_mapping, prediction_result, None
110
 
111
- print(id2label_mapping)
 
 
112
  id2label_df = pd.DataFrame({
113
  "Dataset Labels": dataset_labels,
114
  "Model Prediction Labels": [id2label_mapping[label] for label in dataset_labels],
 
34
 
35
  return id2label_mapping, dataset_labels
36
 
37
+
38
+ def check_column_mapping_keys_validity(column_mapping, ppl):
39
  # get the element in all the list elements
40
  column_mapping = json.loads(column_mapping)
41
  if "data" not in column_mapping.keys():
42
  return True
43
  user_labels = set([pair[0] for pair in column_mapping["data"]])
44
  model_labels = set([pair[1] for pair in column_mapping["data"]])
45
+
46
+ id2label = ppl.model.config.id2label
47
+ original_labels = set(id2label.values())
48
 
49
+ return user_labels == model_labels == original_labels
50
 
51
 
52
  def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
 
104
  if isinstance(column_mapping["data"], list):
105
  # Use the column mapping passed by user
106
  for user_label, model_label in column_mapping["data"]:
 
107
  id2label_mapping[model_label] = user_label
108
  elif None in id2label_mapping.values():
109
  column_mapping["label"] = {
 
111
  }
112
  return column_mapping, prediction_result, None
113
 
114
+ prediction_result = {
115
+ f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"] for result in results
116
+ }
117
  id2label_df = pd.DataFrame({
118
  "Dataset Labels": dataset_labels,
119
  "Model Prediction Labels": [id2label_mapping[label] for label in dataset_labels],