marinone94 committed
Commit 5e05341
Parent: 4f87524

reset script
run_speech_recognition_seq2seq_streaming.py CHANGED
@@ -20,10 +20,8 @@ with 🤗 Datasets' streaming mode.
 # You can also adapt this script for your own sequence to sequence speech
 # recognition task. Pointers for this are left as comments.
 
-import json
 import logging
 import os
-import subprocess
 import sys
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Union
@@ -49,12 +47,12 @@ from transformers import (
     set_seed,
 )
 from transformers.models.whisper.english_normalizer import BasicTextNormalizer
-from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, LANGUAGES
 from transformers.trainer_pt_utils import IterableDatasetShard
 from transformers.trainer_utils import get_last_checkpoint, is_main_process
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 
+
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.25.0.dev0")
 
@@ -62,8 +60,6 @@ require_version("datasets>=1.18.2", "To fix: pip install -r examples/pytorch/spe
 
 logger = logging.getLogger(__name__)
 
-SENDING_NOTIFICATION = "*** Sending notification to email ***"
-RECIPIENT_ADDRESS = "marinone94@gmail.com"
 
 wandb_token = os.environ.get("WANDB_TOKEN", "None")
 hf_token = os.environ.get("HF_TOKEN", None)
@@ -165,16 +161,10 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
-    dataset_train_name: str = field(
-        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
-    )
-    dataset_train_config_name: Optional[str] = field(
-        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
-    )
-    dataset_eval_name: str = field(
+    dataset_name: str = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
-    dataset_eval_config_name: Optional[str] = field(
+    dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
     text_column: Optional[str] = field(
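Note: the restored `dataset_name` / `dataset_config_name` fields are plain dataclass fields that `HfArgumentParser` turns into `--dataset_name` / `--dataset_config_name` CLI flags. A minimal, self-contained sketch of that mechanism; the toy dataclass and the argument values below are illustrative, not part of the script:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ToyDataArgs:
    # Mirrors the two fields restored in the hunk above
    dataset_name: str = field(default=None, metadata={"help": "Dataset name (via the datasets library)."})
    dataset_config_name: Optional[str] = field(default=None, metadata={"help": "Dataset configuration name."})


parser = HfArgumentParser(ToyDataArgs)
# Same parsing path the script uses for its three argument dataclasses
(args,) = parser.parse_args_into_dataclasses(
    ["--dataset_name", "mozilla-foundation/common_voice_11_0", "--dataset_config_name", "sv-SE"]
)
print(args.dataset_name, args.dataset_config_name)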
@@ -243,16 +233,7 @@ class DataTrainingArguments:
         default=True,
         metadata={"help": "Whether to normalise the references and predictions in the eval WER calculation."},
     )
-    language_train: str = field(
-        default=None,
-        metadata={
-            "help": (
-                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
-                "only. For English speech recognition, it should be set to `None`."
-            )
-        },
-    )
-    language_eval: str = field(
+    language: str = field(
         default=None,
         metadata={
             "help": (
@@ -293,9 +274,6 @@ class DataCollatorSpeechSeq2SeqWithPadding:
 
     processor: Any
     decoder_start_token_id: int
-    task_id: int
-    # TODO: remove - infer language from dataset
-    language_id: int = -100
 
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
@@ -303,7 +281,6 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         model_input_name = self.processor.model_input_names[0]
         input_features = [{model_input_name: feature[model_input_name]} for feature in features]
         label_features = [{"input_ids": feature["labels"]} for feature in features]
-        # lang_features = [f"<|{TO_LANGUAGE_CODE[feature['language']]}|>" for feature in features]
 
         batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
 
@@ -314,177 +291,40 @@ class DataCollatorSpeechSeq2SeqWithPadding:
 
         # if bos token is appended in previous tokenization step,
         # cut bos token here as it's append later anyways
-
-        # lang_token_ids = self.processor.tokenizer(lang_features).input_ids
-        # # Replace language and task if they are in the beginning, otherwise add them
-        # if (labels[:, 1] == self.task_id).all().cpu().item():
-        #     labels[:, 0] = lang_token_ids
-        #     labels[:, 1] = torch.full_like(labels[:, 1], self.task_id)
-        # else:
-        #     # convert task id to tensor of labels dim to concatenate
-        #     task_id = torch.full_like(labels[:, 0], self.task_id)
-        #     labels = torch.cat((lang_token_ids, task_id, labels), dim=1)
-
-        # Set language to pad token
         if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
-            labels[:, 1] = torch.full_like(labels[:, 1], -100)
-            # labels[:, 0] = torch.full_like(labels[:, 0], -100)
-            # labels[:, 1] = torch.full_like(labels[:, 1], -100)
-
-        # remove start of sentence token from labels
-        # if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
-        #     labels = labels[:, 1:]
-
-        # # add start of sentence token to labels + language + task
-        # labels = torch.cat((torch.full_like(labels[:, 0], self.task_id).unsqueeze(0).T, labels), dim=-1)
-        # labels = torch.cat((torch.full_like(labels[:, 0], self.language_id).unsqueeze(0).T, labels), dim=-1)
-        # labels = torch.cat((torch.full_like(labels[:, 0], self.decoder_start_token_id).unsqueeze(0).T, labels), dim=-1)
+            labels = labels[:, 1:]
 
         batch["labels"] = labels
 
         return batch
 
 
-def notify_me(recipient, message=None):
-    """
-    Send an email to the specified address with the specified message
-    """
-    sender = os.environ.get("EMAIL_ADDRESS", None)
-    password = os.environ.get("EMAIL_PASSWORD", None)
-    if sender is None:
-        logging.warning("No email address specified, not sending notification")
-    if password is None:
-        logging.warning("No email password specified, not sending notification")
-    if message is None:
-        message = "Training is finished!"
-
-    if sender is not None:
-        import smtplib
-        from email.mime.text import MIMEText
-
-        msg = MIMEText(message)
-        msg["Subject"] = "Training updates..."
-        msg["From"] = "marinone.auto@gmail.com"
-        msg["To"] = recipient
-
-        # send the email
-        smtp_obj = smtplib.SMTP("smtp.gmail.com", 587)
-        smtp_obj.starttls()
-        smtp_obj.login(sender, password)
-        smtp_obj.sendmail(sender, recipient, msg.as_string())
-        smtp_obj.quit()
-
-
-def rename_col_and_resample(dataset, dataset_name, text_column_names, text_col_name_ref, audio_column_name, sampling_rate):
-    raw_datasets_features = list(dataset.features.keys())
-    logger.info(f"Dataset {dataset_name} - Features: {raw_datasets_features}")
-
-    if text_col_name_ref not in raw_datasets_features:
-        if len(text_column_names) == 1:
-            raise ValueError("None of the text column names provided found in dataset."
-                             f"Text columns: {text_column_names}"
-                             f"Dataset columns: {raw_datasets_features}")
-        flag = False
-        for text_column_name in text_column_names:
-            if text_column_name in raw_datasets_features:
-                logger.info(f"Renaming text column {text_column_name} to {text_col_name_ref}")
-                dataset = dataset.rename_column(text_column_name, text_col_name_ref)
-                flag = True
-                break
-        if flag is False:
-            raise ValueError("None of the text column names provided found in dataset."
-                             f"Text columns: {text_column_names}"
-                             f"Dataset columns: {raw_datasets_features}")
-    if audio_column_name is not None and sampling_rate is not None:
-        ds_sr = int(dataset.features[audio_column_name].sampling_rate)
-        if ds_sr != sampling_rate:
-            dataset = dataset.cast_column(
-                audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
-            )
-
-    raw_datasets_features = list(dataset.features.keys())
-    raw_datasets_features.remove(audio_column_name)
-    raw_datasets_features.remove(text_col_name_ref)
-    # Keep only audio and sentence
-    dataset = dataset.remove_columns(column_names=raw_datasets_features)
-    return dataset
-
-
-def load_maybe_streaming_dataset(
-    dataset_names,
-    dataset_config_names,
-    split="train",
-    streaming=True,
-    audio_column_name=None,
-    sampling_rate=None,
-    **kwargs
-):
+def load_maybe_streaming_dataset(dataset_name, dataset_config_name, split="train", streaming=True, **kwargs):
     """
     Utility function to load a dataset in streaming mode. For datasets with multiple splits,
     each split is loaded individually and then splits combined by taking alternating examples from
     each (interleaving).
     """
-    text_column_names = None
-    if "text_column_name" in kwargs:
-        text_column_names = kwargs.pop("text_column_name").split(",")
-        text_col_name_ref = text_column_names[0]
-
-    if "," in dataset_names or "+" in split:
+    if "+" in split:
         # load multiple splits separated by the `+` symbol with streaming mode
-        dataset_splits = []
-        for dataset_name, dataset_config_name, split_names in zip(
-            dataset_names.split(","), dataset_config_names.split(","), split.split(",")
-        ):
-            for split_name in split_names.split("+"):
-                if dataset_config_name:
-                    dataset = load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
-                else:
-                    dataset = load_dataset(dataset_name, split=split_name, streaming=streaming, **kwargs)
-
-                dataset = rename_col_and_resample(
-                    dataset,
-                    dataset_name,
-                    text_column_names,
-                    text_col_name_ref,
-                    audio_column_name,
-                    sampling_rate
-                )
-
-                dataset_splits.append(dataset)
-
+        dataset_splits = [
+            load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=streaming, **kwargs)
+            for split_name in split.split("+")
+        ]
         # interleave multiple splits to form one dataset
-        interleaved_dataset = interleave_datasets(dataset_splits, stopping_strategy="all_exhausted")
+        interleaved_dataset = interleave_datasets(dataset_splits)
         return interleaved_dataset
     else:
         # load a single split *with* streaming mode
-
-        dataset = load_dataset(dataset_names, dataset_config_names, split=split, streaming=streaming, **kwargs)
-        dataset = rename_col_and_resample(
-            dataset,
-            dataset_names,
-            text_column_names,
-            text_col_name_ref,
-            audio_column_name,
-            sampling_rate
-        )
+        dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=streaming, **kwargs)
         return dataset
 
 
-def print_data_samples(dataset, tokenizer, max_samples=5):
-    shown_samples = 0
-    for batch in dataset:
-        print("Target: ", tokenizer.decode(batch["labels"]))
-        shown_samples += len(batch)
-        if shown_samples >= max_samples:
-            break
-
-
 def main():
     # 1. Parse input arguments
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
     # We now keep distinct sets of args, for a cleaner separation of concerns.
-    logger.info("*** Parse args ***")
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
@@ -499,7 +339,6 @@ def main():
     send_example_telemetry("run_speech_recognition_seq2seq_streaming", model_args, data_args)
 
     # 2. Setup logging
-    logger.info("*** Setup logging ***")
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
         datefmt="%m/%d/%Y %H:%M:%S",
@@ -544,94 +383,78 @@ def main():
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
-    # Load feature extractor
-    feature_extractor = AutoFeatureExtractor.from_pretrained(
-        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
-        cache_dir=model_args.cache_dir,
-        revision=model_args.model_revision,
-        use_auth_token=hf_token if model_args.use_auth_token else None,
-    )
-
     # 4. Load dataset
-    logger.info("*** Load dataset ***")
     raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
 
-    if len(data_args.language_eval.split(",")) > 1:
-        raise ValueError("Implementation does not support multiple language evaluation.")
-
     if training_args.do_train:
         raw_datasets["train"] = load_maybe_streaming_dataset(
-            data_args.dataset_train_name,
-            data_args.dataset_train_config_name,
+            data_args.dataset_name,
+            data_args.dataset_config_name,
             split=data_args.train_split_name,
-            use_auth_token=hf_token if model_args.use_auth_token else None,
+            use_auth_token=True if model_args.use_auth_token else None,
             streaming=data_args.streaming,
-            text_column_name=data_args.text_column_name,
-            audio_column_name=data_args.audio_column_name,
-            sampling_rate=int(feature_extractor.sampling_rate),
-            # language=data_args.language_train
         )
 
     if training_args.do_eval:
         raw_datasets["eval"] = load_maybe_streaming_dataset(
-            data_args.dataset_eval_name,
-            data_args.dataset_eval_config_name,
+            data_args.dataset_name,
+            data_args.dataset_config_name,
            split=data_args.eval_split_name,
-            use_auth_token=hf_token if model_args.use_auth_token else None,
+            use_auth_token=True if model_args.use_auth_token else None,
             streaming=data_args.streaming,
-            text_column_name=data_args.text_column_name,
-            audio_column_name=data_args.audio_column_name,
-            sampling_rate=int(feature_extractor.sampling_rate),
-            # language=data_args.language_eval
         )
 
     raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
 
     if data_args.audio_column_name not in raw_datasets_features:
         raise ValueError(
-            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset. "
+            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
             "Make sure to set `--audio_column_name` to the correct audio column - one of "
             f"{', '.join(raw_datasets_features)}."
         )
 
-    data_args.text_column_name = data_args.text_column_name.split(",")[0]
     if data_args.text_column_name not in raw_datasets_features:
         raise ValueError(
-            f"--text_column_name {data_args.text_column_name} not found in dataset. "
+            f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
             "Make sure to set `--text_column_name` to the correct text column - one of "
             f"{', '.join(raw_datasets_features)}."
         )
 
     # 5. Load pretrained model, tokenizer, and feature extractor
-    logger.info("*** Load pretrained model, tokenizer, and feature extractor ***")
+    #
     # Distributed training:
     # The .from_pretrained methods guarantee that only one local process can concurrently
     config = AutoConfig.from_pretrained(
         model_args.config_name if model_args.config_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
-        use_auth_token=hf_token if model_args.use_auth_token else None
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 
-    # Forced decoder ids will be overwritten before evaluation
     config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
 
     if training_args.gradient_checkpointing:
         config.update({"use_cache": False})
 
+    feature_extractor = AutoFeatureExtractor.from_pretrained(
+        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=True if model_args.use_auth_token else None,
+    )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
-        use_auth_token=hf_token if model_args.use_auth_token else None,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
         model_args.model_name_or_path,
         config=config,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
-        use_auth_token=hf_token if model_args.use_auth_token else None,
+        use_auth_token=True if model_args.use_auth_token else None,
     )
 
     if model.config.decoder_start_token_id is None:
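Note: `train_split_name` above may name several splits joined by `+` (e.g. `train+validation`), which `load_maybe_streaming_dataset` interleaves. A toy sketch of `interleave_datasets` with in-memory data; the reverted call also drops `stopping_strategy="all_exhausted"`, so interleaving now stops once the shortest split is exhausted (the library default):

from datasets import Dataset, interleave_datasets

train = Dataset.from_dict({"id": [0, 1, 2]})
validation = Dataset.from_dict({"id": [10, 11, 12]})

# Examples are taken alternately from each input dataset
combined = interleave_datasets([train, validation])
print([ex["id"] for ex in combined])  # [0, 10, 1, 11, 2, 12]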
@@ -642,26 +465,20 @@ def main():
 
     if model_args.freeze_encoder:
         model.freeze_encoder()
-
-    tokenizer.set_prefix_tokens(language="swedish", task=data_args.task)
 
-    # if data_args.language_train is not None and len(data_args.language_train.split(",")) == 1:
-    #     # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
-    #     # If more than a langugae is specified, it will be specified in the data collator
-    #     tokenizer.set_prefix_tokens(language=data_args.language_train, task=data_args.task)
-    # elif data_args.language_train is not None and len(data_args.language_train.split(",")) > 1:
-    #     # make sure language and task are not stored in the model config
-    #     model.config.forced_decoder_ids = None
+    if data_args.language is not None:
+        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
 
     # 6. Resample speech dataset if necessary
-    # logger.info("*** Resample dataset ***")
-    # dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
-    # if dataset_sampling_rate != feature_extractor.sampling_rate:
-
+    dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
+    if dataset_sampling_rate != feature_extractor.sampling_rate:
+        raw_datasets = raw_datasets.cast_column(
+            data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+        )
 
     # 7. Preprocessing the datasets.
     # We need to read the audio files as arrays and tokenize the targets.
-    logger.info("*** Preprocess dataset ***")
     max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
     audio_column_name = data_args.audio_column_name
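Note: with the single `language` argument restored, the decoder prefix is pinned once here instead of being patched per-batch in the collator. A small sketch of what `set_prefix_tokens` changes; it downloads `openai/whisper-tiny`, and the decoded prefix shown is the expected Whisper convention rather than a guaranteed output:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# Pin the language/task special tokens prepended to every tokenized label
tokenizer.set_prefix_tokens(language="swedish", task="transcribe")

ids = tokenizer("en liten test").input_ids
print(tokenizer.convert_ids_to_tokens(ids)[:3])
# Expected: ['<|startoftranscript|>', '<|sv|>', '<|transcribe|>']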
@@ -701,7 +518,6 @@ def main():
         return batch
 
     with training_args.main_process_first(desc="dataset map pre-processing"):
-        # raw_datasets_features.remove("language")
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=raw_datasets_features,
@@ -726,7 +542,6 @@ def main():
     )
 
     # 8. Load Metric
-    logger.info("*** Load metric ***")
     metric = evaluate.load("wer")
     do_normalize_eval = data_args.do_normalize_eval
 
@@ -751,7 +566,6 @@ def main():
         return {"wer": wer}
 
     # 9. Create a single speech processor
-    logger.info("*** Init processor ***")
     if is_main_process(training_args.local_rank):
         # save feature extractor, tokenizer and config
         feature_extractor.save_pretrained(training_args.output_dir)
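Note: the `compute_metrics` closure ending above returns 100 * WER via the `evaluate` package. A tiny worked example of that metric with toy strings:

import evaluate

metric = evaluate.load("wer")

# One deletion against a four-word reference -> WER = 1/4
wer = 100 * metric.compute(predictions=["en liten katt"], references=["en liten katt satt"])
print(wer)  # 25.0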
@@ -761,20 +575,14 @@ def main():
     processor = AutoProcessor.from_pretrained(training_args.output_dir)
 
     # 10. Define data collator
-    task_token = data_args.task
-    if not task_token.startswith('<|'):
-        task_token = f'<{task_token}>'
-    task_id = tokenizer(task_token).input_ids[0]
     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
         processor=processor,
         decoder_start_token_id=model.config.decoder_start_token_id,
-        task_id=task_id
     )
 
     # 11. Configure Trainer
     # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
     # Only required for streaming: Trainer automatically shuffles non-streaming datasets
-    logger.info("*** Set shuffle callback ***")
     class ShuffleCallback(TrainerCallback):
         def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
             if isinstance(train_dataloader.dataset, IterableDatasetShard):
@@ -782,9 +590,7 @@ def main():
         elif isinstance(train_dataloader.dataset, IterableDataset):
             train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)
 
-
     # Initialize Trainer
-    logger.info("*** Init trainer ***")
     trainer = Seq2SeqTrainer(
         model=model,
         args=training_args,
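Note: `ShuffleCallback` works because a shuffled streaming dataset derives its effective shuffle seed from an epoch counter; bumping it with `set_epoch` yields a new order each pass. A toy sketch of that mechanism, assuming a `datasets` version that provides `Dataset.to_iterable_dataset`:

from datasets import Dataset

ds = Dataset.from_dict({"id": list(range(10))}).to_iterable_dataset()
ds = ds.shuffle(seed=42, buffer_size=10)

ds.set_epoch(0)
first = [ex["id"] for ex in ds]
ds.set_epoch(1)  # what the callback does at each epoch boundary
second = [ex["id"] for ex in ds]
print(first != second)  # typically True: a different shuffle per epoch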
@@ -795,139 +601,63 @@ def main():
         compute_metrics=compute_metrics if training_args.predict_with_generate else None,
         callbacks=[ShuffleCallback()] if data_args.streaming else None,
     )
-    logger.info("*** Trainer initialized ***")
-
-    orig_push_to_hub = trainer.args.push_to_hub
-    trainer.args.push_to_hub = False
 
     # 12. Training
     if training_args.do_train:
-        logger.info("*** Train ***")
-        print_data_samples(vectorized_datasets["train"], tokenizer)
         checkpoint = None
         if training_args.resume_from_checkpoint is not None:
             checkpoint = training_args.resume_from_checkpoint
         elif last_checkpoint is not None:
             checkpoint = last_checkpoint
         train_result = trainer.train(resume_from_checkpoint=checkpoint)
-        logger.info("*** Training completed ***")
-        logger.info("*** Saving model ***")
-        # We don't want to push the model to the hub now
-        # so we temporarily set to false the push_to_hub attribute
-        # and then reset it to the original value
         trainer.save_model()  # Saves the feature extractor too for easy upload
-        logger.info("*** Model saved ***")
+
         metrics = train_result.metrics
         if data_args.max_train_samples:
             metrics["train_samples"] = data_args.max_train_samples
-        logger.info("*** Logging metrics ***")
         trainer.log_metrics("train", metrics)
-        logger.info("*** Metrics logged ***")
-        logger.info("*** Saving metrics ***")
         trainer.save_metrics("train", metrics)
-        logger.info("*** Metrics saved ***")
-        logger.info("*** Saving state ***")
         trainer.save_state()
-        logger.info("*** State saved ***")
-
-        # Run a test prediction to check outputs
-        predictions = trainer.predict(
-            test_dataset=vectorized_datasets["eval"].shuffle(seed=training_args.seed).take(5),
-            metric_key_prefix="test",
-            max_length=training_args.generation_max_length,
-            num_beams=training_args.generation_num_beams,
-        )
-        logger.info("*** Test prediction done ***")
-        preds = tokenizer.batch_decode(predictions.predictions)
-        labels = tokenizer.batch_decode(predictions.label_ids)
-        pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
-        logger.info("Before setting language and task")
-        logger.info(f"{pred_labels}")
-        language_name = LANGUAGES[data_args.language_eval]
-        trainer.model.config.forced_decoder_ids = \
-            tokenizer.get_decoder_prompt_ids(language=language_name, task=data_args.task, no_timestamps=True)
-        preds = tokenizer.batch_decode(predictions.predictions)
-        labels = tokenizer.batch_decode(predictions.label_ids)
-        pred_labels = [f"Prediction: {pred}\nLabel: {label}\n" for pred, label in zip(preds, labels)]
-        logger.info("After setting language and task")
-        logger.info(f"{pred_labels}")
 
     # 13. Evaluation
     results = {}
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-        print_data_samples(vectorized_datasets["eval"], tokenizer)
         metrics = trainer.evaluate(
             metric_key_prefix="eval",
             max_length=training_args.generation_max_length,
             num_beams=training_args.generation_num_beams,
         )
-        logger.info("*** Evaluation done ***")
         if data_args.max_eval_samples:
             metrics["eval_samples"] = data_args.max_eval_samples
+
         trainer.log_metrics("eval", metrics)
-        logger.info("*** Metrics logged ***")
-        logger.info("*** Saving metrics ***")
         trainer.save_metrics("eval", metrics)
-        logger.info("*** Metrics saved ***")
 
     # 14. Write Training Stats
-    logger.info("*** Writing training stats ***")
     kwargs = {
         "finetuned_from": model_args.model_name_or_path,
         "tasks": "automatic-speech-recognition",
         "tags": "whisper-event",
     }
-    if data_args.dataset_train_name is not None:
-        dataset_names = list(data_args.dataset_train_name.split(","))
-        kwargs["dataset_tags"] = dataset_names
-        # if data_args.dataset_train_config_name is not None:
-        #     dataset_config_names = list(data_args.dataset_train_config_name.split(","))
-        #     dataset_config_names_list = [f"{ds_name} {ds_cfg_name}" for ds_name, ds_cfg_name in zip(dataset_names, dataset_config_names)]
-        # else:
-        #     dataset_config_names_list = dataset_names
-        # kwargs["dataset"] = "\n".join(dataset_config_names_list)
-        # if "common_voice" in data_args.dataset_name:
-        #     kwargs["language"] = data_args.dataset_config_name[:2]
-    if data_args.language_train is not None:
-        languages = list(set(data_args.language_train.split(",")))
-        kwargs["language"] = languages
+    if data_args.dataset_name is not None:
+        kwargs["dataset_tags"] = data_args.dataset_name
+        if data_args.dataset_config_name is not None:
+            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+        else:
+            kwargs["dataset"] = data_args.dataset_name
+        if "common_voice" in data_args.dataset_name:
+            kwargs["language"] = data_args.dataset_config_name[:2]
     if model_args.model_index_name is not None:
         kwargs["model_name"] = model_args.model_index_name
 
-    logger.info("*** Training stats written ***")
-    logger.info(json.dumps(kwargs, indent=4))
-
-    # Training complete notification
-    logger.info("*** Training and eval complete ***")
-    logger.info(SENDING_NOTIFICATION)
-    with open(os.path.join(training_args.output_dir, "train_results.json"), "r") as f:
-        train_results = json.load(f)
-    with open(os.path.join(training_args.output_dir, "eval_results.json"), "r") as f:
-        eval_results = json.load(f)
-    notify_me(recipient=RECIPIENT_ADDRESS,
-              message=f"Training complete! {train_results = } {eval_results = }")
-
-    trainer.args.push_to_hub = orig_push_to_hub
     if training_args.push_to_hub:
-        logger.info("*** Pushing to hub ***")
         trainer.push_to_hub(**kwargs)
-        logger.info("*** Pushed to hub ***")
-        logger.info(SENDING_NOTIFICATION)
     else:
-        logger.info("*** Creating model card ***")
         trainer.create_model_card(**kwargs)
-        logger.info("*** Model card created ***")
-        logger.info(SENDING_NOTIFICATION)
-
-    with open(os.path.join(training_args.output_dir, "README.md"), "r") as f:
-        readme = f.read()
-    notify_me(recipient=RECIPIENT_ADDRESS,
-              message=f"Model pushed to hub! {readme = }")
 
     return results
 
 
 if __name__ == "__main__":
-    main()
+    main()
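Note: a worked example of the restored model-card kwargs logic for a hypothetical Common Voice run; all values below are illustrative, not taken from this commit:

dataset_name = "mozilla-foundation/common_voice_11_0"  # illustrative
dataset_config_name = "sv-SE"

kwargs = {"tasks": "automatic-speech-recognition", "tags": "whisper-event"}
if dataset_name is not None:
    kwargs["dataset_tags"] = dataset_name
    if dataset_config_name is not None:
        kwargs["dataset"] = f"{dataset_name} {dataset_config_name}"
    else:
        kwargs["dataset"] = dataset_name
    if "common_voice" in dataset_name:
        kwargs["language"] = dataset_config_name[:2]  # -> "sv"
print(kwargs)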
 