import torch
from datasets import load_dataset, DatasetDict
from datasets import Audio
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate


# Functions
# Define a Data Collator
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths
        # and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")

        # replace padding with -100 to ignore it correctly in the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in the previous tokenization step,
        # cut it here since it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


def main():
    # Metrics
    def compute_metrics(pred):
        pred_ids = pred.predictions
        label_ids = pred.label_ids

        # replace -100 with the pad_token_id
        label_ids[label_ids == -100] = tokenizer.pad_token_id

        # we do not want to group tokens when computing the metrics
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        wer = 100 * metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # Prepare dataset
    def prepare_dataset(batch):
        # load and resample audio data from 48 to 16kHz
        audio = batch["audio"]

        # compute log-Mel input features from input audio array
        batch["input_features"] = feature_extractor(
            audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

        # encode target text to label ids
        batch["labels"] = tokenizer(batch["sentence"]).input_ids
        return batch
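
    # A hedged sketch (not part of the original script) of what `prepare_dataset`
    # returns for a single example; it assumes `speech_data` has already been
    # loaded and cast to 16 kHz further down, so it is left commented out:
    #
    #   sample = speech_data["train"][0]
    #   processed = prepare_dataset(sample)
    #   print(len(processed["input_features"]))   # 80 log-Mel bins
    #   print(processed["labels"][:10])           # first few label token ids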

    # Whisper Training Script
    # Map the source and target columns.
    # Whisper expects these to be "audio" and "sentence";
    # change them if the dataset uses anything else.
    source = "audio"
    target = "sentence"

    # Load a sample dataset
    speech_data = DatasetDict()

    # Examples
    # speech_data["train"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="train", use_auth_token=True)
    # speech_data["test"] = load_dataset("NbAiLab/NPSC", "16K_mp3_bokmaal", split="test", use_auth_token=True)
    # speech_data["train"] = load_dataset("NbAiLab/LIA_speech", split="train", use_auth_token=True)
    # speech_data["test"] = load_dataset("NbAiLab/LIA_speech", split="test", use_auth_token=True)

    # The smallest dataset I found
    speech_data["train"] = load_dataset(
        "mozilla-foundation/common_voice_11_0", "nn-NO", split="train", use_auth_token=True)
    speech_data["test"] = load_dataset(
        "mozilla-foundation/common_voice_11_0", "nn-NO", split="test", use_auth_token=True)

    # Rename columns
    if "audio" not in speech_data.column_names["train"]:
        speech_data = speech_data.rename_column(source, "audio")

    if "sentence" not in speech_data.column_names["train"]:
        speech_data = speech_data.rename_column(target, "sentence")

    # Remove columns that are not needed (everything except "audio" and "sentence")
    # - not sure this is strictly necessary
    remove_list = [i for i in speech_data.column_names["train"]
                   if i not in ["audio", "sentence"]]
    speech_data = speech_data.remove_columns(remove_list)

    # Initialise
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        "openai/whisper-small")
    tokenizer = WhisperTokenizer.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    processor = WhisperProcessor.from_pretrained(
        "openai/whisper-small", language="Norwegian", task="transcribe")
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

    # Prepare data
    speech_data = speech_data.cast_column("audio", Audio(sampling_rate=16000))
    speech_data = speech_data.map(
        prepare_dataset, remove_columns=speech_data.column_names["train"], num_proc=1)

    # Metrics
    metric = evaluate.load("wer")

    # Initialise a pretrained model
    # We need to set use_cache=False here if we want to use gradient checkpointing
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small", use_cache=False)

    # Overriding generation arguments: no tokens are forced as decoder outputs
    # (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)
    # and no tokens are suppressed during generation
    # (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens).
    model.config.forced_decoder_ids = None
    model.config.suppress_tokens = []

    # Training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir="../whisper-testrun1",  # change to a repo name of your choice
        per_device_train_batch_size=16,
        gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
        learning_rate=2e-5,
        warmup_steps=500,
        max_steps=5000,  # changed from 4000
        gradient_checkpointing=True,
        group_by_length=True,
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        predict_with_generate=True,
        generation_max_length=225,
        save_steps=500,
        eval_steps=500,
        logging_steps=25,
        report_to=["tensorboard"],
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=True,
    )

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=speech_data["train"],
        eval_dataset=speech_data["test"],
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        tokenizer=processor.feature_extractor,
    )
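
    # Optional (not in the original script): saving the processor alongside the
    # checkpoints makes the output directory self-contained for later inference.
    # processor.save_pretrained(training_args.output_dir)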

    # Start training
    trainer.train()


def _mp_fn(index):
    # For xla_spawn (TPUs)
    print("The XLA is initiated")
    main()


if __name__ == "__main__":
    main()
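
# A minimal sketch (not part of the training script above) of loading the
# fine-tuned checkpoint for inference; it assumes the processor was also saved
# to "../whisper-testrun1", e.g. via the optional save_pretrained call above:
#
#   from transformers import pipeline
#   asr = pipeline("automatic-speech-recognition", model="../whisper-testrun1")
#   print(asr("some_16khz_audio.wav")["text"])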