Reproducing Distil-Whisper

This sub-folder contains all the training and inference scripts to reproduce the Distil-Whisper project. Distil-Whisper is written in JAX to leverage the fast training and inference speed offered by TPU v4 hardware. However, it also works efficiently on GPU hardware without any additional code changes.

Reproducing the Distil-Whisper project requires four stages to be completed in successive order:

  1. Pseudo-labelling
  2. Initialisation
  3. Training
  4. Evaluation

This README is partitioned according to the four stages. Each section provides a minimal example for running the scripts used in the project. The final scripts used to train the model are referenced in-line.

It is worth noting that the experiments performed in JAX/Flax have been limited to English ASR. For multilingual training, the PyTorch Training Code can be used instead, allowing anyone to run Whisper distillation on a language of their choice.

Requirements

Distil-Whisper is written in Python, JAX and Flax, and heavily leverages the Flax Whisper implementation in 🤗 Transformers. The instructions for installing the package are as follows:

  1. Install JAX from the official instructions, ensuring you install the correct version for your hardware (GPU or TPU).
  2. Install the distil_whisper package by cloning the repository and performing an editable installation:
git clone https://github.com/huggingface/distil-whisper.git
cd distil-whisper/training/flax
pip install -e .

Pseudo-Labelling

Pseudo-labelling is the process of generating target text predictions for the input audio data using the teacher model. The generated text labels then replace the ground truth text labels when performing distillation. The rationale for using pseudo-labels instead of ground truth labels is to circumvent the issue of inconsistent transcription formatting across datasets.

The Python script run_pseudo_labelling.py is a flexible inference script that can be used to generate pseudo-labels under a range of settings, including both greedy and beam-search decoding. It is also compatible with 🤗 Datasets streaming mode, allowing users to load massive audio datasets without first downloading them to disk. For more information on streaming mode, the reader is referred to the blog post: A Complete Guide to Audio Datasets.

The following script demonstrates how to pseudo-label the LibriSpeech 960h dataset with greedy decoding and streaming mode:

#!/usr/bin/env bash

python run_pseudo_labelling.py \
  --model_name_or_path "openai/whisper-large-v2" \
  --dataset_name "librispeech_asr" \
  --dataset_config_name "all" \
  --data_split_name "train.clean.100+train.clean.360+train.other.500" \
  --text_column_name "text" \
  --output_dir "./transcriptions" \
  --per_device_eval_batch_size 16 \
  --max_label_length 256 \
  --dtype "bfloat16" \
  --report_to "wandb" \
  --dataloader_num_workers 16 \
  --streaming \
  --push_to_hub \
  --generation_num_beams 1  # for greedy, set >1 for beam

The script will save the generated pseudo-labels alongside the file ids to the output directory output_dir. Adding the --push_to_hub argument uploads the generated pseudo-labels to the Hugging Face Hub on save.

The directory pseudo_labelling_scripts contains a collection of bash scripts for pseudo-labelling all 10 audio datasets used in the project. The datasets with the Whisper generated transcriptions can be found on the Hugging Face Hub under the Distil Whisper organisation. They can be re-used should you wish to bypass the data labelling stage of the reproduction.

Initialisation

The script create_student_model.py can be used to initialise a small student model from a large teacher model. When initialising a student model with fewer layers than the teacher model, the student is initialised by copying maximally spaced layers from the teacher, as per the DistilBart recommendations.

The following command demonstrates how to initialise a student model from the large-v2 checkpoint, with all 32 encoder layers and 2 decoder layers. The 2 student decoder layers are copied from teacher layers 1 and 32 respectively, as these are the maximally spaced layers.

#!/usr/bin/env bash

python create_student_model.py \
  --teacher_checkpoint "openai/whisper-large-v2" \
  --encoder_layers 32 \
  --decoder_layers 2 \
  --save_dir "./large-32-2" \
  --push_to_hub
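The layer-selection rule can be sketched in plain Python. This is a minimal illustration, not the code in create_student_model.py. It assumes that "maximally spaced" means spreading the student layers evenly from the first to the last teacher layer and flooring fractional positions. The function name maximally_spaced_layers is hypothetical.

```python
def maximally_spaced_layers(num_teacher_layers: int, num_student_layers: int) -> list[int]:
    """Return the 0-indexed teacher layers to copy into the student (hypothetical helper).

    Layers are spread evenly from the first to the last teacher layer, with
    fractional positions floored, similar to
    np.linspace(0, num_teacher_layers - 1, num_student_layers, dtype=int).
    """
    if not 1 <= num_student_layers <= num_teacher_layers:
        raise ValueError("student must have between 1 and num_teacher_layers layers")
    if num_student_layers == 1:
        # With a single layer, keep the first teacher layer (linspace returns its start point).
        return [0]
    span = num_teacher_layers - 1
    # Integer floor division gives the floored, evenly spaced positions exactly.
    return [i * span // (num_student_layers - 1) for i in range(num_student_layers)]
```

For the 32-layer large-v2 decoder and a 2-layer student, this returns [0, 31], i.e. teacher layers 1 and 32 in the 1-indexed numbering used above. A 4-layer student would copy layers [0, 10, 20, 31].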

Training

The script run_distillation.py is an end-to-end script for loading multiple datasets, a student model, a teacher model, and performing teacher-student distillation. It uses the loss formulation from DistilBart, which is a combination of a cross-entropy, KL-divergence and mean-square error (MSE) loss:

https://github.com/huggingface/distil-whisper/blob/4dd831543e6c40b1159f1ec951db7f4fe0e86850/run_distillation.py#L1725

The weight assigned to the MSE loss is configurable, while the weights of the cross-entropy and KL-divergence terms are fixed to the values from the DistilBart paper.
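To make the three-term objective concrete, here is a minimal single-token sketch in plain Python. It is not the implementation at the line linked above. It assumes no temperature scaling, no batching or padding masks, and a single hidden-state vector per model for the MSE term; which hidden states are actually compared is an assumption. The weights are required arguments because this README does not state their values.

```python
import math
from typing import Sequence


def log_softmax(logits: Sequence[float]) -> list[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def distillation_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
    student_hidden: Sequence[float],
    teacher_hidden: Sequence[float],
    *,
    ce_weight: float,
    kl_weight: float,
    mse_weight: float,
) -> float:
    """Weighted sum of CE, KL and MSE terms for one token position (illustrative only)."""
    student_logp = log_softmax(student_logits)
    teacher_logp = log_softmax(teacher_logits)

    # Cross-entropy of the student prediction against the (pseudo-)label token.
    ce = -student_logp[label]

    # KL(teacher || student) between the two next-token distributions.
    kl = sum(math.exp(t) * (t - s) for t, s in zip(teacher_logp, student_logp))

    # Mean-squared error between student and teacher hidden states.
    mse = sum((a - b) ** 2 for a, b in zip(student_hidden, teacher_hidden)) / len(student_hidden)

    return ce_weight * ce + kl_weight * kl + mse_weight * mse
```

When the student matches the teacher exactly, the KL and MSE terms vanish and only the cross-entropy against the pseudo-label remains.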

The following command takes the LibriSpeech 960h dataset that was pseudo-labelled in the first stage and trains the 2-layer decoder model initialised in the previous step. Note that multiple training datasets and splits can be loaded by separating the dataset arguments with + symbols. Thus, the script generalises to any number of training datasets.

#!/usr/bin/env bash

python3 run_distillation.py \
  --model_name_or_path "./large-32-2" \
  --teacher_model_name_or_path "openai/whisper-large-v2" \
  --train_dataset_name "librispeech_asr+librispeech_asr+librispeech_asr" \
  --train_dataset_config_name "all+all+all" \
  --train_split_name "train.clean.100+train.clean.360+train.other.500" \
  --train_dataset_samples "100+360+500" \
  --eval_dataset_name "librispeech_asr" \
  --eval_dataset_config_name "all" \
  --eval_split_name "validation.clean" \
  --eval_steps 5000 \
  --save_steps 5000 \
  --warmup_steps 500 \
  --learning_rate 0.0001 \
  --lr_scheduler_type "constant_with_warmup" \
  --logging_steps 25 \
  --save_total_limit 1 \
  --max_steps 20000 \
  --wer_threshold 10 \
  --per_device_train_batch_size 64 \
  --per_device_eval_batch_size 64 \
  --dataloader_num_workers 16 \
  --dtype "bfloat16" \
  --output_dir "./" \
  --do_train \
  --do_eval \
  --use_scan \
  --gradient_checkpointing \
  --overwrite_output_dir \
  --predict_with_generate \
  --freeze_encoder \
  --streaming \
  --use_auth_token \
  --push_to_hub

The above training script will take approximately 20 hours to complete on a TPU v4-8 and yield a final WER of 2.3%.

Training logs will be reported to TensorBoard and WandB, provided the relevant packages are available. An example of a saved checkpoint pushed to the Hugging Face Hub can be found here: large-32-2.

There are a few noteworthy arguments that can be configured to give optimal training performance:

  • train_dataset_samples: defines the number of training samples in each dataset, which is used to calculate the sampling probabilities in the dataloader. A good starting point is setting the samples to the number of hours of audio data in each split. A more refined strategy is setting it to the number of training samples in each split; however, this might require downloading the dataset offline to compute these statistics.
  • wer_threshold: sets the WER threshold between the normalised pseudo-labels and normalised ground truth labels. Any samples with WER > wer_threshold are discarded from the training data. This is beneficial to avoid training the student model on pseudo-labels where Whisper hallucinated or got the predictions grossly wrong.
  • freeze_encoder: whether to freeze the entire encoder of the student model during training. Beneficial when the student encoder is copied exactly from the teacher encoder. In this case, the encoder hidden-states from the teacher model are re-used for the student model. Stopping the gradient computation through the encoder and sharing the encoder hidden-states provides a significant memory saving, and can enable batch sizes up to 2x larger.
  • dtype: data type (dtype) in which the model computation should be performed. Note that this only controls the dtype of the computations (forward and backward pass), and not the dtype of the parameters or optimiser states.
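The following sketch shows how the +-separated training arguments line up into one record per dataset, and how train_dataset_samples might become sampling probabilities. It is a minimal sketch, assuming the probabilities are simply proportional to the values passed. The helpers split_plus, sampling_probabilities and build_dataset_specs are hypothetical and not part of run_distillation.py.

```python
def split_plus(arg: str) -> list[str]:
    # "a+b+c" -> ["a", "b", "c"], mirroring the "+"-separated CLI convention.
    return arg.split("+")


def sampling_probabilities(train_dataset_samples: str) -> list[float]:
    """Normalise per-split sample counts (or hours) into probabilities summing to 1."""
    counts = [float(x) for x in split_plus(train_dataset_samples)]
    total = sum(counts)
    if total <= 0:
        raise ValueError("train_dataset_samples must sum to a positive value")
    return [c / total for c in counts]


def build_dataset_specs(names: str, configs: str, splits: str, samples: str) -> list[dict]:
    """Zip the "+"-separated arguments into one record per training dataset."""
    fields = [split_plus(names), split_plus(configs), split_plus(splits)]
    probs = sampling_probabilities(samples)
    if len({len(f) for f in fields} | {len(probs)}) != 1:
        raise ValueError("all '+'-separated arguments must have the same number of entries")
    return [
        {"name": n, "config": c, "split": s, "probability": p}
        for n, c, s, p in zip(*fields, probs)
    ]
```

With the LibriSpeech example above, "100+360+500" sums to 960, so the train.clean.360 split is drawn with probability 360 / 960 = 0.375.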

The Distil Whisper project extends the above script to train on a combined dataset formed from 12 open-source ASR datasets, totalling 22k hours and over 50k speakers. Template scripts to run training on this composite dataset can be found in the directory distillation_scripts.

Evaluation

There are two types of evaluation performed in Distil-Whisper:

  1. Short form: evaluation on audio samples less than 30s in duration. Examples include typical ASR test sets, such as the LibriSpeech validation set.
  2. Long form: evaluation on audio samples longer than 30s in duration. Examples include entire TED talks or earnings calls.

Both forms of evaluation are performed using the word-error rate (WER) metric.
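WER is the word-level edit distance (substitutions, insertions and deletions) between a hypothesis and a reference, divided by the number of reference words. The same metric underlies the wer_threshold filter described in the training section. Below is a minimal sketch of both. Two assumptions apply: the threshold is expressed in percent, as --wer_threshold 10 suggests, and the ground truth acts as the reference. The normalise function is a hypothetical, much simpler stand-in for the Whisper English text normaliser.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """WER in percent: word-level edit distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the reference words seen so far
    # (excluding the current one) and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion of reference word r
                curr[j - 1] + 1,         # insertion of hypothesis word h
                prev[j - 1] + (r != h),  # substitution (free if the words match)
            )
        prev = curr
    return 100.0 * prev[-1] / len(ref)


def normalise(text: str) -> str:
    # Hypothetical basic normaliser: lower-case, drop punctuation, collapse whitespace.
    kept = "".join(ch for ch in text.lower() if ch.isalnum() or ch.isspace())
    return " ".join(kept.split())


def keep_sample(ground_truth: str, pseudo_label: str, wer_threshold: float) -> bool:
    # Samples with WER > wer_threshold are discarded from the training data.
    return word_error_rate(normalise(ground_truth), normalise(pseudo_label)) <= wer_threshold
```

For example, word_error_rate("the cat sat", "the dog sat") counts one substitution over three reference words, roughly 33.3%.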

Short Form

The script run_eval.py can be used to evaluate a trained student model over multiple validation sets. The following example demonstrates how to evaluate the student model trained in the previous step on the LibriSpeech validation.clean and validation.other dev sets. Again, it leverages streaming mode to bypass the need to download the data offline:

#!/usr/bin/env bash

python run_eval.py \
  --model_name_or_path "./large-32-2" \
  --dataset_name "librispeech_asr+librispeech_asr" \
  --dataset_config_name "all+all" \
  --dataset_split_name "validation.clean+validation.other" \
  --output_dir "./large-32-2" \
  --per_device_eval_batch_size 64 \
  --dtype "bfloat16" \
  --dataloader_num_workers 16 \
  --report_to "wandb" \
  --streaming \
  --predict_with_generate

Long Form

Long form evaluation runs on the premise that a single long audio file can be chunked into smaller segments and inferred in parallel. The resulting transcriptions are then joined at the boundaries to give the final text prediction. A small overlap (or stride) is used between adjacent segments to ensure a continuous transcription across chunks.

This style of chunked inference is performed using the FlaxWhisperPipeline class, which is heavily inspired by Whisper JAX.
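The windowing step of chunked inference can be sketched as follows. This is a simplified illustration, not the FlaxWhisperPipeline implementation, which also handles batching and merging the transcriptions at the boundaries. The function chunk_windows and its overlap parameter (the total overlap between adjacent windows, in samples) are hypothetical.

```python
def chunk_windows(num_samples: int, chunk_len: int, overlap: int) -> list[tuple[int, int]]:
    """Split [0, num_samples) into windows of at most chunk_len samples.

    Adjacent windows share `overlap` samples, so every boundary between two
    chunks lies inside both of them and no audio is cut without context.
    """
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    if not 0 <= overlap < chunk_len:
        raise ValueError("overlap must be non-negative and smaller than chunk_len")
    step = chunk_len - overlap  # how far each new window advances
    windows = []
    start = 0
    while True:
        end = min(start + chunk_len, num_samples)
        windows.append((start, end))
        if end == num_samples:  # the last window reaches the end of the audio
            break
        start += step
    return windows
```

Whisper operates on 16 kHz audio, so chunk_length_s 15 corresponds to 240,000 samples. With a 3 s overlap, a 40 s file is split into four windows starting at 0 s, 12 s, 24 s and 36 s.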

The script run_long_form_transcription.py can be used to evaluate the trained student model on an arbitrary number of long-form evaluation sets. The following script demonstrates how to evaluate the example student model on two such test sets, Earnings 21 and Earnings 22:

#!/usr/bin/env bash

python run_long_form_transcription.py \
  --model_name_or_path "./large-32-2" \
  --dataset_name "distil-whisper/earnings21+distil-whisper/earnings22" \
  --dataset_config_name "default+default" \
  --dataset_split_name "test+test" \
  --text_column_name "transcription+transcription" \
  --output_dir "./large-32-2" \
  --per_device_eval_batch_size 64 \
  --chunk_length_s 15 \
  --dtype "bfloat16" \
  --report_to "wandb" \
  --streaming

The argument chunk_length_s controls the length of the chunked audio samples. It should be set to match the typical length of audio the student model was trained on. If you are unsure which value of chunk_length_s is optimal for your case, we recommend running a sweep over a range of candidate values. A template script for running a WandB sweep can be found under run_chunk_length_s_sweep.yaml.

1. Pseudo-Labelling

Greedy vs Beam

We found little-to-no difference in the downstream performance of the distilled model between pseudo-labelling with greedy decoding and with beam search. We attribute this to the minimal difference in performance of the pre-trained Whisper model under greedy and beam-search decoding, which yields pseudo-labelled transcriptions of similar quality. We encourage users to generate pseudo-labels with greedy decoding, since it runs significantly faster. Beam search is only advised if the pre-trained model hallucinates significantly on the audio inputs, in which case it helps reduce the frequency and severity of hallucinations. If using beam search, the number of beams can be kept low: even 2 beams significantly reduce the number of hallucinations.

Timestamps

Whisper is trained on a timestamp prediction task as part of the pre-training set-up. Here, a fixed proportion of the pre-training data includes sequence-level timestamps as part of the transcription labels:

<|0.00|> Hey, this is a test transcription. <|3.42|>

Timestamp prediction is useful for enriching the transcriptions with timing information for downstream tasks, such as aligning the Whisper transcription with the output of a speaker diarization system, and also reduces the frequency of hallucinations.

The pseudo-labelling script run_pseudo_labelling.py can be made to predict timestamp information in the audio data by appending the --return_timestamps flag to the launch command. The timestamped labelled data can be passed to the training script in exactly the same way as the non-timestamped version, and the pre-processing function will take care of encoding the timestamps and appending the required task tokens.
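Downstream uses such as alignment with a diarization system need the timestamped labels as (start, end, text) segments. Below is a minimal parsing sketch, assuming each segment is wrapped in its own start and end timestamp tokens as in the example above. The function parse_timestamped is hypothetical and not part of the training code.

```python
import re

# One segment: a start timestamp token, the segment text, then an end timestamp token.
_SEGMENT = re.compile(r"<\|(\d+\.\d+)\|>(.*?)<\|(\d+\.\d+)\|>")


def parse_timestamped(text: str) -> list[tuple[float, float, str]]:
    """Turn '<|t0|> text <|t1|>' style transcriptions into (start, end, text) segments."""
    return [
        (float(start), float(end), body.strip())
        for start, body, end in _SEGMENT.findall(text)
    ]
```

Applied to the example transcription, this returns a single segment from 0.00 s to 3.42 s containing "Hey, this is a test transcription.".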

Previous Context

Whisper is also pre-trained on a prompting task, where the transcription for the preceding utterance is fed as context to the current one:

<|startofprev|> This is the previous context from the preceding utterance.<|startoftranscript|> And this is the current utterance.<|endoftext|>

Annotating the transcriptions with previous context labels is only possible for datasets where we have consecutive files and unique speaker ids, since we need to ensure segment i directly follows on from segment i-1 if we use it as the prompt.

As per the Whisper paper, we mask out the loss over the previous context tokens. At inference time, we can replace the previous context with a “prompt” to encourage the model to generate text in the style of the prompt (e.g. for specific named entities or styles of transcription).
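The sketch below combines the two rules from this section: a previous segment is used as context only if it has the same speaker and directly precedes the current one, and no loss is applied to the context tokens. Several simplifications apply:

- Whitespace splitting stands in for the Whisper tokenizer.
- The language and task tokens after <|startoftranscript|> are omitted.
- The mask is aligned with the token list rather than with shifted decoder labels.

The Segment record, is_valid_context and build_example are all hypothetical names, not the repository's pre-processing code.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

START_OF_PREV = "<|startofprev|>"
START_OF_TRANSCRIPT = "<|startoftranscript|>"
END_OF_TEXT = "<|endoftext|>"  # Whisper's end-of-transcript token


@dataclass
class Segment:
    """Hypothetical record for one utterance cut from a longer recording."""
    speaker_id: str
    index: int  # position of the segment within its recording
    text: str


def is_valid_context(prev: Optional[Segment], curr: Segment) -> bool:
    # Use the previous segment only if it has the same speaker and directly precedes curr.
    return prev is not None and prev.speaker_id == curr.speaker_id and prev.index == curr.index - 1


def build_example(prev: Optional[Segment], curr: Segment) -> Tuple[List[str], List[int]]:
    """Return (tokens, loss_mask), where the mask is 1 at positions that receive loss."""
    tokens: List[str] = []
    mask: List[int] = []
    if is_valid_context(prev, curr):
        context = [START_OF_PREV] + prev.text.split()
        tokens += context
        mask += [0] * len(context)  # previous context: loss masked out
    tokens.append(START_OF_TRANSCRIPT)
    mask.append(0)  # given to the decoder as a prefix, not predicted
    target = curr.text.split() + [END_OF_TEXT]
    tokens += target
    mask += [1] * len(target)  # current utterance and end token: loss applied
    return tokens, mask
```

At inference time, the same context slot can instead hold a user-written prompt, as described above.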

Acknowledgements

  • 🤗 Hugging Face Transformers for the base Whisper implementation
  • Google's TPU Research Cloud (TRC) programme for their generous provision of Cloud TPUs