ulysses115 committed
Commit 883b1e0 (1 parent: b2350c3)

Upload inference.ipynb with huggingface_hub

Files changed (1):
inference.ipynb +200 -0
inference.ipynb ADDED
@@ -0,0 +1,200 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "import matplotlib.pyplot as plt\n",
+    "import IPython.display as ipd\n",
+    "\n",
+    "import os\n",
+    "import json\n",
+    "import math\n",
+    "import torch\n",
+    "from torch import nn\n",
+    "from torch.nn import functional as F\n",
+    "from torch.utils.data import DataLoader\n",
+    "\n",
+    "import commons\n",
+    "import utils\n",
+    "from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate\n",
+    "from models import SynthesizerTrn\n",
+    "from text.symbols import symbols\n",
+    "from text import text_to_sequence\n",
+    "\n",
+    "from scipy.io.wavfile import write\n",
+    "\n",
+    "\n",
+    "def get_text(text, hps):\n",
+    "    text_norm = text_to_sequence(text, hps.data.text_cleaners)\n",
+    "    if hps.data.add_blank:\n",
+    "        text_norm = commons.intersperse(text_norm, 0)\n",
+    "    text_norm = torch.LongTensor(text_norm)\n",
+    "    return text_norm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## LJ Speech"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "hps = utils.get_hparams_from_file(\"./configs/ljs_base.json\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net_g = SynthesizerTrn(\n",
+    "    len(symbols),\n",
+    "    hps.data.filter_length // 2 + 1,\n",
+    "    hps.train.segment_size // hps.data.hop_length,\n",
+    "    **hps.model).cuda()\n",
+    "_ = net_g.eval()\n",
+    "\n",
+    "_ = utils.load_checkpoint(\"/path/to/pretrained_ljs.pth\", net_g, None)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "stn_tst = get_text(\"VITS is Awesome!\", hps)\n",
+    "with torch.no_grad():\n",
+    "    x_tst = stn_tst.cuda().unsqueeze(0)\n",
+    "    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
+    "    audio = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
+    "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## VCTK"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "hps = utils.get_hparams_from_file(\"./configs/vctk_base.json\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net_g = SynthesizerTrn(\n",
+    "    len(symbols),\n",
+    "    hps.data.filter_length // 2 + 1,\n",
+    "    hps.train.segment_size // hps.data.hop_length,\n",
+    "    n_speakers=hps.data.n_speakers,\n",
+    "    **hps.model).cuda()\n",
+    "_ = net_g.eval()\n",
+    "\n",
+    "_ = utils.load_checkpoint(\"/path/to/pretrained_vctk.pth\", net_g, None)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "stn_tst = get_text(\"VITS is Awesome!\", hps)\n",
+    "with torch.no_grad():\n",
+    "    x_tst = stn_tst.cuda().unsqueeze(0)\n",
+    "    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).cuda()\n",
+    "    sid = torch.LongTensor([4]).cuda()\n",
+    "    audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n",
+    "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Voice Conversion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data)\n",
+    "collate_fn = TextAudioSpeakerCollate()\n",
+    "loader = DataLoader(dataset, num_workers=8, shuffle=False,\n",
+    "    batch_size=1, pin_memory=True,\n",
+    "    drop_last=True, collate_fn=collate_fn)\n",
+    "data_list = list(loader)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with torch.no_grad():\n",
+    "    x, x_lengths, spec, spec_lengths, y, y_lengths, sid_src = [x.cuda() for x in data_list[0]]\n",
+    "    sid_tgt1 = torch.LongTensor([1]).cuda()\n",
+    "    sid_tgt2 = torch.LongTensor([2]).cuda()\n",
+    "    sid_tgt3 = torch.LongTensor([4]).cuda()\n",
+    "    audio1 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt1)[0][0,0].data.cpu().float().numpy()\n",
+    "    audio2 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt2)[0][0,0].data.cpu().float().numpy()\n",
+    "    audio3 = net_g.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt3)[0][0,0].data.cpu().float().numpy()\n",
+    "print(\"Original SID: %d\" % sid_src.item())\n",
+    "ipd.display(ipd.Audio(y[0].cpu().numpy(), rate=hps.data.sampling_rate, normalize=False))\n",
+    "print(\"Converted SID: %d\" % sid_tgt1.item())\n",
+    "ipd.display(ipd.Audio(audio1, rate=hps.data.sampling_rate, normalize=False))\n",
+    "print(\"Converted SID: %d\" % sid_tgt2.item())\n",
+    "ipd.display(ipd.Audio(audio2, rate=hps.data.sampling_rate, normalize=False))\n",
+    "print(\"Converted SID: %d\" % sid_tgt3.item())\n",
+    "ipd.display(ipd.Audio(audio3, rate=hps.data.sampling_rate, normalize=False))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}