HaiLua commited on
Commit
96801a1
1 Parent(s): 364160a

Create infer_pack/models.py

Browse files
Files changed (1) hide show
  1. infer_pack/models.py +982 -0
infer_pack/models.py ADDED
@@ -0,0 +1,982 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math, pdb, os
2
+ from time import time as ttime
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from infer_pack import modules
7
+ from infer_pack import attentions
8
+ from infer_pack import commons
9
+ from infer_pack.commons import init_weights, get_padding
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from infer_pack.commons import init_weights
13
+ import numpy as np
14
+ from infer_pack import commons
15
+
16
+
17
+ class TextEncoder256(nn.Module):
18
+ def __init__(
19
+ self,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ f0=True,
28
+ ):
29
+ super().__init__()
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emb_phone = nn.Linear(256, hidden_channels)
38
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
+ if f0 == True:
40
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
+ self.encoder = attentions.Encoder(
42
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
+ )
44
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
+
46
+ def forward(self, phone, pitch, lengths):
47
+ if pitch == None:
48
+ x = self.emb_phone(phone)
49
+ else:
50
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
+ x = self.lrelu(x)
53
+ x = torch.transpose(x, 1, -1) # [b, h, t]
54
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
+ x.dtype
56
+ )
57
+ x = self.encoder(x * x_mask, x_mask)
58
+ stats = self.proj(x) * x_mask
59
+
60
+ m, logs = torch.split(stats, self.out_channels, dim=1)
61
+ return m, logs, x_mask
62
+
63
+
64
+ class TextEncoder256Sim(nn.Module):
65
+ def __init__(
66
+ self,
67
+ out_channels,
68
+ hidden_channels,
69
+ filter_channels,
70
+ n_heads,
71
+ n_layers,
72
+ kernel_size,
73
+ p_dropout,
74
+ f0=True,
75
+ ):
76
+ super().__init__()
77
+ self.out_channels = out_channels
78
+ self.hidden_channels = hidden_channels
79
+ self.filter_channels = filter_channels
80
+ self.n_heads = n_heads
81
+ self.n_layers = n_layers
82
+ self.kernel_size = kernel_size
83
+ self.p_dropout = p_dropout
84
+ self.emb_phone = nn.Linear(256, hidden_channels)
85
+ self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
+ if f0 == True:
87
+ self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
+ self.encoder = attentions.Encoder(
89
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
+ )
91
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
92
+
93
+ def forward(self, phone, pitch, lengths):
94
+ if pitch == None:
95
+ x = self.emb_phone(phone)
96
+ else:
97
+ x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
+ x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
+ x = self.lrelu(x)
100
+ x = torch.transpose(x, 1, -1) # [b, h, t]
101
+ x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
+ x.dtype
103
+ )
104
+ x = self.encoder(x * x_mask, x_mask)
105
+ x = self.proj(x) * x_mask
106
+ return x, x_mask
107
+
108
+
109
+ class ResidualCouplingBlock(nn.Module):
110
+ def __init__(
111
+ self,
112
+ channels,
113
+ hidden_channels,
114
+ kernel_size,
115
+ dilation_rate,
116
+ n_layers,
117
+ n_flows=4,
118
+ gin_channels=0,
119
+ ):
120
+ super().__init__()
121
+ self.channels = channels
122
+ self.hidden_channels = hidden_channels
123
+ self.kernel_size = kernel_size
124
+ self.dilation_rate = dilation_rate
125
+ self.n_layers = n_layers
126
+ self.n_flows = n_flows
127
+ self.gin_channels = gin_channels
128
+
129
+ self.flows = nn.ModuleList()
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.ResidualCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ dilation_rate,
137
+ n_layers,
138
+ gin_channels=gin_channels,
139
+ mean_only=True,
140
+ )
141
+ )
142
+ self.flows.append(modules.Flip())
143
+
144
+ def forward(self, x, x_mask, g=None, reverse=False):
145
+ if not reverse:
146
+ for flow in self.flows:
147
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
148
+ else:
149
+ for flow in reversed(self.flows):
150
+ x = flow(x, x_mask, g=g, reverse=reverse)
151
+ return x
152
+
153
+ def remove_weight_norm(self):
154
+ for i in range(self.n_flows):
155
+ self.flows[i * 2].remove_weight_norm()
156
+
157
+
158
+ class PosteriorEncoder(nn.Module):
159
+ def __init__(
160
+ self,
161
+ in_channels,
162
+ out_channels,
163
+ hidden_channels,
164
+ kernel_size,
165
+ dilation_rate,
166
+ n_layers,
167
+ gin_channels=0,
168
+ ):
169
+ super().__init__()
170
+ self.in_channels = in_channels
171
+ self.out_channels = out_channels
172
+ self.hidden_channels = hidden_channels
173
+ self.kernel_size = kernel_size
174
+ self.dilation_rate = dilation_rate
175
+ self.n_layers = n_layers
176
+ self.gin_channels = gin_channels
177
+
178
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
179
+ self.enc = modules.WN(
180
+ hidden_channels,
181
+ kernel_size,
182
+ dilation_rate,
183
+ n_layers,
184
+ gin_channels=gin_channels,
185
+ )
186
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
187
+
188
+ def forward(self, x, x_lengths, g=None):
189
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
190
+ x.dtype
191
+ )
192
+ x = self.pre(x) * x_mask
193
+ x = self.enc(x, x_mask, g=g)
194
+ stats = self.proj(x) * x_mask
195
+ m, logs = torch.split(stats, self.out_channels, dim=1)
196
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
197
+ return z, m, logs, x_mask
198
+
199
+ def remove_weight_norm(self):
200
+ self.enc.remove_weight_norm()
201
+
202
+
203
+ class Generator(torch.nn.Module):
204
+ def __init__(
205
+ self,
206
+ initial_channel,
207
+ resblock,
208
+ resblock_kernel_sizes,
209
+ resblock_dilation_sizes,
210
+ upsample_rates,
211
+ upsample_initial_channel,
212
+ upsample_kernel_sizes,
213
+ gin_channels=0,
214
+ ):
215
+ super(Generator, self).__init__()
216
+ self.num_kernels = len(resblock_kernel_sizes)
217
+ self.num_upsamples = len(upsample_rates)
218
+ self.conv_pre = Conv1d(
219
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
220
+ )
221
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
222
+
223
+ self.ups = nn.ModuleList()
224
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
225
+ self.ups.append(
226
+ weight_norm(
227
+ ConvTranspose1d(
228
+ upsample_initial_channel // (2**i),
229
+ upsample_initial_channel // (2 ** (i + 1)),
230
+ k,
231
+ u,
232
+ padding=(k - u) // 2,
233
+ )
234
+ )
235
+ )
236
+
237
+ self.resblocks = nn.ModuleList()
238
+ for i in range(len(self.ups)):
239
+ ch = upsample_initial_channel // (2 ** (i + 1))
240
+ for j, (k, d) in enumerate(
241
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
242
+ ):
243
+ self.resblocks.append(resblock(ch, k, d))
244
+
245
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
246
+ self.ups.apply(init_weights)
247
+
248
+ if gin_channels != 0:
249
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
250
+
251
+ def forward(self, x, g=None):
252
+ x = self.conv_pre(x)
253
+ if g is not None:
254
+ x = x + self.cond(g)
255
+
256
+ for i in range(self.num_upsamples):
257
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
258
+ x = self.ups[i](x)
259
+ xs = None
260
+ for j in range(self.num_kernels):
261
+ if xs is None:
262
+ xs = self.resblocks[i * self.num_kernels + j](x)
263
+ else:
264
+ xs += self.resblocks[i * self.num_kernels + j](x)
265
+ x = xs / self.num_kernels
266
+ x = F.leaky_relu(x)
267
+ x = self.conv_post(x)
268
+ x = torch.tanh(x)
269
+
270
+ return x
271
+
272
+ def remove_weight_norm(self):
273
+ for l in self.ups:
274
+ remove_weight_norm(l)
275
+ for l in self.resblocks:
276
+ l.remove_weight_norm()
277
+
278
+
279
+ class SineGen(torch.nn.Module):
280
+ """Definition of sine generator
281
+ SineGen(samp_rate, harmonic_num = 0,
282
+ sine_amp = 0.1, noise_std = 0.003,
283
+ voiced_threshold = 0,
284
+ flag_for_pulse=False)
285
+ samp_rate: sampling rate in Hz
286
+ harmonic_num: number of harmonic overtones (default 0)
287
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
288
+ noise_std: std of Gaussian noise (default 0.003)
289
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
290
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
291
+ Note: when flag_for_pulse is True, the first time step of a voiced
292
+ segment is always sin(np.pi) or cos(0)
293
+ """
294
+
295
+ def __init__(
296
+ self,
297
+ samp_rate,
298
+ harmonic_num=0,
299
+ sine_amp=0.1,
300
+ noise_std=0.003,
301
+ voiced_threshold=0,
302
+ flag_for_pulse=False,
303
+ ):
304
+ super(SineGen, self).__init__()
305
+ self.sine_amp = sine_amp
306
+ self.noise_std = noise_std
307
+ self.harmonic_num = harmonic_num
308
+ self.dim = self.harmonic_num + 1
309
+ self.sampling_rate = samp_rate
310
+ self.voiced_threshold = voiced_threshold
311
+
312
+ def _f02uv(self, f0):
313
+ # generate uv signal
314
+ uv = torch.ones_like(f0)
315
+ uv = uv * (f0 > self.voiced_threshold)
316
+ return uv
317
+
318
+ def forward(self, f0, upp):
319
+ """sine_tensor, uv = forward(f0)
320
+ input F0: tensor(batchsize=1, length, dim=1)
321
+ f0 for unvoiced steps should be 0
322
+ output sine_tensor: tensor(batchsize=1, length, dim)
323
+ output uv: tensor(batchsize=1, length, 1)
324
+ """
325
+ with torch.no_grad():
326
+ f0 = f0[:, None].transpose(1, 2)
327
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
328
+ # fundamental component
329
+ f0_buf[:, :, 0] = f0[:, :, 0]
330
+ for idx in np.arange(self.harmonic_num):
331
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
332
+ idx + 2
333
+ ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
334
+ rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
335
+ rand_ini = torch.rand(
336
+ f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
337
+ )
338
+ rand_ini[:, 0] = 0
339
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
340
+ tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
341
+ tmp_over_one *= upp
342
+ tmp_over_one = F.interpolate(
343
+ tmp_over_one.transpose(2, 1),
344
+ scale_factor=upp,
345
+ mode="linear",
346
+ align_corners=True,
347
+ ).transpose(2, 1)
348
+ rad_values = F.interpolate(
349
+ rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
350
+ ).transpose(
351
+ 2, 1
352
+ ) #######
353
+ tmp_over_one %= 1
354
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
355
+ cumsum_shift = torch.zeros_like(rad_values)
356
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
357
+ sine_waves = torch.sin(
358
+ torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
359
+ )
360
+ sine_waves = sine_waves * self.sine_amp
361
+ uv = self._f02uv(f0)
362
+ uv = F.interpolate(
363
+ uv.transpose(2, 1), scale_factor=upp, mode="nearest"
364
+ ).transpose(2, 1)
365
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
366
+ noise = noise_amp * torch.randn_like(sine_waves)
367
+ sine_waves = sine_waves * uv + noise
368
+ return sine_waves, uv, noise
369
+
370
+
371
+ class SourceModuleHnNSF(torch.nn.Module):
372
+ """SourceModule for hn-nsf
373
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
374
+ add_noise_std=0.003, voiced_threshod=0)
375
+ sampling_rate: sampling_rate in Hz
376
+ harmonic_num: number of harmonic above F0 (default: 0)
377
+ sine_amp: amplitude of sine source signal (default: 0.1)
378
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
379
+ note that amplitude of noise in unvoiced is decided
380
+ by sine_amp
381
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
382
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
383
+ F0_sampled (batchsize, length, 1)
384
+ Sine_source (batchsize, length, 1)
385
+ noise_source (batchsize, length 1)
386
+ uv (batchsize, length, 1)
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ sampling_rate,
392
+ harmonic_num=0,
393
+ sine_amp=0.1,
394
+ add_noise_std=0.003,
395
+ voiced_threshod=0,
396
+ is_half=True,
397
+ ):
398
+ super(SourceModuleHnNSF, self).__init__()
399
+
400
+ self.sine_amp = sine_amp
401
+ self.noise_std = add_noise_std
402
+ self.is_half = is_half
403
+ # to produce sine waveforms
404
+ self.l_sin_gen = SineGen(
405
+ sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
406
+ )
407
+
408
+ # to merge source harmonics into a single excitation
409
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
410
+ self.l_tanh = torch.nn.Tanh()
411
+
412
+ def forward(self, x, upp=None):
413
+ sine_wavs, uv, _ = self.l_sin_gen(x, upp)
414
+ if self.is_half:
415
+ sine_wavs = sine_wavs.half()
416
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
417
+ return sine_merge, None, None # noise, uv
418
+
419
+
420
+ class GeneratorNSF(torch.nn.Module):
421
+ def __init__(
422
+ self,
423
+ initial_channel,
424
+ resblock,
425
+ resblock_kernel_sizes,
426
+ resblock_dilation_sizes,
427
+ upsample_rates,
428
+ upsample_initial_channel,
429
+ upsample_kernel_sizes,
430
+ gin_channels,
431
+ sr,
432
+ is_half=False,
433
+ ):
434
+ super(GeneratorNSF, self).__init__()
435
+ self.num_kernels = len(resblock_kernel_sizes)
436
+ self.num_upsamples = len(upsample_rates)
437
+
438
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
439
+ self.m_source = SourceModuleHnNSF(
440
+ sampling_rate=sr, harmonic_num=0, is_half=is_half
441
+ )
442
+ self.noise_convs = nn.ModuleList()
443
+ self.conv_pre = Conv1d(
444
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
445
+ )
446
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
447
+
448
+ self.ups = nn.ModuleList()
449
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
450
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
451
+ self.ups.append(
452
+ weight_norm(
453
+ ConvTranspose1d(
454
+ upsample_initial_channel // (2**i),
455
+ upsample_initial_channel // (2 ** (i + 1)),
456
+ k,
457
+ u,
458
+ padding=(k - u) // 2,
459
+ )
460
+ )
461
+ )
462
+ if i + 1 < len(upsample_rates):
463
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
464
+ self.noise_convs.append(
465
+ Conv1d(
466
+ 1,
467
+ c_cur,
468
+ kernel_size=stride_f0 * 2,
469
+ stride=stride_f0,
470
+ padding=stride_f0 // 2,
471
+ )
472
+ )
473
+ else:
474
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
475
+
476
+ self.resblocks = nn.ModuleList()
477
+ for i in range(len(self.ups)):
478
+ ch = upsample_initial_channel // (2 ** (i + 1))
479
+ for j, (k, d) in enumerate(
480
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
481
+ ):
482
+ self.resblocks.append(resblock(ch, k, d))
483
+
484
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
485
+ self.ups.apply(init_weights)
486
+
487
+ if gin_channels != 0:
488
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
489
+
490
+ self.upp = np.prod(upsample_rates)
491
+
492
+ def forward(self, x, f0, g=None):
493
+ har_source, noi_source, uv = self.m_source(f0, self.upp)
494
+ har_source = har_source.transpose(1, 2)
495
+ x = self.conv_pre(x)
496
+ if g is not None:
497
+ x = x + self.cond(g)
498
+
499
+ for i in range(self.num_upsamples):
500
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
501
+ x = self.ups[i](x)
502
+ x_source = self.noise_convs[i](har_source)
503
+ x = x + x_source
504
+ xs = None
505
+ for j in range(self.num_kernels):
506
+ if xs is None:
507
+ xs = self.resblocks[i * self.num_kernels + j](x)
508
+ else:
509
+ xs += self.resblocks[i * self.num_kernels + j](x)
510
+ x = xs / self.num_kernels
511
+ x = F.leaky_relu(x)
512
+ x = self.conv_post(x)
513
+ x = torch.tanh(x)
514
+ return x
515
+
516
+ def remove_weight_norm(self):
517
+ for l in self.ups:
518
+ remove_weight_norm(l)
519
+ for l in self.resblocks:
520
+ l.remove_weight_norm()
521
+
522
+
523
+ sr2sr = {
524
+ "32k": 32000,
525
+ "40k": 40000,
526
+ "48k": 48000,
527
+ }
528
+
529
+
530
+ class SynthesizerTrnMs256NSFsid(nn.Module):
531
+ def __init__(
532
+ self,
533
+ spec_channels,
534
+ segment_size,
535
+ inter_channels,
536
+ hidden_channels,
537
+ filter_channels,
538
+ n_heads,
539
+ n_layers,
540
+ kernel_size,
541
+ p_dropout,
542
+ resblock,
543
+ resblock_kernel_sizes,
544
+ resblock_dilation_sizes,
545
+ upsample_rates,
546
+ upsample_initial_channel,
547
+ upsample_kernel_sizes,
548
+ spk_embed_dim,
549
+ gin_channels,
550
+ sr,
551
+ **kwargs
552
+ ):
553
+ super().__init__()
554
+ if type(sr) == type("strr"):
555
+ sr = sr2sr[sr]
556
+ self.spec_channels = spec_channels
557
+ self.inter_channels = inter_channels
558
+ self.hidden_channels = hidden_channels
559
+ self.filter_channels = filter_channels
560
+ self.n_heads = n_heads
561
+ self.n_layers = n_layers
562
+ self.kernel_size = kernel_size
563
+ self.p_dropout = p_dropout
564
+ self.resblock = resblock
565
+ self.resblock_kernel_sizes = resblock_kernel_sizes
566
+ self.resblock_dilation_sizes = resblock_dilation_sizes
567
+ self.upsample_rates = upsample_rates
568
+ self.upsample_initial_channel = upsample_initial_channel
569
+ self.upsample_kernel_sizes = upsample_kernel_sizes
570
+ self.segment_size = segment_size
571
+ self.gin_channels = gin_channels
572
+ # self.hop_length = hop_length#
573
+ self.spk_embed_dim = spk_embed_dim
574
+ self.enc_p = TextEncoder256(
575
+ inter_channels,
576
+ hidden_channels,
577
+ filter_channels,
578
+ n_heads,
579
+ n_layers,
580
+ kernel_size,
581
+ p_dropout,
582
+ )
583
+ self.dec = GeneratorNSF(
584
+ inter_channels,
585
+ resblock,
586
+ resblock_kernel_sizes,
587
+ resblock_dilation_sizes,
588
+ upsample_rates,
589
+ upsample_initial_channel,
590
+ upsample_kernel_sizes,
591
+ gin_channels=gin_channels,
592
+ sr=sr,
593
+ is_half=kwargs["is_half"],
594
+ )
595
+ self.enc_q = PosteriorEncoder(
596
+ spec_channels,
597
+ inter_channels,
598
+ hidden_channels,
599
+ 5,
600
+ 1,
601
+ 16,
602
+ gin_channels=gin_channels,
603
+ )
604
+ self.flow = ResidualCouplingBlock(
605
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
606
+ )
607
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
608
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
609
+
610
+ def remove_weight_norm(self):
611
+ self.dec.remove_weight_norm()
612
+ self.flow.remove_weight_norm()
613
+ self.enc_q.remove_weight_norm()
614
+
615
+ def forward(
616
+ self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
617
+ ): # 这里ds是id,[bs,1]
618
+ # print(1,pitch.shape)#[bs,t]
619
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
620
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
621
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
622
+ z_p = self.flow(z, y_mask, g=g)
623
+ z_slice, ids_slice = commons.rand_slice_segments(
624
+ z, y_lengths, self.segment_size
625
+ )
626
+ # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
627
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
628
+ # print(-2,pitchf.shape,z_slice.shape)
629
+ o = self.dec(z_slice, pitchf, g=g)
630
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
631
+
632
+ def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
633
+ g = self.emb_g(sid).unsqueeze(-1)
634
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
635
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
636
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
637
+ o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
638
+ return o, x_mask, (z, z_p, m_p, logs_p)
639
+
640
+
641
+ class SynthesizerTrnMs256NSFsid_nono(nn.Module):
642
+ def __init__(
643
+ self,
644
+ spec_channels,
645
+ segment_size,
646
+ inter_channels,
647
+ hidden_channels,
648
+ filter_channels,
649
+ n_heads,
650
+ n_layers,
651
+ kernel_size,
652
+ p_dropout,
653
+ resblock,
654
+ resblock_kernel_sizes,
655
+ resblock_dilation_sizes,
656
+ upsample_rates,
657
+ upsample_initial_channel,
658
+ upsample_kernel_sizes,
659
+ spk_embed_dim,
660
+ gin_channels,
661
+ sr=None,
662
+ **kwargs
663
+ ):
664
+ super().__init__()
665
+ self.spec_channels = spec_channels
666
+ self.inter_channels = inter_channels
667
+ self.hidden_channels = hidden_channels
668
+ self.filter_channels = filter_channels
669
+ self.n_heads = n_heads
670
+ self.n_layers = n_layers
671
+ self.kernel_size = kernel_size
672
+ self.p_dropout = p_dropout
673
+ self.resblock = resblock
674
+ self.resblock_kernel_sizes = resblock_kernel_sizes
675
+ self.resblock_dilation_sizes = resblock_dilation_sizes
676
+ self.upsample_rates = upsample_rates
677
+ self.upsample_initial_channel = upsample_initial_channel
678
+ self.upsample_kernel_sizes = upsample_kernel_sizes
679
+ self.segment_size = segment_size
680
+ self.gin_channels = gin_channels
681
+ # self.hop_length = hop_length#
682
+ self.spk_embed_dim = spk_embed_dim
683
+ self.enc_p = TextEncoder256(
684
+ inter_channels,
685
+ hidden_channels,
686
+ filter_channels,
687
+ n_heads,
688
+ n_layers,
689
+ kernel_size,
690
+ p_dropout,
691
+ f0=False,
692
+ )
693
+ self.dec = Generator(
694
+ inter_channels,
695
+ resblock,
696
+ resblock_kernel_sizes,
697
+ resblock_dilation_sizes,
698
+ upsample_rates,
699
+ upsample_initial_channel,
700
+ upsample_kernel_sizes,
701
+ gin_channels=gin_channels,
702
+ )
703
+ self.enc_q = PosteriorEncoder(
704
+ spec_channels,
705
+ inter_channels,
706
+ hidden_channels,
707
+ 5,
708
+ 1,
709
+ 16,
710
+ gin_channels=gin_channels,
711
+ )
712
+ self.flow = ResidualCouplingBlock(
713
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
714
+ )
715
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
716
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
717
+
718
+ def remove_weight_norm(self):
719
+ self.dec.remove_weight_norm()
720
+ self.flow.remove_weight_norm()
721
+ self.enc_q.remove_weight_norm()
722
+
723
+ def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
724
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
725
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
726
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
727
+ z_p = self.flow(z, y_mask, g=g)
728
+ z_slice, ids_slice = commons.rand_slice_segments(
729
+ z, y_lengths, self.segment_size
730
+ )
731
+ o = self.dec(z_slice, g=g)
732
+ return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
733
+
734
+ def infer(self, phone, phone_lengths, sid, max_len=None):
735
+ g = self.emb_g(sid).unsqueeze(-1)
736
+ m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
737
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
738
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
739
+ o = self.dec((z * x_mask)[:, :, :max_len], g=g)
740
+ return o, x_mask, (z, z_p, m_p, logs_p)
741
+
742
+
743
+ class SynthesizerTrnMs256NSFsid_sim(nn.Module):
744
+ """
745
+ Synthesizer for Training
746
+ """
747
+
748
+ def __init__(
749
+ self,
750
+ spec_channels,
751
+ segment_size,
752
+ inter_channels,
753
+ hidden_channels,
754
+ filter_channels,
755
+ n_heads,
756
+ n_layers,
757
+ kernel_size,
758
+ p_dropout,
759
+ resblock,
760
+ resblock_kernel_sizes,
761
+ resblock_dilation_sizes,
762
+ upsample_rates,
763
+ upsample_initial_channel,
764
+ upsample_kernel_sizes,
765
+ spk_embed_dim,
766
+ # hop_length,
767
+ gin_channels=0,
768
+ use_sdp=True,
769
+ **kwargs
770
+ ):
771
+ super().__init__()
772
+ self.spec_channels = spec_channels
773
+ self.inter_channels = inter_channels
774
+ self.hidden_channels = hidden_channels
775
+ self.filter_channels = filter_channels
776
+ self.n_heads = n_heads
777
+ self.n_layers = n_layers
778
+ self.kernel_size = kernel_size
779
+ self.p_dropout = p_dropout
780
+ self.resblock = resblock
781
+ self.resblock_kernel_sizes = resblock_kernel_sizes
782
+ self.resblock_dilation_sizes = resblock_dilation_sizes
783
+ self.upsample_rates = upsample_rates
784
+ self.upsample_initial_channel = upsample_initial_channel
785
+ self.upsample_kernel_sizes = upsample_kernel_sizes
786
+ self.segment_size = segment_size
787
+ self.gin_channels = gin_channels
788
+ # self.hop_length = hop_length#
789
+ self.spk_embed_dim = spk_embed_dim
790
+ self.enc_p = TextEncoder256Sim(
791
+ inter_channels,
792
+ hidden_channels,
793
+ filter_channels,
794
+ n_heads,
795
+ n_layers,
796
+ kernel_size,
797
+ p_dropout,
798
+ )
799
+ self.dec = GeneratorNSF(
800
+ inter_channels,
801
+ resblock,
802
+ resblock_kernel_sizes,
803
+ resblock_dilation_sizes,
804
+ upsample_rates,
805
+ upsample_initial_channel,
806
+ upsample_kernel_sizes,
807
+ gin_channels=gin_channels,
808
+ is_half=kwargs["is_half"],
809
+ )
810
+
811
+ self.flow = ResidualCouplingBlock(
812
+ inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
813
+ )
814
+ self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
815
+ print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
816
+
817
+ def remove_weight_norm(self):
818
+ self.dec.remove_weight_norm()
819
+ self.flow.remove_weight_norm()
820
+ self.enc_q.remove_weight_norm()
821
+
822
+ def forward(
823
+ self, phone, phone_lengths, pitch, pitchf, y_lengths, ds
824
+ ): # y是spec不需要了现在
825
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
826
+ x, x_mask = self.enc_p(phone, pitch, phone_lengths)
827
+ x = self.flow(x, x_mask, g=g, reverse=True)
828
+ z_slice, ids_slice = commons.rand_slice_segments(
829
+ x, y_lengths, self.segment_size
830
+ )
831
+
832
+ pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
833
+ o = self.dec(z_slice, pitchf, g=g)
834
+ return o, ids_slice
835
+
836
+ def infer(
837
+ self, phone, phone_lengths, pitch, pitchf, ds, max_len=None
838
+ ): # y是spec不需要了现在
839
+ g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
840
+ x, x_mask = self.enc_p(phone, pitch, phone_lengths)
841
+ x = self.flow(x, x_mask, g=g, reverse=True)
842
+ o = self.dec((x * x_mask)[:, :, :max_len], pitchf, g=g)
843
+ return o, o
844
+
845
+
846
+ class MultiPeriodDiscriminator(torch.nn.Module):
847
+ def __init__(self, use_spectral_norm=False):
848
+ super(MultiPeriodDiscriminator, self).__init__()
849
+ periods = [2, 3, 5, 7, 11, 17]
850
+ # periods = [3, 5, 7, 11, 17, 23, 37]
851
+
852
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
853
+ discs = discs + [
854
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
855
+ ]
856
+ self.discriminators = nn.ModuleList(discs)
857
+
858
+ def forward(self, y, y_hat):
859
+ y_d_rs = [] #
860
+ y_d_gs = []
861
+ fmap_rs = []
862
+ fmap_gs = []
863
+ for i, d in enumerate(self.discriminators):
864
+ y_d_r, fmap_r = d(y)
865
+ y_d_g, fmap_g = d(y_hat)
866
+ # for j in range(len(fmap_r)):
867
+ # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
868
+ y_d_rs.append(y_d_r)
869
+ y_d_gs.append(y_d_g)
870
+ fmap_rs.append(fmap_r)
871
+ fmap_gs.append(fmap_g)
872
+
873
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
874
+
875
+
876
+ class DiscriminatorS(torch.nn.Module):
877
+ def __init__(self, use_spectral_norm=False):
878
+ super(DiscriminatorS, self).__init__()
879
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
880
+ self.convs = nn.ModuleList(
881
+ [
882
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
883
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
884
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
885
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
886
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
887
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
888
+ ]
889
+ )
890
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
891
+
892
+ def forward(self, x):
893
+ fmap = []
894
+
895
+ for l in self.convs:
896
+ x = l(x)
897
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
898
+ fmap.append(x)
899
+ x = self.conv_post(x)
900
+ fmap.append(x)
901
+ x = torch.flatten(x, 1, -1)
902
+
903
+ return x, fmap
904
+
905
+
906
+ class DiscriminatorP(torch.nn.Module):
907
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
908
+ super(DiscriminatorP, self).__init__()
909
+ self.period = period
910
+ self.use_spectral_norm = use_spectral_norm
911
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
912
+ self.convs = nn.ModuleList(
913
+ [
914
+ norm_f(
915
+ Conv2d(
916
+ 1,
917
+ 32,
918
+ (kernel_size, 1),
919
+ (stride, 1),
920
+ padding=(get_padding(kernel_size, 1), 0),
921
+ )
922
+ ),
923
+ norm_f(
924
+ Conv2d(
925
+ 32,
926
+ 128,
927
+ (kernel_size, 1),
928
+ (stride, 1),
929
+ padding=(get_padding(kernel_size, 1), 0),
930
+ )
931
+ ),
932
+ norm_f(
933
+ Conv2d(
934
+ 128,
935
+ 512,
936
+ (kernel_size, 1),
937
+ (stride, 1),
938
+ padding=(get_padding(kernel_size, 1), 0),
939
+ )
940
+ ),
941
+ norm_f(
942
+ Conv2d(
943
+ 512,
944
+ 1024,
945
+ (kernel_size, 1),
946
+ (stride, 1),
947
+ padding=(get_padding(kernel_size, 1), 0),
948
+ )
949
+ ),
950
+ norm_f(
951
+ Conv2d(
952
+ 1024,
953
+ 1024,
954
+ (kernel_size, 1),
955
+ 1,
956
+ padding=(get_padding(kernel_size, 1), 0),
957
+ )
958
+ ),
959
+ ]
960
+ )
961
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
962
+
963
+ def forward(self, x):
964
+ fmap = []
965
+
966
+ # 1d to 2d
967
+ b, c, t = x.shape
968
+ if t % self.period != 0: # pad first
969
+ n_pad = self.period - (t % self.period)
970
+ x = F.pad(x, (0, n_pad), "reflect")
971
+ t = t + n_pad
972
+ x = x.view(b, c, t // self.period, self.period)
973
+
974
+ for l in self.convs:
975
+ x = l(x)
976
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
977
+ fmap.append(x)
978
+ x = self.conv_post(x)
979
+ fmap.append(x)
980
+ x = torch.flatten(x, 1, -1)
981
+
982
+ return x, fmap