GSK-2362-improve-uiux-for-hfspace

#7
by ZeroCommand - opened
Files changed (2)
  1. app.py +99 -93
  2. text_classification.py +29 -18
app.py CHANGED
@@ -10,7 +10,7 @@ import json
 
 from transformers.pipelines import TextClassificationPipeline
 
-from text_classification import text_classification_fix_column_mapping
+from text_classification import check_column_mapping_keys_validity, text_classification_fix_column_mapping
 
 
 HF_REPO_ID = 'HF_REPO_ID'
@@ -59,26 +59,27 @@ def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
         return dataset_id, None, None
     return dataset_id, dataset_config, dataset_split
 
-def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping):
+def try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping='{}'):
     # Validate model
-    m_id, ppl = check_model(model_id=model_id)
     if m_id is None:
-        gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
+        gr.Warning('Model is not accessible. Please set your HF_TOKEN if it is a private model.')
         return (
-            dataset_config, dataset_split,
             gr.update(interactive=False), # Submit button
+            gr.update(visible=True), # Loading row
+            gr.update(visible=False), # Preview row
+            gr.update(visible=False), # Model prediction input
             gr.update(visible=False), # Model prediction preview
             gr.update(visible=False), # Label mapping preview
-            gr.update(visible=True), # Column mapping
         )
     if isinstance(ppl, Exception):
-        gr.Warning(f'Failed to load "{model_id} model": {ppl}')
+        gr.Warning(f'Failed to load model": {ppl}')
         return (
-            dataset_config, dataset_split,
             gr.update(interactive=False), # Submit button
+            gr.update(visible=True), # Loading row
+            gr.update(visible=False), # Preview row
+            gr.update(visible=False), # Model prediction input
             gr.update(visible=False), # Model prediction preview
             gr.update(visible=False), # Label mapping preview
-            gr.update(visible=True), # Column mapping
        )
 
     # Validate dataset
@@ -98,11 +99,13 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping):
 
     if not dataset_ok:
         return (
-            config, split,
             gr.update(interactive=False), # Submit button
+            gr.update(visible=True), # Loading row
+            gr.update(visible=False), # Preview row
+            gr.update(visible=False), # Model prediction input
             gr.update(visible=False), # Model prediction preview
             gr.update(visible=False), # Label mapping preview
-            gr.update(visible=True), # Column mapping
+            # gr.update(visible=True), # Column mapping
         )
 
     # TODO: Validate column mapping by running once
@@ -110,45 +113,48 @@
     id2label_df = None
     if isinstance(ppl, TextClassificationPipeline):
         try:
+            print('validating phase, ', column_mapping)
             column_mapping = json.loads(column_mapping)
         except Exception:
             column_mapping = {}
 
-        column_mapping, prediction_result, id2label_df = \
+        column_mapping, prediction_input, prediction_result, id2label_df = \
             text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split)
 
         column_mapping = json.dumps(column_mapping, indent=2)
 
-        del ppl
-
         if prediction_result is None:
             gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
             return (
-                config, split,
                 gr.update(interactive=False), # Submit button
+                gr.update(visible=True), # Loading row
+                gr.update(visible=False), # Preview row
+                gr.update(visible=False), # Model prediction input
                 gr.update(visible=False), # Model prediction preview
                 gr.update(visible=False), # Label mapping preview
-                gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
+                # gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
             )
         elif id2label_df is None:
             gr.Warning('The prediction result does not conform the labels in the dataset. Please provide label mappings in "Advance" settings.')
             return (
-                config, split,
                 gr.update(interactive=False), # Submit button
+                gr.update(visible=False), # Loading row
+                gr.update(visible=True), # Preview row
+                gr.update(value=f'**Sample Input**: {prediction_input}', visible=True), # Model prediction input
                 gr.update(value=prediction_result, visible=True), # Model prediction preview
                 gr.update(visible=False), # Label mapping preview
-                gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
            )
 
     gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
 
     return (
+        gr.update(interactive=True), # Submit button
         gr.update(visible=False), # Loading row
         gr.update(visible=True), # Preview row
-        gr.update(interactive=True), # Submit button
+        gr.update(value=f'**Sample Input**: {prediction_input}', visible=True), # Model prediction input
         gr.update(value=prediction_result, visible=True), # Model prediction preview
-        gr.update(value=id2label_df, visible=True), # Label mapping preview
-        gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
+        gr.update(value=id2label_df, visible=True, interactive=True), # Label mapping preview
     )
 
 
@@ -200,36 +206,56 @@ def try_submit(m_id, d_id, config, split, column_mappings, local):
 
 with gr.Blocks(theme=theme) as iface:
     with gr.Tab("Text Classification"):
-        global_ds_id = gr.State('ds')
-
         def check_dataset_and_get_config(dataset_id):
-            global_ds_id.value = dataset_id
             try:
                 configs = datasets.get_dataset_config_names(dataset_id)
-                print(configs)
                 return gr.Dropdown(configs, value=configs[0], visible=True)
             except Exception:
                 # Dataset may not exist
                 pass
 
-        def check_dataset_and_get_split(choice):
-            print('choice: ',choice, global_ds_id.value)
+        def check_dataset_and_get_split(dataset_config, dataset_id):
             try:
-                splits = list(datasets.load_dataset(global_ds_id.value, choice).keys())
-                print('splits: ',splits)
+                splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
                 return gr.Dropdown(splits, value=splits[0], visible=True)
             except Exception as e:
                 # Dataset may not exist
-                print(e)
+                gr.Warning(f"Failed to load dataset {dataset_id} with config {dataset_config}: {e}")
                 pass
 
-        def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split):
-            print('model_id: ',model_id)
-            if model_id and dataset_id and dataset_config and dataset_split:
-                return gr.update(interactive=True)
+        def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
+            column_mapping = '{}'
+            m_id, ppl = check_model(model_id=model_id)
+
+            if id2label_mapping_dataframe is not None:
+                column_mapping = id2label_mapping_dataframe.to_json(orient="split")
+            if check_column_mapping_keys_validity(column_mapping, ppl) is False:
+                gr.Warning('Label mapping table has invalid contents. Please check again.')
+                return (gr.update(interactive=False),
+                        gr.update(),
+                        gr.update(),
+                        gr.update(),
+                        gr.update(),
+                        gr.update())
             else:
-                return gr.update(interactive=False)
-
+                if model_id and dataset_id and dataset_config and dataset_split:
+                    return try_validate(m_id, ppl, dataset_id, dataset_config, dataset_split, column_mapping)
+                else:
+                    del ppl
+
+                    return (gr.update(interactive=False),
+                            gr.update(visible=True),
+                            gr.update(visible=False),
+                            gr.update(visible=False),
+                            gr.update(visible=False),
+                            gr.update(visible=False))
+        with gr.Row():
+            gr.Markdown('''
+            <h1 style="text-align: center;">
+            Giskard Evaluator
+            </h1>
+            Welcome to Giskard Evaluator Space! Get your report immediately by simply input your model id and dataset id below. Follow our leads and improve your model in no time.
+            ''')
         with gr.Row():
             model_id_input = gr.Textbox(
                 label="Hugging Face model id",
@@ -245,22 +271,10 @@ with gr.Blocks(theme=theme) as iface:
             dataset_split_input = gr.Dropdown(['default'], value=['default'], label='Dataset Split', visible=False)
 
         dataset_id_input.change(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
-        dataset_config_input.change(check_dataset_and_get_split, dataset_config_input, dataset_split_input)
-
-        with gr.Row():
-            validate_btn = gr.Button("Validate Model and Dataset", variant="primary", interactive=False)
-            model_id_input.change(gate_validate_btn,
-                                  inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                                  outputs=[validate_btn])
-            dataset_id_input.change(gate_validate_btn,
-                                  inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                                  outputs=[validate_btn])
-            dataset_config_input.change(gate_validate_btn,
-                                  inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                                  outputs=[validate_btn])
-            dataset_split_input.change(gate_validate_btn,
-                                  inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
-                                  outputs=[validate_btn])
+        dataset_config_input.change(
+            check_dataset_and_get_split,
+            inputs=[dataset_config_input, dataset_id_input],
+            outputs=[dataset_split_input])
 
         with gr.Row(visible=True) as loading_row:
             gr.Markdown('''
@@ -270,51 +284,45 @@ with gr.Blocks(theme=theme) as iface:
             ''')
 
         with gr.Row(visible=False) as preview_row:
-            with gr.Column():
-                id2label_mapping_dataframe = gr.DataFrame(label="Preview of label mapping")
-
-                gr.Markdown('''
-                <span style="background-color:#5fc269; color:white">Does this look right? If not, Check and update your feature mapping -></span>
-                ''')
-
-                example_labels = gr.Label(label='Model Prediction Sample')
-
-
-        with gr.Accordion("Advance", open=False):
-            run_local = gr.Checkbox(value=True, label="Run in this Space")
-            column_mapping_input = gr.Textbox(
-                value="",
-                lines=6,
-                label="Column mapping",
-                placeholder="Description of mapping of columns in model to dataset, in json format, e.g.:\n"
-                            '{\n'
-                            '   "text": "context",\n'
-                            '   "label": {0: "Positive", 1: "Negative"}\n'
-                            '}',
-            )
+            gr.Markdown('''
+            <h1 style="text-align: center;">
+            Confirm Label Details
+            </h1>
+            Base on your model and dataset, we inferred this label mapping. **If the mapping is incorrect, please modify it in the table below.**
+            ''')
+
+        with gr.Row():
+            id2label_mapping_dataframe = gr.DataFrame(label="Preview of label mapping", interactive=True, visible=False)
 
+        with gr.Row():
+            example_input = gr.Markdown('Sample Input: ', visible=False)
+
+        with gr.Row():
+            example_labels = gr.Label(label='Model Prediction Sample', visible=False)
+
+
         run_btn = gr.Button(
             "Get Evaluation Result",
             variant="primary",
             interactive=False,
+            size="lg",
         )
-        validate_btn.click(
-            try_validate,
-            inputs=[
-                model_id_input,
-                dataset_id_input,
-                dataset_config_input,
-                dataset_split_input,
-            ],
-            outputs=[
-                loading_row,
-                preview_row,
-                run_btn,
-                example_labels,
-                id2label_mapping_dataframe,
-                column_mapping_input,
-            ],
-        )
+
+        model_id_input.change(gate_validate_btn,
+                              inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
+                              outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
+        dataset_id_input.change(gate_validate_btn,
+                              inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
+                              outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
+        dataset_config_input.change(gate_validate_btn,
+                              inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
+                              outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
+        dataset_split_input.change(gate_validate_btn,
+                              inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
+                              outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
+        id2label_mapping_dataframe.input(gate_validate_btn,
+                              inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, id2label_mapping_dataframe],
+                              outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
 
         run_btn.click(
             try_submit,
@@ -323,8 +331,6 @@ with gr.Blocks(theme=theme) as iface:
                 dataset_id_input,
                 dataset_config_input,
                 dataset_split_input,
-                column_mapping_input,
-                run_local,
             ],
             outputs=[
                 run_btn,
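
Note on the app.py wiring: every change/input event above now routes through gate_validate_btn, which must return a six-element tuple of gr.update(...) objects in the same order as the outputs list [run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe]. A minimal sketch of that contract follows; the stub below replaces the real check_model/try_validate logic with a simple readiness check and is only an illustration, not part of this change.

import gradio as gr

# Sketch only: stands in for gate_validate_btn to show the output-tuple contract.
# The real callback calls check_model / try_validate; a readiness check replaces
# that logic here.
def gate_validate_btn_stub(model_id, dataset_id, dataset_config, dataset_split,
                           id2label_mapping_dataframe=None):
    ready = bool(model_id and dataset_id and dataset_config and dataset_split)
    return (
        gr.update(interactive=ready),  # run_btn
        gr.update(visible=not ready),  # loading_row
        gr.update(visible=ready),      # preview_row
        gr.update(visible=ready),      # example_input
        gr.update(visible=ready),      # example_labels
        gr.update(visible=ready),      # id2label_mapping_dataframe
    )

Returning gr.update(...) objects rather than raw values lets the single callback toggle visibility and interactivity of all six components at once.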
text_classification.py CHANGED
@@ -1,7 +1,6 @@
 import datasets
-
 import logging
-
+import json
 import pandas as pd
 
 
@@ -36,6 +35,20 @@ def text_classification_map_model_and_dataset_labels(id2label, dataset_features):
     return id2label_mapping, dataset_labels
 
 
+def check_column_mapping_keys_validity(column_mapping, ppl):
+    # get the element in all the list elements
+    column_mapping = json.loads(column_mapping)
+    if "data" not in column_mapping.keys():
+        return True
+    user_labels = set([pair[0] for pair in column_mapping["data"]])
+    model_labels = set([pair[1] for pair in column_mapping["data"]])
+
+    id2label = ppl.model.config.id2label
+    original_labels = set(id2label.values())
+
+    return user_labels == model_labels == original_labels
+
+
 def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
     # We assume dataset is ok here
     ds = datasets.load_dataset(d_id, config)[split]
@@ -72,10 +85,12 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
     id2label_mapping = {}
     id2label = ppl.model.config.id2label
     label2id = {v: k for k, v in id2label.items()}
+    prediction_input = None
     prediction_result = None
     try:
         # Use the first item to test prediction
-        results = ppl({"text": df.head(1).at[0, column_mapping["text"]]}, top_k=None)
+        prediction_input = df.head(1).at[0, column_mapping["text"]]
+        results = ppl({"text": prediction_input}, top_k=None)
         prediction_result = {
             f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
         }
@@ -85,33 +100,29 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split):
 
     # Infer labels
     id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
-    if "label" in column_mapping.keys():
-        if not isinstance(column_mapping["label"], dict) or set(column_mapping["label"].values()) != set(dataset_labels):
-            logging.warning(f'Provided {column_mapping["label"]} does not match labels in Dataset')
-            return column_mapping, prediction_result, None
-
-        if isinstance(column_mapping["label"], dict):
+    if "data" in column_mapping.keys():
+        if isinstance(column_mapping["data"], list):
             # Use the column mapping passed by user
-            for i, model_label in column_mapping["label"].items():
-                id2label_mapping[model_label] = dataset_labels[int(i)]
+            for user_label, model_label in column_mapping["data"]:
+                id2label_mapping[model_label] = user_label
     elif None in id2label_mapping.values():
         column_mapping["label"] = {
             i: None for i in id2label.keys()
         }
         return column_mapping, prediction_result, None
 
-    id2label_mapping = {
-        v: k for k, v in id2label_mapping.items()
+    prediction_result = {
+        f'[{label2id[result["label"]]}]{result["label"]}(original) - {id2label_mapping[result["label"]]}(mapped)': result["score"] for result in results
     }
     id2label_df = pd.DataFrame({
-        "ID": list(range(len(dataset_labels))),
-        "Labels": dataset_labels,
-        "Labels in original model": [f"{id2label_mapping[label]}({label2id[id2label_mapping[label]]})" for label in dataset_labels],
+        "Dataset Labels": dataset_labels,
+        "Model Prediction Labels": [id2label_mapping[label] for label in dataset_labels],
     })
-    if "label" not in column_mapping.keys():
+
+    if "data" not in column_mapping.keys():
         # Column mapping should contain original model labels
         column_mapping["label"] = {
             str(i): id2label_mapping[label] for i, label in zip(id2label.keys(), dataset_labels)
         }
 
-    return column_mapping, prediction_result, id2label_df
+    return column_mapping, prediction_input, prediction_result, id2label_df
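
Note on the new helper: the Gradio label-mapping table is serialised with to_json(orient="split"), so check_column_mapping_keys_validity reads row pairs from the "data" key. A minimal, self-contained sketch of that behaviour follows; the dummy pipeline classes and the NEGATIVE/POSITIVE labels are assumptions for illustration, while the helper body is copied from the diff above.

import json
import pandas as pd

# Hypothetical stand-ins for a transformers pipeline: the helper only reads
# ppl.model.config.id2label.
class _DummyConfig:
    id2label = {0: "NEGATIVE", 1: "POSITIVE"}

class _DummyModel:
    config = _DummyConfig()

class _DummyPipeline:
    model = _DummyModel()

def check_column_mapping_keys_validity(column_mapping, ppl):
    # copied from the diff above
    column_mapping = json.loads(column_mapping)
    if "data" not in column_mapping.keys():
        return True
    user_labels = set([pair[0] for pair in column_mapping["data"]])
    model_labels = set([pair[1] for pair in column_mapping["data"]])

    id2label = ppl.model.config.id2label
    original_labels = set(id2label.values())

    return user_labels == model_labels == original_labels

# DataFrame.to_json(orient="split") yields
# {"columns": [...], "index": [...], "data": [[dataset_label, model_label], ...]}
good = pd.DataFrame({
    "Dataset Labels": ["NEGATIVE", "POSITIVE"],
    "Model Prediction Labels": ["NEGATIVE", "POSITIVE"],
})
bad = pd.DataFrame({
    "Dataset Labels": ["NEGATIVE", "POSITIVE"],
    "Model Prediction Labels": ["NEG", "POS"],
})
print(check_column_mapping_keys_validity(good.to_json(orient="split"), _DummyPipeline()))  # True
print(check_column_mapping_keys_validity(bad.to_json(orient="split"), _DummyPipeline()))   # False

As written, the check passes only when both columns of the mapping table contain exactly the label set from ppl.model.config.id2label.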