Automatic Speech Recognition
Transformers
4 languages
whisper
whisper-event
Generated from Trainer
Inference Endpoints
marinone94 committed on
Commit
6ea4d4a
1 Parent(s): ee5b1b2

allow multiple languages and datasets

Browse files
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -49,6 +49,7 @@ from transformers import (
49
  set_seed,
50
  )
51
  from transformers.models.whisper.english_normalizer import BasicTextNormalizer
 
52
  from transformers.trainer_pt_utils import IterableDatasetShard
53
  from transformers.trainer_utils import get_last_checkpoint, is_main_process
54
  from transformers.utils import check_min_version, send_example_telemetry
@@ -61,6 +62,9 @@ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/spe
61
 
62
  logger = logging.getLogger(__name__)
63
 
 
 
 
64
  wandb_token = os.environ.get("WANDB_TOKEN", "None")
65
  hf_token = os.environ.get("HF_TOKEN", None)
66
  if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
@@ -160,10 +164,16 @@ class DataTrainingArguments:
160
  Arguments pertaining to what data we are going to input our model for training and eval.
161
  """
162
 
163
- dataset_name: str = field(
 
 
 
 
 
 
164
  default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
165
  )
166
- dataset_config_name: Optional[str] = field(
167
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
168
  )
169
  text_column: Optional[str] = field(
@@ -232,7 +242,16 @@ class DataTrainingArguments:
232
  default=True,
233
  metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
234
  )
235
- language: str = field(
 
 
 
 
 
 
 
 
 
236
  default=None,
237
  metadata={
238
  "help": (
@@ -273,6 +292,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
273
 
274
  processor: Any
275
  decoder_start_token_id: int
 
276
 
277
  def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
278
  # split inputs and labels since they have to be of different lengths and need
@@ -280,6 +300,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
280
  model_input_name = self.processor.model_input_names[0]
281
  input_features = [{model_input_name: feature[model_input_name]} for feature in features]
282
  label_features = [{"input_ids": feature["labels"]} for feature in features]
 
283
 
284
  batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
285
 
@@ -292,6 +313,15 @@ class DataCollatorSpeechSeq2SeqWithPadding:
292
  # cut bos token here as it's appended later anyway
293
  if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
294
  labels = labels[:, 1:]
 
 
 
 
 
 
 
 
 
295
 
296
  batch["labels"] = labels
297
 
@@ -316,7 +346,7 @@ def notify_me(recipient, message=None):
316
  from email.mime.text import MIMEText
317
 
318
  msg = MIMEText(message)
319
- msg["Subject"] = "Training is finished!"
320
  msg["From"] = "marinone.auto@gmail.com"
321
  msg["To"] = recipient
322
 
@@ -334,16 +364,26 @@ def load_maybe_streaming_dataset(dataset_names, dataset_config_names, split="tra
334
  each split is loaded individually and then splits combined by taking alternating examples from
335
  each (interleaving).
336
  """
 
 
 
 
337
  if "," in dataset_names or "+" in split:
338
  # load multiple splits separated by the `+` symbol with streaming mode
339
  dataset_splits = []
340
- for dataset_name, dataset_config_name, split_names in zip(
341
- dataset_names.split(","), dataset_config_names.split(","), split.split(",")
342
  ):
343
  for split_name in split_names.split("+"):
344
- dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
 
 
 
 
 
 
345
  dataset_splits.append(dataset)
346
-
347
  # interleave multiple splits to form one dataset
348
  interleaved_dataset = interleave_datasets(dataset_splits)
349
  return interleaved_dataset
@@ -426,20 +466,23 @@ def main():
426
 
427
  if training_args.do_train:
428
  raw_datasets["train"] = load_maybe_streaming_dataset(
429
- data_args.dataset_name,
430
- data_args.dataset_config_name,
431
  split=data_args.train_split_name,
432
  use_auth_token=hf_token if model_args.use_auth_token else None,
433
  streaming=data_args.streaming,
 
 
434
  )
435
 
436
  if training_args.do_eval:
437
  raw_datasets["eval"] = load_maybe_streaming_dataset(
438
- data_args.dataset_name,
439
- data_args.dataset_config_name,
440
  split=data_args.eval_split_name,
441
  use_auth_token=hf_token if model_args.use_auth_token else None,
442
  streaming=data_args.streaming,
 
443
  )
