{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "612692e6-fe5f-4787-86d9-c660bb9a21ec", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "177\n", "bert loaded\n", "bert_encoder loaded\n", "predictor loaded\n", "decoder loaded\n", "text_encoder loaded\n", "predictor_encoder loaded\n", "style_encoder loaded\n", "diffusion loaded\n", "text_aligner loaded\n", "pitch_extractor loaded\n", "mpd loaded\n", "msd loaded\n", "wd loaded\n", "RTF = 0.061232\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "RTF = 0.066185\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import torch\n", "torch.manual_seed(0)\n", "torch.backends.cudnn.benchmark = False\n", "torch.backends.cudnn.deterministic = True\n", "\n", "import random\n", "random.seed(0)\n", "\n", "import numpy as np\n", "np.random.seed(0)\n", "\n", "#%cd ..\n", "\n", "# load packages\n", "import time\n", "import random\n", "import yaml\n", "from munch import Munch\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "import torchaudio\n", "import librosa\n", "from nltk.tokenize import word_tokenize\n", "\n", "from models import *\n", "from utils import *\n", "from text_utils import TextCleaner\n", "textclenaer = TextCleaner()\n", "\n", "%matplotlib inline\n", "\n", "device = 'cuda'\n", "\n", "to_mel = torchaudio.transforms.MelSpectrogram(\n", " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n", "mean, std = -4, 4\n", "\n", "def length_to_mask(lengths):\n", " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n", " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n", " return mask\n", "\n", "def preprocess(wave):\n", " wave_tensor = torch.from_numpy(wave).float()\n", " mel_tensor = to_mel(wave_tensor)\n", " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n", " return mel_tensor\n", "\n", "def compute_style(ref_dicts):\n", " reference_embeddings = {}\n", " for key, path in ref_dicts.items():\n", " wave, sr = librosa.load(path, sr=24000)\n", " audio, index = librosa.effects.trim(wave, top_db=30)\n", " if sr != 24000:\n", " audio = librosa.resample(audio, sr, 24000)\n", " mel_tensor = preprocess(audio).to(device)\n", "\n", " with torch.no_grad():\n", " ref = model.style_encoder(mel_tensor.unsqueeze(1))\n", " reference_embeddings[key] = (ref.squeeze(1), audio)\n", " \n", " return reference_embeddings\n", "\n", "# load phonemizer\n", "import phonemizer\n", "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n", "\n", "config = yaml.safe_load(open(\"Models/LJSpeech/config.yml\"))\n", "\n", "# load pretrained ASR model\n", "ASR_config = config.get('ASR_config', False)\n", "ASR_path = config.get('ASR_path', False)\n", "text_aligner = load_ASR_models(ASR_path, ASR_config)\n", "\n", "# load pretrained F0 model\n", "F0_path = config.get('F0_path', False)\n", "pitch_extractor = load_F0_models(F0_path)\n", "\n", "# load BERT model\n", "from Utils.PLBERT.util import load_plbert\n", "BERT_path = config.get('PLBERT_dir', False)\n", "plbert = load_plbert(BERT_path)\n", "\n", "model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)\n", "_ = [model[key].eval() for key in model]\n", "_ = [model[key].to(device) for key in model]\n", "\n", "params_whole = torch.load(\"Models/LJSpeech/epoch_2nd_00100.pth\", map_location='cpu')\n", "params = params_whole['net']\n", "\n", "for key in model:\n", " if key in params:\n", " print('%s loaded' % key)\n", " try:\n", " model[key].load_state_dict(params[key])\n", " except:\n", " from collections import OrderedDict\n", " state_dict = params[key]\n", " new_state_dict = OrderedDict()\n", " for k, v in state_dict.items():\n", " name = k[7:] # remove `module.`\n", " new_state_dict[name] = v\n", " # load params\n", " model[key].load_state_dict(new_state_dict, strict=False)\n", "# except:\n", "# _load(params[key], model[key])\n", "_ = [model[key].eval() for key in model]\n", "\n", "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n", "\n", "sampler = DiffusionSampler(\n", " model.diffusion.diffusion,\n", " sampler=ADPM2Sampler(),\n", " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n", " clamp=False\n", ")\n", "\n", "def inference(text, noise, diffusion_steps=5, embedding_scale=1):\n", " text = text.strip()\n", " text = text.replace('\"', '')\n", " ps = global_phonemizer.phonemize([text])\n", " ps = word_tokenize(ps[0])\n", " ps = ' '.join(ps)\n", "\n", " tokens = textclenaer(ps)\n", " tokens.insert(0, 0)\n", " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n", " \n", " with torch.no_grad():\n", " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n", " text_mask = length_to_mask(input_lengths).to(tokens.device)\n", "\n", " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n", " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n", " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n", "\n", " s_pred = sampler(noise, \n", " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n", " embedding_scale=embedding_scale).squeeze(0)\n", "\n", " s = s_pred[:, 128:]\n", " ref = s_pred[:, :128]\n", "\n", " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n", "\n", " x, _ = model.predictor.lstm(d)\n", " duration = model.predictor.duration_proj(x)\n", " duration = torch.sigmoid(duration).sum(axis=-1)\n", " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n", "\n", " pred_dur[-1] += 5\n", "\n", " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n", " c_frame = 0\n", " for i in range(pred_aln_trg.size(0)):\n", " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n", " c_frame += int(pred_dur[i].data)\n", "\n", " # encode prosody\n", " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n", " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n", " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), \n", " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n", " \n", " return out.squeeze().cpu().numpy()\n", "\n", "# synthesize a text\n", "prompt = \"Can we have sex together?\"\n", "\n", "\n", "start = time.time()\n", "noise = torch.randn(1,1,256).to(device)\n", "wav = inference(prompt, noise, diffusion_steps=5, embedding_scale=1)\n", "rtf = (time.time() - start) / (len(wav) / 24000)\n", "print(f\"RTF = {rtf:5f}\")\n", "import IPython.display as ipd\n", "display(ipd.Audio(wav, rate=24000))\n", "\n", "start = time.time()\n", "noise = torch.randn(1,1,256).to(device)\n", "wav = inference(prompt, noise, diffusion_steps=10, embedding_scale=1)\n", "rtf = (time.time() - start) / (len(wav) / 24000)\n", "print(f\"RTF = {rtf:5f}\")\n", "import IPython.display as ipd\n", "display(ipd.Audio(wav, rate=24000))\n", "\n", "noise = torch.randn(1,1,256).to(device)\n", "wav = inference(prompt, noise, diffusion_steps=10, embedding_scale=2) # embedding_scale=2 for more pronounced emotion\n", "display(ipd.Audio(wav, rate=24000, normalize=False))\n", "\n", "\n", "\n", "\n", "\n", "def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):\n", " text = text.strip()\n", " text = text.replace('\"', '')\n", " ps = global_phonemizer.phonemize([text])\n", " ps = word_tokenize(ps[0])\n", " ps = ' '.join(ps)\n", "\n", " tokens = textclenaer(ps)\n", " tokens.insert(0, 0)\n", " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n", " \n", " with torch.no_grad():\n", " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n", " text_mask = length_to_mask(input_lengths).to(tokens.device)\n", "\n", " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n", " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n", " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n", "\n", " s_pred = sampler(noise, \n", " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n", " embedding_scale=embedding_scale).squeeze(0)\n", " \n", " if s_prev is not None:\n", " # convex combination of previous and current style\n", " s_pred = alpha * s_prev + (1 - alpha) * s_pred\n", " \n", " s = s_pred[:, 128:]\n", " ref = s_pred[:, :128]\n", "\n", " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n", "\n", " x, _ = model.predictor.lstm(d)\n", " duration = model.predictor.duration_proj(x)\n", " duration = torch.sigmoid(duration).sum(axis=-1)\n", " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n", "\n", " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n", " c_frame = 0\n", " for i in range(pred_aln_trg.size(0)):\n", " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n", " c_frame += int(pred_dur[i].data)\n", "\n", " # encode prosody\n", " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n", " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n", " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), \n", " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n", " \n", " return out.squeeze().cpu().numpy(), s_pred\n", "\n", "sentences = prompt.split('.') # simple split by comma\n", "wavs = []\n", "s_prev = None\n", "for text in sentences:\n", " if text.strip() == \"\": continue\n", " text += '.' # add it back\n", " noise = torch.randn(1,1,256).to(device)\n", " wav, s_prev = LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=10, embedding_scale=1.5)\n", " wavs.append(wav)\n", "display(ipd.Audio(np.concatenate(wavs), rate=24000, normalize=False))" ] }, { "cell_type": "code", "execution_count": null, "id": "5a31e6a0-6dc4-4185-9ff0-b9ee59e38df0", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python (styleenv)", "language": "python", "name": "styleenv" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }