{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "9e852db9", "metadata": {}, "outputs": [], "source": [ "# This notebook is currently designed for a GPU using fp16. Hyperparameters however are barely tuned." ] }, { "cell_type": "code", "execution_count": null, "id": "e730080b", "metadata": {}, "outputs": [], "source": [ "import json\n", "import random\n", "import torch\n", "from pathlib import Path\n", "from accelerate import Accelerator\n", "from datasets import load_dataset, concatenate_datasets\n", "from datasets.features import Audio\n", "from torch.utils.data import Dataset, DataLoader\n", "from torch.optim import AdamW\n", "from torch.utils.tensorboard import SummaryWriter\n", "from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor\n", "from wer import calculate_wer # Not what's used in eval.py.\n", "from model import Wav2VecGPT2Model" ] }, { "cell_type": "code", "execution_count": null, "id": "72af6337", "metadata": { "scrolled": true }, "outputs": [], "source": [ "common_voice = load_dataset('mozilla-foundation/common_voice_7_0', 'de', use_auth_token=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "6396e61d", "metadata": {}, "outputs": [], "source": [ "EXPERIMENT_NAME = '00'\n", "\n", "model_dir = Path('end2end/de') / EXPERIMENT_NAME\n", "log_dir = model_dir / 'logs'\n", "log_dir.mkdir(exist_ok=True, parents=True)\n", "\n", "config = {\n", " 'encoder_id': 'jonatasgrosman/wav2vec2-large-xlsr-53-german',\n", " 'decoder_id': 'dbmdz/german-gpt2',\n", " 'decoder_pad_token': '_',\n", " 'decoder_bos_token': '~',\n", " 'num_beams': 1,\n", " 'num_val_examples': 1500,\n", " 'batch_size': 8,\n", " 'base_lr': 3e-4,\n", " 'weight_decay': 0.,\n", " 'accumulate_grad': 4,\n", " 'max_epochs': 10,\n", " 'max_len': 36 # len(max(tokenizer(common_voice['validation']['sentence'] + common_voice['test']['sentence']).input_ids, key=len))" ] }, { "cell_type": "code", "execution_count": null, "id": "6c632a61", "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(config['decoder_id'])\n", "tokenizer.add_special_tokens({'pad_token': config['decoder_pad_token'], 'bos_token': config['decoder_bos_token']})\n", "\n", "wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['encoder_id'])\n", "\n", "model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(\n", " config['encoder_id'], config['decoder_id'], max_length=config['max_len'], num_beams=config['num_beams']\n", ")\n", "\n", "model.config.decoder_start_token_id = tokenizer.bos_token_id\n", "model.config.pad_token_id = tokenizer.pad_token_id" ] }, { "cell_type": "code", "execution_count": null, "id": "30e5b73c", "metadata": {}, "outputs": [], "source": [ "# Load model from decoder-only training.\n", "model.load_state_dict(torch.load('decoder_only/de/00/model.pt'))" ] }, { "cell_type": "code", "execution_count": null, "id": "5466e908", "metadata": {}, "outputs": [], "source": [ "class AudioDataset(Dataset):\n", " \n", " def __init__(self, ds):\n", " self.ds = ds\n", " \n", " def __len__(self):\n", " return len(self.ds)\n", " \n", " def __getitem__(self, idx):\n", " eg = self.ds[idx]\n", " return eg['audio']['array'], eg['sentence']\n", " \n", "def collate_fn(examples):\n", " # Remove the longest examples, should be only three and these may lead to OOM- or Index-Errors.\n", " examples = [eg for eg in examples if len(eg[0]) < 300_000]\n", " \n", " audio_features = wave2vec_extractor(\n", " [eg[0] for eg in examples], sampling_rate=16_000, return_tensors='pt', padding='longest'\n", " ).input_values\n", " \n", " input_ids = tokenizer(\n", " [eg[1] for eg in examples], return_tensors='pt', padding=True\n", " ).input_ids\n", " \n", " return audio_features, input_ids" ] }, { "cell_type": "code", "execution_count": null, "id": "0453ccc1", "metadata": {}, "outputs": [], "source": [ "train = common_voice['train'].cast_column('audio', Audio(sampling_rate=16_000))\n", "val = common_voice['validation'].cast_column('audio', Audio(sampling_rate=16_000))" ] }, { "cell_type": "code", "execution_count": null, "id": "ad81c9ab", "metadata": {}, "outputs": [], "source": [ "random.seed(419)\n", "val_inds = list(range(len(common_voice['validation'])))\n", "random.shuffle(val_inds)\n", "\n", "train_ds = AudioDataset(concatenate_datasets([train, val.select(val_inds[config['num_val_examples']:])]))\n", "val_ds = AudioDataset(val.select(val_inds[:config['num_val_examples']]))\n", "\n", "train_dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=4)\n", "val_dl = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=4)" ] }, { "cell_type": "code", "execution_count": null, "id": "f0d1c290", "metadata": {}, "outputs": [], "source": [ "accelerator = Accelerator(fp16=True)\n", "print(f'Using {accelerator.device}.')" ] }, { "cell_type": "code", "execution_count": null, "id": "2af1f2f1", "metadata": {}, "outputs": [], "source": [ "optimizer = AdamW(model.parameters(), lr=config['base_lr'], weight_decay=config['weight_decay'])" ] }, { "cell_type": "code", "execution_count": null, "id": "6921d32c", "metadata": {}, "outputs": [], "source": [ "model, optimizer, train_dl, val_dl = accelerator.prepare(model, optimizer, train_dl, val_dl)" ] }, { "cell_type": "code", "execution_count": null, "id": "d699c404", "metadata": {}, "outputs": [], "source": [ "with open(log_dir / 'config.json', 'w') as config_file:\n", " json.dump(config, config_file, indent=4)\n", " \n", "writer = SummaryWriter(log_dir)\n", "val_golds = common_voice['validation'].select(val_inds[:config['num_val_examples']])['sentence']\n", "best_val_wer = 10.\n", "global_train_step = 0\n", "\n", "for epoch in range(config['max_epochs']):\n", " model.train()\n", " for batch_step, (audio_features, input_ids) in enumerate(train_dl):\n", " global_train_step += 1\n", " \n", " out = model(labels=input_ids, input_values=audio_features)\n", " accelerator.backward(out.loss)\n", " writer.add_scalar('train_loss', out.loss.item(), global_train_step)\n", " \n", " if (batch_step + 1) % config['accumulate_grad'] == 0:\n", " optimizer.step()\n", " optimizer.zero_grad()\n", " if batch_step % 300 == 0:\n", " print(out.loss.item())\n", " \n", " model.eval()\n", " val_preds = []\n", " for audio_features, input_ids in val_dl:\n", " with torch.no_grad():\n", " generated = model.generate(audio_features)\n", " val_preds += tokenizer.batch_decode(generated)\n", " val_preds = [pred.lstrip('~').rstrip('_') for pred in val_preds]\n", " wer = calculate_wer(val_preds, val_golds)\n", " writer.add_scalar('val_wer', wer, epoch)\n", " print('WER: ', wer)\n", " \n", " if wer < best_val_wer:\n", " torch.save(model.state_dict(), model_dir / 'model.pt')\n", " print('Saved Model.')\n", " best_val_wer = wer" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }