inoki-giskard committed on
Commit
583defc
1 Parent(s): ea670d5

Attempt to match labels in model and in dataset

Browse files
Files changed (1) hide show
  1. app.py +67 -5
app.py CHANGED
@@ -6,6 +6,10 @@ import os
6
  import time
7
  from pathlib import Path
8
 
 
 
 
 
9
 
10
  HF_REPO_ID = 'HF_REPO_ID'
11
  HF_SPACE_ID = 'SPACE_ID'
@@ -54,15 +58,41 @@ def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
54
  return dataset_id, dataset_config, dataset_split
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def try_validate(model_id, dataset_id, dataset_config, dataset_split):
58
  # Validate model
59
  m_id, ppl = check_model(model_id=model_id)
60
  if m_id is None:
61
  gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
62
- return dataset_config, dataset_split, gr.update(interactive=False)
63
  if isinstance(ppl, Exception):
64
  gr.Warning(f'Failed to load "{model_id} model": {ppl}')
65
- return dataset_config, dataset_split, gr.update(interactive=False)
66
 
67
  # Validate dataset
68
  d_id, config, split = check_dataset(dataset_id=dataset_id, dataset_config=dataset_config, dataset_split=dataset_split)
@@ -80,15 +110,42 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split):
80
  dataset_ok = True
81
 
82
  if not dataset_ok:
83
- return config, split, gr.update(interactive=False)
84
 
85
  # TODO: Validate column mapping by running once
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  del ppl
88
 
89
  gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
90
 
91
- return config, split, gr.update(interactive=True)
92
 
93
 
94
  def try_submit(m_id, d_id, config, split, local):
@@ -133,7 +190,7 @@ def try_submit(m_id, d_id, config, split, local):
133
  with open(output_dir / "report.html", "w") as f:
134
  print(f'Writing to {output_dir / "report.html"}')
135
  f.write(rendered_report)
136
-
137
  print(f"Finished local evaluation on {eval_str}: {time.time() - start:.2f}s")
138
 
139
 
@@ -155,6 +212,7 @@ with gr.Blocks(theme=theme) as iface:
155
  value=0,
156
  )
157
  run_local = gr.Checkbox(value=True, label="Run in this Space")
 
158
 
159
  with gr.Column():
160
  dataset_id_input = gr.Textbox(
@@ -180,6 +238,8 @@ with gr.Blocks(theme=theme) as iface:
180
  value="test",
181
  )
182
 
 
 
183
  with gr.Row():
184
  validate_btn = gr.Button("Validate model and dataset", variant="primary")
185
  run_btn = gr.Button(
@@ -199,6 +259,8 @@ with gr.Blocks(theme=theme) as iface:
199
  dataset_config_input,
200
  dataset_split_input,
201
  run_btn,
 
 
202
  ],
203
  )
