## Training Distil-Whisper
This sub-folder contains all the scripts required to train a Distil-Whisper model in your choice of language. They are
slightly modified from the original scripts used to distill Whisper for English ASR (as per the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
The main difference is that these scripts are written in [PyTorch](https://pytorch.org), whereas the original scripts
are in [JAX](https://jax.readthedocs.io/en/latest/#)/[Flax](https://flax.readthedocs.io/en/latest/). These scripts are
also made to be easier to run end-to-end, whereas the original scripts require more steps and are somewhat hard-coded
for English ASR. Both sets of scripts achieve equivalent downstream results when the hyper-parameters are set equal.
If you are interested in reproducing the original Distil-Whisper checkpoints, we refer you to the sub-folder [Flax Training](./flax/README.md).
Otherwise, if you wish to distill Whisper on your own language/dataset, we recommend you use these scripts for ease of use
and the configurability they provide.
Reproducing the Distil-Whisper project requires four stages to be completed in successive order:
1. [Pseudo-labelling](#1-pseudo-labelling)
2. [Initialisation](#2-initialisation)
3. [Training](#3-training)
4. [Evaluation](#4-evaluation)
This README is partitioned according to the four stages. Each section provides a minimal example for running the
scripts used in the project. We will use a running example of distilling the Whisper model for Hindi speech recognition
on the Common Voice dataset. Note that this dataset only contains ~20 hours of audio data. Thus, it can be run extremely
quickly, but does not provide sufficient data to achieve optimal performance. We recommend training on upwards of 1000
hours of data should you want to match the performance of Whisper on high-resource languages.
## Requirements
The Distil-Whisper training code is written in [PyTorch](https://pytorch.org) and [Accelerate](https://huggingface.co/docs/accelerate/index).
It heavily leverages the Whisper implementation in [🤗 Transformers](https://github.com/huggingface/transformers) for both
training and inference.
The instructions for installing the package are as follows:
1. Install PyTorch from the [official instructions](https://pytorch.org/get-started/locally/), ensuring you install the correct version for your hardware and CUDA version.
2. Fork the `distil-whisper` repository by clicking on the [fork](https://github.com/huggingface/distil-whisper/fork) button on the repository's page.
3. Clone the `distil-whisper` repository and add the base repository as a remote. This will allow you to "pull" any upstream changes that are made to the base repository:
```bash
git clone https://github.com/<your GitHub handle>/distil-whisper.git
cd distil-whisper
git remote add upstream https://github.com/huggingface/distil-whisper.git
```
4. pip install the required packages from the [setup.py](./setup.py) file:
```bash
cd training
pip install -e .
cd ../..
```
5. Configure Accelerate by running the following command. Note that you should set the number of GPUs you wish to use for distillation, and also the data type (dtype) to your preferred dtype for training/inference (e.g. `bfloat16` on A100 GPUs, `float16` on V100 GPUs, etc.):
```bash
accelerate config
```
6. The last thing we need to do is link our Hugging Face account so that we can pull/push model repositories on the Hub. This will allow us to save our final distilled weights on the Hub so that we can share them with the community. Run the command:
```bash
git config --global credential.helper store
huggingface-cli login
```
And then enter an authentication token from https://huggingface.co/settings/tokens. Create a new token if you do not have one already. You should make sure that this token has "write" privileges.
To confirm that you have a working environment, first accept the terms of use of the Common Voice 16.1 dataset on the Hub: https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1
You can run the following code cell to stream one sample of data from the Common Voice dataset, and check that you can
perform inference using the "tiny" Whisper model:
```python
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", low_cpu_mem_usage=True)
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model.to("cuda")
common_voice = load_dataset("mozilla-foundation/common_voice_16_1", "en", split="validation", streaming=True)
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
inputs = processor(next(iter(common_voice))["audio"]["array"], sampling_rate=16000, return_tensors="pt")
input_features = inputs.input_features
generated_ids = model.generate(input_features.to("cuda"), max_new_tokens=128)
pred_text = processor.decode(generated_ids[0], skip_special_tokens=True)
print("Pred text:", pred_text)
print("Environment set up successful?", generated_ids.shape[-1] == 20)
```
## 1. Pseudo-Labelling
The python script [`run_pseudo_labelling.py`](run_pseudo_labelling.py) is a flexible inference script that can be used
to generate pseudo-labels under a range of settings, including using both greedy and beam-search. It is also compatible
with [🤗 Datasets](https://github.com/huggingface/datasets) *streaming mode*, allowing users to load massive audio
datasets with **no disk space requirements**. For more information on streaming mode, the reader is referred to the
blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#streaming-mode-the-silver-bullet).
> As of the latest Distil-Whisper release, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3), this
pseudo-labelling script also performs the added operation of concatenating (or packing) the audio inputs to 30-seconds.
Not only does this lead to a WER improvement when using the sequential long-form decoding algorithm, but concatenating audios
to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised.
The following script demonstrates how to pseudo-label the Hindi split of the Common Voice 16.1 dataset with greedy sampling:
```bash
#!/usr/bin/env bash
accelerate launch run_pseudo_labelling.py \
--model_name_or_path "openai/whisper-large-v3" \
--dataset_name "mozilla-foundation/common_voice_16_1" \
--dataset_config_name "hi" \
--dataset_split_name "train+validation+test" \
--text_column_name "sentence" \
--id_column_name "path" \
--output_dir "./common_voice_16_1_hi_pseudo_labelled" \
--wandb_project "distil-whisper-labelling" \
--per_device_eval_batch_size 64 \
--dtype "bfloat16" \
--attn_implementation "sdpa" \
--logging_steps 500 \
--max_label_length 256 \
--concatenate_audio \
--preprocessing_batch_size 500 \
--preprocessing_num_workers 8 \
--dataloader_num_workers 8 \
--report_to "wandb" \
--language "hi" \
--task "transcribe" \
--return_timestamps \
--streaming False \
--generation_num_beams 1 \
--push_to_hub
```
On an 80 GB A100 GPU, the above script takes approximately 5 minutes to concatenate and pre-process the 20 hours of
audio data, and a further 10 minutes to transcribe the pseudo-labels. The pseudo-labelled dataset corresponding to this
script is available on the Hugging Face Hub under [sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled](https://huggingface.co/datasets/sanchit-gandhi/common_voice_16_1_hi_pseudo_labelled).
The WER of the pre-trained Whisper large-v3 model is 17.2% on the test split. We will compare the performance of our distilled model against this number.
There are three noteworthy arguments that configure the dataset concatenation (or packing) process:
1. `concatenate_audio`: whether or not to concatenate (or pack) the audios to 30-second chunks. The latest Distil-Whisper model, [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3#differences-with-distil-large-v2), highlights the WER improvements obtained using the sequential long-form decoding algorithm when concatenated audios are used. Concatenating audios to 30-seconds also improves the throughput during training, since the amount of zero-padding on the audio inputs is minimised. Hence, it is highly recommended to set `--concatenate_audio=True`.
2. `preprocessing_batch_size`: the batch size to use when concatenating (or packing) the audios. Using a larger batch size results in a greater portion of audio samples being packed to 30-seconds, at the expense of higher memory consumption. If you exceed your system's RAM when performing the concatenation operation, halve the `preprocessing_batch_size` to 250, or halve it again to 125.
3. `preprocessing_num_workers`: the number of multiprocessing workers to use when concatenating the audios. Using more workers will result in faster pre-processing, at the expense of higher memory consumption. Ensure you do not exceed the maximum number of CPUs on your device.
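The packing idea can be illustrated with a minimal sketch. The helper `pack_durations` below is hypothetical and is not the logic used in `run_pseudo_labelling.py`; it assumes a simple greedy strategy that groups consecutive samples until adding the next one would exceed 30 seconds:

```python
def pack_durations(durations_s, max_length_s=30.0):
    """Greedily group consecutive sample indices so each group stays within max_length_s.

    Hypothetical helper for illustration only. A sample that is already longer
    than max_length_s ends up in a group of its own.
    """
    packs, current, current_len = [], [], 0.0
    for idx, duration in enumerate(durations_s):
        # Close the current group if this sample would push it over the limit
        if current and current_len + duration > max_length_s:
            packs.append(current)
            current, current_len = [], 0.0
        current.append(idx)
        current_len += duration
    if current:
        packs.append(current)
    return packs


# Six clips with durations in seconds; each inner list is one packed training sample
print(pack_durations([10, 12, 9, 20, 5, 31]))
```

Under this scheme, fewer packed samples fall well short of 30 seconds, which is why less zero-padding is needed at training time.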
In addition, the following arguments configure the inference of the Whisper model:
1. `language`: explicitly setting the language token during inference substantially improves the generation performance of the Whisper model, since the model is forced always to predict in the given language. We recommend you set the language to the language you wish to distil the Whisper model on. The only exception is when distilling an English-only model (i.e. where the model id is appended with an `.en`, e.g. `small.en`), in which case the language argument should be set to None, since there is no language token used during training/inference.
2. `return_timestamps`: whether or not to predict timestamps in the pseudo-labels. Timestamp prediction is required should you want your distilled model to be able to predict timestamps at inference time (e.g. for the original OpenAI long-form transcription algorithm). However, the pseudo-labels are marginally less accurate than not using timestamps. We recommend pseudo-labelling **with** timestamps to ensure the distilled model is as general as possible.
3. `attn_implementation`: which attention implementation to use for inference. Set to `sdpa` for [PyTorch SDPA](https://huggingface.co/docs/transformers/v4.35.2/en/perf_infer_gpu_one#bettertransformer), or `flash_attention_2` if your hardware supports Flash Attention 2 and you have the [package installed](https://github.com/Dao-AILab/flash-attention).
4. `streaming`: whether or not to use Datasets' streaming mode. If enabled, the audio data will be streamed from the Hugging Face Hub with no disk space requirements. However, the user is then responsible for adding the pseudo-labels to the dataset script in a follow-up step (see [Using Streaming Mode](#TODO)). If set to `False`, the audio data will be downloaded and pre-processed offline. At the end of pseudo-labelling, the pseudo-labels will be automatically appended to the original dataset, meaning the dataset is ready to be used for the subsequent training step without any additional steps.
5. `generation_num_beams`: how many beams to use while decoding. In practice, we found the distilled model to perform comparably when the data was pseudo-labelled with `generation_num_beams=1` (greedy) or `generation_num_beams>1` (beam). This is likely because the WER filter compensates for the lower quality pseudo-labels obtained using greedy search. However, using `generation_num_beams=1` gives substantially faster inference time for the pseudo-labelling step, and so we recommend this configuration.
Should you have your own audio dataset, you can first [convert it](https://huggingface.co/docs/datasets/audio_dataset) to
Hugging Face Datasets format and push it to the Hugging Face Hub. You can then pseudo-label it using the script above,
replacing the `--dataset_name` with the name of your dataset on the Hub.
Otherwise, you may wish to use an open-source dataset already available on the Hugging Face Hub. We provide a summary of
the three most popular multilingual datasets in the table below. For more details, refer to the blog post: [A Complete Guide to Audio Datasets](https://huggingface.co/blog/audio-datasets#multilingual-speech-recognition).
| Dataset | Languages | Domain | Speaking Style | License | Text Column | ID Column |
|-----------------------------------------------------------------------------------------------|-----------|---------------------------------------|----------------|-----------|---------------------|--------------|
| [Multilingual LibriSpeech](https://huggingface.co/datasets/facebook/multilingual_librispeech) | 6 | Audiobooks | Narrated | CC-BY-4.0 | `"text"` | `"id"` |
| [Common Voice 16](https://huggingface.co/datasets/mozilla-foundation/common_voice_16_1) | 120 | Wikipedia text & crowd-sourced speech | Narrated | CC0-1.0 | `"sentence"` | `"path"` |
| [VoxPopuli](https://huggingface.co/datasets/facebook/voxpopuli) | 15 | European Parliament recordings | Spontaneous | CC0 | `"normalized_text"` | `"audio_id"` |
To achieve *robustness* to different distributions of audio data, it is recommended to train on multiple datasets where possible.
For example, the above three datasets all have splits for the German language. Thus, if distilling a Whisper model for German,
it would be wise to use a combination of the three datasets during training, in order to cover at least three distinct domains
(audiobooks, crowd-sourced speech, parliament recordings). You may wish to use a combination of open-source datasets, or
a combination of open-source and individually owned datasets to cover multiple distributions and domains.
## 2. Initialisation
The script [`create_student_model.py`](create_student_model.py) can be used to initialise a small student model
from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is
initialised by copying maximally spaced layers from the teacher, as per the [DistilBart](https://arxiv.org/abs/2010.13002)
recommendations.
First, we need to create a model repository on the Hugging Face Hub. This repository will contain all the required files
to reproduce the training run, alongside model weights, training logs and a README.md card. You can either create a model
repository directly on the Hugging Face Hub using the link https://huggingface.co/new, or via the CLI, as we'll show here.
Let's pick a name for our distilled model: `distil-whisper-large-v3-hi`. We can run the following command to create a repository under this name:
```bash
huggingface-cli repo create distil-whisper-large-v3-hi
```
We can now see the model on the Hub, e.g. under https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
Let's clone the repository so that we can place our training script and model weights inside:
```bash
git lfs install
git clone https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi
```
Be sure to change the repo address to `https://huggingface.co/<your-user-name>/<your-repo-name>`
We can now copy the relevant training scripts to the repository:
```bash
cd distil-whisper-large-v3-hi
cp ../distil-whisper/training/create_student_model.py .
cp ../distil-whisper/training/run_distillation.py .
```
The following command demonstrates how to initialise a student model from the Whisper [large-v3](https://huggingface.co/openai/whisper-large-v3)
checkpoint, with all 32 encoder layers and 2 decoder layers. The 2 student decoder layers are copied from teacher layers
1 and 32 respectively, as the maximally spaced layers:
```bash
#!/usr/bin/env bash
python create_student_model.py \
--teacher_checkpoint "openai/whisper-large-v3" \
--encoder_layers 32 \
--decoder_layers 2 \
--save_dir "./distil-large-v3-init"
```
The initialised model will be saved to the sub-directory `distil-large-v3-init` in our model repository.
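As a rough illustration of what "maximally spaced" means, the sketch below picks evenly spaced teacher layer indices. The helper `maximally_spaced_layers` is hypothetical and is not taken from `create_student_model.py`; it assumes 0-indexed layers, that the first and last teacher layers are always kept, and that a single-layer student copies the first teacher layer:

```python
def maximally_spaced_layers(teacher_layers, student_layers):
    """Return 0-indexed teacher layer indices to copy into the student.

    Hypothetical helper for illustration: indices are spread evenly between
    the first and last teacher layers.
    """
    if not 1 <= student_layers <= teacher_layers:
        raise ValueError("student_layers must be between 1 and teacher_layers")
    if student_layers == 1:
        # Assumption: a single-layer student copies the first teacher layer
        return [0]
    step = (teacher_layers - 1) / (student_layers - 1)
    return [round(i * step) for i in range(student_layers)]


# 32 teacher decoder layers, 2 student decoder layers
print(maximally_spaced_layers(32, 2))
```

For a 32-layer teacher and a 2-layer student this gives indices 0 and 31, i.e. teacher layers 1 and 32 when counting from one.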
## 3. Training
The script [`run_distillation.py`](run_distillation.py) is an end-to-end script for loading multiple
datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation
from the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), which is a weighted sum of the cross-entropy and
KL-divergence loss terms.
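To make the loss concrete, here is a minimal pure-Python sketch for a single token position. The function name `distillation_loss`, the default weights, and the direction of the KL term (teacher relative to student) are assumptions for illustration; the actual implementation in `run_distillation.py` operates on batched PyTorch tensors and may differ in detail:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, label_id, ce_weight=1.0, kl_weight=1.0):
    """Weighted sum of cross-entropy (against the pseudo-label) and KL divergence.

    The weights are illustrative placeholders, not values confirmed by the scripts.
    """
    p_student = softmax(student_logits)
    p_teacher = softmax(teacher_logits)
    # Cross-entropy between the student distribution and the pseudo-label token
    ce = -math.log(p_student[label_id])
    # KL(teacher || student): how far the student is from the teacher distribution
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_teacher, p_student))
    return ce_weight * ce + kl_weight * kl


# The student agrees with the teacher on the top token but is less confident
print(distillation_loss([2.0, 0.5, 0.1], [4.0, 0.2, 0.1], label_id=0))
```

When the student and teacher distributions match exactly, the KL term vanishes and only the cross-entropy term remains.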
The following command takes the Common Voice dataset that was pseudo-labelled in the first stage and trains the
2-layer decoder model initialised in the previous step. We pass the local path to the pseudo-labelled Common Voice dataset
(`../common_voice_16_1_hi_pseudo_labelled`), which you can change to the path where your local pseudo-labelled dataset is
saved.
In this example, we will combine the train and validation splits to give our training set, and evaluate on the test split
only. This is purely to demonstrate how to combine multiple pseudo-labelled datasets for training, rather than recommended
advice for defining train/validation splits. We advise that you train on the train splits of your dataset, evaluate and
tune hyper-parameters on the validation split, and only test the final checkpoint on the test split. Note how multiple
training datasets and splits can be loaded by separating the dataset arguments by `+` symbols. Thus, the script generalises
to any number of training datasets.
```bash
#!/usr/bin/env bash
accelerate launch run_distillation.py \
--model_name_or_path "./distil-large-v3-init" \
--teacher_model_name_or_path "openai/whisper-large-v3" \
--train_dataset_name "../common_voice_16_1_hi_pseudo_labelled+../common_voice_16_1_hi_pseudo_labelled" \
--train_split_name "train+validation" \
--text_column_name "sentence+sentence" \
--train_dataset_samples "7+4" \
--eval_dataset_name "../common_voice_16_1_hi_pseudo_labelled" \
--eval_split_name "test" \
--eval_text_column_name "sentence" \
--eval_steps 1000 \
--save_steps 1000 \
--warmup_steps 50 \
--learning_rate 0.0001 \
--lr_scheduler_type "constant_with_warmup" \
--timestamp_probability 0.2 \
--condition_on_prev_probability 0.2 \
--language "hi" \
--task "transcribe" \
--logging_steps 25 \
--save_total_limit 1 \
--max_steps 5000 \
--wer_threshold 20 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--dataloader_num_workers 8 \
--preprocessing_num_workers 8 \
--ddp_timeout 7200 \
--dtype "bfloat16" \
--attn_implementation "sdpa" \
--output_dir "./" \
--do_train \
--do_eval \
--gradient_checkpointing \
--overwrite_output_dir \
--predict_with_generate \
--freeze_encoder \
--freeze_embed_positions \
--streaming False \
--push_to_hub
```
The above training script will take approximately 3 hours to complete on an 80 GB A100 GPU and yield a final WER of 76%.
While the generations are starting to take form, there is still a 59% WER gap to the teacher model. This is hardly
surprising given we only have 15 hours of un-filtered data, and closer to just 1.5 hours with data filtering.
As mentioned above, using upwards of 1000 hours of data and training for 10k steps will likely yield
more competitive performance. For the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430), we trained on 21k hours
of audio data for 80k steps. We found that upwards of 13k hours of audio data was required to reach convergence on English
ASR (see Section 9.2 of the [paper](https://arxiv.org/abs/2311.00430)), so the more data you have, the better!
Scaling to multiple GPUs using [distributed data parallelism (DDP)](https://pytorch.org/tutorials/beginner/ddp_series_theory.html)
is trivial: simply run `accelerate config` and select the multi-GPU option, specifying the IDs of the GPUs you wish to use. The
above script can then be run using DDP with no code changes.
Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a
saved checkpoint pushed to the Hugging Face Hub can be found here: [sanchit-gandhi/distil-whisper-large-v3-hi](https://huggingface.co/sanchit-gandhi/distil-whisper-large-v3-hi).
There are a few noteworthy data arguments:
1. `train_dataset_samples`: defines the number of training samples in each dataset. Used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split, however this might require downloading the dataset offline to compute these statistics.
2. `wer_threshold`: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > `wer_threshold` are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong. In our English distillation experiments, we found a WER threshold of 10% provides the optimal trade-off between ensuring high-quality transcriptions, and not filtering unnecessary amounts of training data. For multilingual distillation, the threshold should be set in accordance with the WER achieved by the pre-trained model on the test set.
3. `streaming`: whether or not to use Datasets' streaming mode. Recommended for large datasets, where the audio data can be streamed from the Hugging Face Hub with no disk space requirements.
4. `timestamp_probability`: the per-sample probability for retaining timestamp tokens in the labels (should they contain them). Retaining some portion of timestamp tokens in the training data is required to ensure the distilled model can predict timestamps at inference time. In our experiments, we found that training on timestamps with high-probability hurts the distilled model's transcription performance. Thus, we recommend setting this to a value below 0.5. Typically, a value of 0.2 works well, giving good transcription and timestamp performance.
5. `condition_on_prev_probability`: the per-sample probability for conditioning on previous labels. Conditioning on previous tokens is required to ensure the distilled model can be used with the "sequential" long-form transcription algorithm at inference time. We did not experiment with this parameter, but found values around 0.2 to provide adequate performance. OpenAI pre-trained Whisper with a 50% probability for conditioning on previous tokens. Thus, you might wish to try higher values.
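The first two arguments can be sketched in a few lines of Python. Both helpers are hypothetical and only approximate what the training script does: `sampling_probabilities` assumes the probabilities are proportional to the `+`-separated values, and `filter_by_wer` assumes the texts have already been normalised. The column name `whisper_transcript` is a placeholder for wherever the pseudo-labels are stored:

```python
def sampling_probabilities(train_dataset_samples):
    """Turn a '+'-separated string such as '7+4' into per-dataset sampling probabilities."""
    counts = [float(value) for value in train_dataset_samples.split("+")]
    total = sum(counts)
    return [count / total for count in counts]


def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words, in percent."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))  # distances against an empty reference
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, substitution))
        prev = curr
    return 100.0 * prev[-1] / max(len(ref), 1)


def filter_by_wer(samples, wer_threshold, label_column="sentence", pseudo_column="whisper_transcript"):
    """Keep only samples whose pseudo-label WER does not exceed the threshold."""
    return [
        sample for sample in samples
        if word_error_rate(sample[label_column], sample[pseudo_column]) <= wer_threshold
    ]


samples = [
    {"sentence": "the cat sat on the mat", "whisper_transcript": "the cat sat on the mat"},
    {"sentence": "the cat sat on the mat", "whisper_transcript": "a dog stood near a door"},
]
print(sampling_probabilities("7+4"))
print(len(filter_by_wer(samples, wer_threshold=20)))
```

With a threshold of 20, the second sample is discarded because most of its words differ from the ground truth, which is the kind of grossly wrong pseudo-label the filter is meant to catch.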
As well as a few noteworthy model arguments that can be configured to give optimal training performance:
1. `freeze_encoder`: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable up to 2x batch sizes.
2. `freeze_embed_positions`: whether to freeze the student model's decoder positional embeddings. Using the same embed positions as the teacher model, which is designed to handle context lengths up to 448 tokens, helps the student model retain its input id representation up to the full max input length.
3. `dtype`: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
And finally, a few noteworthy training arguments:
1. `max_steps`: defines the total number of optimisation steps (forward + backward pass) during training. To reach convergence, you should use a dataset of at least 1k hours and train for a minimum of 50k steps.
2. `lr_scheduler_type`: defines the learning rate schedule, one of `constant_with_warmup` or `linear`. When experimenting with a training set-up or training for very few steps (< 5k), using `constant_with_warmup` is typically beneficial, since the learning rate remains high over the short training run. When performing long training runs (> 5k), using a `linear` schedule generally results in superior downstream performance of the distilled model.
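The difference between the two schedules can be seen in a small sketch. The function `lr_at_step` is a hypothetical stand-alone helper, written under the assumption that both schedules ramp up linearly during warmup and that `linear` then decays to zero at `max_steps`; the default values mirror the training command above:

```python
def lr_at_step(step, base_lr=1e-4, warmup_steps=50, max_steps=5000, schedule="constant_with_warmup"):
    """Learning rate at a given optimisation step (illustrative, not the library code)."""
    if step < warmup_steps:
        # Linear ramp from 0 up to base_lr during warmup
        return base_lr * step / max(1, warmup_steps)
    if schedule == "constant_with_warmup":
        return base_lr
    if schedule == "linear":
        # Linear decay from base_lr at the end of warmup down to 0 at max_steps
        remaining = max(0, max_steps - step)
        return base_lr * remaining / max(1, max_steps - warmup_steps)
    raise ValueError(f"Unknown schedule: {schedule}")


for step in (25, 50, 2500, 5000):
    print(step, lr_at_step(step), lr_at_step(step, schedule="linear"))
```

Over a short run the constant schedule keeps the learning rate at its peak for every step after warmup, whereas the linear schedule spends much of the run at a reduced rate.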
TODO:
- [ ] Template for model cards
## 4. Evaluation
There are four types of evaluation performed in Distil-Whisper:
1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
2. Sequential long form: evaluation on audio samples longer than 30s in duration using the original "sequential" long-form algorithm. Examples include entire TED talks or earnings calls.
3. Chunked long form: evaluation on audio samples longer than 30s in duration using the Transformers "chunked" long-form algorithm.
4. Speculative decoding: evaluation on audio samples less than 30s in duration, where a faster, distilled model is used as the assistant to a slower, teacher model.
All four forms of evaluation are performed using the script [`run_eval.py`](run_eval.py). Unlike the pseudo-labelling
and training scripts, the evaluation script assumes that only one GPU accelerator is used. We can copy the corresponding
evaluation script to the model repository using the following command:
```bash
cp ../distil-whisper/training/run_eval.py .
```
Models are assessed jointly using:
1. The *word-error rate (WER)* metric: measures the number of substitution, deletion and insertion errors relative to the total number of words. A lower WER indicates a more accurate model.
2. The *inverse real-time factor (RTFx)* metric: measures the ratio of `audio input time : model compute time`. A higher RTFx indicates a faster model.
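As a small worked example of the second metric, the sketch below computes RTFx from per-sample timings. The helper `rtfx` is hypothetical, and aggregating by summing over all samples before dividing is an assumption rather than something confirmed by `run_eval.py`:

```python
def rtfx(audio_seconds, compute_seconds):
    """Inverse real-time factor: total audio duration divided by total compute time."""
    return sum(audio_seconds) / sum(compute_seconds)


# Two 30-second clips, each transcribed in half a second
student_rtfx = rtfx([30.0, 30.0], [0.5, 0.5])
# The same clips transcribed in 1.5 seconds each by a slower model
teacher_rtfx = rtfx([30.0, 30.0], [1.5, 1.5])
print(student_rtfx, teacher_rtfx, student_rtfx / teacher_rtfx)
```

The ratio of two RTFx values measured on the same audio and hardware gives the relative speed-up between the two models.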
In all cases, it is particularly important to evaluate the final model on data that is *out-of-distribution (OOD)* with
the training data. Evaluating on OOD data provides insight as to how well the distilled model is likely to generalise to
different audio distributions at inference time. In our example, the Common Voice test set is *in-distribution (ID)*
with our training data, since it is taken from the same distribution as the Common Voice training set, whereas the FLEURS
test set is OOD, since it is not used as part of the training set.
### Short Form
The script [`run_eval.py`](run_eval.py) can be used to evaluate a trained student model over multiple short-form
validation sets. The following example demonstrates how to evaluate the student model trained in the previous step on
the Common Voice `test` set (ID) and also the FLEURS `test` set (OOD). Again, it leverages streaming mode to bypass
the need to download the data offline:
```bash
#!/usr/bin/env bash
python run_eval.py \
--model_name_or_path "./" \
--dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
--dataset_config_name "default+hi_in" \
--dataset_split_name "test+test" \
--text_column_name "sentence+transcription" \
--batch_size 16 \
--dtype "bfloat16" \
--generation_max_length 256 \
--language "hi" \
--attn_implementation "sdpa" \
--streaming
```
The student model achieves an average WER of TODO% with an RTFx of TODO for a batch size of 16. We can easily adapt the above
script to evaluate the teacher model, simply by switching the `model_name_or_path` to `openai/whisper-large-v3`, which
achieves an average WER of TODO% with an RTFx of TODO. Therefore, for a batch size of 16, the student model is a factor of TODO
times faster than the teacher. The WER gap can be closed by training on more data (at least 1k hours) for more training
steps (at least 50k).
### Sequential Long Form
The original Whisper paper presents a long-form transcription algorithm that sequentially transcribes 30-second segments
of audio and shifts the sliding window according to the timestamps predicted by the model. This style of sequential
inference is performed directly using the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
method in Transformers.
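The sliding-window behaviour can be sketched in plain Python. Everything here is illustrative: `sequential_transcribe` is a hypothetical helper, and `transcribe_window` stands in for a model call that returns the text for a window together with the last predicted timestamp, in seconds relative to the start of that window:

```python
def sequential_transcribe(total_s, transcribe_window, window_s=30.0):
    """Transcribe long audio window by window, advancing by the last predicted timestamp.

    transcribe_window(start_s, end_s) -> (text, last_timestamp_s) is a placeholder
    for the model; it is an assumption of this sketch, not a Transformers API.
    """
    offset, texts = 0.0, []
    while offset < total_s:
        end = min(offset + window_s, total_s)
        text, last_timestamp = transcribe_window(offset, end)
        texts.append(text)
        # Shift the window to the last predicted timestamp; if it is missing or
        # out of range, fall back to skipping the whole window to guarantee progress
        window_length = end - offset
        offset += last_timestamp if 0 < last_timestamp <= window_length else window_length
    return " ".join(texts)


def fake_model(start_s, end_s):
    """Toy stand-in that always reports its last complete segment ending at 20 s."""
    return f"[{start_s:.0f}-{end_s:.0f}]", 20.0


print(sequential_transcribe(50.0, fake_model))
```

Because each window starts where the previous transcription ended, the windows must be processed one after another, which is why this algorithm cannot be parallelised across a single audio file.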
The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
long-form evaluation sets using the sequential algorithm. Since we don't have a long-form validation set for Hindi to hand,
in this example we'll evaluate the official Distil-Whisper model [`distil-large-v3`](https://huggingface.co/distil-whisper/distil-large-v3)
on the TED-LIUM validation set:
```bash
#!/usr/bin/env bash
accelerate launch run_eval.py \
--model_name_or_path "distil-whisper/distil-large-v3" \
--dataset_name "distil-whisper/tedlium-long-form" \
--dataset_config_name "default" \
--dataset_split_name "validation" \
--text_column_name "text" \
--batch_size 16 \
--dtype "bfloat16" \
--generation_max_length 256 \
--language "en" \
--attn_implementation "sdpa" \
--streaming
```
### Chunked Long Form
Chunked long form evaluation runs on the premise that a single long audio file can be *chunked* into smaller segments and
inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction.
A small overlap (or *stride*) is used between adjacent segments to ensure a continuous transcription across chunks.
This style of chunked inference is performed using the [`pipeline`](https://huggingface.co/docs/transformers/main_classes/pipelines)
class, which provides a wrapper around the [`.generate`](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate)
function for long-form inference.
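The chunking step can be illustrated with a short sketch. The helper `chunk_boundaries` is hypothetical and simplifies the bookkeeping done by the pipeline: it assumes a single fixed overlap between neighbouring chunks, with the overlap value chosen arbitrarily for the example:

```python
def chunk_boundaries(total_s, chunk_length_s=25.0, overlap_s=5.0):
    """Return (start, end) times in seconds of overlapping chunks covering the audio."""
    if not 0 <= overlap_s < chunk_length_s:
        raise ValueError("overlap_s must be non-negative and smaller than chunk_length_s")
    step = chunk_length_s - overlap_s
    chunks, start = [], 0.0
    while True:
        end = min(start + chunk_length_s, total_s)
        chunks.append((start, end))
        if end >= total_s:
            break
        start += step
    return chunks


# A 60-second file split into 25-second chunks that overlap by 5 seconds
print(chunk_boundaries(60.0))
```

Since the chunks are independent of one another, they can be batched and transcribed in parallel, and the overlapping regions are then used to stitch the transcriptions together.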
The script [`run_eval.py`](run_eval.py) can be used to evaluate the trained student model on an arbitrary number of
long-form evaluation sets using the pipeline class. Again, in this example we'll evaluate distil-large-v3 on the
TED-LIUM validation set:
```bash
#!/usr/bin/env bash
python run_eval.py \
--model_name_or_path "distil-whisper/distil-large-v3" \
--dataset_name "distil-whisper/tedlium-long-form" \
--dataset_config_name "default" \
--dataset_split_name "validation" \
--text_column_name "text" \
--use_pipeline \
--chunk_length_s 25.0 \
--language "en" \
--return_timestamps \
--dtype "bfloat16" \
--streaming
```
The argument `chunk_length_s` controls the length of the chunked audio samples, in seconds. It should be set to match the typical
length of audio the student model was trained on. If you are unsure which value of `chunk_length_s` is optimal for your case,
it is recommended to run a *sweep* over a range of candidate values. A template script for running a [WandB sweep](https://docs.wandb.ai/guides/sweeps)
can be found under [`run_chunk_length_s_sweep.yaml`](flax/long_form_transcription_scripts/run_chunk_length_s_sweep.yaml).
### Speculative Decoding
Speculative decoding, or assisted generation, relies on the premise that a faster assistant model can be used to speed up
the generation of a slower main model. Speculative decoding mathematically ensures that exactly the same outputs as
Whisper are obtained, while being ~2 times faster. This makes it a drop-in replacement for existing Whisper
pipelines.
Distil-Whisper checkpoints can be trained to serve as efficient assistant models to Whisper for speculative decoding. More precisely,
by freezing the encoder during training, the distilled model can share the same encoder weights as Whisper during inference, since
the encoder weights are unchanged. In doing so, only the distilled 2-layer decoder has to be loaded in addition to the
original Whisper model, which is approximately an 8% increase to the total parameter count, with up to 2x faster inference
for low batch sizes. For more details on speculative decoding, the reader is advised to refer to the following blog post:
[Speculative Decoding for 2x Faster Whisper Inference](https://huggingface.co/blog/whisper-speculative-decoding).
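The reason the outputs are unchanged can be seen in a toy sketch of greedy speculative decoding. Everything below is illustrative: `main_next` and `assistant_next` are hypothetical stand-ins for the slow and fast models (plain functions from a token prefix to the next token), not Whisper or Transformers APIs. The assistant drafts `k` tokens, the main model checks them, and only tokens the main model would have chosen itself are kept.

```python
def main_next(prefix):
    # Toy "slow" model: next token is a deterministic function of the prefix
    return (sum(prefix) * 7 + len(prefix)) % 10


def assistant_next(prefix):
    # Toy "fast" model: agrees with the main model except at every 4th position
    tok = main_next(prefix)
    return tok if len(prefix) % 4 else (tok + 1) % 10


def greedy_decode(next_fn, prompt, max_new_tokens):
    seq = list(prompt)
    for _ in range(max_new_tokens):
        seq.append(next_fn(tuple(seq)))
    return seq


def speculative_decode(prompt, max_new_tokens, k=4):
    seq = list(prompt)
    target_len = len(prompt) + max_new_tokens
    main_calls = 0  # verification rounds (one forward pass each in a real model)
    while len(seq) < target_len:
        # 1. The assistant drafts k tokens autoregressively
        draft = []
        for _ in range(k):
            draft.append(assistant_next(tuple(seq + draft)))
        # 2. The main model verifies the draft in a single round
        main_calls += 1
        accepted = []
        for tok in draft:
            expected = main_next(tuple(seq + accepted))
            if tok != expected:
                accepted.append(expected)  # replace the first wrong token
                break
            accepted.append(tok)
        else:
            # Whole draft accepted: the main model contributes one extra token
            accepted.append(main_next(tuple(seq + accepted)))
        seq.extend(accepted)
    return seq[:target_len], main_calls


prompt = [1, 2, 3]
reference = greedy_decode(main_next, prompt, 20)
output, rounds = speculative_decode(prompt, 20)
print(output == reference, rounds)
```

Every token appended to `seq` is the main model's own greedy choice, so the result matches plain greedy decoding, while the number of main-model rounds drops whenever the assistant's drafts are accepted.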
In the example below, we use our distilled model as an assistant to the large-v3 teacher model during inference:
```bash
#!/usr/bin/env bash
python run_eval.py \
--model_name_or_path "openai/whisper-large-v3" \
--assistant_model_name_or_path "./" \
--dataset_name "../common_voice_16_1_hi_pseudo_labelled+google/fleurs" \
--dataset_config_name "default+hi_in" \
--dataset_split_name "test+test" \
--text_column_name "sentence+transcription" \
--batch_size 16 \
--dtype "bfloat16" \
--generation_max_length 256 \
--language "hi" \
--attn_implementation "sdpa" \
--streaming
```
We see that we achieve a WER of TODO%, the same as what we obtained with the large-v3 model, but with an RTFx of TODO,
a factor of TODO faster than using the large-v3 model alone. The RTFx value can be improved by training the student on
more data and for more training steps, since this will improve the number of predicted tokens that match the teacher
predictions.
## Overview of Training Methods
### 1. Fine-Tuning
For fine-tuning, we take the original Whisper checkpoint and train it on one or more datasets using the standard
cross-entropy loss. As such, there is no involvement from the teacher checkpoint during training, and so the fine-tuned
model is permitted to *overfit* to the distribution of the training data we provide. This makes it appealing for "low-resource"
languages where the original Whisper model performs poorly, since we can boost the performance of the model on a single
language by *overfitting* to that distribution of data. Note that this means the fine-tuned model is prone to losing
its robustness to different audio distributions, which is the trade-off for improving performance on a specified dataset.
As a rule of thumb, fine-tuning is appropriate for languages where the original Whisper model has a WER above 20%, and we
have a relatively small quantity of training data available (< 1000 hours). With fine-tuning, we require as little as **10 hours**
of training data to significantly boost the performance of the Whisper model. For an in-depth guide to fine-tuning Whisper,
the reader is advised to refer to the blog post: [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper).
### 2. Shrink and Fine-Tune
Shrink and fine-tune (SFT) is a knowledge distillation (KD) technique in which we first *shrink* the teacher model to a
smaller student model by copying maximally spaced layers, and then *fine-tune* the student model on the cross-entropy loss
as described above. Typically, we retain the full encoder from the Whisper model and only shrink the decoder. Retaining
the entire encoder helps significantly with maintaining Whisper's robustness to different audio distributions (_cf._
Section 9.3 of the [Distil-Whisper paper](https://arxiv.org/abs/2311.00430)).
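As a rough illustration of "copying maximally spaced layers", the sketch below picks evenly spaced teacher layer indices that always include the first and last layer. The helper `maximally_spaced_layers` is hypothetical, and the rounding rule is an assumption; the repository's own student initialisation script may select indices differently.

```python
def maximally_spaced_layers(teacher_layers, student_layers):
    """Pick `student_layers` evenly spaced indices from range(teacher_layers).

    Hypothetical helper for illustration; rounding to the nearest index is an
    assumption, not necessarily what the training scripts do.
    """
    if not 1 <= student_layers <= teacher_layers:
        raise ValueError("need 1 <= student_layers <= teacher_layers")
    if student_layers == 1:
        return [0]
    step = (teacher_layers - 1) / (student_layers - 1)
    return [round(i * step) for i in range(student_layers)]


# A 32-layer teacher decoder shrunk to a 2-layer student decoder
print(maximally_spaced_layers(32, 2))
```

The student layers at these positions would then be initialised from the corresponding teacher layers before fine-tuning.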
We can either train the student model on a dataset of (audio, text) pairs as above, or use the pre-trained
Whisper model to generate *pseudo-labels* for our audio data and train on the (audio, pseudo-label) pairs.
Pseudo-labels can be used when either of the following is true:
1. The original text transcriptions are normalised (lower-cased or no punctuation): the Whisper generated pseudo-labels contain both punctuation and casing, and so can be used as a substitute for the normalised transcriptions.
2. The pre-trained Whisper model achieves < 20% WER on the target language: we then know the majority of the pseudo-labels will be accurate enough for us to train on.
They are not recommended when both of the following are true:
1. The original text is punctuated and cased.
2. The pre-trained Whisper model achieves > 20% WER on the target language: in this case, we want to overfit to the particular distribution of the language, and so train directly on the original text data.
To discard inaccurate pseudo-labels during training, we employ a simple WER heuristic to filter our pseudo-labelled
training data. We first normalise the original text and the pseudo-labelled text using the Whisper normaliser. If the
WER between the two normalised texts exceeds a 10% threshold, we discard the training sample. Otherwise, we retain it for training.
Section 9.1 of the Distil-Whisper [paper](https://arxiv.org/abs/2311.00430) demonstrates the importance of using this
threshold for training.
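The filtering heuristic can be sketched in a few lines. This is a simplified illustration, not the code used by the training scripts: `normalise` below is a crude stand-in for the Whisper normaliser (it only lower-cases and strips punctuation), and `wer` and `keep_sample` are hypothetical helper names.

```python
import re


def normalise(text):
    # Crude stand-in for the Whisper normaliser: lower-case, drop punctuation
    return re.sub(r"[^\w\s]", "", text.lower()).split()


def wer(reference, hypothesis):
    """Word error rate: word-level edit distance / number of reference words."""
    ref, hyp = normalise(reference), normalise(hypothesis)
    prev = list(range(len(hyp) + 1))  # distances for the empty reference
    for i, r in enumerate(ref, 1):
        curr = [i]
        for j, h in enumerate(hyp, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution or match
        prev = curr
    return prev[-1] / max(len(ref), 1)


def keep_sample(original_text, pseudo_label, threshold=0.10):
    # Discard the sample when the WER exceeds the threshold
    return wer(original_text, pseudo_label) <= threshold


print(keep_sample("the cat sat on the mat", "The cat sat on the mat."))
print(keep_sample("the cat sat on the mat", "The cat sat on a mat."))
```

In the first call the texts differ only in casing and punctuation, so the sample is kept; in the second, one of six words differs, which puts the WER above the 10% threshold and the sample is discarded.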
### 3. KL Divergence
In the KL Divergence setting, the student model is initialised by shrinking the teacher as before, and then trained to
match the teacher's predicted distribution over tokens, by minimising the KL divergence between the teacher and student predictions.
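For intuition, the sketch below computes a KL divergence between teacher and student next-token distributions from raw logits, averaged over the positions in a sequence. It is a pure-Python illustration under stated assumptions: the direction KL(teacher || student), the averaging over positions, and the helper names are choices made here, and the training script may weight or combine this term with the cross-entropy loss differently.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(teacher_logits, student_logits):
    """KL(teacher || student) for one token position."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def sequence_kl(teacher_seq, student_seq):
    """Mean per-position KL over a sequence of logit vectors."""
    terms = [kl_divergence(t, s) for t, s in zip(teacher_seq, student_seq)]
    return sum(terms) / len(terms)


teacher = [[2.0, 1.0, 0.0], [0.5, 0.5, 3.0]]
student = [[1.5, 1.0, 0.2], [0.0, 1.0, 2.0]]
print(sequence_kl(teacher, student))  # positive: the distributions differ
print(sequence_kl(teacher, teacher))  # zero: identical distributions
```

The loss is zero only when the student reproduces the teacher's distribution at every position, which is what "matching the predictions of the teacher" means here.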
### Summary of Methods
The following table summarises the two training paradigms: fine-tuning and knowledge distillation (KD). It gives
rule-of-thumb thresholds for the pre-trained WER and the quantity of training data for which each paradigm is appropriate:
| Method | Pre-Trained WER / % | Training Data / h |
|-------------|---------------------|-------------------|
| Fine-tuning | > 20 | < 1000 |
| KD | < 20 | > 1000 |
## Acknowledgements
* OpenAI for the Whisper [model](https://huggingface.co/openai/whisper-large-v3) and [original codebase](https://github.com/openai/whisper)
* Hugging Face 🤗 [Transformers](https://github.com/huggingface/transformers) for the Whisper model implementation
* Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) program for Cloud TPU v4s used to train the official Distil-Whisper models
* The Hugging Face 🤗 cluster for enabling experimentation with the PyTorch scripts
## Citation
If you use this code-base, please consider citing the Distil-Whisper paper:
```
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```