{
"cells": [
{
"cell_type": "markdown",
"id": "75b58048-7d14-4fc6-8085-1fc08c81b4a6",
"metadata": {
"id": "75b58048-7d14-4fc6-8085-1fc08c81b4a6"
},
"source": [
"# Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers"
]
},
{
"cell_type": "markdown",
"id": "fbfa8ad5-4cdc-4512-9058-836cbbf65e1a",
"metadata": {
"id": "fbfa8ad5-4cdc-4512-9058-836cbbf65e1a"
},
"source": [
"In this Colab, we present a step-by-step guide on how to fine-tune Whisper \n",
"for any multilingual ASR dataset using Hugging Face 🤗 Transformers. This is a \n",
"more \"hands-on\" version of the accompanying [blog post](https://huggingface.co/blog/fine-tune-whisper). \n",
"For a more in-depth explanation of Whisper, the Common Voice dataset and the theory behind fine-tuning, the reader is advised to refer to the blog post."
]
},
{
"cell_type": "markdown",
"id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e",
"metadata": {
"id": "afe0d503-ae4e-4aa7-9af4-dbcba52db41e"
},
"source": [
"## Introduction"
]
},
{
"cell_type": "markdown",
"id": "9ae91ed4-9c3e-4ade-938e-f4c2dcfbfdc0",
"metadata": {
"id": "9ae91ed4-9c3e-4ade-938e-f4c2dcfbfdc0"
},
"source": [
"Whisper is a pre-trained model for automatic speech recognition (ASR) \n",
"published in [September 2022](https://openai.com/blog/whisper/) by the authors \n",
"Alec Radford et al. from OpenAI. Unlike many of its predecessors, such as \n",
"[Wav2Vec 2.0](https://arxiv.org/abs/2006.11477), which are pre-trained \n",
"on un-labelled audio data, Whisper is pre-trained on a vast quantity of \n",
"**labelled** audio-transcription data, 680,000 hours to be precise. \n",
"This is an order of magnitude more data than the un-labelled audio data used \n",
"to train Wav2Vec 2.0 (60,000 hours). What is more, 117,000 hours of this \n",
"pre-training data is multilingual ASR data. This results in checkpoints \n",
"that can be applied to over 96 languages, many of which are considered \n",
"_low-resource_.\n",
"\n",
"When scaled to 680,000 hours of labelled pre-training data, Whisper models \n",
"demonstrate a strong ability to generalise to many datasets and domains.\n",
"The pre-trained checkpoints achieve competitive results to state-of-the-art \n",
"ASR systems, with near 3% word error rate (WER) on the test-clean subset of \n",
"LibriSpeech ASR and a new state-of-the-art on TED-LIUM with 4.7% WER (_c.f._ \n",
"Table 8 of the [Whisper paper](https://cdn.openai.com/papers/whisper.pdf)).\n",
"The extensive multilingual ASR knowledge acquired by Whisper during pre-training \n",
"can be leveraged for other low-resource languages; through fine-tuning, the \n",
"pre-trained checkpoints can be adapted for specific datasets and languages \n",
"to further improve upon these results. We'll show just how Whisper can be fine-tuned \n",
"for low-resource languages in this Colab."
]
},
{
"cell_type": "markdown",
"id": "21b6316e-8a55-4549-a154-66d3da2ab74a",
"metadata": {
"id": "21b6316e-8a55-4549-a154-66d3da2ab74a"
},
"source": [
"The Whisper checkpoints come in five configurations of varying model sizes.\n",
"The smallest four are trained on either English-only or multilingual data.\n",
"The largest checkpoint is multilingual only. All nine of the pre-trained checkpoints \n",
"are available on the [Hugging Face Hub](https://huggingface.co/models?search=openai/whisper). The \n",
"checkpoints are summarised in the following table with links to the models on the Hub:\n",
"\n",
"| Size | Layers | Width | Heads | Parameters | English-only | Multilingual |\n",
"|--------|--------|-------|-------|------------|------------------------------------------------------|---------------------------------------------------|\n",
"| tiny | 4 | 384 | 6 | 39 M | [✓](https://huggingface.co/openai/whisper-tiny.en) | [✓](https://huggingface.co/openai/whisper-tiny.) |\n",
"| base | 6 | 512 | 8 | 74 M | [✓](https://huggingface.co/openai/whisper-base.en) | [✓](https://huggingface.co/openai/whisper-base) |\n",
"| small | 12 | 768 | 12 | 244 M | [✓](https://huggingface.co/openai/whisper-small.en) | [✓](https://huggingface.co/openai/whisper-small) |\n",
"| medium | 24 | 1024 | 16 | 769 M | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |\n",
"| large | 32 | 1280 | 20 | 1550 M | x | [✓](https://huggingface.co/openai/whisper-large) |\n",
"\n",
"For demonstration purposes, we'll fine-tune the multilingual version of the \n",
"[`\"small\"`](https://huggingface.co/openai/whisper-small) checkpoint with 244M params (~= 1GB). \n",
"As for our data, we'll train and evaluate our system on a low-resource language \n",
"taken from the [Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0)\n",
"dataset. We'll show that with as little as 8 hours of fine-tuning data, we can achieve \n",
"strong performance in this language."
]
},
{
"cell_type": "markdown",
"id": "3a680dfc-cbba-4f6c-8a1f-e1a5ff3f123a",
"metadata": {
"id": "3a680dfc-cbba-4f6c-8a1f-e1a5ff3f123a"
},
"source": [
"------------------------------------------------------------------------\n",
"\n",
"\\\\({}^1\\\\) The name Whisper follows from the acronym “WSPSR”, which stands for “Web-scale Supervised Pre-training for Speech Recognition”."
]
},
{
"cell_type": "markdown",
"id": "b219c9dd-39b6-4a95-b2a1-3f547a1e7bc0",
"metadata": {
"id": "b219c9dd-39b6-4a95-b2a1-3f547a1e7bc0"
},
"source": [
"## Load Dataset"
]
},
{
"cell_type": "markdown",
"id": "674429c5-0ab4-4adf-975b-621bb69eca38",
"metadata": {
"id": "674429c5-0ab4-4adf-975b-621bb69eca38"
},
"source": [
"Using 🤗 Datasets, downloading and preparing data is extremely simple. \n",
"We can download and prepare the Common Voice splits in just one line of code. \n",
"\n",
"First, ensure you have accepted the terms of use on the Hugging Face Hub: [mozilla-foundation/common_voice_11_0](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0). Once you have accepted the terms, you will have full access to the dataset and be able to download the data locally.\n",
"\n",
"Since Hindi is very low-resource, we'll combine the `train` and `validation` \n",
"splits to give approximately 8 hours of training data. We'll use the 4 hours \n",
"of `test` data as our held-out test set:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
"metadata": {
"id": "a2787582-554f-44ce-9f38-4180a5ed6b44"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ubuntu/.local/lib/python3.8/site-packages/pandas/core/computation/expressions.py:20: UserWarning: Pandas requires version '2.7.3' or newer of 'numexpr' (version '2.7.1' currently installed).\n",
" from pandas.core.computation.check import NUMEXPR_INSTALLED\n",
"Found cached dataset common_voice_11_0 (/home/ubuntu/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ar/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f)\n",
"Found cached dataset common_voice_11_0 (/home/ubuntu/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ar/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n",
" num_rows: 38481\n",
" })\n",
" test: Dataset({\n",
" features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n",
" num_rows: 10440\n",
" })\n",
"})\n"
]
}
],
"source": [
"from datasets import load_dataset, DatasetDict\n",
"\n",
"common_voice = DatasetDict()\n",
"\n",
"common_voice[\"train\"] = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"ar\", split=\"train+validation\", use_auth_token=True)\n",
"common_voice[\"test\"] = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"ar\", split=\"test\", use_auth_token=True)\n",
"\n",
"print(common_voice)"
]
},
{
"cell_type": "markdown",
"id": "d5c7c3d6-7197-41e7-a088-49b753c1681f",
"metadata": {
"id": "d5c7c3d6-7197-41e7-a088-49b753c1681f"
},
"source": [
"Most ASR datasets only provide input audio samples (`audio`) and the \n",
"corresponding transcribed text (`sentence`). Common Voice contains additional \n",
"metadata information, such as `accent` and `locale`, which we can disregard for ASR.\n",
"Keeping the notebook as general as possible, we only consider the input audio and\n",
"transcribed text for fine-tuning, discarding the additional metadata information:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "20ba635d-518c-47ac-97ee-3cad25f1e0ce",
"metadata": {
"id": "20ba635d-518c-47ac-97ee-3cad25f1e0ce"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['audio', 'sentence'],\n",
" num_rows: 38481\n",
" })\n",
" test: Dataset({\n",
" features: ['audio', 'sentence'],\n",
" num_rows: 10440\n",
" })\n",
"})\n"
]
}
],
"source": [
"common_voice = common_voice.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"path\", \"segment\", \"up_votes\"])\n",
"\n",
"print(common_voice)"
]
},
{
"cell_type": "markdown",
"id": "2d63b2d2-f68a-4d74-b7f1-5127f6d16605",
"metadata": {
"id": "2d63b2d2-f68a-4d74-b7f1-5127f6d16605"
},
"source": [
"## Prepare Feature Extractor, Tokenizer and Data"
]
},
{
"cell_type": "markdown",
"id": "601c3099-1026-439e-93e2-5635b3ba5a73",
"metadata": {
"id": "601c3099-1026-439e-93e2-5635b3ba5a73"
},
"source": [
"The ASR pipeline can be de-composed into three stages: \n",
"1) A feature extractor which pre-processes the raw audio-inputs\n",
"2) The model which performs the sequence-to-sequence mapping \n",
"3) A tokenizer which post-processes the model outputs to text format\n",
"\n",
"In 🤗 Transformers, the Whisper model has an associated feature extractor and tokenizer, \n",
"called [WhisperFeatureExtractor](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperFeatureExtractor)\n",
"and [WhisperTokenizer](https://huggingface.co/docs/transformers/main/model_doc/whisper#transformers.WhisperTokenizer) \n",
"respectively.\n",
"\n",
"We'll go through details for setting-up the feature extractor and tokenizer one-by-one!"
]
},
{
"cell_type": "markdown",
"id": "560332eb-3558-41a1-b500-e83a9f695f84",
"metadata": {
"id": "560332eb-3558-41a1-b500-e83a9f695f84"
},
"source": [
"### Load WhisperFeatureExtractor"
]
},
{
"cell_type": "markdown",
"id": "32ec8068-0bd7-412d-b662-0edb9d1e7365",
"metadata": {
"id": "32ec8068-0bd7-412d-b662-0edb9d1e7365"
},
"source": [
"The Whisper feature extractor performs two operations:\n",
"1. Pads / truncates the audio inputs to 30s: any audio inputs shorter than 30s are padded to 30s with silence (zeros), and those longer that 30s are truncated to 30s\n",
"2. Converts the audio inputs to _log-Mel spectrogram_ input features, a visual representation of the audio and the form of the input expected by the Whisper model"
]
},
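{
"cell_type": "markdown",
"id": "b7e2c1d4-pad-truncate-sketch",
"metadata": {
"id": "b7e2c1d4-pad-truncate-sketch"
},
"source": [
"To make the first operation concrete, here is a minimal NumPy sketch of the pad / truncate step. It is an illustration only, not the `WhisperFeatureExtractor` implementation: the helper name `pad_or_truncate` and the constants are assumptions introduced for this example, and it assumes audio sampled at 16 kHz, the rate Whisper expects. The log-Mel conversion is not reproduced here.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"SAMPLING_RATE = 16_000  # Hz (assumption: audio is already sampled at 16 kHz)\n",
"CHUNK_SECONDS = 30      # fixed input length described above\n",
"N_SAMPLES = SAMPLING_RATE * CHUNK_SECONDS  # 480,000 samples\n",
"\n",
"\n",
"def pad_or_truncate(audio, n_samples=N_SAMPLES):\n",
"    # Hypothetical helper: return a 1-D float32 array of exactly n_samples values.\n",
"    audio = np.asarray(audio, dtype=np.float32)\n",
"    if audio.shape[0] >= n_samples:\n",
"        # Longer than 30s: keep only the first 30s.\n",
"        return audio[:n_samples].copy()\n",
"    # Shorter than 30s: append zeros (silence) up to 30s.\n",
"    return np.pad(audio, (0, n_samples - audio.shape[0]))\n",
"\n",
"\n",
"short_clip = np.ones(5 * SAMPLING_RATE, dtype=np.float32)   # 5s of dummy audio\n",
"long_clip = np.ones(45 * SAMPLING_RATE, dtype=np.float32)   # 45s of dummy audio\n",
"\n",
"padded = pad_or_truncate(short_clip)\n",
"truncated = pad_or_truncate(long_clip)\n",
"print(padded.shape, truncated.shape)\n",
"```\n",
"\n",
"Both clips come out 480,000 samples long (30s at 16 kHz), so every example reaches the model with the same input length regardless of the original recording duration."
]
},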
{
"cell_type": "markdown",
"id": "589d9ec1-d12b-4b64-93f7-04c63997da19",
"metadata": {
"id": "589d9ec1-d12b-4b64-93f7-04c63997da19"
},
"source": [
"