{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import os\n", "import re\n", "import gc\n", "import math\n", "import json\n", "import hashlib\n", "import numpy as np\n", "import logging\n", "import torchaudio\n", "from tqdm.auto import tqdm\n", "import torch.nn.functional as F\n", "from encodec.utils import convert_audio\n", "from accelerate import Accelerator\n", "from accelerate.utils import set_seed\n", "from transformers import BertTokenizer\n", "from huggingface_hub import hf_hub_download\n", "from packaging import version\n", "from diffusers.optimization import get_scheduler\n", "\n", "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n", "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n", "from bark.model import GPTConfig, GPT\n", "from bark.model_fine import FineGPT, FineGPTConfig" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Training Args" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_batch_size = 8\n", "eval_batch_size = 8\n", "grad_accum = 2\n", "ckpt_path = 'models/coarse_2.pt'\n", "model_type = \"coarse\"\n", "dataset_path = 'datasets/joe_biden_state_of_union/'\n", "logging_dir = 'logs/'\n", "log_with = 'wandb'\n", "hubert_path = 'data/models/hubert/hubert.pt'\n", "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n", "\n", "output_dir = 'coarse_output/'\n", "resume_from_checkpoint = None\n", "\n", "checkpointing_steps = 1000\n", "\n", "mixed_precision = 'bf16'\n", "bits = 16 #4 4 and 8 bit are a work in progress\n", "compute_dtype = torch.bfloat16\n", "double_quant = True\n", "quant_type = 
'nf4'\n", "\n", "lora_dim = 64\n", "lora_scaling = 1\n", "lora_dropout = 0.1\n", "lora_module_name = 'transformer.h'\n", "optimize_lora_params_only = False\n", "\n", "learning_rate = 1e-4\n", "scale_lr = False\n", "use_8bit_adam = False\n", "adam_beta1 = 0.9\n", "adam_beta2 = 0.999\n", "adam_epsilon = 1e-8\n", "weight_decay = 0.01\n", "\n", "llm_int8_skip_modules = None\n", "keep_in_fp32_modules = ['lm_head']\n", "\n", "lr_scheduler_type = 'linear'\n", "lr_warmup_steps = 60\n", "num_train_epochs = 5\n", "max_train_steps = None\n", "max_grad_norm = 1.0\n", "\n", "semantic_cross_entropy_loss_weight = 0.0\n", "\n", "seed = 741" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Define Functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "CONTEXT_WINDOW_SIZE = 1024\n", "\n", "MAX_SEMANTIC_LEN = 256\n", "\n", "SEMANTIC_RATE_HZ = 49.9\n", "SEMANTIC_VOCAB_SIZE = 10_000\n", "\n", "TEXT_ENCODING_OFFSET = 10_048\n", "SEMANTIC_PAD_TOKEN = 10_000\n", "TEXT_PAD_TOKEN = 129_595\n", "SEMANTIC_INFER_TOKEN = 129_599\n", "\n", "MAX_COARSE_LEN = 768\n", "\n", "SAMPLE_RATE = 24_000\n", "CHANNELS = 1\n", "\n", "COARSE_SEMANTIC_PAD_TOKEN = 12_048\n", "COARSE_INFER_TOKEN = 12_050\n", "\n", "CODEBOOK_SIZE = 1024\n", "N_COARSE_CODEBOOKS = 2\n", "N_FINE_CODEBOOKS = 8\n", "COARSE_RATE_HZ = 75\n", "\n", "logger = logging.getLogger(__name__)\n", "\n", "\n", "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n", "\n", "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n", "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n", "\n", "\n", "def _clear_cuda_cache():\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " torch.cuda.synchronize()\n", "\n", "\n", "def _md5(fname):\n", " hash_md5 = hashlib.md5()\n", " with open(fname, \"rb\") as f:\n", " for chunk in iter(lambda: f.read(4096), b\"\"):\n", " 
hash_md5.update(chunk)\n", " return hash_md5.hexdigest()\n", "\n", "\n", "def _download(from_hf_path, file_name, to_local_path):\n", " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n", " path = '/'.join(to_local_path.split(\"/\")[:-1])\n", " os.makedirs(path, exist_ok=True)\n", " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n", " os.replace(os.path.join(path, file_name), to_local_path)\n", "\n", "\n", "def _tokenize(tokenizer, text):\n", " return tokenizer.encode(text, add_special_tokens=False)\n", "\n", "\n", "def _detokenize(tokenizer, enc_text):\n", " return tokenizer.decode(enc_text)\n", "\n", "\n", "def _normalize_whitespace(text):\n", " return re.sub(r\"\\s+\", \" \", text).strip()\n", "\n", "\n", "REMOTE_MODEL_PATHS = {\n", " \"text_small\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"text.pt\",\n", " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n", " },\n", " \"coarse_small\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"coarse.pt\",\n", " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n", " },\n", " \"fine_small\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"fine.pt\",\n", " \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n", " },\n", " \"text\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"text_2.pt\",\n", " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n", " },\n", " \"coarse\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"coarse_2.pt\",\n", " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n", " },\n", " \"fine\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"fine_2.pt\",\n", " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n", " },\n", "}\n", "\n", "\n", "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n", " if model_type == \"text\":\n", " ConfigClass = GPTConfig\n", " ModelClass = GPT\n", " elif model_type == \"coarse\":\n", " ConfigClass = GPTConfig\n", " ModelClass = 
GPT\n", " elif model_type == \"fine\":\n", " ConfigClass = FineGPTConfig\n", " ModelClass = FineGPT\n", " else:\n", " raise NotImplementedError()\n", " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n", " model_info = REMOTE_MODEL_PATHS[model_key]\n", " if ckpt_path in [None, '']:\n", " ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n", " if not os.path.exists(ckpt_path):\n", " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n", " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n", " checkpoint = torch.load(ckpt_path, map_location=device)\n", " # this is a hack\n", " model_args = checkpoint[\"model_args\"]\n", " if \"input_vocab_size\" not in model_args:\n", " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n", " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n", " del model_args[\"vocab_size\"]\n", " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n", " model = ModelClass(gptconf)\n", " state_dict = checkpoint[\"model\"]\n", " # fixup checkpoint\n", " unwanted_prefix = \"_orig_mod.\"\n", " for k, v in list(state_dict.items()):\n", " if k.startswith(unwanted_prefix):\n", " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n", " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n", " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n", " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n", " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n", " if len(extra_keys) != 0:\n", " raise ValueError(f\"extra keys found: {extra_keys}\")\n", " if len(missing_keys) != 0:\n", " raise ValueError(f\"missing keys: {missing_keys}\")\n", " model.load_state_dict(state_dict, strict=False)\n", " n_params = model.get_num_params()\n", " val_loss = checkpoint[\"best_val_loss\"].item()\n", " print(f\"Loaded {model_type} model with {n_params} 
class TtsCollater():
    """Collate ``(semantic_tokens, coarse_tokens)`` pairs into fixed-length batches.

    Each item is a pair of 1-D token tensors. Semantic sequences longer than
    ``max_semantic_len`` are cropped at a random offset, and the coarse sequence
    is cropped from the matching position. Both are right-padded with
    ``pad_token``, and ``infer_token`` is appended to the semantic sequence.

    Args:
        max_semantic_len: Semantic window length (default ``MAX_SEMANTIC_LEN``).
        max_coarse_len: Coarse window length (default ``MAX_COARSE_LEN``).
        semantic_to_coarse_ratio: Flattened coarse tokens per semantic token
            (default ``COARSE_RATE_HZ / SEMANTIC_RATE_HZ * n_coarse_codebooks``).
        n_coarse_codebooks: Number of codebooks interleaved in the flattened
            coarse stream (default ``N_COARSE_CODEBOOKS``).
        pad_token: Padding value for both streams
            (default ``COARSE_SEMANTIC_PAD_TOKEN``).
        infer_token: Token appended after the semantic window
            (default ``COARSE_INFER_TOKEN``).
    """

    def __init__(self, max_semantic_len=None, max_coarse_len=None,
                 semantic_to_coarse_ratio=None, n_coarse_codebooks=None,
                 pad_token=None, infer_token=None):
        # Module-level constants are resolved lazily so explicit arguments win.
        self.max_semantic_len = MAX_SEMANTIC_LEN if max_semantic_len is None else max_semantic_len
        self.max_coarse_len = MAX_COARSE_LEN if max_coarse_len is None else max_coarse_len
        self.n_coarse_codebooks = N_COARSE_CODEBOOKS if n_coarse_codebooks is None else n_coarse_codebooks
        if semantic_to_coarse_ratio is None:
            semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * self.n_coarse_codebooks
        self.semantic_to_coarse_ratio = semantic_to_coarse_ratio
        self.pad_token = COARSE_SEMANTIC_PAD_TOKEN if pad_token is None else pad_token
        self.infer_token = COARSE_INFER_TOKEN if infer_token is None else infer_token

    def __call__(self, batch):
        """Return a dict with stacked ``'semantic_tokens'`` and ``'coarse_tokens'``.

        ``'semantic_tokens'`` has ``max_semantic_len + 1`` columns (window plus
        infer token); ``'coarse_tokens'`` has ``max_coarse_len`` columns.
        """
        semantic_rows = []
        coarse_rows = []

        for semantic, coarse in batch:
            if len(semantic) > self.max_semantic_len:
                start_idx = np.random.randint(0, len(semantic) - self.max_semantic_len + 1)
                semantic = semantic[start_idx:start_idx + self.max_semantic_len]

                # The flattened coarse stream interleaves the codebooks frame by
                # frame, so the crop must begin on a frame boundary. Otherwise the
                # window would start on a token from a later codebook.
                coarse_start = int(start_idx * self.semantic_to_coarse_ratio)
                coarse_start -= coarse_start % self.n_coarse_codebooks
                coarse = coarse[coarse_start:]

            # Truncate explicitly instead of relying on a negative pad width.
            coarse = coarse[:self.max_coarse_len]

            semantic = F.pad(semantic, (0, self.max_semantic_len - len(semantic)), value=self.pad_token)
            semantic = torch.cat([semantic, torch.tensor([self.infer_token])])
            coarse = F.pad(coarse, (0, self.max_coarse_len - len(coarse)), value=self.pad_token)

            semantic_rows.append(semantic)
            coarse_rows.append(coarse)

        return {
            'semantic_tokens': torch.stack(semantic_rows).contiguous(),
            'coarse_tokens': torch.stack(coarse_rows).contiguous()
        }
HuBERTManager()\n", "# from hubert.pre_kmeans_hubert import CustomHubert\n", "# from hubert.customtokenizer import CustomTokenizer\n", "# hubert_manager.make_sure_hubert_installed()\n", "# hubert_manager.make_sure_tokenizer_installed()\n", "\n", "# # Load the HuBERT model\n", "# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n", "# hubert_model.eval()\n", "# for param in hubert_model.parameters():\n", "# param.requires_grad = False\n", "\n", "# # Load the CustomTokenizer model\n", "# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n", "\n", "# from bark.generation import load_codec_model\n", "# codec_model = load_codec_model(use_gpu=True)\n", "# codec_model.eval()\n", "# for param in codec_model.parameters():\n", "# param.requires_grad = False\n", "\n", "\n", "# def get_duration(wav, sr):\n", "# return wav.shape[1] / sr\n", "\n", "# valid_lines_train = []\n", "# # convert wavs to semantic tokens\n", "# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n", "# wav, sr = torchaudio.load(wav_path)\n", "# if not get_duration(wav, sr) > max_duration_sec:\n", "# valid_lines_train.append((wav_path, txt))\n", "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", "\n", "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", "\n", "# # save semantic tokens\n", "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", "# semantic_tokens = semantic_tokens.cpu().numpy()\n", "\n", "# # Extract discrete codes from EnCodec\n", "# with torch.no_grad():\n", "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", "\n", "# # move codes to cpu\n", "# codes = codes.cpu().numpy()\n", "\n", "# # save tokens\n", "# np.savez_compressed(os.path.join(path, 
'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", "\n", "# # rewrite train.txt with valid lines\n", "# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n", "# for wav_path, txt in valid_lines_train:\n", "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", "# f.write(f'{wav_path}|{txt}\\n')\n", "\n", "# valid_lines_valid = []\n", "# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n", "# wav, sr = torchaudio.load(wav_path)\n", "# if not get_duration(wav, sr) > max_duration_sec:\n", "# valid_lines_valid.append((wav_path, txt))\n", "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", "\n", "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", "\n", "# # save semantic tokens\n", "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", "# semantic_tokens = semantic_tokens.cpu().numpy()\n", " \n", "# # Extract discrete codes from EnCodec\n", "# with torch.no_grad():\n", "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", "\n", "# # move codes to cpu\n", "# codes = codes.cpu().numpy()\n", "\n", "# # save tokens\n", "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", "\n", "# # rewrite valid.txt with valid lines\n", "# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n", "# for wav_path, txt in valid_lines_valid:\n", "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", "# f.write(f'{wav_path}|{txt}\\n')\n", "\n", "# del hubert_model\n", "# del hubert_tokenizer\n", "# del codec_model\n", "# gc.collect()\n", "# torch.cuda.empty_cache()" ] }, { "attachments": {}, 
"cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if scale_lr:\n", " learning_rate = (\n", " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n", " )\n", "\n", "if use_8bit_adam:\n", " try:\n", " import bitsandbytes as bnb\n", " except ImportError:\n", " raise ImportError(\n", " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n", " )\n", "\n", " optimizer_class = bnb.optim.AdamW8bit\n", "else:\n", " optimizer_class = torch.optim.AdamW" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "quantization_config=BitsAndBytesConfig(\n", " load_in_4bit=bits == 4,\n", " load_in_8bit=bits == 8,\n", " llm_int8_threshold=6.0,\n", " llm_int8_has_fp16_weight=False,\n", " bnb_4bit_compute_dtype=compute_dtype,\n", " bnb_4bit_use_double_quant=double_quant,\n", " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n", ")\n", "\n", "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n", "# if quantization_config.load_in_8bit:\n", "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n", "# elif quantization_config.load_in_4bit:\n", "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n", "\n", "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n", "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n", "# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n", "# else:\n", "# modules_to_not_convert = llm_int8_skip_modules\n", "\n", "# if not isinstance(modules_to_not_convert, list):\n", "# modules_to_not_convert = 
# Storage dtype for k-bit runs; left unbound for any other `bits` value.
if bits == 8:
    target_dtype = torch.int8
elif bits == 4:
    from accelerate.utils import CustomDtype
    target_dtype = CustomDtype.INT4

if lora_dim > 0:
    # Keep 1-D parameters (e.g. layernorm) in fp32 for stability.
    for small_param in (p for p in model.parameters() if p.ndim == 1):
        small_param.data = small_param.data.to(torch.float32)

    class CastOutputToFloat(nn.Sequential):
        """Sequential wrapper that casts its output to float32."""

        def forward(self, x):
            wrapped_output = super().forward(x)
            return wrapped_output.to(torch.float32)

    model.lm_head = CastOutputToFloat(model.lm_head)

    # Swap the linear layers under `lora_module_name` for LoRA layers.
    model = convert_linear_layer_to_lora(
        model,
        lora_module_name,
        lora_dim=lora_dim,
        lora_scaling=lora_scaling,
        lora_dropout=lora_dropout,
    )
    if optimize_lora_params_only:
        model = only_optimize_lora_parameters(model)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)
if overrode_max_train_steps:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    accelerator.init_trackers("bark_coarse", config={})

# Train!
total_batch_size = train_batch_size * accelerator.num_processes * grad_accum
logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
logger.info(f"  Num Epochs = {num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {grad_accum}")
logger.info(f"  Total optimization steps = {max_train_steps}")
global_step = 0   # optimizer steps completed so far
first_epoch = 0   # epoch to start from
resume_step = 0   # dataloader batches to skip inside `first_epoch`

if resume_from_checkpoint:
    if resume_from_checkpoint != "latest":
        path = os.path.basename(resume_from_checkpoint)
    else:
        # Get the most recent checkpoint
        dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1] if dirs else None

    if path is None:
        # No checkpoint directory exists yet; start from scratch instead of
        # failing with an IndexError.
        accelerator.print(
            f"Checkpoint '{resume_from_checkpoint}' does not exist. Starting a new training run."
        )
    else:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(output_dir, path))
        global_step = int(path.split("-")[1])

        # `global_step` counts optimizer steps, while the dataloader yields
        # `grad_accum` batches per optimizer step. Epochs are therefore derived
        # from optimizer steps, and the in-epoch offset from batches.
        resume_global_step = global_step * grad_accum
        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % (num_update_steps_per_epoch * grad_accum)
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")

