{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "e8bfeb22-b42b-47cb-b3df-199864340445", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset, DatasetDict\n", "from transformers import WhisperForConditionalGeneration\n", "from transformers import Seq2SeqTrainingArguments\n", "from transformers import Seq2SeqTrainer\n", "\n", "from transformers import WhisperTokenizer\n", "from transformers import WhisperFeatureExtractor\n", "from transformers import WhisperProcessor\n", "from datasets import Audio\n", "import evaluate\n", "\n", "import torch\n", "\n", "from dataclasses import dataclass\n", "from typing import Any, Dict, List, Union" ] }, { "cell_type": "code", "execution_count": 2, "id": "37465609-e163-4fe8-8522-1711ed551af5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset common_voice_11_0 (/home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f)\n", "Found cached dataset common_voice_11_0 (/home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n", " num_rows: 22\n", " })\n", " test: Dataset({\n", " features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],\n", " num_rows: 6\n", " })\n", "})\n" ] } ], "source": [ "common_voice = DatasetDict()\n", "\n", "common_voice[\"train\"] = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"ml\", split=\"train[:5%]+validation\", use_auth_token=True)\n", "common_voice[\"test\"] = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"ml\", split=\"test[:5%]\", use_auth_token=True)\n", "\n", "print(common_voice)" ] }, { "cell_type": "code", "execution_count": 3, "id": "d940f0e6-8c51-47e4-929f-fcf8f91be3d6", "metadata": {}, "outputs": [], "source": [ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-tiny\")\n", "tokenizer = WhisperTokenizer.from_pretrained(\"openai/whisper-tiny\", language=\"Malayalam\", task=\"transcribe\")\n", "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", language=\"Malayalam\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 4, "id": "ddf259d2-1387-46f2-8964-b977576cc89b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'client_id': '29ca16eb2c0faea0be0ad73b5d826f5e81dc6fd4acfa9241a002b5d3619fd51c5b00b009e7b98b50caa5829f8a96697d5942b120749ee63a5d637c632bd0f7bc', 'path': '/home/.cache/huggingface/datasets/downloads/extracted/5e6fee23ff6621c1021a557e4424852db80c5f277edb03408614c85e4831964c/common_voice_ml_28913601.mp3', 'audio': {'path': '/home/.cache/huggingface/datasets/downloads/extracted/5e6fee23ff6621c1021a557e4424852db80c5f277edb03408614c85e4831964c/common_voice_ml_28913601.mp3', 'array': array([-5.9054565e-16, -5.8716256e-14, -5.4170010e-15, ...,\n", " 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32), 'sampling_rate': 48000}, 'sentence': 'എന്തുകൊണ്ട് യുവാക്കൾ കൂടുതൽ രാഷ്ട്രീയമായി ചിന്തിക്കണം, എന്തുകൊണ്ട് അവർ സംഘടിതരാകണം എന്നതിന്റെ ഉദാത്തമായ ഉദാഹരണമാകുന്നു കേരളം.', 'up_votes': 2, 'down_votes': 0, 'age': '', 'gender': '', 
'accent': '', 'locale': 'ml', 'segment': ''}\n" ] } ], "source": [ "print(common_voice[\"train\"][0])" ] }, { "cell_type": "code", "execution_count": 5, "id": "6018adab-f4ff-43c4-bb94-a70db7d78d91", "metadata": {}, "outputs": [], "source": [ "common_voice = common_voice.cast_column(\"audio\", Audio(sampling_rate=16000))" ] }, { "cell_type": "code", "execution_count": 6, "id": "1b435bee-3042-45f4-8451-b6a466a9ec98", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'client_id': '29ca16eb2c0faea0be0ad73b5d826f5e81dc6fd4acfa9241a002b5d3619fd51c5b00b009e7b98b50caa5829f8a96697d5942b120749ee63a5d637c632bd0f7bc', 'path': '/home/.cache/huggingface/datasets/downloads/extracted/5e6fee23ff6621c1021a557e4424852db80c5f277edb03408614c85e4831964c/common_voice_ml_28913601.mp3', 'audio': {'path': '/home/.cache/huggingface/datasets/downloads/extracted/5e6fee23ff6621c1021a557e4424852db80c5f277edb03408614c85e4831964c/common_voice_ml_28913601.mp3', 'array': array([-4.3097585e-14, 1.7633505e-13, 2.9013527e-13, ...,\n", " 0.0000000e+00, 0.0000000e+00, 0.0000000e+00], dtype=float32), 'sampling_rate': 16000}, 'sentence': 'എന്തുകൊണ്ട് യുവാക്കൾ കൂടുതൽ രാഷ്ട്രീയമായി ചിന്തിക്കണം, എന്തുകൊണ്ട് അവർ സംഘടിതരാകണം എന്നതിന്റെ ഉദാത്തമായ ഉദാഹരണമാകുന്നു കേരളം.', 'up_votes': 2, 'down_votes': 0, 'age': '', 'gender': '', 'accent': '', 'locale': 'ml', 'segment': ''}\n" ] } ], "source": [ "print(common_voice[\"train\"][0])" ] }, { "cell_type": "code", "execution_count": 7, "id": "0e4cc10e-3d90-4c27-9ccf-7a4fd6875353", "metadata": {}, "outputs": [], "source": [ "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n", "\n", "do_lower_case = False\n", "do_remove_punctuation = True\n", "\n", "normalizer = BasicTextNormalizer()" ] }, { "cell_type": "code", "execution_count": 8, "id": "241a1504-c8fb-4322-bb53-fabaa01a607f", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " # load and (possibly) resample audio data to 16kHz\n", " audio = batch[\"audio\"]\n", "\n", " # compute log-Mel input features from input audio array \n", " batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " # compute input length of audio sample in seconds\n", " batch[\"input_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n", " \n", " # optional pre-processing steps\n", " transcription = batch[\"sentence\"]\n", " if do_lower_case:\n", " transcription = transcription.lower()\n", " if do_remove_punctuation:\n", " transcription = normalizer(transcription).strip()\n", " \n", " # encode target text to label ids\n", " batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 9, "id": "e4059864-c622-4f80-99d3-3450fd852454", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f/cache-10c57a3e7cf91619.arrow\n", "Loading cached processed dataset at /home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f/cache-8adb63851a4a51f7.arrow\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.97 s, sys: 16 ms, total: 2.99 s\n", "Wall time: 2.99 s\n" ] } ], "source": [ "%%time\n", "common_voice = 
common_voice.map(prepare_dataset, remove_columns=common_voice.column_names[\"train\"], num_proc=1)" ] }, { "cell_type": "code", "execution_count": 10, "id": "9367d36e-5144-4d3e-ba56-0661e6124f34", "metadata": {}, "outputs": [], "source": [ "max_input_length = 30.0\n", "\n", "def is_audio_in_length_range(length):\n", " return length < max_input_length" ] }, { "cell_type": "code", "execution_count": 11, "id": "87f10bfa-8db2-4142-a4eb-fbae0c72acb3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading cached processed dataset at /home/.cache/huggingface/datasets/mozilla-foundation___common_voice_11_0/ml/11.0.0/f8e47235d9b4e68fa24ed71d63266a02018ccf7194b2a8c9c598a5f3ab304d9f/cache-153a5b29ef28024e.arrow\n" ] } ], "source": [ "common_voice[\"train\"] = common_voice[\"train\"].filter(\n", " is_audio_in_length_range,\n", " input_columns=[\"input_length\"],\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "id": "0cafcf11-31cb-400e-a6c8-4968386770ed", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class DataCollatorSpeechSeq2SeqWithPadding:\n", " processor: Any\n", "\n", " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", " # split inputs and labels since they have to be of different lengths and need different padding methods\n", " # first treat the audio inputs by simply returning torch tensors\n", " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", "\n", " # get the tokenized label sequences\n", " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " # pad the label sequences to the longest label in the batch\n", " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n", "\n", " # replace padding with -100 so padded positions are ignored by the loss\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", " # if bos token is appended in previous tokenization step,\n", " # cut bos token here as it's appended later anyways\n", " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", "\n", " batch[\"labels\"] = labels\n", "\n", " return batch" ] }, { "cell_type": "code", "execution_count": 13, "id": "8a44b772-a3f7-49bf-9c49-d204b83eae00", "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)" ] }
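, { "cell_type": "code", "execution_count": null, "id": "7c1de2aa-3b5f-4f86-9d41-2a6c0e9b5d17", "metadata": {}, "outputs": [], "source": [ "# Hedged sketch, not part of the original notebook: sanity-check the collator on\n", "# two prepared training examples before handing it to the trainer. The expected\n", "# shapes assume Whisper's fixed 30-second log-Mel windows (80 mel bins x 3000 frames).\n", "example_batch = data_collator([common_voice[\"train\"][i] for i in range(2)])\n", "\n", "print(example_batch[\"input_features\"].shape)  # expected: torch.Size([2, 80, 3000])\n", "print(example_batch[\"labels\"].shape)  # (2, longest label sequence in this batch)" ] }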
, { "cell_type": "code", "execution_count": 14, "id": "be5e492f-d37e-4548-ad4c-7375fb444d69", "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "metric = evaluate.load(\"wer\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "eda340b8-663d-4596-9a86-f166ed0ba036", "metadata": {}, "outputs": [], "source": [ "# evaluate with the 'normalised' WER\n", "do_normalize_eval = True\n", "\n", "def compute_metrics(pred):\n", " pred_ids = pred.predictions\n", " label_ids = pred.label_ids\n", "\n", " # replace -100 with the pad_token_id\n", " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " # we do not want to group tokens when computing the metrics\n", " pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", " if do_normalize_eval:\n", " pred_str = [normalizer(pred) for pred in pred_str]\n", " label_str = [normalizer(label) for label in label_str]\n", "\n", " wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer}" ] }, { "cell_type": "code", "execution_count": 16, "id": "f54bf70f-aa99-4353-b079-7eda0baabf4a", "metadata": {}, "outputs": [], "source": [ "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")" ] }, { "cell_type": "code", "execution_count": 17, "id": "4e4ccefd-1069-4a76-9e1c-921364502959", "metadata": {}, "outputs": [], "source": [ "model.config.forced_decoder_ids = None\n", "model.config.suppress_tokens = []\n", "model.config.use_cache = False" ] }, { "cell_type": "code", "execution_count": 22, "id": "b884bb51-99e8-4a60-8596-e9e8d954742a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "PyTorch: setting up devices\n" ] } ], "source": [ "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"./\",\n", " per_device_train_batch_size=64,\n", " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n", " learning_rate=1e-5,\n", " warmup_steps=50,\n", " max_steps=500,\n", " gradient_checkpointing=True,\n", " fp16=True,\n", " evaluation_strategy=\"steps\",\n", " per_device_eval_batch_size=8,\n", " predict_with_generate=True,\n", " generation_max_length=225,\n", " save_steps=500, # keep save/eval steps <= max_steps so load_best_model_at_end can take effect\n", " eval_steps=500,\n", " logging_steps=25,\n", " report_to=[\"tensorboard\"],\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"wer\",\n", " greater_is_better=False,\n", " push_to_hub=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 23, "id": "c90ef090-4f16-4bc8-8c9c-6e3293c68917", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/whisper-ml-first-model/./ is already a clone of https://huggingface.co/kurianbenoy/whisper-ml-first-model. Make sure you pull the latest changes with `repo.git_pull()`.\n", "max_steps is given, it will override any value given in num_train_epochs\n", "Using cuda_amp half precision backend\n" ] } ], "source": [ "trainer = Seq2SeqTrainer(\n", " args=training_args,\n", " model=model,\n", " train_dataset=common_voice[\"train\"],\n", " eval_dataset=common_voice[\"test\"],\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", " tokenizer=processor.feature_extractor,\n", ")" ] }, { "cell_type": "code", "execution_count": 24, "id": "50c03267-897f-497d-b11f-78ed16c80480", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Feature extractor saved in ./preprocessor_config.json\n", "tokenizer config file saved in ./tokenizer_config.json\n", "Special tokens file saved in ./special_tokens_map.json\n", "added tokens file saved in ./added_tokens.json\n" ] } ], "source": [ "processor.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "code", "execution_count": null, "id": "57c06b81-60f2-4c78-ab5f-7e1e1dac97c8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n", "/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. 
Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 22\n", " Num Epochs = 500\n", " Instantaneous batch size per device = 64\n", " Total train batch size (w. parallel, distributed & accumulation) = 64\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 500\n", " Number of trainable parameters = 37760640\n" ] }, { "data": { "text/html": [ "\n", "
<table border=\"1\">\n", "  <tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th></tr>\n", "</table>"
],
"text/plain": [
"