
Description

This model is a distilled version of Whisper large-v2 obtained by pruning the decoder. It is trained to match the output distribution of its teacher (large-v2) using a distillation loss (KL divergence) combined with a cross-entropy (CE) loss on the ground-truth transcripts. The original model contains 32 decoder layers, whereas the distilled model keeps only 8, and it achieves 4.2% WER on the LibriSpeech dataset after fine-tuning for just one epoch. The model decodes roughly 2x faster than vanilla large-v2 and is about 40% smaller in size.
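
As a rough sketch of the objective described above, the student is trained with a weighted combination of cross-entropy on the ground-truth transcripts and KL divergence against the teacher's softened output distribution. The loss weight and temperature below are illustrative assumptions, not the values used to train this checkpoint.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    # Cross-entropy against the ground-truth token ids (padding labelled -100 is ignored)
    ce = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
    )
    # KL divergence between temperature-softened teacher and student distributions
    kl = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature**2
    # alpha and temperature here are placeholders, not the values used for this model
    return alpha * ce + (1.0 - alpha) * kl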

Train on your data

accelerate launch student-teacher-distillation-streaming.py \
    --freeze_encoder \
    --keep_punctuation \
    --keep_case \
    --teacher_model_name_or_path openai/whisper-large-v2 \
    --student_model_name_or_path large-v2-2 \
    --student_cache_dir large-v2-2-en-cv \
    --teacher_cache_dir cache \
    --output_dir whisper-large-v2-2-en-cv \
    --data_cache_dir commonvoice \
    --text_column sentence \
    --train_dataset_name mozilla-foundation/common_voice_13_0 \
    --train_dataset_config_name en \
    --train_split_name train \
    --validation_dataset_name mozilla-foundation/common_voice_13_0 \
    --validation_dataset_config_name en \
    --validation_split_name test \
    --max_val_samples 2000
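
The script above expects a local student checkpoint (large-v2-2 in this example). How that checkpoint was initialised is not documented in this card; the snippet below is a hypothetical sketch of one way to prune the large-v2 decoder from 32 to 8 layers before distillation (keeping every fourth layer is an assumption, not necessarily the strategy used for this model).

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Start from the teacher and keep 8 evenly spaced decoder layers out of 32
student = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
keep = list(range(0, 32, 4))
student.model.decoder.layers = torch.nn.ModuleList(
    [student.model.decoder.layers[i] for i in keep]
)
student.config.decoder_layers = len(keep)

# Save the pruned model (and the unchanged processor) as the local student checkpoint
student.save_pretrained("large-v2-2")
WhisperProcessor.from_pretrained("openai/whisper-large-v2").save_pretrained("large-v2-2")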

Inference

>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset

>>> # load model and processor
>>> processor = WhisperProcessor.from_pretrained("rsonavane/distil-whisper-large-v2-8-ls")
>>> model = WhisperForConditionalGeneration.from_pretrained("rsonavane/distil-whisper-large-v2-8-ls")
>>> model.config.forced_decoder_ids = None

>>> # load dummy dataset and read audio files
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> sample = ds[0]["audio"]
>>> input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

>>> # generate token ids
>>> predicted_ids = model.generate(input_features)
>>> # decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|endoftext|>']

>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']
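
To sanity-check transcription quality on your own samples, the word error rate can be computed with the evaluate library. The reference text below comes from the same dummy LibriSpeech sample; the simple lower-casing is only a rough normalisation, and punctuation or number formatting will still affect the score.

>>> import evaluate

>>> wer = evaluate.load("wer")
>>> reference = ds[0]["text"].lower()
>>> prediction = transcription[0].strip().lower()
>>> wer.compute(predictions=[prediction], references=[reference])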

Limitations

This experiment aimed to explore how effective decoder pruning followed by distillation is at preserving transcription quality. The student acquires an internal representation of English similar to that of its teacher, while offering faster inference and lower cost for downstream tasks. It can also be fine-tuned for additional languages, retaining much of the original model's performance at reduced inference latency. Inference frameworks such as JAX could reduce latency further.
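
For example, after fine-tuning on another language, the usual Whisper decoder prompt ids can be used at inference time; the language and task values below are placeholders, not capabilities verified for this checkpoint.

>>> forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="transcribe")
>>> predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)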