# Test for None before comparing, so a None weight disables the auxiliary loss
# instead of raising a TypeError.
use_semantic_loss = (
    semantic_cross_entropy_loss_weight is not None
    and semantic_cross_entropy_loss_weight > 0
)

for epoch in range(first_epoch, num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        # Skip steps until we reach the resumed step
        if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            if step % grad_accum == 0:
                progress_bar.update(1)
            continue

        with accelerator.accumulate(model):
            semantic_len = batch['semantic_tokens'].size(1)

            # Next-token targets: coarse token t+1 is predicted from position t.
            targets = batch['coarse_tokens'][:, 1:].contiguous()

            # Remove the last coarse token from the inputs since there is no target for it.
            coarse_inputs = batch['coarse_tokens'][:, :-1]

            # Combine the semantic tokens and coarse tokens and feed them into the model.
            inputs = torch.cat([batch['semantic_tokens'], coarse_inputs], dim=1)
            logits = model(inputs, training=True)

            # Only the logits at the coarse positions contribute to the main loss.
            coarse_logits = logits[:, semantic_len:].contiguous()

            # Compute the loss.
            loss = criterion(coarse_logits.view(-1, model.config.output_vocab_size), targets.view(-1))

            if use_semantic_loss:
                # Same next-token alignment as the coarse loss: the logits at
                # position t are scored against semantic token t+1.
                semantic_logits = logits[:, :semantic_len - 1].contiguous()
                semantic_targets = batch['semantic_tokens'][:, 1:].contiguous()
                semantic_loss = criterion(
                    # Reshape by the logits' own last dimension, not the input vocab size.
                    semantic_logits.view(-1, semantic_logits.size(-1)),
                    semantic_targets.view(-1),
                )
                num_semantic_logits = semantic_logits.size(1)
                num_coarse_logits = coarse_logits.size(1)
                # Length-weighted average of the two losses.
                loss = (
                    semantic_loss * num_semantic_logits * semantic_cross_entropy_loss_weight +
                    loss * num_coarse_logits
                ) / (num_semantic_logits + num_coarse_logits)

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = (
                    param for param in model.parameters() if param.requires_grad
                )
                accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if global_step % checkpointing_steps == 0:
                if accelerator.is_main_process:
                    save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

        if global_step >= max_train_steps:
            break

    accelerator.wait_for_everyone()

    # The inner `break` only leaves the batch loop; stop the epoch loop as well
    # so no further batches are processed once the step budget is spent.
    if global_step >= max_train_steps:
        break

if accelerator.is_main_process:
    if lora_dim > 0:
        # Convert the LoRA layers back to plain linear layers before export.
        model = convert_lora_to_linear_layer(model)
    # save model
    # NOTE(review): if lm_head was wrapped in an nn.Sequential earlier in the
    # notebook, its key is presumably saved as `lm_head.0.weight` - confirm the
    # consumer of this file expects that.
    accelerator.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))

    config = model.config.__dict__
    # save config
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

