lewtun HF staff committed on
Commit
aacdddf
1 Parent(s): 118ffe4

Handle multiple configs

Browse files
Files changed (2) hide show
  1. app.py +46 -32
  2. utils.py +11 -0
app.py CHANGED
@@ -16,6 +16,7 @@ from utils import (
16
  create_autotrain_project_name,
17
  format_col_mapping,
18
  get_compatible_models,
 
19
  get_dataset_card_url,
20
  get_key,
21
  get_metadata,
@@ -123,16 +124,6 @@ SUPPORTED_METRICS = [
123
  ]
124
 
125
 
126
- def get_config_metadata(config, metadata=None):
127
- if metadata is None:
128
- return None
129
- config_metadata = [m for m in metadata if m["config"] == config]
130
- if len(config_metadata) == 1:
131
- return config_metadata[0]
132
- else:
133
- return None
134
-
135
-
136
  #######
137
  # APP #
138
  #######
@@ -190,10 +181,6 @@ if metadata is None:
190
 
191
  with st.expander("Advanced configuration"):
192
  # Select task
193
- # Hack to filter for unsupported tasks
194
- # TODO(lewtun): remove this once we have SQuAD metrics support
195
- if metadata is not None and metadata[0]["task_id"] in UNSUPPORTED_TASKS:
196
- metadata = None
197
  selected_task = st.selectbox(
198
  "Select a task",
199
  SUPPORTED_TASKS,
@@ -211,6 +198,9 @@ with st.expander("Advanced configuration"):
211
  See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
212
  """,
213
  )
 
 
 
214
 
215
  # Select splits
216
  splits_resp = http_get(
@@ -225,8 +215,8 @@ with st.expander("Advanced configuration"):
225
  if split["config"] == selected_config:
226
  split_names.append(split["split"])
227
 
228
- if metadata is not None:
229
- eval_split = metadata[0]["splits"].get("eval_split", None)
230
  else:
231
  eval_split = None
232
  selected_split = st.selectbox(
@@ -270,12 +260,16 @@ with st.expander("Advanced configuration"):
270
  text_col = st.selectbox(
271
  "This column should contain the text to be classified",
272
  col_names,
273
- index=col_names.index(get_key(metadata[0]["col_mapping"], "text")) if metadata is not None else 0,
 
 
274
  )
275
  target_col = st.selectbox(
276
  "This column should contain the labels associated with the text",
277
  col_names,
278
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
279
  )
280
  col_mapping[text_col] = "text"
281
  col_mapping[target_col] = "target"
@@ -289,11 +283,13 @@ with st.expander("Advanced configuration"):
289
  st.text("")
290
  st.text("")
291
  st.text("")
 
292
  st.markdown("`text2` column")
293
  st.text("")
294
  st.text("")
295
  st.text("")
296
  st.text("")
 
297
  st.markdown("`target` column")
298
  with col2:
299
  text1_col = st.selectbox(
@@ -333,12 +329,16 @@ with st.expander("Advanced configuration"):
333
  tokens_col = st.selectbox(
334
  "This column should contain the array of tokens to be classified",
335
  col_names,
336
- index=col_names.index(get_key(metadata[0]["col_mapping"], "tokens")) if metadata is not None else 0,
 
 
337
  )
338
  tags_col = st.selectbox(
339
  "This column should contain the labels associated with each part of the text",
340
  col_names,
341
- index=col_names.index(get_key(metadata[0]["col_mapping"], "tags")) if metadata is not None else 0,
 
 
342
  )
343
  col_mapping[tokens_col] = "tokens"
344
  col_mapping[tags_col] = "tags"
@@ -355,12 +355,16 @@ with st.expander("Advanced configuration"):
355
  text_col = st.selectbox(
356
  "This column should contain the text to be translated",
357
  col_names,
358
- index=col_names.index(get_key(metadata[0]["col_mapping"], "source")) if metadata is not None else 0,
 
 
359
  )
360
  target_col = st.selectbox(
361
  "This column should contain the target translation",
362
  col_names,
363
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
364
  )
365
  col_mapping[text_col] = "source"
366
  col_mapping[target_col] = "target"
@@ -377,19 +381,23 @@ with st.expander("Advanced configuration"):
377
  text_col = st.selectbox(
378
  "This column should contain the text to be summarized",
379
  col_names,
380
- index=col_names.index(get_key(metadata[0]["col_mapping"], "text")) if metadata is not None else 0,
 
 
381
  )
382
  target_col = st.selectbox(
383
  "This column should contain the target summary",
384
  col_names,
385
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
386
  )
387
  col_mapping[text_col] = "text"
388
  col_mapping[target_col] = "target"
389
 
390
  elif selected_task == "extractive_question_answering":
391
- if metadata is not None:
392
- col_mapping = metadata[0]["col_mapping"]
393
  # Hub YAML parser converts periods to hyphens, so we remap them here
394
  col_mapping = format_col_mapping(col_mapping)
395
  with col1:
@@ -413,22 +421,24 @@ with st.expander("Advanced configuration"):
413
  context_col = st.selectbox(
414
  "This column should contain the question's context",
415
  col_names,
416
- index=col_names.index(get_key(col_mapping, "context")) if metadata is not None else 0,
417
  )
418
  question_col = st.selectbox(
419
  "This column should contain the question to be answered, given the context",
420
  col_names,
421
- index=col_names.index(get_key(col_mapping, "question")) if metadata is not None else 0,
422
  )
423
  answers_text_col = st.selectbox(
424
  "This column should contain example answers to the question, extracted from the context",
425
  col_names,
426
- index=col_names.index(get_key(col_mapping, "answers.text")) if metadata is not None else 0,
427
  )
428
  answers_start_col = st.selectbox(
429
  "This column should contain the indices in the context of the first character of each `answers.text`",
430
  col_names,
431
- index=col_names.index(get_key(col_mapping, "answers.answer_start")) if metadata is not None else 0,
 
 
432
  )
433
  col_mapping[context_col] = "context"
434
  col_mapping[question_col] = "question"
@@ -446,12 +456,16 @@ with st.expander("Advanced configuration"):
446
  image_col = st.selectbox(
447
  "This column should contain the images to be classified",
448
  col_names,
449
- index=col_names.index(get_key(metadata[0]["col_mapping"], "image")) if metadata is not None else 0,
 
 
450
  )
451
  target_col = st.selectbox(
452
  "This column should contain the labels associated with the images",
453
  col_names,
454
- index=col_names.index(get_key(metadata[0]["col_mapping"], "target")) if metadata is not None else 0,
 
 
455
  )
456
  col_mapping[image_col] = "image"
457
  col_mapping[target_col] = "target"
16
  create_autotrain_project_name,
17
  format_col_mapping,
18
  get_compatible_models,
19
+ get_config_metadata,
20
  get_dataset_card_url,
21
  get_key,
22
  get_metadata,
124
  ]
125
 
126
 
 
 
 
 
 
 
 
 
 
 
127
  #######
128
  # APP #
129
  #######
181
 
182
  with st.expander("Advanced configuration"):
183
  # Select task
 
 
 
 
184
  selected_task = st.selectbox(
185
  "Select a task",
186
  SUPPORTED_TASKS,
198
  See the [docs](https://huggingface.co/docs/datasets/master/en/load_hub#configurations) for more details.
199
  """,
200
  )
201
+ # Get metadata for config
202
+ config_metadata = get_config_metadata(selected_config, metadata)
203
+ print(f"INFO -- Config metadata: {config_metadata}")
204
 
205
  # Select splits
206
  splits_resp = http_get(
215
  if split["config"] == selected_config:
216
  split_names.append(split["split"])
217
 
218
+ if config_metadata is not None:
219
+ eval_split = config_metadata["splits"].get("eval_split", None)
220
  else:
221
  eval_split = None
222
  selected_split = st.selectbox(
260
  text_col = st.selectbox(
261
  "This column should contain the text to be classified",
262
  col_names,
263
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
264
+ if config_metadata is not None
265
+ else 0,
266
  )
267
  target_col = st.selectbox(
268
  "This column should contain the labels associated with the text",
269
  col_names,
270
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
271
+ if config_metadata is not None
272
+ else 0,
273
  )
274
  col_mapping[text_col] = "text"
275
  col_mapping[target_col] = "target"
283
  st.text("")
284
  st.text("")
285
  st.text("")
286
+ st.text("")
287
  st.markdown("`text2` column")
288
  st.text("")
289
  st.text("")
290
  st.text("")
291
  st.text("")
292
+ st.text("")
293
  st.markdown("`target` column")
294
  with col2:
295
  text1_col = st.selectbox(
329
  tokens_col = st.selectbox(
330
  "This column should contain the array of tokens to be classified",
331
  col_names,
332
+ index=col_names.index(get_key(config_metadata["col_mapping"], "tokens"))
333
+ if config_metadata is not None
334
+ else 0,
335
  )
336
  tags_col = st.selectbox(
337
  "This column should contain the labels associated with each part of the text",
338
  col_names,
339
+ index=col_names.index(get_key(config_metadata["col_mapping"], "tags"))
340
+ if config_metadata is not None
341
+ else 0,
342
  )
343
  col_mapping[tokens_col] = "tokens"
344
  col_mapping[tags_col] = "tags"
355
  text_col = st.selectbox(
356
  "This column should contain the text to be translated",
357
  col_names,
358
+ index=col_names.index(get_key(config_metadata["col_mapping"], "source"))
359
+ if config_metadata is not None
360
+ else 0,
361
  )
362
  target_col = st.selectbox(
363
  "This column should contain the target translation",
364
  col_names,
365
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
366
+ if config_metadata is not None
367
+ else 0,
368
  )
369
  col_mapping[text_col] = "source"
370
  col_mapping[target_col] = "target"
381
  text_col = st.selectbox(
382
  "This column should contain the text to be summarized",
383
  col_names,
384
+ index=col_names.index(get_key(config_metadata["col_mapping"], "text"))
385
+ if config_metadata is not None
386
+ else 0,
387
  )
388
  target_col = st.selectbox(
389
  "This column should contain the target summary",
390
  col_names,
391
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
392
+ if config_metadata is not None
393
+ else 0,
394
  )
395
  col_mapping[text_col] = "text"
396
  col_mapping[target_col] = "target"
397
 
398
  elif selected_task == "extractive_question_answering":
399
+ if config_metadata is not None:
400
+ col_mapping = config_metadata["col_mapping"]
401
  # Hub YAML parser converts periods to hyphens, so we remap them here
402
  col_mapping = format_col_mapping(col_mapping)
403
  with col1:
421
  context_col = st.selectbox(
422
  "This column should contain the question's context",
423
  col_names,
424
+ index=col_names.index(get_key(col_mapping, "context")) if config_metadata is not None else 0,
425
  )
426
  question_col = st.selectbox(
427
  "This column should contain the question to be answered, given the context",
428
  col_names,
429
+ index=col_names.index(get_key(col_mapping, "question")) if config_metadata is not None else 0,
430
  )
431
  answers_text_col = st.selectbox(
432
  "This column should contain example answers to the question, extracted from the context",
433
  col_names,
434
+ index=col_names.index(get_key(col_mapping, "answers.text")) if config_metadata is not None else 0,
435
  )
436
  answers_start_col = st.selectbox(
437
  "This column should contain the indices in the context of the first character of each `answers.text`",
438
  col_names,
439
+ index=col_names.index(get_key(col_mapping, "answers.answer_start"))
440
+ if config_metadata is not None
441
+ else 0,
442
  )
443
  col_mapping[context_col] = "context"
444
  col_mapping[question_col] = "question"
456
  image_col = st.selectbox(
457
  "This column should contain the images to be classified",
458
  col_names,
459
+ index=col_names.index(get_key(config_metadata["col_mapping"], "image"))
460
+ if config_metadata is not None
461
+ else 0,
462
  )
463
  target_col = st.selectbox(
464
  "This column should contain the labels associated with the images",
465
  col_names,
466
+ index=col_names.index(get_key(config_metadata["col_mapping"], "target"))
467
+ if config_metadata is not None
468
+ else 0,
469
  )
470
  col_mapping[image_col] = "image"
471
  col_mapping[target_col] = "target"
utils.py CHANGED
@@ -198,3 +198,14 @@ def create_autotrain_project_name(dataset_id: str) -> str:
198
  # Project names need to be unique, so we append a random string to guarantee this
199
  project_id = str(uuid.uuid4())[:8]
200
  return f"eval-project-{dataset_id_formatted}-{project_id}"
 
 
 
 
 
 
 
 
 
 
 
198
  # Project names need to be unique, so we append a random string to guarantee this
199
  project_id = str(uuid.uuid4())[:8]
200
  return f"eval-project-{dataset_id_formatted}-{project_id}"
201
+
202
+
203
def get_config_metadata(config: str, metadata: Union[List[Dict], None] = None) -> Union[Dict, None]:
    """Gets the dataset card metadata for the given config.

    Args:
        config: Name of the dataset configuration to look up.
        metadata: List of metadata entries, each expected to be a dict
            carrying a "config" key. May be None or empty.

    Returns:
        The single entry whose "config" value equals `config`, or None when
        `metadata` is None/empty, when no entry matches, or when more than one
        entry matches (an ambiguous match is treated as "no metadata").
    """
    if not metadata:
        return None
    # Entries that are not dicts or that lack a "config" key are skipped rather
    # than raising, so one malformed entry does not hide a valid one.
    matches = [
        entry
        for entry in metadata
        if isinstance(entry, dict) and "config" in entry and entry["config"] == config
    ]
    return matches[0] if len(matches) == 1 else None