ZeroCommand commited on
Commit
9e212de
1 Parent(s): 1aa43b4

updated version of ui

Browse files
Files changed (2) hide show
  1. app.py +74 -81
  2. text_classification.py +14 -17
app.py CHANGED
@@ -59,26 +59,28 @@ def check_dataset(dataset_id, dataset_config="default", dataset_split="test"):
59
  return dataset_id, None, None
60
  return dataset_id, dataset_config, dataset_split
61
 
62
- def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping):
63
  # Validate model
64
  m_id, ppl = check_model(model_id=model_id)
65
  if m_id is None:
66
  gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
67
  return (
68
- dataset_config, dataset_split,
69
  gr.update(interactive=False), # Submit button
 
 
 
70
  gr.update(visible=False), # Model prediction preview
71
  gr.update(visible=False), # Label mapping preview
72
- gr.update(visible=True), # Column mapping
73
  )
74
  if isinstance(ppl, Exception):
75
  gr.Warning(f'Failed to load "{model_id} model": {ppl}')
76
  return (
77
- dataset_config, dataset_split,
78
  gr.update(interactive=False), # Submit button
 
 
 
79
  gr.update(visible=False), # Model prediction preview
80
  gr.update(visible=False), # Label mapping preview
81
- gr.update(visible=True), # Column mapping
82
  )
83
 
84
  # Validate dataset
@@ -98,11 +100,13 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
98
 
99
  if not dataset_ok:
100
  return (
101
- config, split,
102
  gr.update(interactive=False), # Submit button
 
 
 
103
  gr.update(visible=False), # Model prediction preview
104
  gr.update(visible=False), # Label mapping preview
105
- gr.update(visible=True), # Column mapping
106
  )
107
 
108
  # TODO: Validate column mapping by running once
@@ -110,11 +114,12 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
110
  id2label_df = None
111
  if isinstance(ppl, TextClassificationPipeline):
112
  try:
 
113
  column_mapping = json.loads(column_mapping)
114
  except Exception:
115
  column_mapping = {}
116
 
117
- column_mapping, prediction_result, id2label_df = \
118
  text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split)
119
 
120
  column_mapping = json.dumps(column_mapping, indent=2)
@@ -124,31 +129,35 @@ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_map
124
  if prediction_result is None:
125
  gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
126
  return (
127
- config, split,
128
  gr.update(interactive=False), # Submit button
 
 
 
129
  gr.update(visible=False), # Model prediction preview
130
  gr.update(visible=False), # Label mapping preview
131
- gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
132
  )
133
  elif id2label_df is None:
134
  gr.Warning('The prediction result does not conform the labels in the dataset. Please provide label mappings in "Advance" settings.')
135
  return (
136
- config, split,
137
  gr.update(interactive=False), # Submit button
 
 
 
138
  gr.update(value=prediction_result, visible=True), # Model prediction preview
139
  gr.update(visible=False), # Label mapping preview
140
- gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
141
  )
142
 
143
  gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
144
 
145
  return (
 
146
  gr.update(visible=False), # Loading row
147
  gr.update(visible=True), # Preview row
148
- gr.update(interactive=True), # Submit button
149
  gr.update(value=prediction_result, visible=True), # Model prediction preview
150
- gr.update(value=id2label_df, visible=True), # Label mapping preview
151
- gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
152
  )
153
 
154
 
@@ -200,10 +209,7 @@ def try_submit(m_id, d_id, config, split, column_mappings, local):
200
 
201
  with gr.Blocks(theme=theme) as iface:
202
  with gr.Tab("Text Classification"):
203
- global_ds_id = gr.State('ds')
204
-
205
  def check_dataset_and_get_config(dataset_id):
206
- global_ds_id.value = dataset_id
207
  try:
208
  configs = datasets.get_dataset_config_names(dataset_id)
209
  print(configs)
@@ -212,10 +218,9 @@ with gr.Blocks(theme=theme) as iface:
212
  # Dataset may not exist