accelerator.end_training()
def _run_validation(model, dataloader, criterion):
    """Compute the mean coarse-token loss of ``model`` over ``dataloader``.

    Args:
        model: Callable as ``model(inputs, training=True)``, exposing
            ``config.output_vocab_size`` and ``eval()``.
        dataloader: Iterable of dicts with ``'semantic_tokens'`` and
            ``'coarse_tokens'`` 2-D tensors.
        criterion: Loss called as ``criterion(logits_2d, targets_1d)``.

    Returns:
        ``(average_loss, num_samples, num_batches)``. ``average_loss`` is
        ``nan`` when the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    batch_count = 0
    sample_count = 0
    with torch.no_grad():
        for val_batch in dataloader:
            semantic = val_batch['semantic_tokens']
            coarse = val_batch['coarse_tokens']

            # Same next-token layout as training: drop the last coarse token
            # from the inputs and the first from the targets.
            val_targets = coarse[:, 1:].contiguous()
            val_inputs = torch.cat([semantic, coarse[:, :-1]], dim=1)

            # Forward pass; keep only the logits at the coarse positions.
            val_logits = model(val_inputs, training=True)
            val_coarse_logits = val_logits[:, semantic.size(1):].contiguous()

            val_loss = criterion(
                val_coarse_logits.view(-1, model.config.output_vocab_size),
                val_targets.view(-1),
            )
            total_loss += val_loss.item()
            batch_count += 1
            sample_count += semantic.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    average_loss = total_loss / batch_count if batch_count else float("nan")
    return average_loss, sample_count, batch_count


if accelerator.is_main_process:
    average_validation_loss, num_samples, num_batches = _run_validation(
        model, validation_dataloader, criterion
    )
    if num_batches == 0:
        logger.warning("Validation dataloader produced no batches; validation loss is undefined.")
    logger.info(f"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.")
    print(f"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.")