marinone94 commited on
Commit
79a4bc0
1 Parent(s): 2b994e2

allow for multiple datasets from hf in run

Browse files
Files changed (1) hide show
  1. run_speech_recognition_ctc.py +75 -17
run_speech_recognition_ctc.py CHANGED
@@ -30,7 +30,7 @@ import datasets
30
  import numpy as np
31
  import torch
32
  import wandb
33
- from datasets import DatasetDict, load_dataset, load_metric
34
 
35
  import transformers
36
  from transformers import (
@@ -140,21 +140,33 @@ class DataTrainingArguments:
140
  """
141
 
142
  dataset_name: str = field(
143
- metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
 
 
 
 
144
  )
145
  dataset_config_name: str = field(
146
- default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
 
 
 
 
147
  )
148
  train_split_name: str = field(
149
  default="train+validation",
150
  metadata={
151
- "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
 
 
152
  },
153
  )
154
  eval_split_name: str = field(
155
  default="test",
156
  metadata={
157
- "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
 
 
158
  },
159
  )
160
  audio_column_name: str = field(
@@ -407,12 +419,36 @@ def main():
407
  raw_datasets = DatasetDict()
408
 
409
  if training_args.do_train:
410
- raw_datasets["train"] = load_dataset(
411
- data_args.dataset_name,
412
- data_args.dataset_config_name,
413
- split=data_args.train_split_name,
414
- use_auth_token=data_args.use_auth_token,
415
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
416
 
417
  if data_args.audio_column_name not in raw_datasets["train"].column_names:
418
  raise ValueError(
@@ -432,12 +468,34 @@ def main():
432
  raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
433
 
434
  if training_args.do_eval:
435
- raw_datasets["eval"] = load_dataset(
436
- data_args.dataset_name,
437
- data_args.dataset_config_name,
438
- split=data_args.eval_split_name,
439
- use_auth_token=data_args.use_auth_token,
440
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
 
442
  if data_args.max_eval_samples is not None:
443
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
 
30
  import numpy as np
31
  import torch
32
  import wandb
33
+ from datasets import DatasetDict, concatenate_datasets, load_dataset, load_metric
34
 
35
  import transformers
36
  from transformers import (
 
140
  """
141
 
142
  dataset_name: str = field(
143
+ metadata={
144
+ "help": "The name of the dataset to use (via the datasets library)."
145
+ " To use multiple datasets, specify them separated by a comma."
146
+ " e.g.: 'mozilla-foundation/common_voice_7_0,marinone94/nst_sv'"
147
+ }
148
  )
149
  dataset_config_name: str = field(
150
+ default=None, metadata={
151
+ "help": "The configuration name of the dataset to use (via the datasets library)."
152
+ " To use multiple datasets, specify them separated by a comma."
153
+ " e.g.: 'sv-SE,sv'"
154
+ }
155
  )
156
  train_split_name: str = field(
157
  default="train+validation",
158
  metadata={
159
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train+validation'"
160
+ " To use multiple datasets, specify them separated by a comma."
161
+ " e.g.: 'train+validation,all'"
162
  },
163
  )
164
  eval_split_name: str = field(
165
  default="test",
166
  metadata={
167
+ "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
168
+ " To use multiple datasets, specify them separated by a comma."
169
+ " e.g.: 'test,None'"
170
  },
171
  )
172
  audio_column_name: str = field(
 
419
  raw_datasets = DatasetDict()
420
 
421
  if training_args.do_train:
422
+
423
+ # Multiple datasets might need to be loaded from HF
424
+ # It assumes they all follow the common voice format
425
+ for (dataset_name, dataset_config_name, train_split_name) in zip(
426
+ data_args.dataset_name.split(","),
427
+ data_args.dataset_config_name.split(","),
428
+ data_args.train_split_name.split(","),
429
+ ):
430
+
431
+
432
+ if train_split_name != "None":
433
+ if "train" not in raw_datasets:
434
+ raw_datasets["train"] = load_dataset(
435
+ dataset_name,
436
+ dataset_config_name,
437
+ split=train_split_name,
438
+ use_auth_token=data_args.use_auth_token,
439
+ )
440
+ else:
441
+ raw_datasets["train"] = concatenate_datasets(
442
+ [
443
+ raw_datasets["train"],
444
+ load_dataset(
445
+ dataset_name,
446
+ dataset_config_name,
447
+ split=train_split_name,
448
+ use_auth_token=data_args.use_auth_token,
449
+ )
450
+ ]
451
+ )
452
 
453
  if data_args.audio_column_name not in raw_datasets["train"].column_names:
454
  raise ValueError(
 
468
  raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
469
 
470
  if training_args.do_eval:
471
+ # Multiple datasets might need to be loaded from HF
472
+ # It assumes they all follow the common voice format
473
+ for (dataset_name, dataset_config_name, eval_split_name) in zip(
474
+ data_args.dataset_name.split(","),
475
+ data_args.dataset_config_name.split(","),
476
+ data_args.eval_split_name.split(","),
477
+ ):
478
+
479
+ if eval_split_name != "None":
480
+ if "eval" not in raw_datasets:
481
+ raw_datasets["eval"] = load_dataset(
482
+ dataset_name,
483
+ dataset_config_name,
484
+ split=eval_split_name,
485
+ use_auth_token=data_args.use_auth_token,
486
+ )
487
+ else:
488
+ raw_datasets["eval"] = concatenate_datasets(
489
+ [
490
+ raw_datasets["eval"],
491
+ load_dataset(
492
+ dataset_name,
493
+ dataset_config_name,
494
+ split=eval_split_name,
495
+ use_auth_token=data_args.use_auth_token,
496
+ )
497
+ ]
498
+ )
499
 
500
  if data_args.max_eval_samples is not None:
501
  raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))