Audio Course documentation

Fine-tuning the ASR model

In this section, we’ll cover a step-by-step guide on fine-tuning Whisper for speech recognition on the Common Voice 13 dataset. We’ll use the ‘small’ version of the model and a relatively lightweight dataset, enabling you to run fine-tuning fairly quickly on any 16GB+ GPU with low disk space requirements, such as the 16GB T4 GPU provided in the Google Colab free tier.

Should you have a smaller GPU or encounter memory issues during training, you can follow the suggestions provided for reducing memory usage. Conversely, should you have access to a larger GPU, you can amend the training arguments to maximise your throughput. Thus, this guide is accessible regardless of your GPU specifications!

Likewise, this guide outlines how to fine-tune the Whisper model for the Dhivehi language. However, the steps covered here generalise to any language in the Common Voice dataset, and more generally to any ASR dataset on the Hugging Face Hub. You can tweak the code to quickly switch to a language of your choice and fine-tune a Whisper model in your native tongue 🌍

Right! Now that’s out of the way, let’s get started and kick off our fine-tuning pipeline!

Prepare Environment

We strongly advise you to upload model checkpoints directly to the Hugging Face Hub while training. The Hub provides:

  • Integrated version control: you can be sure that no model checkpoint is lost during training.
  • Tensorboard logs: track important metrics over the course of training.
  • Model cards: document what a model does and its intended use cases.
  • Community: an easy way to share and collaborate with the community! 🤗

Linking the notebook to the Hub is straightforward - it simply requires entering your Hub authentication token when prompted. You can find your token in your Hugging Face account settings (https://huggingface.co/settings/tokens):

from huggingface_hub import notebook_login

notebook_login()

Output:

Login successful
Your token has been saved to /root/.huggingface/token

Load Dataset

Common Voice 13 contains approximately ten hours of labelled Dhivehi data, three of which are held-out test data. This is extremely little data for fine-tuning, so we’ll rely on the extensive multilingual ASR knowledge acquired by Whisper during pre-training for the low-resource Dhivehi language.

Using 🤗 Datasets, downloading and preparing data is extremely simple. We can download and prepare the Common Voice 13 splits in just one line of code. Since Dhivehi is very low-resource, we’ll combine the train and validation splits to give approximately seven hours of training data. We’ll use the three hours of test data as our held-out test set:

from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

common_voice["train"] = load_dataset(
    "mozilla-foundation/common_voice_13_0", "dv", split="train+validation"
)
common_voice["test"] = load_dataset(
    "mozilla-foundation/common_voice_13_0", "dv", split="test"
)

print(common_voice)

Output:

DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 4904
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],
        num_rows: 2212
    })
})

You can change the language identifier from `"dv"` to a language identifier of your choice. To see all possible languages in Common Voice 13, check out the dataset card on the Hugging Face Hub: https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0

Most ASR datasets only provide input audio samples (audio) and the corresponding transcribed text (sentence). Common Voice contains additional metadata information, such as accent and locale, which we can disregard for ASR. Keeping the notebook as general as possible, we only consider the input audio and transcribed text for fine-tuning, discarding the additional metadata information:

common_voice = common_voice.select_columns(["audio", "sentence"])

Feature Extractor, Tokenizer and Processor

The ASR pipeline can be decomposed into three stages:

  1. The feature extractor, which pre-processes the raw audio inputs to log-mel spectrograms
  2. The model, which performs the sequence-to-sequence mapping
  3. The tokenizer, which post-processes the predicted tokens to text

In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, called WhisperFeatureExtractor and WhisperTokenizer respectively. To make our lives simple, these two objects are wrapped under a single class, called the WhisperProcessor. We can call the WhisperProcessor to perform both the audio pre-processing and the text token post-processing. In doing so, we only need to keep track of two objects during training: the processor and the model.

When performing multilingual fine-tuning, we need to set the "language" and "task" when instantiating the processor. The "language" should be set to the source audio language, and the task to "transcribe" for speech recognition or "translate" for speech translation. These arguments modify the behaviour of the tokenizer, and should be set correctly to ensure the target labels are encoded properly.
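
To make this concrete, here is a minimal sketch of the special-token prefix that the language and task settings select by default. The helper build_decoder_prefix and the two-entry language table are hypothetical and exist only for illustration; in practice the WhisperTokenizer builds this prefix for you.

```python
# Illustrative only: a tiny stand-in for the prefix the Whisper tokenizer
# places in front of every label sequence. The helper and the small
# language table below are hypothetical, not part of 🤗 Transformers.
LANGUAGE_CODES = {"sinhalese": "si", "english": "en"}  # assumed subset
TASKS = ("transcribe", "translate")


def build_decoder_prefix(language: str, task: str) -> str:
    if task not in TASKS:
        raise ValueError(f"task must be one of {TASKS}, got {task!r}")
    code = LANGUAGE_CODES[language.lower()]
    # start-of-transcript, language tag, task tag, then "no timestamps"
    return f"<|startoftranscript|><|{code}|><|{task}|><|notimestamps|>"


prefix = build_decoder_prefix("sinhalese", "transcribe")
print(prefix)
```

If the language or task is wrong, every label sequence starts with the wrong tags, which is why these two arguments matter for fine-tuning.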

We can see all possible languages supported by Whisper by importing the list of languages:

from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE

TO_LANGUAGE_CODE

If you scroll through this list, you’ll notice that many languages are present, but Dhivehi is one of the few that are not! This means that Whisper was not pre-trained on Dhivehi. However, this doesn’t mean that we can’t fine-tune Whisper on it. In doing so, we’ll be teaching Whisper a new language, one that the pre-trained checkpoint does not support. That’s pretty cool, right?

When you fine-tune it on a new language, Whisper does a good job at leveraging its knowledge of the other 96 languages it’s pre-trained on. Largely speaking, all modern languages will be linguistically similar to at least one of the 96 languages Whisper already knows, so we’ll fall under this paradigm of cross-lingual knowledge representation.

To fine-tune Whisper on a new language, we need to find the most similar language that Whisper was pre-trained on. The Wikipedia article for Dhivehi states that Dhivehi is closely related to the Sinhalese language of Sri Lanka. If we check the language codes again, we can see that Sinhalese is present in the Whisper language set, so we can safely set our language argument to "sinhalese".

Right! We’ll load our processor from the pre-trained checkpoint, setting the language to "sinhalese" and task to "transcribe" as explained above:

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="sinhalese", task="transcribe"
)

It’s worth reiterating that in most circumstances, you’ll find that the language you want to fine-tune on is in the set of pre-training languages, in which case you can simply set the language directly as your source audio language! Note that both of these arguments should be omitted for English-only fine-tuning, where there is only one option for the language ("English") and task ("transcribe").

Pre-Process the Data

Let’s have a look at the dataset features. Pay particular attention to the "audio" column - this details the sampling rate of our audio inputs:

common_voice["train"].features

Output:

{'audio': Audio(sampling_rate=48000, mono=True, decode=True, id=None),
 'sentence': Value(dtype='string', id=None)}

Since our input audio is sampled at 48kHz, we need to downsample it to 16kHz prior to passing it to the Whisper feature extractor, 16kHz being the sampling rate expected by the Whisper model.
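
As a rough illustration of what the change in sampling rate means for array sizes, the arithmetic can be sketched as follows. The 5-second clip is a made-up example, and the hop length of 160 samples is assumed to be the Whisper feature extractor's default rather than something taken from this guide.

```python
# Back-of-the-envelope arithmetic for resampling. The 5-second clip is a
# made-up example; the 160-sample hop length is an assumption about the
# Whisper feature extractor's default configuration.
SOURCE_RATE = 48_000
TARGET_RATE = 16_000
HOP_LENGTH = 160


def resampled_length(num_samples: int, source_rate: int, target_rate: int) -> int:
    """Number of samples after resampling, rounded to the nearest sample."""
    return round(num_samples * target_rate / source_rate)


clip_48k = 5 * SOURCE_RATE  # 5 s of audio at 48 kHz
clip_16k = resampled_length(clip_48k, SOURCE_RATE, TARGET_RATE)
duration_s = clip_16k / TARGET_RATE  # duration is unchanged by resampling
frames_in_30s = 30 * TARGET_RATE // HOP_LENGTH  # spectrogram frames per 30 s window

print(clip_48k, clip_16k, duration_s, frames_in_30s)
```

Resampling changes the number of samples by a factor of three here, but not the duration of the clip.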

We’ll set the audio inputs to the correct sampling rate using the dataset’s cast_column method. This operation does not change the audio in place, but rather signals to 🤗 Datasets to resample audio samples on the fly when they are loaded:

from datasets import Audio

sampling_rate = processor.feature_extractor.sampling_rate
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=sampling_rate))

Now we can write a function to prepare our data for the model:

  1. We load and resample the audio data on a sample-by-sample basis by calling example["audio"]. As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.
  2. We use the feature extractor to compute the log-mel spectrogram input features from our 1-dimensional audio array.
  3. We encode the transcriptions to label ids through the use of the tokenizer.

def prepare_dataset(example):
    audio = example["audio"]

    example = processor(
        audio=audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["sentence"],
    )

    # compute input length of audio sample in seconds
    example["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    return example

We can apply the data preparation function to all of our training examples using 🤗 Datasets’ .map method. We’ll remove the columns from the raw training data (the audio and text), leaving just the columns returned by the prepare_dataset function:

common_voice = common_voice.map(
    prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=1
)

Finally, we filter out any training data with audio samples longer than 30s. These samples would otherwise be truncated by the Whisper feature extractor, which could affect the stability of training. We define a function that returns True for samples that are shorter than 30s, and False for those that are longer:

max_input_length = 30.0


def is_audio_in_length_range(length):
    return length < max_input_length

We apply our filter function to all samples of our training dataset through 🤗 Datasets’ .filter method:

common_voice["train"] = common_voice["train"].filter(
    is_audio_in_length_range,
    input_columns=["input_length"],
)
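
To see what this filter guards against, here is a small sketch of fixed-length windowing, assuming a 30s window at 16kHz. The pad_or_trim helper is hypothetical and written only for this example; it is not the feature extractor's actual code.

```python
# Sketch of fixed-length windowing, assuming a 30 s window at 16 kHz.
# pad_or_trim is a hypothetical helper written for this example.
SAMPLING_RATE = 16_000
WINDOW_SAMPLES = 30 * SAMPLING_RATE  # 480,000 samples


def pad_or_trim(samples: list, length: int = WINDOW_SAMPLES) -> list:
    if len(samples) >= length:
        return samples[:length]  # anything past the window is dropped
    return samples + [0.0] * (length - len(samples))  # pad with silence


short_clip = [0.1] * (5 * SAMPLING_RATE)  # 5 s
long_clip = [0.1] * (40 * SAMPLING_RATE)  # 40 s

padded = pad_or_trim(short_clip)
trimmed = pad_or_trim(long_clip)
lost_seconds = (len(long_clip) - len(trimmed)) / SAMPLING_RATE

print(len(padded), len(trimmed), lost_seconds)
```

In this sketch the 40s clip loses its last 10s of audio while its transcription would still cover the full clip, so the audio and labels no longer match.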

Let’s check how much training data we removed through this filtering step:

common_voice["train"]

Output:

Dataset({
    features: ['input_features', 'labels', 'input_length'],
    num_rows: 4904
})

Alright! In this case we actually have the same number of samples as before, so there were no samples longer than 30s. This might not be the case if you switch languages, so it’s best to keep this filter step in place for robustness. With that, we have our data fully prepared for training! Let’s continue and take a look at how we can use this data to fine-tune Whisper.

Training and Evaluation

Now that we’ve prepared our data, we’re ready to dive into the training pipeline. The 🤗 Trainer will do much of the heavy lifting for us. All we have to do is:

  • Define a data collator: the data collator takes our pre-processed data and prepares PyTorch tensors ready for the model.

  • Evaluation metrics: during evaluation, we want to evaluate the model using the word error rate (WER) metric. We need to define a compute_metrics function that handles this computation.

  • Load a pre-trained checkpoint: we need to load a pre-trained checkpoint and configure it correctly for training.

  • Define the training arguments: these will be used by the 🤗 Trainer in constructing the training schedule.

Once we’ve fine-tuned the model, we will evaluate it on the test data to verify that we have correctly trained it to transcribe speech in Dhivehi.

Define a Data Collator

The data collator for a sequence-to-sequence speech model is unique in the sense that it treats the input_features and labels independently: the input_features must be handled by the feature extractor and the labels by the tokenizer.

The input_features are already padded to 30s and converted to a log-Mel spectrogram of fixed dimension, so all we have to do is convert them to batched PyTorch tensors. We do this using the feature extractor’s .pad method with return_tensors="pt". Note that no additional padding is applied here since the inputs are of fixed dimension; the input_features are simply converted to PyTorch tensors.

On the other hand, the labels are un-padded. We first pad the sequences to the maximum length in the batch using the tokenizer’s .pad method. The padding tokens are then replaced by -100 so that these tokens are not taken into account when computing the loss. We then cut the start-of-transcript token from the beginning of the label sequence, as it is added back later during training.
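
The label handling can be sketched in pure Python as follows. The token ids and the start-token id are made up for illustration, and the sketch folds the "pad, then mask" steps into one; the real collator relies on the tokenizer and its attention mask.

```python
# Pure-Python sketch of the label handling described above. The token ids
# and BOS_ID are made up; the real collator uses the tokenizer instead.
BOS_ID = 1  # hypothetical start-of-transcript id
IGNORE_INDEX = -100  # label value ignored by PyTorch's cross-entropy loss


def collate_labels(sequences):
    max_len = max(len(seq) for seq in sequences)
    batch = []
    for seq in sequences:
        # pad on the right, marking padded positions as ignored
        batch.append(list(seq) + [IGNORE_INDEX] * (max_len - len(seq)))
    # drop the leading start token only if every sequence begins with it
    if all(row[0] == BOS_ID for row in batch):
        batch = [row[1:] for row in batch]
    return batch


labels = collate_labels([[1, 7, 8, 9], [1, 5]])
print(labels)
```

Every row ends up the same length, and the positions filled with -100 contribute nothing to the loss.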

We can leverage the WhisperProcessor we defined earlier to perform both the feature extractor and the tokenizer operations:

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [
            {"input_features": feature["input_features"][0]} for feature in features
        ]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

We can now initialise the data collator we’ve just defined:

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

Onwards!

Evaluation Metrics

Next, we define the evaluation metric we’ll use on our evaluation set. We’ll use the Word Error Rate (WER) metric introduced in the section on Evaluation, the ‘de-facto’ metric for assessing ASR systems.
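
As a reminder of what the metric measures, here is a minimal pure-Python sketch of WER as a word-level edit distance divided by the number of reference words. It is meant as an illustration only, not as the implementation used by 🤗 Evaluate, and the example sentences are made up.

```python
def word_error_rate(reference: str, prediction: str) -> float:
    """WER = (substitutions + deletions + insertions) / reference words."""
    ref, hyp = reference.split(), prediction.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # previous[j]: edit distance between the reference words seen so far
    # (before the current one) and the first j predicted words
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            current.append(
                min(
                    previous[j] + 1,  # deletion
                    current[j - 1] + 1,  # insertion
                    previous[j - 1] + cost,  # match or substitution
                )
            )
        previous = current
    return previous[-1] / len(ref)


reference = "the cat sat on the mat"
prediction = "the cat sit on the"
print(word_error_rate(reference, prediction))
```

Here one substitution and one deletion against six reference words give a WER of 2/6. Because insertions count as errors too, the WER can exceed 100% when a prediction contains many spurious words.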

We’ll load the WER metric from 🤗 Evaluate:

import evaluate

metric = evaluate.load("wer")

We then simply have to define a function that takes our model predictions and returns the WER metric. This function, called compute_metrics, first replaces -100 with the pad_token_id in the label_ids (undoing the step we applied in the data collator to ignore padded tokens correctly in the loss). It then decodes the predicted and label ids to strings. Finally, it computes the WER between the predictions and reference labels. Here, we have the option of evaluating with the ‘normalised’ transcriptions and predictions, which have punctuation and casing removed. We recommend you follow this to benefit from the WER improvement obtained by normalising the transcriptions.

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()


def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # decode the predicted and label ids to strings
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    # compute orthographic wer
    wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)

    # compute normalised WER
    pred_str_norm = [normalizer(pred) for pred in pred_str]
    label_str_norm = [normalizer(label) for label in label_str]
    # filtering step to only evaluate the samples that correspond to non-zero references:
    pred_str_norm = [
        pred_str_norm[i] for i in range(len(pred_str_norm)) if len(label_str_norm[i]) > 0
    ]
    label_str_norm = [
        label_str_norm[i]
        for i in range(len(label_str_norm))
        if len(label_str_norm[i]) > 0
    ]

    wer = 100 * metric.compute(predictions=pred_str_norm, references=label_str_norm)

    return {"wer_ortho": wer_ortho, "wer": wer}
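
To illustrate the normalisation and filtering steps, here is a simplified sketch. The simple_normalize function is a hypothetical stand-in that only lower-cases text and strips punctuation; the BasicTextNormalizer used above does more than this, and the example strings are made up.

```python
import unicodedata


def simple_normalize(text: str) -> str:
    """Lower-case and replace punctuation with spaces (illustrative only)."""
    cleaned = "".join(
        " " if unicodedata.category(ch).startswith("P") else ch
        for ch in text.lower()
    )
    return " ".join(cleaned.split())


predictions = ["Hello, world!", "..."]
references = ["hello world", "?!"]

pred_norm = [simple_normalize(p) for p in predictions]
ref_norm = [simple_normalize(r) for r in references]

# keep only the pairs whose normalised reference is non-empty
pairs = [(p, r) for p, r in zip(pred_norm, ref_norm) if len(r) > 0]
print(pairs)
```

After normalisation the first pair matches exactly, so casing and punctuation differences no longer count as errors. The second pair is dropped because its reference is empty after normalisation, and a WER is undefined when there are no reference words to divide by.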

Load a Pre-Trained Checkpoint

Now let’s load the pre-trained Whisper small checkpoint. Again, this is trivial through use of 🤗 Transformers!

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

We’ll set use_cache to False for training since we’re using gradient checkpointing and the two are incompatible. We’ll also override two generation arguments to control the behaviour of the model during inference: we’ll force the language and task tokens during generation by setting the language and task arguments, and also re-enable the cache for generation to speed up inference:

from functools import partial

# disable cache during training since it's incompatible with gradient checkpointing
model.config.use_cache = False

# set language and task for generation and re-enable cache
model.generate = partial(
    model.generate, language="sinhalese", task="transcribe", use_cache=True
)
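
If functools.partial is new to you, the following sketch shows how it pins keyword arguments. The fake_generate function is a hypothetical stand-in for model.generate, used so the example runs without a model.

```python
from functools import partial


# Stand-in for model.generate, used only to show how partial behaves.
def fake_generate(inputs, language=None, task=None, use_cache=False):
    return {"inputs": inputs, "language": language, "task": task, "use_cache": use_cache}


generate = partial(fake_generate, language="sinhalese", task="transcribe", use_cache=True)

default_call = generate("audio")  # pinned arguments are applied
override_call = generate("audio", task="translate")  # caller can still override
print(default_call["task"], override_call["task"])
```

The pinned values act as new defaults: any call that does not pass language, task or use_cache explicitly picks them up automatically.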

Define the Training Configuration

In the final step, we define all the parameters related to training. Here, we set the number of training steps to 500. This is enough steps to see a big WER improvement compared to the pre-trained Whisper model, while ensuring that fine-tuning can be run in approximately 45 minutes on a Google Colab free tier. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments docs.

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-dv",  # name on the HF Hub
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=50,
    max_steps=500,  # increase to 4000 if you have your own GPU or a Colab paid plan
    gradient_checkpointing=True,
    fp16=True,
    fp16_full_eval=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=16,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=500,
    eval_steps=500,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

If you do not want to upload the model checkpoints to the Hub, set `push_to_hub=False`.
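
The comment on gradient_accumulation_steps can be made concrete with a short sketch. Both helper functions are hypothetical and written for this example; they only show that halving the per-device batch size while doubling the accumulation steps keeps the effective batch size constant.

```python
def effective_batch_size(per_device: int, accumulation: int, num_devices: int = 1) -> int:
    return per_device * accumulation * num_devices


def halve_batch(per_device: int, accumulation: int):
    """Halve the per-device batch size and double accumulation to compensate."""
    if per_device % 2 != 0:
        raise ValueError("per-device batch size must be even to halve it")
    return per_device // 2, accumulation * 2


config = (16, 1)  # per_device_train_batch_size, gradient_accumulation_steps
for _ in range(3):
    print(config, effective_batch_size(*config))
    config = halve_batch(*config)
```

Each printed configuration has an effective batch size of 16, so you can trade memory for extra forward and backward passes per optimiser step.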

We can forward the training arguments to the 🤗 Trainer along with our model, dataset, data collator and compute_metrics function:

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

And with that, we’re ready to start training!

Training

To launch training, simply execute:

trainer.train()

Training will take approximately 45 minutes depending on your GPU or the one allocated to the Google Colab. Depending on your GPU, it is possible that you will encounter a CUDA "out-of-memory" error when you start training. In this case, you can reduce the per_device_train_batch_size incrementally by factors of 2 and employ gradient_accumulation_steps to compensate.

Output:

Training Loss   Epoch   Step   Validation Loss   Wer Ortho   Wer
0.136           1.63    500    0.1727            63.8972     14.0661

Our final WER is 14.1% - not bad for seven hours of training data and just 500 training steps! That amounts to an absolute improvement of 112 percentage points versus the pre-trained model! That means we’ve taken a model that previously had no knowledge about Dhivehi, and fine-tuned it to recognise Dhivehi speech with adequate accuracy in under one hour 🤯

The big question is how this compares to other ASR systems. For that, we can view the autoevaluate leaderboard, a leaderboard that categorises models by language and dataset, and subsequently ranks them according to their WER.

Looking at the leaderboard, we see that our model trained for 500 steps convincingly beats the pre-trained Whisper Small checkpoint that we evaluated in the previous section. Nice job 👏

We see that there are a few checkpoints that do better than the one we trained. The beauty of the Hugging Face Hub is that it’s a collaborative platform - if we don’t have the time or resources to perform a longer training run ourselves, we can load a checkpoint that someone else in the community has trained and been kind enough to share (making sure to thank them for it!). You’ll be able to load these checkpoints in exactly the same way as the pre-trained ones using the pipeline class as we did previously! So there’s nothing stopping you cherry-picking the best model on the leaderboard to use for your task!

We can automatically submit our checkpoint to the leaderboard when we push the training results to the Hub - we simply have to set the appropriate keyword arguments (kwargs). You can change these values to match your dataset, language and model name accordingly:

kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_13_0",
    "dataset": "Common Voice 13",  # a 'pretty' name for the training dataset
    "language": "dv",
    "model_name": "Whisper Small Dv - Sanchit Gandhi",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

The training results can now be uploaded to the Hub. To do so, execute the push_to_hub command:

trainer.push_to_hub(**kwargs)

This will save the training logs and model weights under "your-username/the-name-you-picked". For this example, check out the upload at sanchit-gandhi/whisper-small-dv.

While the fine-tuned model yields satisfactory results on the Common Voice 13 Dhivehi test data, it is by no means optimal. The purpose of this guide is to demonstrate how to fine-tune an ASR model using the 🤗 Trainer for multilingual speech recognition.

If you have access to your own GPU or are subscribed to a Google Colab paid plan, you can increase max_steps to 4000 steps to improve the WER further by training for more steps. Training for 4000 steps will take approximately 3-5 hours depending on your GPU and yield WER results approximately 3% lower than training for 500 steps. If you decide to train for 4000 steps, we also recommend changing the learning rate scheduler to a linear schedule (set lr_scheduler_type="linear"), as this will yield an additional performance boost over long training runs.
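
To picture the difference between the two schedulers, here is a small sketch of the learning rate multipliers, assuming warmup_steps stays at 50 and max_steps is raised to 4000. The two functions are written for this example and reflect our understanding of the schedules rather than the library's source code.

```python
def constant_with_warmup(step: int, warmup_steps: int) -> float:
    """Multiplier on the base learning rate: linear ramp, then flat."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return 1.0


def linear_with_warmup(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier on the base learning rate: linear ramp, then linear decay to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


BASE_LR = 1e-5
for step in (0, 25, 50, 2025, 4000):
    constant = BASE_LR * constant_with_warmup(step, warmup_steps=50)
    linear = BASE_LR * linear_with_warmup(step, warmup_steps=50, total_steps=4000)
    print(step, constant, linear)
```

Both schedules ramp up identically over the first 50 steps. After that, the constant schedule holds the base learning rate, while the linear schedule decays it steadily to zero by the final step.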

The results could likely be improved further by optimising the training hyperparameters, such as learning rate and dropout, and using a larger pre-trained checkpoint (medium or large). We leave this as an exercise to the reader.

Sharing Your Model

You can now share this model with anyone using the link on the Hub. They can load it with the identifier "your-username/the-name-you-picked" directly into the pipeline() object. For instance, to load the fine-tuned checkpoint “sanchit-gandhi/whisper-small-dv”:

from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="sanchit-gandhi/whisper-small-dv")

Conclusion

In this section, we covered a step-by-step guide on fine-tuning the Whisper model for speech recognition using 🤗 Datasets, Transformers and the Hugging Face Hub. We first loaded the Dhivehi subset of the Common Voice 13 dataset and pre-processed it by computing log-mel spectrograms and tokenising the text. We then defined a data collator, evaluation metric and training arguments, before using the 🤗 Trainer to train and evaluate our model. We finished by uploading the fine-tuned model to the Hugging Face Hub, and showcased how to share and use it with the pipeline() class.

If you followed through to this point, you should now have a fine-tuned checkpoint for speech recognition, well done! 🥳 Even more importantly, you’re equipped with all the tools you need to fine-tune the Whisper model on any speech recognition dataset or domain. So what are you waiting for? Pick one of the datasets covered in the section Choosing a Dataset or select a dataset of your own, and see whether you can get state-of-the-art performance! The leaderboard is waiting for you…