204
  run_btn.click(
 
6
  import time
7
  from pathlib import Path
8
 
9
+ import pandas as pd
10
+
11
+ from transformers.pipelines import TextClassificationPipeline
12
+
13
 
14
  HF_REPO_ID = 'HF_REPO_ID'
15
  HF_SPACE_ID = 'SPACE_ID'
 
58
  return dataset_id, dataset_config, dataset_split
59
 
60
 
61
def text_classificaiton_match_label_case_unsensative(id2label_mapping, label):
    """Find the model label that equals ``label`` when letter case is ignored.

    Args:
        id2label_mapping: Mapping whose keys are the model's label names.
        label: Dataset label name to look up.

    Returns:
        A ``(model_label, label)`` tuple. ``model_label`` is the first key of
        ``id2label_mapping`` whose upper-cased form equals the upper-cased
        ``label``, or ``None`` when no key matches.
    """
    # Upper-case the needle once instead of on every iteration.
    wanted = label.upper()
    for model_label in id2label_mapping:
        if model_label.upper() == wanted:
            return model_label, label
    # Always hand back a 2-tuple: callers unpack the result, and an implicit
    # ``None`` return would make that unpacking raise TypeError.
    return None, label
65
+
66
+
67
def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
    """Map each model label name to the matching dataset label name.

    Only dataset features that are ``datasets.ClassLabel`` instances and that
    have exactly as many names as the model has labels are considered. Within
    such a feature, every dataset label is paired with the model label of the
    same spelling, falling back to a case-insensitive comparison.

    Args:
        id2label: Mapping from model class id to model label name.
        dataset_features: Mapping whose values are the dataset's feature
            objects.

    Returns:
        A dict keyed by model label name. Each value is the dataset label
        matched to it, or ``None`` when no dataset label matched. If several
        features qualify, later matches overwrite earlier ones.
    """
    id2label_mapping = dict.fromkeys(id2label.values())
    for feature in dataset_features.values():
        if not isinstance(feature, datasets.ClassLabel):
            continue
        if len(feature.names) != len(id2label_mapping):
            continue

        # Try to match labels
        for label in feature.names:
            if label in id2label_mapping:
                id2label_mapping[label] = label
                continue

            # Fall back to a case-insensitive match.
            match = text_classificaiton_match_label_case_unsensative(id2label_mapping, label)
            # The helper may yield nothing usable (None, or a tuple whose
            # first item is None) when the dataset label has no counterpart
            # in the model; skip it instead of crashing on tuple unpacking
            # or inserting a bogus ``None`` key.
            model_label = match[0] if match else None
            if model_label is None:
                continue
            id2label_mapping[model_label] = label

    return id2label_mapping
85
+
86
+
87
  def try_validate(model_id, dataset_id, dataset_config, dataset_split):
88
  # Validate model
89
  m_id, ppl = check_model(model_id=model_id)
90
  if m_id is None:
91
  gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
92
+ return dataset_config, dataset_split, gr.update(interactive=False), gr.update(visible=False), gr.update(visible=False)
93
  if isinstance(ppl, Exception):
94
  gr.Warning(f'Failed to load "{model_id} model": {ppl}')
95
+ return dataset_config, dataset_split, gr.update(interactive=False), gr.update(visible=False), gr.update(visible=False)
96
 
97
  # Validate dataset
98
  d_id, config, split = check_dataset(dataset_id=dataset_id, dataset_config=dataset_config, dataset_split=dataset_split)
 
110
  dataset_ok = True
111
 
112
  if not dataset_ok:
113
+ return config, split, gr.update(interactive=False), gr.update(visible=False), gr.update(visible=False)
114
 
115
  # TODO: Validate column mapping by running once
116
+ prediction_result = {}
117
+ id2label_df = None
118
+ if isinstance(ppl, TextClassificationPipeline):
119
+ # Retrieve all labels
120
+ id2label_mapping = {}
121
+ try:
122
+ results = ppl({"text": "Test"}, top_k=None)
123
+ prediction_result = {
124
+ result["label"]: result["score"] for result in results
125
+ }
126
+ except Exception as e:
127
+ # Pipeline is not executable
128
+ pass
129
+
130
+ # We assume dataset is ok here
131
+ ds = datasets.load_dataset(d_id, config)[split]
132
+ try:
133
+ id2label = ppl.model.config.id2label
134
+ id2label_mapping = text_classification_map_model_and_dataset_labels(ppl.model.config.id2label, ds.features)
135
+ id2label_df = pd.DataFrame({
136
+ "ID": [i for i in id2label.keys()],
137
+ "Model labels": [id2label[label] for label in id2label.keys()],
138
+ "Dataset labels": [id2label_mapping[id2label[label]] for label in id2label.keys()],
139
+ })
140
+ except AttributeError:
141
+ # Dataset does not have features
142
+ pass
143
 
144
  del ppl
145
 
146
  gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
147
 
148
+ return config, split, gr.update(interactive=True), gr.update(value=prediction_result, visible=True), gr.update(value=id2label_df, visible=True)
149
 
150
 
151
  def try_submit(m_id, d_id, config, split, local):
 
190
  with open(output_dir / "report.html", "w") as f:
191
  print(f'Writing to {output_dir / "report.html"}')
192
  f.write(rendered_report)
193
+
194
  print(f"Finished local evaluation on {eval_str}: {time.time() - start:.2f}s")
195
 
196
 
 
212
  value=0,
213
  )
214
  run_local = gr.Checkbox(value=True, label="Run in this Space")
215
+ example_labels = gr.Label(label='Model pipeline test prediction result', visible=False)
216
 
217
  with gr.Column():
218
  dataset_id_input = gr.Textbox(
 
238
  value="test",
239
  )
240
 
241
+ id2label_mapping_dataframe = gr.DataFrame(visible=False)
242
+
243
  with gr.Row():
244
  validate_btn = gr.Button("Validate model and dataset", variant="primary")
245
  run_btn = gr.Button(
 
259
  dataset_config_input,
260
  dataset_split_input,
261
  run_btn,
262
+ example_labels,
263
+ id2label_mapping_dataframe,
264
  ],
265
  )
266
  run_btn.click(