Yusen committed
Commit 19e199d · Parent(s): 22224e1

Upload 3 files

Files changed (3)
  1. data_utils.py +154 -0
  2. models.py +351 -0
  3. modules.py +342 -0
data_utils.py ADDED
@@ -0,0 +1,154 @@
import os
import random

import numpy as np
import torch
import torch.utils.data

from mel_processing import spectrogram_torch
from utils import load_wav_to_torch, load_filepaths_and_text

"""Multi speaker version"""


class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
    1) loads audio, content (soft unit) features, f0 and speaker id
    2) computes spectrograms from the audio files (cached to disk)
    3) tiles and randomly crops each example to a fixed spectrogram length
    """

    def __init__(self, audiopaths, hparams):
        self.audiopaths = load_filepaths_and_text(audiopaths)
        self.max_wav_value = hparams.data.max_wav_value
        self.sampling_rate = hparams.data.sampling_rate
        self.filter_length = hparams.data.filter_length
        self.hop_length = hparams.data.hop_length
        self.win_length = hparams.data.win_length
        self.use_sr = hparams.train.use_sr
        self.spec_len = hparams.train.max_speclen
        self.spk_map = hparams.spk

        random.seed(1234)
        random.shuffle(self.audiopaths)

    def get_audio(self, filename):
        filename = filename.replace("\\", "/")
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        # Cache the spectrogram next to the wav so it is computed only once.
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            spec = torch.load(spec_filename)
        else:
            spec = spectrogram_torch(audio_norm, self.filter_length,
                                     self.sampling_rate, self.hop_length, self.win_length,
                                     center=False)
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename)

        # The parent directory name is the speaker label.
        spk = filename.split("/")[-2]
        spk = torch.LongTensor([self.spk_map[spk]])

        # Content features; repeated 2x along time to match the spectrogram rate.
        c = torch.load(filename + ".soft.pt").squeeze(0)
        c = torch.repeat_interleave(c, repeats=2, dim=1)

        f0 = np.load(filename + ".f0.npy")
        f0 = torch.FloatTensor(f0)
        # Trim all streams to a common length; off-by-a-few frames are tolerated.
        lmin = min(c.size(-1), spec.size(-1), f0.shape[0])
        assert abs(c.size(-1) - spec.size(-1)) < 4, (c.size(-1), spec.size(-1), f0.shape, filename)
        assert abs(lmin - spec.size(-1)) < 4, (c.size(-1), spec.size(-1), f0.shape)
        assert abs(lmin - c.size(-1)) < 4, (c.size(-1), spec.size(-1), f0.shape)
        spec, c, f0 = spec[:, :lmin], c[:, :lmin], f0[:lmin]
        audio_norm = audio_norm[:, :lmin * self.hop_length]

        # Tile short utterances until they reach the training segment length,
        # then crop a random window of exactly spec_len frames.
        _spec, _c, _audio_norm, _f0 = spec, c, audio_norm, f0
        while spec.size(-1) < self.spec_len:
            spec = torch.cat((spec, _spec), -1)
            c = torch.cat((c, _c), -1)
            f0 = torch.cat((f0, _f0), -1)
            audio_norm = torch.cat((audio_norm, _audio_norm), -1)
        start = random.randint(0, spec.size(-1) - self.spec_len)
        end = start + self.spec_len
        spec = spec[:, start:end]
        c = c[:, start:end]
        f0 = f0[start:end]
        audio_norm = audio_norm[:, start * self.hop_length:end * self.hop_length]

        return c, f0, spec, audio_norm, spk

    def __getitem__(self, index):
        return self.get_audio(self.audiopaths[index][0])

    def __len__(self):
        return len(self.audiopaths)
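
# --- Editor's note: minimal usage sketch, not part of the original file. ---
# Assumes an hparams object loaded from the repo's json config and a filelist
# path supplied by the training script; both names are assumptions here.
# Because every item is cropped to max_speclen frames, the default collate can
# stack a batch directly.
def _example_train_loader(hps, filelist, batch_size=8):
    dataset = TextAudioSpeakerLoader(filelist, hps)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                       shuffle=False, num_workers=4)
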
class EvalDataLoader(torch.utils.data.Dataset):
    """
    Evaluation variant of TextAudioSpeakerLoader: loads the same
    (content, f0, spectrogram, audio, speaker) tuples, but keeps only the
    first five utterances and returns them full length, without cropping.
    """

    def __init__(self, audiopaths, hparams):
        self.audiopaths = load_filepaths_and_text(audiopaths)
        self.max_wav_value = hparams.data.max_wav_value
        self.sampling_rate = hparams.data.sampling_rate
        self.filter_length = hparams.data.filter_length
        self.hop_length = hparams.data.hop_length
        self.win_length = hparams.data.win_length
        self.use_sr = hparams.train.use_sr
        self.audiopaths = self.audiopaths[:5]
        self.spk_map = hparams.spk

    def get_audio(self, filename):
        filename = filename.replace("\\", "/")
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))
        audio_norm = audio / self.max_wav_value
        audio_norm = audio_norm.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            spec = torch.load(spec_filename)
        else:
            spec = spectrogram_torch(audio_norm, self.filter_length,
                                     self.sampling_rate, self.hop_length, self.win_length,
                                     center=False)
            spec = torch.squeeze(spec, 0)
            torch.save(spec, spec_filename)

        spk = filename.split("/")[-2]
        spk = torch.LongTensor([self.spk_map[spk]])

        c = torch.load(filename + ".soft.pt").squeeze(0)
        c = torch.repeat_interleave(c, repeats=2, dim=1)

        f0 = np.load(filename + ".f0.npy")
        f0 = torch.FloatTensor(f0)
        lmin = min(c.size(-1), spec.size(-1), f0.shape[0])
        assert abs(c.size(-1) - spec.size(-1)) < 4, (c.size(-1), spec.size(-1), f0.shape)
        assert abs(f0.shape[0] - spec.shape[-1]) < 4, (c.size(-1), spec.size(-1), f0.shape)
        spec, c, f0 = spec[:, :lmin], c[:, :lmin], f0[:lmin]
        audio_norm = audio_norm[:, :lmin * self.hop_length]

        return c, f0, spec, audio_norm, spk

    def __getitem__(self, index):
        return self.get_audio(self.audiopaths[index][0])

    def __len__(self):
        return len(self.audiopaths)
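
# --- Editor's note: expected on-disk layout, inferred from get_audio above. ---
# For every filelist entry "dataset/<speaker>/xxx.wav", both loaders expect
# preprocessing to have produced these sidecar files:
#   dataset/<speaker>/xxx.wav.soft.pt  - content (soft unit) features
#   dataset/<speaker>/xxx.wav.f0.npy   - frame-level f0 track
#   dataset/<speaker>/xxx.spec.pt      - spectrogram cache (created on first load)
# and <speaker> must appear as a key in hparams.spk.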
models.py ADDED
@@ -0,0 +1,351 @@
import torch
from torch import nn
from torch.nn import functional as F

import attentions
import commons
import modules

from torch.nn import Conv1d, Conv2d
from torch.nn.utils import weight_norm, spectral_norm
from commons import get_padding
from vdecoder.hifigan.models import Generator
from utils import f0_to_coarse

class ResidualCouplingBlock(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 n_flows=4,
                 gin_channels=0):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Each flow step is an affine coupling layer followed by a channel flip,
        # so successive steps transform alternating halves of the channels.
        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class Encoder(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterized sample from the predicted posterior N(m, exp(logs)^2).
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class TextEncoder(nn.Module):
    """Despite the name, this encodes content (soft unit) features, not text."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 gin_channels=0,
                 filter_channels=None,
                 n_heads=None,
                 p_dropout=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
        # Coarse f0 is quantized to 256 bins and injected as a learned embedding.
        self.f0_emb = nn.Embedding(256, hidden_channels)

        self.enc_ = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout)

    def forward(self, x, x_lengths, f0=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = x + self.f0_emb(f0).transpose(1, 2)
        x = self.enc_(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask

        return z, m, logs, x_mask


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d: fold the waveform into (frames, period) so the strided
        # 2-D convs compare samples that are `period` steps apart.
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList([
            norm_f(Conv1d(1, 16, 15, 1, padding=7)),
            norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
            norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
            norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
            norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
            norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
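
# --- Editor's note: minimal usage sketch, not part of the original file. ---
# Shows the tensor contract of MultiPeriodDiscriminator: (B, 1, T) waveforms
# in, one (logits, feature-map list) pair per sub-discriminator out. The
# shapes below are illustrative.
def _example_discriminator_pass():
    net_d = MultiPeriodDiscriminator()
    y = torch.randn(2, 1, 8192)       # real waveform batch
    y_hat = torch.randn(2, 1, 8192)   # generated waveform batch
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat)
    # DiscriminatorS plus one DiscriminatorP per period in [2, 3, 5, 7, 11].
    assert len(y_d_rs) == 6
    return y_d_rs, y_d_gs, fmap_rs, fmap_gs
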
class SpeakerEncoder(torch.nn.Module):
    def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
        super(SpeakerEncoder, self).__init__()
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

    def forward(self, mels):
        self.lstm.flatten_parameters()
        _, (hidden, _) = self.lstm(mels)
        embeds_raw = self.relu(self.linear(hidden[-1]))
        # L2-normalize so embeddings can be compared by cosine similarity.
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

    def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
        mel_slices = []
        for i in range(0, total_frames - partial_frames, partial_hop):
            mel_range = torch.arange(i, i + partial_frames)
            mel_slices.append(mel_range)

        return mel_slices

    def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
        mel_len = mel.size(1)
        last_mel = mel[:, -partial_frames:]

        if mel_len > partial_frames:
            # Embed overlapping fixed-size windows and average the results.
            mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
            mels = [mel[:, s] for s in mel_slices]
            mels.append(last_mel)
            mels = torch.stack(mels, 0).squeeze(1)

            with torch.no_grad():
                partial_embeds = self(mels)
            embed = torch.mean(partial_embeds, dim=0).unsqueeze(0)
        else:
            with torch.no_grad():
                embed = self(last_mel)

        return embed
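
# --- Editor's note: minimal usage sketch, not part of the original file. ---
# embed_utterance slides 128-frame windows (hop 64) over a mel spectrogram laid
# out as (1, n_frames, mel_n_channels) and averages the per-window embeddings.
def _example_embed_utterance():
    encoder = SpeakerEncoder()      # defaults: 80 mel bins, 256-dim embedding
    mel = torch.randn(1, 400, 80)   # (batch=1, n_frames, mel_n_channels)
    embed = encoder.embed_utterance(mel)   # -> (1, 256)
    return embed
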
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(self,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 gin_channels,
                 ssl_dim,
                 n_speakers,
                 **kwargs):

        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.ssl_dim = ssl_dim
        self.emb_g = nn.Embedding(n_speakers, gin_channels)

        # Prior encoder over content features, conditioned on coarse f0.
        self.enc_p_ = TextEncoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16, 0, filter_channels, n_heads, p_dropout)
        # Note the decoder's config is hard-coded here rather than taken from
        # the constructor arguments above.
        hps = {
            "sampling_rate": 32000,
            "inter_channels": 192,
            "resblock": "1",
            "resblock_kernel_sizes": [3, 7, 11],
            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
            "upsample_rates": [10, 8, 2, 2],
            "upsample_initial_channel": 512,
            "upsample_kernel_sizes": [16, 16, 4, 4],
            "gin_channels": 256,
        }
        self.dec = Generator(h=hps)
        # Posterior encoder over linear spectrograms, plus the normalizing flow
        # that maps the posterior into the prior space.
        self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

    def forward(self, c, f0, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
        if c_lengths is None:
            c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
        if spec_lengths is None:
            spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)

        g = self.emb_g(g).transpose(1, 2)

        z_ptemp, m_p, logs_p, _ = self.enc_p_(c, c_lengths, f0=f0_to_coarse(f0))
        z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)

        z_p = self.flow(z, spec_mask, g=g)
        # Decode only a random slice (with its aligned pitch) to bound memory.
        z_slice, pitch_slice, ids_slice = commons.rand_slice_segments_with_pitch(z, f0, spec_lengths, self.segment_size)
        o = self.dec(z_slice, g=g, f0=pitch_slice)

        return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, c, f0, g=None, mel=None, c_lengths=None):
        if c_lengths is None:
            c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
        g = self.emb_g(g).transpose(1, 2)

        # Sample from the prior, invert the flow, then vocode with f0.
        z_p, m_p, logs_p, c_mask = self.enc_p_(c, c_lengths, f0=f0_to_coarse(f0))
        z = self.flow(z_p, c_mask, g=g, reverse=True)

        o = self.dec(z * c_mask, g=g, f0=f0)

        return o
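
# --- Editor's note: minimal inference sketch, not part of the original file. ---
# `net_g` is a SynthesizerTrn built from the json config; `c` is the content
# feature sequence, `f0` its frame-aligned pitch track, and `speaker_id` an
# integer from hparams.spk. All names here are illustrative assumptions.
def _example_infer(net_g, c, f0, speaker_id):
    # c: (1, ssl_dim, T), f0: (1, T)
    g = torch.LongTensor([[speaker_id]])   # (1, 1) so emb_g output is (1, 1, gin)
    with torch.no_grad():
        audio = net_g.infer(c, f0, g=g)    # -> waveform of roughly T * hop samples
    return audio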
modules.py ADDED
@@ -0,0 +1,342 @@
import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Conv1d
from torch.nn.utils import weight_norm, remove_weight_norm

import commons
from commons import init_weights, get_padding


LRELU_SLOPE = 0.1

class LayerNorm(nn.Module):
    """LayerNorm over the channel dimension of (B, C, T) tensors."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero-init the projection so the block starts as an identity residual.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask


class DDSConv(nn.Module):
    """
    Dilated and Depthwise-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            # Dilation grows with depth, so the receptive field grows geometrically.
            dilation = kernel_size ** i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
                                            groups=channels, dilation=dilation, padding=padding))
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask


class WN(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One conditioning projection shared by all layers; each layer
            # reads its own 2*hidden_channels slice of the output.
            cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            # Gated activation unit: tanh(a) * sigmoid(b), as in WaveNet.
            acts = commons.fused_add_tanh_sigmoid_multiply(
                x_in,
                g_l,
                n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, :self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels:, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
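
# --- Editor's note: minimal shape sketch, not part of the original file. ---
# WN preserves time resolution: (B, hidden, T) in, (B, hidden, T) out. The
# global conditioning g is projected once and broadcast over time. Sizes below
# are illustrative.
def _example_wn_pass():
    wn = WN(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=4,
            gin_channels=256)
    x = torch.randn(2, 192, 100)
    x_mask = torch.ones(2, 1, 100)
    g = torch.randn(2, 256, 1)      # e.g. a speaker embedding, constant in time
    return wn(x, x_mask, g=g)       # -> (2, 192, 100)
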
class ResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            remove_weight_norm(l)


class Log(nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x


class Flip(nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        # Reversing the channel order lets stacked coupling layers transform
        # alternating halves; the flip is volume-preserving (logdet = 0).
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
            return x, logdet
        else:
            return x


class ElementwiseAffine(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x


class ResidualCouplingLayer(nn.Module):
    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 p_dropout=0,
                 gin_channels=0,
                 mean_only=False):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        # Zero-init so the flow starts as the identity transform.
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        # Split channels in half: x0 parameterizes the affine transform of x1.
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
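
# --- Editor's note: minimal invertibility check, not part of the original file. ---
# A coupling layer must reconstruct its input when run with reverse=True; the
# sizes below are illustrative. Since post is zero-initialized, a freshly
# constructed layer also starts out as the identity map.
def _check_coupling_invertibility():
    layer = ResidualCouplingLayer(channels=4, hidden_channels=16, kernel_size=5,
                                  dilation_rate=1, n_layers=2, mean_only=True)
    x = torch.randn(3, 4, 50)
    x_mask = torch.ones(3, 1, 50)
    y, logdet = layer(x, x_mask)
    x_rec = layer(y, x_mask, reverse=True)
    assert torch.allclose(x, x_rec, atol=1e-5)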