213
  pass
214
 
215
- def check_dataset_and_get_split(choice):
216
- print('choice: ',choice, global_ds_id.value)
217
  try:
218
- splits = list(datasets.load_dataset(global_ds_id.value, choice).keys())
219
  print('splits: ',splits)
220
  return gr.Dropdown(splits, value=splits[0], visible=True)
221
  except Exception as e:
@@ -223,12 +228,20 @@ with gr.Blocks(theme=theme) as iface:
223
  print(e)
224
  pass
225
 
226
- def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split):
227
  print('model_id: ',model_id)
 
 
 
 
228
  if model_id and dataset_id and dataset_config and dataset_split:
229
- return gr.update(interactive=True)
230
  else:
231
- return gr.update(interactive=False)
 
 
 
 
232
 
233
  with gr.Row():
234
  model_id_input = gr.Textbox(
@@ -245,22 +258,10 @@ with gr.Blocks(theme=theme) as iface:
245
  dataset_split_input = gr.Dropdown(['default'], value=['default'], label='Dataset Split', visible=False)
246
 
247
  dataset_id_input.change(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
248
- dataset_config_input.change(check_dataset_and_get_split, dataset_config_input, dataset_split_input)
249
-
250
- with gr.Row():
251
- validate_btn = gr.Button("Validate Model and Dataset", variant="primary", interactive=False)
252
- model_id_input.change(gate_validate_btn,
253
- inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
254
- outputs=[validate_btn])
255
- dataset_id_input.change(gate_validate_btn,
256
- inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
257
- outputs=[validate_btn])
258
- dataset_config_input.change(gate_validate_btn,
259
- inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
260
- outputs=[validate_btn])
261
- dataset_split_input.change(gate_validate_btn,
262
- inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
263
- outputs=[validate_btn])
264
 
265
  with gr.Row(visible=True) as loading_row:
266
  gr.Markdown('''
@@ -270,51 +271,45 @@ with gr.Blocks(theme=theme) as iface:
270
  ''')
271
 
272
  with gr.Row(visible=False) as preview_row:
273
- with gr.Column():
274
- id2label_mapping_dataframe = gr.DataFrame(label="Preview of label mapping")
275
-
276
- gr.Markdown('''
277
- <span style="background-color:#5fc269; color:white">Does this look right? If not, Check and update your feature mapping -></span>
278
- ''')
279
-
280
- example_labels = gr.Label(label='Model Prediction Sample')
281
-
282
-
283
- with gr.Accordion("Advance", open=False):
284
- run_local = gr.Checkbox(value=True, label="Run in this Space")
285
- column_mapping_input = gr.Textbox(
286
- value="",
287
- lines=6,
288
- label="Column mapping",
289
- placeholder="Description of mapping of columns in model to dataset, in json format, e.g.:\n"
290
- '{\n'
291
- ' "text": "context",\n'
292
- ' "label": {0: "Positive", 1: "Negative"}\n'
293
- '}',
294
- )
295
 
 
296
  run_btn = gr.Button(
297
  "Get Evaluation Result",
298
  variant="primary",
299
  interactive=False,
 
300
  )
301
- validate_btn.click(
302
- try_validate,
303
- inputs=[
304
- model_id_input,
305
- dataset_id_input,
306
- dataset_config_input,
307
- dataset_split_input,
308
- ],
309
- outputs=[
310
- loading_row,
311
- preview_row,
312
- run_btn,
313
- example_labels,
314
- id2label_mapping_dataframe,
315
- column_mapping_input,
316
- ],
317
- )
318
 
319
  run_btn.click(
320
  try_submit,
@@ -323,8 +318,6 @@ with gr.Blocks(theme=theme) as iface:
323
  dataset_id_input,
324
  dataset_config_input,
325
  dataset_split_input,
326
- column_mapping_input,
327
- run_local,
328
  ],
329
  outputs=[
330
  run_btn,
 
59
  return dataset_id, None, None
60
  return dataset_id, dataset_config, dataset_split
61
 
62
+ def try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping='{}'):
63
  # Validate model
64
  m_id, ppl = check_model(model_id=model_id)
65
  if m_id is None:
66
  gr.Warning(f'Model "{model_id}" is not accessible. Please set your HF_TOKEN if it is a private model.')
67
  return (
 
68
  gr.update(interactive=False), # Submit button
69
+ gr.update(visible=True), # Loading row
70
+ gr.update(visible=False), # Preview row
71
+ gr.update(visible=False), # Model prediction input
72
  gr.update(visible=False), # Model prediction preview
73
  gr.update(visible=False), # Label mapping preview
 
74
  )
75
  if isinstance(ppl, Exception):
76
  gr.Warning(f'Failed to load "{model_id} model": {ppl}')
77
  return (
 
78
  gr.update(interactive=False), # Submit button
79
+ gr.update(visible=True), # Loading row
80
+ gr.update(visible=False), # Preview row
81
+ gr.update(visible=False), # Model prediction input
82
  gr.update(visible=False), # Model prediction preview
83
  gr.update(visible=False), # Label mapping preview
 
84
  )
85
 
86
  # Validate dataset
 
100
 
101
  if not dataset_ok:
102
  return (
 
103
  gr.update(interactive=False), # Submit button
104
+ gr.update(visible=True), # Loading row
105
+ gr.update(visible=False), # Preview row
106
+ gr.update(visible=False), # Model prediction input
107
  gr.update(visible=False), # Model prediction preview
108
  gr.update(visible=False), # Label mapping preview
109
+ # gr.update(visible=True), # Column mapping
110
  )
111
 
112
  # TODO: Validate column mapping by running once
 
114
  id2label_df = None
115
  if isinstance(ppl, TextClassificationPipeline):
116
  try:
117
+ print('validating phase, ', column_mapping)
118
  column_mapping = json.loads(column_mapping)
119
  except Exception:
120
  column_mapping = {}
121
 
122
+ column_mapping, prediction_input, prediction_result, id2label_df = \
123
  text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, split)
124
 
125
  column_mapping = json.dumps(column_mapping, indent=2)
 
129
  if prediction_result is None:
130
  gr.Warning('The model failed to predict with the first row in the dataset. Please provide column mappings in "Advance" settings.')
131
  return (
 
132
  gr.update(interactive=False), # Submit button
133
+ gr.update(visible=True), # Loading row
134
+ gr.update(visible=False), # Preview row
135
+ gr.update(visible=False), # Model prediction input
136
  gr.update(visible=False), # Model prediction preview
137
  gr.update(visible=False), # Label mapping preview
138
+ # gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
139
  )
140
  elif id2label_df is None:
141
  gr.Warning('The prediction result does not conform the labels in the dataset. Please provide label mappings in "Advance" settings.')
142
  return (
 
143
  gr.update(interactive=False), # Submit button
144
+ gr.update(visible=False), # Loading row
145
+ gr.update(visible=True), # Preview row
146
+ gr.update(value=f'**Sample Input**: {prediction_input}', visible=True), # Model prediction input
147
  gr.update(value=prediction_result, visible=True), # Model prediction preview
148
  gr.update(visible=False), # Label mapping preview
149
+ # gr.update(value=column_mapping, visible=True, interactive=True), # Column mapping
150
  )
151
 
152
  gr.Info("Model and dataset validations passed. Your can submit the evaluation task.")
153
 
154
  return (
155
+ gr.update(interactive=True), # Submit button
156
  gr.update(visible=False), # Loading row
157
  gr.update(visible=True), # Preview row
158
+ gr.update(value=f'**Sample Input**: {prediction_input}', visible=True), # Model prediction input
159
  gr.update(value=prediction_result, visible=True), # Model prediction preview
160
+ gr.update(value=id2label_df, visible=True, interactive=True), # Label mapping preview
 
161
  )
162
 
163
 
 
209
 
210
  with gr.Blocks(theme=theme) as iface:
211
  with gr.Tab("Text Classification"):
 
 
212
  def check_dataset_and_get_config(dataset_id):
 
213
  try:
214
  configs = datasets.get_dataset_config_names(dataset_id)
215
  print(configs)
 
218
  # Dataset may not exist
219
  pass
220
 
221
+ def check_dataset_and_get_split(dataset_config, dataset_id):
 
222
  try:
223
+ splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
224
  print('splits: ',splits)
225
  return gr.Dropdown(splits, value=splits[0], visible=True)
226
  except Exception as e:
 
228
  print(e)
229
  pass
230
 
231
+ def gate_validate_btn(model_id, dataset_id, dataset_config, dataset_split, id2label_mapping_dataframe=None):
232
  print('model_id: ',model_id)
233
+ column_mapping = '{}'
234
+ if id2label_mapping_dataframe is not None:
235
+ column_mapping = id2label_mapping_dataframe.to_json(orient="split")
236
+ print(column_mapping)
237
  if model_id and dataset_id and dataset_config and dataset_split:
238
+ return try_validate(model_id, dataset_id, dataset_config, dataset_split, column_mapping)
239
  else:
240
+ return (gr.update(interactive=False),
241
+ gr.update(visible=True),
242
+ gr.update(visible=False),
243
+ gr.update(visible=False),
244
+ gr.update(visible=False))
245
 
246
  with gr.Row():
247
  model_id_input = gr.Textbox(
 
258
  dataset_split_input = gr.Dropdown(['default'], value=['default'], label='Dataset Split', visible=False)
259
 
260
  dataset_id_input.change(check_dataset_and_get_config, dataset_id_input, dataset_config_input)
261
+ dataset_config_input.change(
262
+ check_dataset_and_get_split,
263
+ inputs=[dataset_config_input, dataset_id_input],
264
+ outputs=[dataset_split_input])
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  with gr.Row(visible=True) as loading_row:
267
  gr.Markdown('''
 
271
  ''')
272
 
273
  with gr.Row(visible=False) as preview_row:
274
+ gr.Markdown('''
275
+ <h1 style="text-align: center;">
276
+ Confirm Label Details
277
+ </h1>
278
+ Base on your model and dataset, we inferred this label mapping. **If the mapping is incorrect, please modify it in the table below.**
279
+ ''')
280
+
281
+ with gr.Row():
282
+ id2label_mapping_dataframe = gr.DataFrame(label="Preview of label mapping", interactive=True, visible=False)
283
+
284
+ with gr.Row():
285
+ example_input = gr.Markdown('Sample Input: ', visible=False)
286
+
287
+ with gr.Row():
288
+ example_labels = gr.Label(label='Model Prediction Sample', visible=False)
 
 
 
 
 
 
 
289
 
290
+
291
  run_btn = gr.Button(
292
  "Get Evaluation Result",
293
  variant="primary",
294
  interactive=False,
295
+ size="lg",
296
  )
297
+
298
+ model_id_input.change(gate_validate_btn,
299
+ inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
300
+ outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
301
+ dataset_id_input.change(gate_validate_btn,
302
+ inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
303
+ outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
304
+ dataset_config_input.change(gate_validate_btn,
305
+ inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
306
+ outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
307
+ dataset_split_input.change(gate_validate_btn,
308
+ inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input],
309
+ outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
310
+ id2label_mapping_dataframe.input(gate_validate_btn,
311
+ inputs=[model_id_input, dataset_id_input, dataset_config_input, dataset_split_input, id2label_mapping_dataframe],
312
+ outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe])
 
313
 
314
  run_btn.click(
315
  try_submit,
 
318
  dataset_id_input,
319
  dataset_config_input,
320
  dataset_split_input,
 
 
321
  ],
322
  outputs=[
323
  run_btn,
text_classification.py CHANGED
@@ -72,10 +72,12 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
72
  id2label_mapping = {}
73
  id2label = ppl.model.config.id2label
74
  label2id = {v: k for k, v in id2label.items()}
 
75
  prediction_result = None
76
  try:
77
  # Use the first item to test prediction
78
- results = ppl({"text": df.head(1).at[0, column_mapping["text"]]}, top_k=None)
 
79
  prediction_result = {
80
  f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
81
  }
@@ -85,33 +87,28 @@ def text_classification_fix_column_mapping(column_mapping, ppl, d_id, config, sp
85
 
86
  # Infer labels
87
  id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
88
- if "label" in column_mapping.keys():
89
- if not isinstance(column_mapping["label"], dict) or set(column_mapping["label"].values()) != set(dataset_labels):
90
- logging.warning(f'Provided {column_mapping["label"]} does not match labels in Dataset')
91
- return column_mapping, prediction_result, None
92
-
93
- if isinstance(column_mapping["label"], dict):
94
  # Use the column mapping passed by user
95
- for i, model_label in column_mapping["label"].items():
96
- id2label_mapping[model_label] = dataset_labels[int(i)]
 
97
  elif None in id2label_mapping.values():
98
  column_mapping["label"] = {
99
  i: None for i in id2label.keys()
100
  }
101
  return column_mapping, prediction_result, None
102
 
103
- id2label_mapping = {
104
- v: k for k, v in id2label_mapping.items()
105
- }
106
  id2label_df = pd.DataFrame({
107
- "ID": list(range(len(dataset_labels))),
108
- "Labels": dataset_labels,
109
- "Labels in original model": [f"{id2label_mapping[label]}({label2id[id2label_mapping[label]]})" for label in dataset_labels],
110
  })
111
- if "label" not in column_mapping.keys():
 
112
  # Column mapping should contain original model labels
113
  column_mapping["label"] = {
114
  str(i): id2label_mapping[label] for i, label in zip(id2label.keys(), dataset_labels)
115
  }
116
 
117
- return column_mapping, prediction_result, id2label_df
 
72
  id2label_mapping = {}
73
  id2label = ppl.model.config.id2label
74
  label2id = {v: k for k, v in id2label.items()}
75
+ prediction_input = None
76
  prediction_result = None
77
  try:
78
  # Use the first item to test prediction
79
+ prediction_input = df.head(1).at[0, column_mapping["text"]]
80
+ results = ppl({"text": prediction_input}, top_k=None)
81
  prediction_result = {
82
  f'{result["label"]}({label2id[result["label"]]})': result["score"] for result in results
83
  }
 
87
 
88
  # Infer labels
89
  id2label_mapping, dataset_labels = text_classification_map_model_and_dataset_labels(id2label, dataset_features)
90
+ if "data" in column_mapping.keys():
91
+ if isinstance(column_mapping["data"], list):
 
 
 
 
92
  # Use the column mapping passed by user
93
+ for user_label, model_label in column_mapping["data"]:
94
+ print(user_label, model_label)
95
+ id2label_mapping[model_label] = user_label
96
  elif None in id2label_mapping.values():
97
  column_mapping["label"] = {
98
  i: None for i in id2label.keys()
99
  }
100
  return column_mapping, prediction_result, None
101
 
102
+ print(id2label_mapping)
 
 
103
  id2label_df = pd.DataFrame({
104
+ "Dataset Labels": dataset_labels,
105
+ "Model Prediction Labels": [id2label_mapping[label] for label in dataset_labels],
 
106
  })
107
+
108
+ if "data" not in column_mapping.keys():
109
  # Column mapping should contain original model labels
110
  column_mapping["label"] = {
111
  str(i): id2label_mapping[label] for i, label in zip(id2label.keys(), dataset_labels)
112
  }
113
 
114
+ return column_mapping, prediction_input, prediction_result, id2label_df