{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "612692e6-fe5f-4787-86d9-c660bb9a21ec", "metadata": {}, "outputs": [ { "ename": "ModuleNotFoundError", "evalue": "No module named 'models'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[1], line 27\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mlibrosa\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mnltk\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtokenize\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m word_tokenize\n\u001b[0;32m---> 27\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mmodels\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;241m*\u001b[39m\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtext_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TextCleaner\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'models'" ] } ], "source": [ "import torch\n", "torch.manual_seed(0)\n", "torch.backends.cudnn.benchmark = False\n", "torch.backends.cudnn.deterministic = True\n", "\n", "import random\n", "random.seed(0)\n", "\n", "import numpy as np\n", "np.random.seed(0)\n", "\n", "#%cd ..\n", "\n", "# load packages\n", "import time\n", "import random\n", "import yaml\n", "from munch import Munch\n", "import numpy as np\n", "import torch\n", "from torch import nn\n", "import torch.nn.functional as F\n", "import torchaudio\n", "import librosa\n", "from nltk.tokenize import word_tokenize\n", "\n", "from models import *\n", "from utils import *\n", "from text_utils import TextCleaner\n", "textclenaer = TextCleaner()\n", "\n", "%matplotlib inline\n", "\n", "device = 'cuda'\n", "\n", "to_mel = torchaudio.transforms.MelSpectrogram(\n", " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n", "mean, std = -4, 4\n", "\n", "def length_to_mask(lengths):\n", " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n", " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n", " return mask\n", "\n", "def preprocess(wave):\n", " wave_tensor = torch.from_numpy(wave).float()\n", " mel_tensor = to_mel(wave_tensor)\n", " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n", " return mel_tensor\n", "\n", "def compute_style(ref_dicts):\n", " reference_embeddings = {}\n", " for key, path in ref_dicts.items():\n", " wave, sr = librosa.load(path, sr=24000)\n", " audio, index = librosa.effects.trim(wave, top_db=30)\n", " if sr != 24000:\n", " audio = librosa.resample(audio, sr, 24000)\n", " mel_tensor = preprocess(audio).to(device)\n", "\n", " with torch.no_grad():\n", " ref = model.style_encoder(mel_tensor.unsqueeze(1))\n", " reference_embeddings[key] = (ref.squeeze(1), audio)\n", " \n", " return reference_embeddings\n", "\n", "# load phonemizer\n", "import phonemizer\n", "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n", "\n", "config = yaml.safe_load(open(\"Models/LJSpeech/config.yml\"))\n", "\n", "# load pretrained ASR model\n", "ASR_config = config.get('ASR_config', False)\n", 
"ASR_path = config.get('ASR_path', False)\n", "text_aligner = load_ASR_models(ASR_path, ASR_config)\n", "\n", "# load pretrained F0 model\n", "F0_path = config.get('F0_path', False)\n", "pitch_extractor = load_F0_models(F0_path)\n", "\n", "# load BERT model\n", "from Utils.PLBERT.util import load_plbert\n", "BERT_path = config.get('PLBERT_dir', False)\n", "plbert = load_plbert(BERT_path)\n", "\n", "model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)\n", "_ = [model[key].eval() for key in model]\n", "_ = [model[key].to(device) for key in model]\n", "\n", "params_whole = torch.load(\"Models/LJSpeech/epoch_2nd_00100.pth\", map_location='cpu')\n", "params = params_whole['net']\n", "\n", "for key in model:\n", " if key in params:\n", " print('%s loaded' % key)\n", " try:\n", " model[key].load_state_dict(params[key])\n", " except:\n", " from collections import OrderedDict\n", " state_dict = params[key]\n", " new_state_dict = OrderedDict()\n", " for k, v in state_dict.items():\n", " name = k[7:] # remove `module.`\n", " new_state_dict[name] = v\n", " # load params\n", " model[key].load_state_dict(new_state_dict, strict=False)\n", "# except:\n", "# _load(params[key], model[key])\n", "_ = [model[key].eval() for key in model]\n", "\n", "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n", "\n", "sampler = DiffusionSampler(\n", " model.diffusion.diffusion,\n", " sampler=ADPM2Sampler(),\n", " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n", " clamp=False\n", ")\n", "\n", "def inference(text, noise, diffusion_steps=5, embedding_scale=1):\n", " text = text.strip()\n", " text = text.replace('\"', '')\n", " ps = global_phonemizer.phonemize([text])\n", " ps = word_tokenize(ps[0])\n", " ps = ' '.join(ps)\n", "\n", " tokens = textclenaer(ps)\n", " tokens.insert(0, 0)\n", " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n", " \n", " with torch.no_grad():\n", " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)\n", " text_mask = length_to_mask(input_lengths).to(tokens.device)\n", "\n", " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n", " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n", " d_en = model.bert_encoder(bert_dur).transpose(-1, -2) \n", "\n", " s_pred = sampler(noise, \n", " embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,\n", " embedding_scale=embedding_scale).squeeze(0)\n", "\n", " s = s_pred[:, 128:]\n", " ref = s_pred[:, :128]\n", "\n", " d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n", "\n", " x, _ = model.predictor.lstm(d)\n", " duration = model.predictor.duration_proj(x)\n", " duration = torch.sigmoid(duration).sum(axis=-1)\n", " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n", "\n", " pred_dur[-1] += 5\n", "\n", " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n", " c_frame = 0\n", " for i in range(pred_aln_trg.size(0)):\n", " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n", " c_frame += int(pred_dur[i].data)\n", "\n", " # encode prosody\n", " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n", " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n", " out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)), \n", " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n", " \n", " return out.squeeze().cpu().numpy()\n", "\n", "# synthesize a text\n", "text = ''' Hi James! How are you? 
'''\n", "\n", "start = time.time()\n", "noise = torch.randn(1,1,256).to(device)\n", "wav = inference(text, noise, diffusion_steps=5, embedding_scale=1)\n", "rtf = (time.time() - start) / (len(wav) / 24000)\n", "print(f\"RTF = {rtf:5f}\")\n", "import IPython.display as ipd\n", "display(ipd.Audio(wav, rate=24000))" ] }, { "cell_type": "code", "execution_count": null, "id": "e28c4e1a-ddb6-4914-8b26-a6053e90c272", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python (styleenv)", "language": "python", "name": "styleenv" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }