marinone94 committed
Commit 1d35cf4
1 Parent(s): e51d4c5

remove columns not in common across datasets

Files changed (1):
  1. run_speech_recognition_ctc.py +28 -13
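
Why this change is needed, briefly: `datasets.concatenate_datasets` expects the datasets being merged to share a schema, so when several corpora are combined into one training split, any column that exists in only some of them has to be dropped. A minimal sketch of the idea (the toy data and names below are hypothetical, not from the actual script):

```python
from datasets import Dataset, concatenate_datasets

# Two toy corpora whose metadata columns only partially overlap.
ds_a = Dataset.from_dict({"audio_path": ["a.wav"], "sentence": ["hej"], "accent": ["x"]})
ds_b = Dataset.from_dict({"audio_path": ["b.wav"], "sentence": ["då"], "age": ["30"]})

# Keep only the columns common to both, then concatenate.
common = [col for col in ds_a.column_names if col in ds_b.column_names]
ds_a = ds_a.remove_columns([c for c in ds_a.column_names if c not in common])
ds_b = ds_b.remove_columns([c for c in ds_b.column_names if c not in common])

train = concatenate_datasets([ds_a, ds_b])
print(train.column_names)  # ['audio_path', 'sentence']
```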
run_speech_recognition_ctc.py CHANGED
@@ -331,7 +331,7 @@ def create_vocabulary_from_data(
         batched=True,
         batch_size=10000,
         keep_in_memory=False,
-        # remove_columns=datasets["train"].column_names,
+        remove_columns=datasets["train"].column_names,
     )
 
     # take union of all unique characters in each dataset
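
A note on the hunk above, hedged: in `datasets`, passing `remove_columns` to `.map()` drops the listed columns from the mapped output, so the vocabulary dataset ends up containing only the field produced by the mapping function rather than also carrying every original text column. A small self-contained illustration (the `extract_all_chars` helper mirrors the pattern used by the upstream example script; treat the details as an assumption):

```python
from datasets import Dataset

dataset = Dataset.from_dict({"target_text": ["abc", "bcd"]})

def extract_all_chars(batch):
    # Collapse the batch into one row holding the set of characters seen.
    all_text = " ".join(batch["target_text"])
    return {"vocab": [list(set(all_text))]}

vocab = dataset.map(
    extract_all_chars,
    batched=True,
    batch_size=-1,  # a non-positive batch_size maps the whole dataset as one batch
    remove_columns=dataset.column_names,  # drop the original columns
)
print(vocab.column_names)  # ['vocab']
```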
@@ -418,6 +418,11 @@ def main():
     # 1. First, let's load the dataset
     raw_datasets = DatasetDict()
 
+    def common_cols(columns_a, columns_b):
+        col_a = set(columns_a)
+        col_b = set(columns_b)
+        return [col for col in col_a if col in col_b]
+
     if training_args.do_train:
 
         # Multiple datasets might need to be loaded from HF
@@ -437,18 +442,21 @@ def main():
                     split=train_split_name,
                     use_auth_token=data_args.use_auth_token,
                 )
+                min_columns_train = raw_datasets["train"].column_names
             else:
+                new_dataset = load_dataset(
+                    dataset_name,
+                    dataset_config_name,
+                    split=train_split_name,
+                    use_auth_token=data_args.use_auth_token,
+                )
                 raw_datasets["train"] = concatenate_datasets(
                     [
                         raw_datasets["train"],
-                        load_dataset(
-                            dataset_name,
-                            dataset_config_name,
-                            split=train_split_name,
-                            use_auth_token=data_args.use_auth_token,
-                        )
+                        new_dataset,
                     ]
                 )
+                min_columns_train = common_cols(min_columns_train, new_dataset.column_names)
         else:
             logging.warning(f"{dataset_name} {dataset_config_name} as split is {train_split_name}")
 
@@ -468,6 +476,8 @@ def main():
 
     if data_args.max_train_samples is not None:
         raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
+    other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
+    raw_datasets["train"] = raw_datasets["train"].remove_columns(other_columns_train)
 
     if training_args.do_eval:
         # Multiple datasets might need to be loaded from HF
@@ -486,23 +496,28 @@ def main():
                     split=eval_split_name,
                     use_auth_token=data_args.use_auth_token,
                 )
+                min_columns_eval = raw_datasets["eval"].column_names
             else:
+                new_dataset = load_dataset(
+                    dataset_name,
+                    dataset_config_name,
+                    split=eval_split_name,
+                    use_auth_token=data_args.use_auth_token,
+                )
                 raw_datasets["eval"] = concatenate_datasets(
                     [
                         raw_datasets["eval"],
-                        load_dataset(
-                            dataset_name,
-                            dataset_config_name,
-                            split=eval_split_name,
-                            use_auth_token=data_args.use_auth_token,
-                        )
+                        new_dataset,
                     ]
                 )
+                min_columns_eval = common_cols(min_columns_eval, new_dataset.column_names)
         else:
             logging.warning(f"{dataset_name} {dataset_config_name} as split is {eval_split_name}")
 
     if data_args.max_eval_samples is not None:
         raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
+    other_columns_eval = [col for col in raw_datasets["eval"].column_names if col not in min_columns_eval]
+    raw_datasets["eval"] = raw_datasets["eval"].remove_columns(other_columns_eval)
 
     # 2. We remove some special characters from the datasets
     # that make training complicated and do not help in transcribing the speech
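
One small caveat on `common_cols` as written above: it iterates over a `set`, so the order of the returned column names is not guaranteed to match the input order. That is harmless here, since the list is only used for membership tests, but it is worth knowing if the order ever starts to matter:

```python
def common_cols(columns_a, columns_b):
    col_a = set(columns_a)
    col_b = set(columns_b)
    return [col for col in col_a if col in col_b]

# Membership is correct, but ordering follows set iteration:
print(sorted(common_cols(["audio", "sentence", "age"], ["sentence", "audio"])))
# ['audio', 'sentence']
```

An order-preserving variant would iterate the first list directly, e.g. `[col for col in columns_a if col in set(columns_b)]`.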
 
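
Finally, a detail that is easy to miss when touching this code: `Dataset.remove_columns` returns a new dataset rather than modifying the existing one in place, which is why the pruning lines above assign the result back into `raw_datasets`. A short reminder:

```python
from datasets import Dataset

ds = Dataset.from_dict({"a": [1], "b": [2]})
ds.remove_columns(["b"])       # result discarded; ds still has both columns
ds = ds.remove_columns(["b"])  # assign the result back to actually drop "b"
print(ds.column_names)         # ['a']
```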