{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "39b65a99-3b38-4d91-b710-87fd1dcbecd2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/hf_env/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from datasets import Audio, interleave_datasets, IterableDataset, IterableDatasetDict, load_dataset\n", "from transformers import WhisperProcessor\n", "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n", "from typing import List, Optional" ] }, { "cell_type": "code", "execution_count": 2, "id": "24ddd8b0-01b2-4fde-8fa6-a9e0b8f6cbab", "metadata": {}, "outputs": [], "source": [ "def load_multiple_streaming_datasets(\n", " dataset_names: List,\n", " dataset_config_names: List,\n", " splits: Optional[List] = None,\n", " text_column_names: Optional[List] = None,\n", " sampling_rate: Optional[int] = 16000,\n", " stopping_strategy: Optional[str] = \"all_exhausted\",\n", " **kwargs\n", ") -> IterableDataset:\n", "\n", " if len(dataset_names) != len(dataset_config_names):\n", " raise ValueError(\n", " f\"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and\"\n", " f\" {len(dataset_config_names)} configs.\"\n", " )\n", "\n", " if splits is not None and len(splits) != len(dataset_names):\n", " raise ValueError(\n", " f\"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits.\"\n", " )\n", "\n", " if text_column_names is not None and len(text_column_names) != len(dataset_names):\n", " raise ValueError(\n", " f\"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and\"\n", " f\" {len(text_column_names)} text column names.\"\n", " )\n", "\n", " splits = splits if splits is not None else [\"train\" for i in range(len(dataset_names))]\n", " text_column_names = (\n", " text_column_names if text_column_names is not None else [\"text\" for i in range(len(dataset_names))]\n", " )\n", "\n", " all_datasets = []\n", " # iterate over the datasets we want to interleave\n", " for i, dataset_name in enumerate(dataset_names):\n", " dataset = load_dataset(dataset_name, dataset_config_names[i], split=splits[i], streaming=True, **kwargs)\n", " # resample to specified sampling rate\n", " dataset = dataset.cast_column(\"audio\", Audio(sampling_rate))\n", " # normalise columns to [\"audio\", \"sentence\"]\n", " if text_column_names[i] != \"sentence\":\n", " dataset = dataset.rename_column(text_column_names[i], \"sentence\")\n", " dataset = dataset.remove_columns(set(dataset.features.keys()) - set([\"audio\", \"sentence\"]))\n", " all_datasets.append(dataset)\n", "\n", " interleaved_dataset = interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)\n", " return interleaved_dataset" ] }, { "cell_type": "code", "execution_count": 3, "id": "863cba04-f5ed-49d0-9870-f8c72959de07", "metadata": {}, "outputs": [], "source": [ "def normalize_transcriptions(batch):\n", " # optional pre-processing steps\n", " transcription = batch[\"sentence\"]\n", " if do_lower_case:\n", " transcription = transcription.lower()\n", " if do_remove_punctuation:\n", " transcription = normalizer(transcription).strip()\n", " batch[\"sentence\"] = transcription\n", " return batch" ] }, { "cell_type": "code", "execution_count": 4, "id": 
"17f2c20a-1f60-4eab-9a85-de9b25a65eff", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " # load and (possibly) resample audio data to 16kHz\n", " audio = batch[\"audio\"]\n", "\n", " # compute log-Mel input features from input audio array \n", " batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " # compute input length of audio sample in seconds\n", " batch[\"input_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n", " \n", " # optional pre-processing steps\n", " transcription = batch[\"sentence\"]\n", " if do_lower_case:\n", " transcription = transcription.lower()\n", " if do_remove_punctuation:\n", " transcription = normalizer(transcription).strip()\n", " \n", " # encode target text to label ids\n", " batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 9, "id": "5d6c7cbe-4ca3-4ad1-b07c-d5126c1b93b7", "metadata": {}, "outputs": [], "source": [ "dataset_names = [\"mozilla-foundation/common_voice_11_0\", \"google/fleurs\", \"openslr\", \"collectivat/tv3_parla\", \"projecte-aina/parlament_parla\", \"projecte-aina/parlament_parla\"]\n", "dataset_config_names = [\"ca\", \"ca_es\", \"SLR69\", \"ca\", \"clean\", \"other\"]\n", "text_column_names = [\"sentence\", \"transcription\", \"sentence\", \"text\", \"sentence\", \"sentence\"]" ] }, { "cell_type": "code", "execution_count": 10, "id": "aa5f610e-f52b-4cac-806e-56f9d50b6aa8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration ca\n" ] } ], "source": [ "trainset = load_multiple_streaming_datasets(dataset_names, dataset_config_names=dataset_config_names, text_column_names=text_column_names, use_auth_token=True)" ] }, { "cell_type": "code", "execution_count": 11, "id": "1090feb6-827c-4b6e-a682-73b3f0e983df", "metadata": {}, "outputs": [], "source": [ "testset = IterableDataset\n", "testset = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"ca\", split=\"test\", streaming=True, use_auth_token=True)\n", "testset = testset.cast_column(\"audio\", Audio(sampling_rate=16000))" ] }, { "cell_type": "code", "execution_count": 12, "id": "5b3a3be3-312e-4574-a467-b35c0c978308", "metadata": {}, "outputs": [], "source": [ "COLUMNS_TO_KEEP = [\"sentence\", \"audio\"]\n", "all_columns = testset.features\n", "columns_to_remove = set(all_columns) - set(COLUMNS_TO_KEEP)\n", "\n", "testset = testset.remove_columns(columns_to_remove)" ] }, { "cell_type": "code", "execution_count": 13, "id": "f2c2d794-b915-4d28-8d75-e16c38d924a5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),\n", " 'sentence': Value(dtype='string', id=None)}" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainset.features" ] }, { "cell_type": "code", "execution_count": 14, "id": "c06ecaaa-5d2f-4c79-a617-73a92af367e3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),\n", " 'sentence': Value(dtype='string', id=None)}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "testset.features" ] }, { "cell_type": "code", "execution_count": null, "id": "ba64db32-891f-4d1d-b273-de88194d40ea", "metadata": {}, "outputs": [], "source": [ "#######################\n", "### Datasets are ready to 
vectorize\n", "#######################" ] }, { "cell_type": "code", "execution_count": 15, "id": "5f019459-3ab7-4677-bcb4-6fc00c73af00", "metadata": {}, "outputs": [], "source": [ "do_lower_case = True\n", "do_remove_punctuation = True\n", "\n", "normalizer = BasicTextNormalizer()" ] }, { "cell_type": "code", "execution_count": 16, "id": "c0f97732-7376-44bb-8981-f511fc88d6d4", "metadata": {}, "outputs": [], "source": [ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-medium\", language=\"Catalan\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 19, "id": "404c4a64-1363-42f9-9065-ac36218dfb38", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = trainset.map(prepare_dataset).with_format(\"torch\")\n", "vectorized_testset = testset.map(prepare_dataset).with_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": 46, "id": "30a49fdb-6663-4485-a62c-d9ae01665e8c", "metadata": {}, "outputs": [], "source": [ "# trainset = trainset.map(normalize_transcriptions)\n", "# testset = raw_testset.map(normalize_transcriptions)" ] }, { "cell_type": "code", "execution_count": 66, "id": "9928ab49-5adf-4407-8aac-198487324913", "metadata": {}, "outputs": [], "source": [ "#def prepare_dataset(batch):\n", "# audio = batch[\"audio\"]\n", "# batch = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"], text=batch[\"sentence\"])\n", " \n", "# batch[\"input_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n", "# return batch" ] }, { "cell_type": "code", "execution_count": 67, "id": "351a0b24-a5ac-4094-8fb6-8a5f9f69460a", "metadata": {}, "outputs": [], "source": [ "#trainset = trainset.map(prepare_dataset)\n", "#testset = testset.map(prepare_dataset)" ] }, { "cell_type": "code", "execution_count": 21, "id": "7965552d-db78-40d4-b993-a5fe81d807ea", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = vectorized_trainset.shuffle( buffer_size=500,seed=0,)\n", "vectorized_testset = vectorized_testset.shuffle( buffer_size=500,seed=0,)" ] }, { "cell_type": "code", "execution_count": 22, "id": "8f850564-b99b-4d1d-8554-f21b7a25651f", "metadata": {}, "outputs": [], "source": [ "MAX_DURATION_IN_SECONDS = 30.0\n", "\n", "def is_audio_length_in_range(input_length):\n", " return input_length < MAX_DURATION_IN_SECONDS" ] }, { "cell_type": "code", "execution_count": 24, "id": "7093ea65-1af5-4f4b-aadf-4dba4c9aba4d", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = vectorized_trainset.filter(is_audio_length_in_range, input_columns=[\"input_length\"])\n", "vectorized_testset = vectorized_testset.filter(is_audio_length_in_range, input_columns=[\"input_length\"])" ] }, { "cell_type": "code", "execution_count": 71, "id": "4e2666ef-54eb-4dc4-ad9e-0c65321b05ba", "metadata": {}, "outputs": [], "source": [ "#######################\n", "### Setting up Model\n", "#######################" ] }, { "cell_type": "code", "execution_count": 25, "id": "241fc3d6-8e50-4fe3-b2d9-92bc9b434c39", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from dataclasses import dataclass\n", "from typing import Any, Dict, List, Union\n", "\n", "@dataclass\n", "class DataCollatorSpeechSeq2SeqWithPadding:\n", " processor: Any\n", "\n", " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", " # split inputs and labels since they have to be of different lengths and need different padding methods\n", " # first treat the audio inputs by simply returning torch tensors\n", " 
input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", "\n", " # get the tokenized label sequences\n", " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " # pad the labels to max length\n", " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n", "\n", " # replace padding with -100 to ignore loss correctly\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", " # if bos token is appended in previous tokenization step,\n", " # cut bos token here as it's append later anyways\n", " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", "\n", " batch[\"labels\"] = labels\n", "\n", " return batch" ] }, { "cell_type": "code", "execution_count": 26, "id": "ea8452c1-a2ab-4282-91cc-034c13741533", "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)" ] }, { "cell_type": "code", "execution_count": 27, "id": "75c02085-b510-48ef-82b5-8a17f89a7f3a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 100%|██████████| 4.49k/4.49k [00:00<00:00, 6.61MB/s]\n" ] } ], "source": [ "import evaluate\n", "\n", "metric = evaluate.load(\"wer\")" ] }, { "cell_type": "code", "execution_count": 28, "id": "cde4ef71-64e5-49fd-bd6d-3fa245860be6", "metadata": {}, "outputs": [], "source": [ "# evaluate with the 'normalised' WER\n", "do_normalize_eval = True\n", "\n", "def compute_metrics(pred):\n", " pred_ids = pred.predictions\n", " label_ids = pred.label_ids\n", "\n", " # replace -100 with the pad_token_id\n", " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " # we do not want to group tokens when computing the metrics\n", " pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", " if do_normalize_eval:\n", " pred_str = [normalizer(pred) for pred in pred_str]\n", " label_str = [normalizer(label) for label in label_str]\n", "\n", " wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer}" ] }, { "cell_type": "code", "execution_count": 29, "id": "b939e25e-992e-4137-9a6e-0ffda7a9421a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: 100%|██████████| 1.97k/1.97k [00:00<00:00, 1.63MB/s]\n", "Downloading: 100%|██████████| 3.06G/3.06G [01:03<00:00, 48.3MB/s] \n" ] } ], "source": [ "from transformers import WhisperForConditionalGeneration\n", "\n", "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-medium\")" ] }, { "cell_type": "code", "execution_count": 30, "id": "c941f2e0-80b6-47ed-bfba-c23ca93de177", "metadata": {}, "outputs": [], "source": [ "model.config.forced_decoder_ids = None\n", "model.config.suppress_tokens = []\n", "model.config.use_cache = False" ] }, { "cell_type": "code", "execution_count": null, "id": "c10151ab-42d9-4cc1-a6fb-f09757e4597e", "metadata": {}, "outputs": [], "source": [ "#######################\n", "### Setting up my Train Configs\n", "#######################" ] }, { "cell_type": "code", "execution_count": 31, "id": "b63c08f5-b4b4-4b0c-8007-547ba437d852", "metadata": {}, "outputs": [], "source": [ "from 
transformers import Seq2SeqTrainingArguments\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"./\",\n", " per_device_train_batch_size=32,\n", " gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size\n", " learning_rate=1e-5,\n", " warmup_steps=100,\n", " max_steps=1000,\n", " gradient_checkpointing=True,\n", " fp16=True,\n", " evaluation_strategy=\"steps\",\n", " per_device_eval_batch_size=8,\n", " predict_with_generate=True,\n", " generation_max_length=225,\n", " save_steps=1000,\n", " eval_steps=1000,\n", " logging_steps=25,\n", " report_to=[\"tensorboard\"],\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"wer\",\n", " greater_is_better=False,\n", " push_to_hub=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "id": "7957deff-3874-4ca4-a9b5-8a329ac88473", "metadata": {}, "outputs": [], "source": [ "from transformers import TrainerCallback\n", "from transformers.trainer_pt_utils import IterableDatasetShard\n", "from torch.utils.data import IterableDataset\n", "\n", "# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch\n", "class ShuffleCallback(TrainerCallback):\n", " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n", " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n", " pass # set_epoch() is handled by the Trainer\n", " elif isinstance(train_dataloader.dataset, IterableDataset):\n", " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)" ] }, { "cell_type": "code", "execution_count": 33, "id": "e2a68a7f-c34a-410d-9457-eea4b585c517", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "max_steps is given, it will override any value given in num_train_epochs\n", "Using cuda_amp half precision backend\n" ] } ], "source": [ "from transformers import Seq2SeqTrainer\n", "\n", "trainer = Seq2SeqTrainer(\n", " args=training_args,\n", " model=model,\n", " train_dataset=vectorized_trainset,\n", " eval_dataset=vectorized_testset,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", " tokenizer=processor,\n", " callbacks=[ShuffleCallback()],\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "id": "c0c5c27a-e1c8-40e2-9911-1c94dbddda0d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Configuration saved in ./config.json\n", "Model weights saved in ./pytorch_model.bin\n", "Feature extractor saved in ./preprocessor_config.json\n", "tokenizer config file saved in ./tokenizer_config.json\n", "Special tokens file saved in ./special_tokens_map.json\n", "added tokens file saved in ./added_tokens.json\n" ] } ], "source": [ "model.save_pretrained(training_args.output_dir)\n", "processor.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "code", "execution_count": 35, "id": "2bcd5878-cf8a-49cc-b79b-a396b1fb0d28", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/hf_env/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 64000\n", " Num Epochs = 9223372036854775807\n", " Instantaneous batch size per device = 32\n", " Total train batch size (w. 
parallel, distributed & accumulation) = 64\n", " Gradient Accumulation steps = 2\n", " Total optimization steps = 1000\n", " Number of trainable parameters = 763857920\n", "Reading metadata...: 905243it [00:16, 54272.90it/s]\n", "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: sentence, audio, input_length. If sentence, audio, input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n" ] }, { "data": { "text/html": [ "\n", "
<table>\n", "  <thead>\n", "    <tr>\n", "      <th>Step</th>\n", "      <th>Training Loss</th>\n", "      <th>Validation Loss</th>\n", "      <th>Wer</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <td>1000</td>\n", "      <td>0.121900</td>\n", "      <td>0.201443</td>\n", "      <td>10.968810</td>\n", "    </tr>\n", "  </tbody>\n", "</table>
"
],
"text/plain": [
"