{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "521e21ab",
   "metadata": {},
   "source": [
    "**Note:** This notebook is currently designed for a GPU using fp16. However, the hyperparameters are barely tuned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1732f970",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import torch\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f55f4047",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXPERIMENT_NAME = '00'\n",
    "DATA_PATH = Path('../data/common_voice/de')\n",
    "\n",
    "model_dir = Path('decoder_only/de') / EXPERIMENT_NAME\n",
    "log_dir = model_dir / 'logs'\n",
    "log_dir.mkdir(exist_ok=True, parents=True)\n",
    "\n",
    "config = {\n",
    "    'use_train_frac': 1.0,  # With all samples (~360,000), the wav2vec outputs take up ~275 GB of disk space!\n",
    "    'use_val_frac': 0.2,\n",
    "    'encoder_id': 'jonatasgrosman/wav2vec2-large-xlsr-53-german',\n",
    "    'decoder_id': 'dbmdz/german-gpt2',\n",
    "    'decoder_pad_token': '_',\n",
    "    'decoder_bos_token': '~',\n",
    "    'num_beams': 1,\n",
    "    'batch_size': 16,\n",
    "    'weight_decay': 0.,\n",
    "    'accumulate_grad': 2,\n",
    "    'max_epochs': 10,\n",
    "    # Computed as: len(max(tokenizer(common_voice['validation']['sentence'] + common_voice['test']['sentence']).input_ids, key=len))\n",
    "    'max_len': 36,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb3de6a4",
   "metadata": {},
   "source": [
    "# Feature Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b176328e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "from datasets import load_dataset\n",
    "from datasets.features import Audio\n",
    "from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54e70696",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logging in is only needed to download Common Voice, which only happens while DATA_PATH\n",
    "# does not exist yet; skip the interactive prompt once the features are on disk.\n",
    "if not DATA_PATH.exists():\n",
    "    notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d22752",
   "metadata": {},
   "outputs": [],
"source": [ "def extract_features_to_files(model, feature_extractor, dataset_split, batch_size, output_path):\n", " output_path = Path(output_path)\n", " output_path.mkdir(parents=True, exist_ok=True)\n", "\n", " model.eval().cuda()\n", " for i in range(0, len(dataset_split), batch_size):\n", " batch = dataset_split[i:i+batch_size]\n", " sent_batch = batch['sentence']\n", " audio_batch = batch['audio']\n", " for i, eg in enumerate(audio_batch):\n", " # Remove the longest examples, should be only three and these may lead to OOM- or Index-Errors.\n", " if len(eg['array']) > 300_000:\n", " print('Too Long.')\n", " audio_batch.pop(i)\n", " sent_batch.pop(i)\n", " features = feature_extractor([eg['array'] for eg in audio_batch],\n", " sampling_rate=16_000,\n", " return_tensors='pt',\n", " padding='longest')\n", "\n", " with torch.no_grad():\n", " out = model(features.input_values.cuda(), attention_mask=features.attention_mask.cuda())\n", "\n", " assert len(sent_batch) == len(audio_batch) == len(out.last_hidden_state)\n", " for sent, audio, hs in zip(sent_batch, audio_batch, out.last_hidden_state.bfloat16().cpu()):\n", " file_name = audio['path'].split('/')[-1]\n", " torch.save(\n", " # .clone() is necessary: https://github.com/pytorch/pytorch/issues/1995\n", " {'sentence': sent, 'wave2vec_features': hs.clone()},\n", " output_path / file_name\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "06324b6f", "metadata": {}, "outputs": [], "source": [ "if not DATA_PATH.exists():\n", " \n", " common_voice = load_dataset('mozilla-foundation/common_voice_7_0', 'de', use_auth_token=True)\n", " \n", " random.seed(419)\n", " train_inds = list(range(len(common_voice['train'])))\n", " random.shuffle(train_inds)\n", " val_inds = list(range(len(common_voice['validation'])))\n", " random.shuffle(val_inds)\n", " \n", " train_inds = train_inds[:int(config['use_train_frac'] * len(train_inds))]\n", " train = common_voice['train'].select(train_inds)\n", " train = 
train.cast_column('audio', Audio(sampling_rate=16_000))\n", " \n", " val_inds = val_inds[:int(config['use_val_frac'] * len(val_inds))]\n", " val = common_voice['validation'].select(val_inds)\n", " val = val.cast_column('audio', Audio(sampling_rate=16_000))\n", " \n", " # Load Model for feature extraction.\n", " wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['encoder_id'])\n", " wave2vec = Wav2Vec2Model.from_pretrained(config['encoder_id'])\n", " wave2vec.eval().cuda()\n", " \n", " extract_features_to_files(wave2vec, wave2vec_extractor, train, batch_size=8, output_path=DATA_PATH / 'train')\n", " extract_features_to_files(wave2vec, wave2vec_extractor, val, batch_size=8, output_path=DATA_PATH / 'val')\n", " \n", " wave2vec.cpu()\n", " torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "id": "b2ae2a47", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "code", "execution_count": null, "id": "188ef54f", "metadata": {}, "outputs": [], "source": [ "import json\n", "from accelerate import Accelerator\n", "from torch.utils.data import DataLoader\n", "from torch.optim import AdamW\n", "from torch.utils.tensorboard import SummaryWriter\n", "from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor\n", "from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2BaseModelOutput\n", "from data_loading import make_collate_fn, S2TDataset\n", "from wer import calculate_wer # Not what's used in eval.py.\n", "from model import Wav2VecGPT2Model" ] }, { "cell_type": "code", "execution_count": null, "id": "41518c81", "metadata": { "scrolled": false }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(config['decoder_id'])\n", "tokenizer.add_special_tokens({'pad_token': config['decoder_pad_token'], 'bos_token': config['decoder_bos_token']})\n", "\n", "model = Wav2VecGPT2Model.from_encoder_decoder_pretrained(\n", " config['encoder_id'], config['decoder_id'], max_length=config['max_len'], 
num_beams=config['num_beams']\n", ")\n", "\n", "model.config.decoder_start_token_id = tokenizer.bos_token_id\n", "model.config.pad_token_id = tokenizer.pad_token_id" ] }, { "cell_type": "code", "execution_count": null, "id": "a95ec028", "metadata": {}, "outputs": [], "source": [ "collate_fn = make_collate_fn(tokenizer)\n", "\n", "train_ds = S2TDataset(DATA_PATH / 'train')\n", "train_dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=4)\n", "\n", "val_ds = S2TDataset(DATA_PATH / 'val')\n", "val_dl = DataLoader(val_ds, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=4)" ] }, { "cell_type": "code", "execution_count": null, "id": "0aaeeced", "metadata": {}, "outputs": [], "source": [ "high_lr_modules = ['cross_attn', 'crossattention', 'enc_to_dec_proj', 'encoder_outputs_pos_emb']\n", "high_lr_params = [p for n, p in model.named_parameters() if any(m in n for m in high_lr_modules)]\n", "\n", "optimizer_grouped_parameters = [\n", " {\n", " \"params\": high_lr_params,\n", " \"lr\": 5e-4,\n", " },\n", " {\n", " \"params\": [p for n, p in model.decoder.named_parameters() if not any(m in n for m in high_lr_modules)],\n", " \"lr\": 1e-6,\n", " },\n", "]\n", "optimizer = AdamW(optimizer_grouped_parameters, weight_decay=0.)" ] }, { "cell_type": "code", "execution_count": null, "id": "cf98d090", "metadata": {}, "outputs": [], "source": [ "accelerator = Accelerator(fp16=True)\n", "print(f'Using {accelerator.device}.')" ] }, { "cell_type": "code", "execution_count": null, "id": "da9e928e", "metadata": {}, "outputs": [], "source": [ "model, optimizer, train_dl, val_dl = accelerator.prepare(model, optimizer, train_dl, val_dl)" ] }, { "cell_type": "code", "execution_count": null, "id": "f191f256", "metadata": { "scrolled": false }, "outputs": [], "source": [ "with open(log_dir / 'config.json', 'w') as config_file:\n", " json.dump(config, config_file, indent=4)\n", " \n", "writer = 
SummaryWriter(log_dir)\n", "val_golds = [eg['sentence'] for eg in val_ds]\n", "best_val_wer = 10.\n", "global_train_step = 0\n", "\n", "for epoch in range(config['max_epochs']):\n", " \n", " model.train()\n", " model.encoder.cpu() # Model gets moved to gpu for evaluation (see below).\n", " torch.cuda.empty_cache()\n", " for batch_step, (encoder_hidden_states, att_mask, input_ids) in enumerate(train_dl):\n", " if encoder_hidden_states.shape[1] > 1024:\n", " # That's too long for the position embeddings. \n", " # TODO: handle this in model code.\n", " print(f'SKIPPED: {encoder_hidden_states.shape}')\n", " continue\n", " global_train_step += 1\n", " \n", " out = model(labels=input_ids, encoder_outputs=Wav2Vec2BaseModelOutput(encoder_hidden_states))\n", " accelerator.backward(out.loss)\n", " writer.add_scalar('train_loss', out.loss.item(), global_train_step)\n", " \n", " if (batch_step + 1) % config['accumulate_grad'] == 0:\n", " optimizer.step()\n", " optimizer.zero_grad()\n", " \n", " if batch_step % 300 == 0:\n", " print(out.loss.item())\n", " \n", " model.eval()\n", " model.cuda() # Necessary for input_ids to be initialized on the correct device.\n", " val_preds = []\n", " for encoder_hidden_states, att_mask, _ in val_dl:\n", " with torch.no_grad():\n", " generated = model.generate(\n", " encoder_outputs=Wav2Vec2BaseModelOutput(last_hidden_state=encoder_hidden_states)\n", " )\n", " val_preds += tokenizer.batch_decode(generated)\n", " val_preds = [pred.lstrip('~').rstrip('_') for pred in val_preds]\n", " wer = calculate_wer(val_preds, val_golds)\n", " writer.add_scalar('val_wer', wer, epoch)\n", " print('WER: ', wer)\n", " \n", " if wer < best_val_wer:\n", " torch.save(model.state_dict(), model_dir / 'model.pt')\n", " print('Saved Model.')\n", " best_val_wer = wer" ] }, { "cell_type": "code", "execution_count": null, "id": "d84a7e5c", "metadata": {}, "outputs": [], "source": [ "# # Load saved pytorch model and save with all necessary model files.\n", "# output_path 
= model_dir /'full_model'\n", "# \n", "# model.load_state_dict(torch.load(model_dir / 'model.pt'))\n", "# \n", "# tokenizer.save_pretrained(output_path)\n", "# wave2vec_extractor = Wav2Vec2FeatureExtractor.from_pretrained(config['encoder_id'])\n", "# wave2vec_extractor.save_pretrained(output_path)\n", "# model.save_pretrained(output_path)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }