{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import IPython.display as ipd\n",
    "\n",
    "import os\n",
    "import json\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import commons\n",
    "import utils\n",
    "from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate\n",
    "from models import SynthesizerTrn\n",
    "from text.symbols import symbols\n",
    "from text import text_to_sequence\n",
    "\n",
    "from scipy.io.wavfile import write\n",
    "\n",
    "\n",
    "def get_text(text, hps):\n",
    "    text_norm = text_to_sequence(text, hps.data.text_cleaners)\n",
    "    if hps.data.add_blank:\n",
    "        text_norm = commons.intersperse(text_norm, 0)\n",
    "    text_norm = torch.LongTensor(text_norm)\n",
    "    return text_norm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LJ Speech"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hps = utils.get_hparams_from_file(\"./configs/ljs_base.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_g = SynthesizerTrn(\n",
    "    len(symbols),\n",
    "    hps.data.filter_length // 2 + 1,\n",
    "    hps.train.segment_size // hps.data.hop_length,\n",
    "    **hps.model).cuda()\n",
    "_ = net_g.eval()\n",
    "\n",
    "_ = utils.load_checkpoint(\"/path/to/pretrained_ljs.pth\", net_g, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stn_tst = get_text(\"VITS is Awesome!\", hps)\n",
    "with torch.no_grad():\n",
    "    x_tst = stn_tst.cuda().unsqueeze(0)\n",
    "    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
    "    audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
    "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VCTK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hps = utils.get_hparams_from_file(\"./configs/vctk_base.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_g = SynthesizerTrn(\n",
    "    len(symbols),\n",
    "    hps.data.filter_length // 2 + 1,\n",
    "    hps.train.segment_size // hps.data.hop_length,\n",
    "    n_speakers=hps.data.n_speakers,\n",
    "    **hps.model).cuda()\n",
    "_ = net_g.eval()\n",
    "\n",
    "_ = utils.load_checkpoint(\"/path/to/pretrained_vctk.pth\", net_g, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stn_tst = get_text(\"VITS is Awesome!\", hps)\n",
    "with torch.no_grad():\n",
    "    x_tst = stn_tst.cuda().unsqueeze(0)\n",
    "    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
    "    sid = torch.LongTensor([4]).cuda()\n",
    "    audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
    "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Voice Conversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)\n",
    "collate_fn = TextAudioSpeakerCollate()\n",
    "loader = DataLoader(dataset, num_workers=8, shuffle=False,\n",
    "    batch_size=1, pin_memory=True,\n",
    "    drop_last=True, collate_fn=collate_fn)\n",
    "data_list = list(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.cuda() for x in data_list[0]]\n",
    "    sid_tgt1 = torch.LongTensor([1]).cuda()\n",
    "    sid_tgt2 = torch.LongTensor([2]).cuda()\n",
    "    sid_tgt3 = torch.LongTensor([4]).cuda()\n",
    "    audio1 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0,0].data.cpu().float().numpy()\n",
    "    audio2 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt2)[0][0,0].data.cpu().float().numpy()\n",
    "    audio3 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt3)[0][0,0].data.cpu().float().numpy()\n",
    "print(\"Original SID: %d\" % sid_src.item())\n",
    "ipd.display(ipd.Audio(y[0].cpu().numpy(), rate=hps.data.sampling_rate, normalize=False))\n",
    "print(\"Converted SID: %d\" % sid_tgt1.item())\n",
    "ipd.display(ipd.Audio(audio1, rate=hps.data.sampling_rate, normalize=False))\n",
    "print(\"Converted SID: %d\" % sid_tgt2.item())\n",
    "ipd.display(ipd.Audio(audio2, rate=hps.data.sampling_rate, normalize=False))\n",
    "print(\"Converted SID: %d\" % sid_tgt3.item())\n",
    "ipd.display(ipd.Audio(audio3, rate=hps.data.sampling_rate, normalize=False))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}