444
 
445
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
@@ -451,6 +494,7 @@ def main():
451
  f"{', '.join(raw_datasets_features)}."
452
  )
453
 
 
454
  if data_args.text_column_name not in raw_datasets_features:
455
  raise ValueError(
456
  f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
@@ -504,9 +548,13 @@ def main():
504
  if model_args.freeze_encoder:
505
  model.freeze_encoder()
506
 
507
- if data_args.language is not None:
508
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
 
509
  tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
 
 
 
510
 
511
  # 6. Resample speech dataset if necessary
512
  logger.info("*** Resample dataset ***")
@@ -558,6 +606,7 @@ def main():
558
  return batch
559
 
560
  with training_args.main_process_first(desc="dataset map pre-processing"):
 
561
  vectorized_datasets = raw_datasets.map(
562
  prepare_dataset,
563
  remove_columns=raw_datasets_features,
@@ -617,9 +666,14 @@ def main():
617
  processor = AutoProcessor.from_pretrained(training_args.output_dir)
618
 
619
  # 10. Define data collator
 
 
 
 
620
  data_collator = DataCollatorSpeechSeq2SeqWithPadding(
621
  processor=processor,
622
  decoder_start_token_id=model.config.decoder_start_token_id,
 
623
  )
624
 
625
  # 11. Configure Trainer
@@ -716,20 +770,24 @@ def main():
716
  if model_args.model_index_name is not None:
717
  kwargs["model_name"] = model_args.model_index_name
718
 
 
 
 
 
 
 
719
  if training_args.push_to_hub:
720
  logger.info("*** Pushing to hub ***")
721
  trainer.push_to_hub(**kwargs)
722
  logger.info("*** Pushed to hub ***")
 
 
723
  else:
724
  logger.info("*** Creating model card ***")
725
  trainer.create_model_card(**kwargs)
726
  logger.info("*** Model card created ***")
727
-
728
- # Training complete notification
729
- logger.info("*** Sending notification ***")
730
- notify_me(recipient="marinone94@gmail.com", message=json.dumps(kwargs, indent=4))
731
-
732
- logger.info("*** Training complete!!! ***")
733
 
734
  return results
735
 
 
49
  set_seed,
50
  )
51
  from transformers.models.whisper.english_normalizer import BasicTextNormalizer
52
+ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
53
  from transformers.trainer_pt_utils import IterableDatasetShard
54
  from transformers.trainer_utils import get_last_checkpoint, is_main_process
55
  from transformers.utils import check_min_version, send_example_telemetry
 
62
 
63
  logger = logging.getLogger(__name__)
64
 
65
+ SENDING_NOTIFICATION = "*** Sending notification to email ***"
66
+ RECIPIENT_ADDRESS = "marinone94@gmail.com"
67
+
68
  wandb_token = os.environ.get("WANDB_TOKEN", "None")
69
  hf_token = os.environ.get("HF_TOKEN", None)
70
  if (hf_token is None or wandb_token == "None") and os.path.exists("./creds.txt"):
 
164
  Arguments pertaining to what data we are going to input our model for training and eval.
165
  """
166
 
167
+ dataset_train_name: str = field(
168
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
169
+ )
170
+ dataset_train_config_name: Optional[str] = field(
171
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
172
+ )
173
+ dataset_eval_name: str = field(
174
  default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
175
  )
176
+ dataset_eval_config_name: Optional[str] = field(
177
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
178
  )
179
  text_column: Optional[str] = field(
 
242
  default=True,
243
  metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
244
  )
245
+ language_train: str = field(
246
+ default=None,
247
+ metadata={
248
+ "help": (
249
+ "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
250
+ "only. For English speech recognition, it should be set to `None`."
251
+ )
252
+ },
253
+ )
254
+ language_eval: str = field(
255
  default=None,
256
  metadata={
257
  "help": (
 
292
 
293
  processor: Any
294
  decoder_start_token_id: int
295
+ task_id: int
296
 
297
  def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
298
  # split inputs and labels since they have to be of different lengths and need
 
300
  model_input_name = self.processor.model_input_names[0]
301
  input_features = [{model_input_name: feature[model_input_name]} for feature in features]
302
  label_features = [{"input_ids": feature["labels"]} for feature in features]
303
+ lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
304
 
305
  batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
306
 
 
313
  # cut bos token here as it's appended later anyway
314
  if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
315
  labels = labels[:, 1:]
316
+ lang_token_ids = self.processor.tokenizer(lang_features).input_ids
317
+ # Replace language and task if they are in the beginning, otherwise add them
318
+ if (labels[:, 1] == self.task_id).all().cpu().item():
319
+ labels[:, 0] = lang_token_ids
320
+ labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
321
+ else:
322
+ # convert task id to tensor of labels dim to concatenate
323
+ task_id = torch.full_like(labels[:, 0], self.task_id)
324
+ labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
325
 
326
  batch["labels"] = labels
327
 
 
346
  from email.mime.text import MIMEText
347
 
348
  msg = MIMEText(message)
349
+ msg["Subject"] = "Training updates..."
350
  msg["From"] = "marinone.auto@gmail.com"
351
  msg["To"] = recipient
352
 
 
364
  each split is loaded individually and then splits combined by taking alternating examples from
365
  each (interleaving).
366
  """
367
+ column_names = None
368
+ if "column_names" in kwargs:
369
+ column_names = kwargs.pop("column_names").split(",")
370
+
371
  if "," in dataset_names or "+" in split:
372
  # load multiple splits separated by the `+` symbol with streaming mode
373
  dataset_splits = []
374
+ for dataset_name, dataset_config_name, split_names, lang in zip(
375
+ dataset_names.split(","), dataset_config_names.split(","), split.split(","), kwargs.pop("language").split(",")
376
  ):
377
  for split_name in split_names.split("+"):
378
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
379
+ raw_datasets_features = list(next(iter(dataset.values())).features.keys())
380
+ if column_names[0] not in raw_datasets_features:
381
+ if len(column_names) == 1 or column_names[1] not in raw_datasets_features:
382
+ raise ValueError("Column name not found in dataset.")
383
+ dataset = dataset.rename_columns(column_names[1], column_names[0])
384
+ dataset["language"] = lang
385
  dataset_splits.append(dataset)
386
+
387
  # interleave multiple splits to form one dataset
388
  interleaved_dataset = interleave_datasets(dataset_splits)
389
  return interleaved_dataset
 
466
 
467
  if training_args.do_train:
468
  raw_datasets["train"] = load_maybe_streaming_dataset(
469
+ data_args.dataset_train_name,
470
+ data_args.dataset_train_config_name,
471
  split=data_args.train_split_name,
472
  use_auth_token=hf_token if model_args.use_auth_token else None,
473
  streaming=data_args.streaming,
474
+ column_names=data_args.text_column_name,
475
+ language=data_args.language_train
476
  )
477
 
478
  if training_args.do_eval:
479
  raw_datasets["eval"] = load_maybe_streaming_dataset(
480
+ data_args.dataset_eval_name,
481
+ data_args.dataset_eval_config_name,
482
  split=data_args.eval_split_name,
483
  use_auth_token=hf_token if model_args.use_auth_token else None,
484
  streaming=data_args.streaming,
485
+ language=data_args.language_eval
486
  )
487
 
488
  raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
 
494
  f"{', '.join(raw_datasets_features)}."
495
  )
496
 
497
+ data_args.text_column_name = data_args.text_column_name.split(",")[0]
498
  if data_args.text_column_name not in raw_datasets_features:
499
  raise ValueError(
500
  f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
 
548
  if model_args.freeze_encoder:
549
  model.freeze_encoder()
550
 
551
+ if data_args.language is not None and len(data_args.language.split(",")) == 1:
552
  # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
553
 + # If more than one language is specified, it will be set in the data collator
554
  tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
555
+ elif data_args.language is not None and len(data_args.language.split(",")) > 1:
556
+ # make sure language and task are not stored in the model config
557
+ model.config.forced_decoder_ids = None
558
 
559
  # 6. Resample speech dataset if necessary
560
  logger.info("*** Resample dataset ***")
 
606
  return batch
607
 
608
  with training_args.main_process_first(desc="dataset map pre-processing"):
609
+ raw_datasets_features.remove("language")
610
  vectorized_datasets = raw_datasets.map(
611
  prepare_dataset,
612
  remove_columns=raw_datasets_features,
 
666
  processor = AutoProcessor.from_pretrained(training_args.output_dir)
667
 
668
  # 10. Define data collator
669
+ task_token = data_args.task
670
+ if not task_token.startswith('<|'):
671
+ task_token = f'<{task_token}>'
672
+ task_id = tokenizer(task_token).input_ids[0]
673
  data_collator = DataCollatorSpeechSeq2SeqWithPadding(
674
  processor=processor,
675
  decoder_start_token_id=model.config.decoder_start_token_id,
676
+ task_id=task_id
677
  )
678
 
679
  # 11. Configure Trainer
 
770
  if model_args.model_index_name is not None:
771
  kwargs["model_name"] = model_args.model_index_name
772
 
773
+ # Training complete notification
774
+ logger.info(SENDING_NOTIFICATION)
775
+ notify_me(recipient=RECIPIENT_ADDRESS, message=json.dumps(kwargs, indent=4))
776
+ logger.info("*** Training complete!!! ***")
777
+
778
+
779
  if training_args.push_to_hub:
780
  logger.info("*** Pushing to hub ***")
781
  trainer.push_to_hub(**kwargs)
782
  logger.info("*** Pushed to hub ***")
783
+ logger.info(SENDING_NOTIFICATION)
784
+ notify_me(recipient=RECIPIENT_ADDRESS, message="Model pushed to hub")
785
  else:
786
  logger.info("*** Creating model card ***")
787
  trainer.create_model_card(**kwargs)
788
  logger.info("*** Model card created ***")
789
+ logger.info(SENDING_NOTIFICATION)
790
+ notify_me(recipient=RECIPIENT_ADDRESS, message="Model card created")
 
 
 
 
791
 
792
  return results
793
 
test_run_nordic.sh CHANGED
@@ -1,9 +1,12 @@
1
  python $1run_speech_recognition_seq2seq_streaming.py \
2
  --model_name_or_path="openai/whisper-tiny" \
3
- --dataset_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC" \
4
- --dataset_config_name="sv-SE,da,nn-NO,,no-distant,,16k_mp3_nynorsk" \
5
- --language="swedish" \
6
- --train_split_name="train+validation,train+validation,train+validation,train,train+test, train,train+validation" \
 
 
 
7
  --eval_split_name="test" \
8
  --model_index_name="Whisper Tiny Swedish" \
9
  --max_train_samples="64" \
@@ -22,7 +25,7 @@ python $1run_speech_recognition_seq2seq_streaming.py \
22
  --generation_max_length="225" \
23
  --length_column_name="input_length" \
24
  --max_duration_in_seconds="30" \
25
- --text_column_name="sentence" \
26
  --freeze_feature_encoder="False" \
27
  --report_to="wandb" \
28
  --metric_for_best_model="wer" \
 
1
  python $1run_speech_recognition_seq2seq_streaming.py \
2
  --model_name_or_path="openai/whisper-tiny" \
3
+ --dataset_train_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,babelbox/babelbox_voice,NbAiLab/NST,arpelarpe/nota,NbAiLab/NPSC,google/fleurs,google/fleurs,google/fleurs" \
4
+ --dataset_train_config_name="sv-SE,da,nn-NO,,no-distant,,16k_mp3_nynorsk,sv_se,da_dk,nb_no" \
5
+ --language_train="swedish,danish,norwegian,swedish,norwegian,danish,norwegian,swedish,danish,norwegian" \
6
+ --train_split_name="train+validation,train+validation,train+validation,train,train+test,train,train+validation,train+validation,train+validation,train+validation" \
7
+ --dataset_eval_name="mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0,mozilla-foundation/common_voice_11_0" \
8
+ --dataset_eval_config_name="sv-SE,da,nn-NO" \
9
+ --language_eval="swedish,danish,norwegian" \
10
  --eval_split_name="test" \
11
  --model_index_name="Whisper Tiny Swedish" \
12
  --max_train_samples="64" \
 
25
  --generation_max_length="225" \
26
  --length_column_name="input_length" \
27
  --max_duration_in_seconds="30" \
28
+ --text_column_name="sentence,text" \
29
  --freeze_feature_encoder="False" \
30
  --report_to="wandb" \
31
  --metric_for_best_model="wer" \