tracert committed on
Commit c94aca8 · verified · 1 Parent(s): fc1d5e3

Upload 2 files

Files changed (2)
  1. models.py +711 -0
  2. utils.py +83 -0
models.py ADDED
@@ -0,0 +1,711 @@
+ #coding:utf-8
+
+ import os
+ import os.path as osp
+
+ import copy
+ import math
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+ from Utils.ASR.models import ASRCNN
+ from Utils.JDC.model import JDCNet
+
+ from munch import Munch
+ import yaml
+
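+ # The four sampling helpers below come in learned (strided / transposed
+ # depthwise conv) and fixed (pooling / interpolation) variants: 'timepreserve'
+ # halves only the frequency axis, 'half' halves both frequency and time.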
+ class LearnedDownSample(nn.Module):
+     def __init__(self, layer_type, dim_in):
+         super().__init__()
+         self.layer_type = layer_type
+
+         if self.layer_type == 'none':
+             self.conv = nn.Identity()
+         elif self.layer_type == 'timepreserve':
+             # Depthwise strided conv (groups=dim_in): halves frequency, keeps time.
+             self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
+         elif self.layer_type == 'half':
+             # Depthwise strided conv: halves both frequency and time.
+             self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
+         else:
+             raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+     def forward(self, x):
+         return self.conv(x)
+
+ class LearnedUpSample(nn.Module):
+     def __init__(self, layer_type, dim_in):
+         super().__init__()
+         self.layer_type = layer_type
+
+         if self.layer_type == 'none':
+             self.conv = nn.Identity()
+         elif self.layer_type == 'timepreserve':
+             self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
+         elif self.layer_type == 'half':
+             self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
+         else:
+             raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+     def forward(self, x):
+         return self.conv(x)
+
+ class DownSample(nn.Module):
+     def __init__(self, layer_type):
+         super().__init__()
+         self.layer_type = layer_type
+
+     def forward(self, x):
+         if self.layer_type == 'none':
+             return x
+         elif self.layer_type == 'timepreserve':
+             return F.avg_pool2d(x, (2, 1))
+         elif self.layer_type == 'half':
+             # Pad an odd-length time axis by repeating the last frame before pooling.
+             if x.shape[-1] % 2 != 0:
+                 x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+             return F.avg_pool2d(x, 2)
+         else:
+             raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+ class UpSample(nn.Module):
+     def __init__(self, layer_type):
+         super().__init__()
+         self.layer_type = layer_type
+
+     def forward(self, x):
+         if self.layer_type == 'none':
+             return x
+         elif self.layer_type == 'timepreserve':
+             return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
+         elif self.layer_type == 'half':
+             return F.interpolate(x, scale_factor=2, mode='nearest')
+         else:
+             raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
+
+
+ class ResBlk(nn.Module):
+     def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+                  normalize=False, downsample='none'):
+         super().__init__()
+         self.actv = actv
+         self.normalize = normalize
+         self.downsample = DownSample(downsample)
+         self.downsample_res = LearnedDownSample(downsample, dim_in)
+         self.learned_sc = dim_in != dim_out
+         self._build_weights(dim_in, dim_out)
+
+     def _build_weights(self, dim_in, dim_out):
+         self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
+         self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
+         if self.normalize:
+             self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
+             self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
+         if self.learned_sc:
+             self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+     def _shortcut(self, x):
+         if self.learned_sc:
+             x = self.conv1x1(x)
+         if self.downsample:
+             x = self.downsample(x)
+         return x
+
+     def _residual(self, x):
+         if self.normalize:
+             x = self.norm1(x)
+         x = self.actv(x)
+         x = self.conv1(x)
+         x = self.downsample_res(x)
+         if self.normalize:
+             x = self.norm2(x)
+         x = self.actv(x)
+         x = self.conv2(x)
+         return x
+
+     def forward(self, x):
+         x = self._shortcut(x) + self._residual(x)
+         return x / math.sqrt(2)  # unit variance
+
+ class StyleEncoder(nn.Module):
+     def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
+         super().__init__()
+         blocks = []
+         blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+         repeat_num = 4
+         for _ in range(repeat_num):
+             dim_out = min(dim_in*2, max_conv_dim)
+             blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+             dim_in = dim_out
+
+         blocks += [nn.LeakyReLU(0.2)]
+         blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+         blocks += [nn.AdaptiveAvgPool2d(1)]
+         blocks += [nn.LeakyReLU(0.2)]
+         self.shared = nn.Sequential(*blocks)
+
+         self.unshared = nn.Linear(dim_out, style_dim)
+
+     def forward(self, x):
+         h = self.shared(x)
+         h = h.view(h.size(0), -1)
+         s = self.unshared(h)
+
+         return s
+
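+ # Usage sketch (the input layout is an assumption, not stated in this file):
+ # a batch of mel spectrograms [B, 1, n_mels, T] maps to one style vector per
+ # utterance:
+ #   enc = StyleEncoder(dim_in=64, style_dim=128, max_conv_dim=512)
+ #   s = enc(torch.randn(4, 1, 80, 96))  # -> [4, 128]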
+ class LinearNorm(torch.nn.Module):
+     def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+         super(LinearNorm, self).__init__()
+         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+         torch.nn.init.xavier_uniform_(
+             self.linear_layer.weight,
+             gain=torch.nn.init.calculate_gain(w_init_gain))
+
+     def forward(self, x):
+         return self.linear_layer(x)
+
+ class Discriminator2d(nn.Module):
+     def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
+         super().__init__()
+         blocks = []
+         blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
+
+         for _ in range(repeat_num):
+             dim_out = min(dim_in*2, max_conv_dim)
+             blocks += [ResBlk(dim_in, dim_out, downsample='half')]
+             dim_in = dim_out
+
+         blocks += [nn.LeakyReLU(0.2)]
+         blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
+         blocks += [nn.LeakyReLU(0.2)]
+         blocks += [nn.AdaptiveAvgPool2d(1)]
+         blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
+         self.main = nn.Sequential(*blocks)
+
+     def get_feature(self, x):
+         features = []
+         for layer in self.main:
+             x = layer(x)
+             features.append(x)
+         out = features[-1]
+         out = out.view(out.size(0), -1)  # (batch, num_domains)
+         return out, features
+
+     def forward(self, x):
+         out, features = self.get_feature(x)
+         out = out.squeeze()  # (batch)
+         return out, features
+
+ class ResBlk1d(nn.Module):
+     def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
+                  normalize=False, downsample='none', dropout_p=0.2):
+         super().__init__()
+         self.actv = actv
+         self.normalize = normalize
+         self.downsample_type = downsample
+         self.learned_sc = dim_in != dim_out
+         self._build_weights(dim_in, dim_out)
+         self.dropout_p = dropout_p
+
+         if self.downsample_type == 'none':
+             self.pool = nn.Identity()
+         else:
+             self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
+
+     def _build_weights(self, dim_in, dim_out):
+         self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
+         self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+         if self.normalize:
+             self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
+             self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
+         if self.learned_sc:
+             self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+     def downsample(self, x):
+         if self.downsample_type == 'none':
+             return x
+         else:
+             if x.shape[-1] % 2 != 0:
+                 x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
+             return F.avg_pool1d(x, 2)
+
+     def _shortcut(self, x):
+         if self.learned_sc:
+             x = self.conv1x1(x)
+         x = self.downsample(x)
+         return x
+
+     def _residual(self, x):
+         if self.normalize:
+             x = self.norm1(x)
+         x = self.actv(x)
+         x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+         x = self.conv1(x)
+         x = self.pool(x)
+         if self.normalize:
+             x = self.norm2(x)
+
+         x = self.actv(x)
+         x = F.dropout(x, p=self.dropout_p, training=self.training)
+
+         x = self.conv2(x)
+         return x
+
+     def forward(self, x):
+         x = self._shortcut(x) + self._residual(x)
+         return x / math.sqrt(2)  # unit variance
+
+ class LayerNorm(nn.Module):
+     def __init__(self, channels, eps=1e-5):
+         super().__init__()
+         self.channels = channels
+         self.eps = eps
+
+         self.gamma = nn.Parameter(torch.ones(channels))
+         self.beta = nn.Parameter(torch.zeros(channels))
+
+     def forward(self, x):
+         x = x.transpose(1, -1)
+         x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+         return x.transpose(1, -1)
+
+ class TextEncoder(nn.Module):
+     def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
+         super().__init__()
+         self.embedding = nn.Embedding(n_symbols, channels)
+
+         padding = (kernel_size - 1) // 2
+         self.cnn = nn.ModuleList()
+         for _ in range(depth):
+             self.cnn.append(nn.Sequential(
+                 weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
+                 LayerNorm(channels),
+                 actv,
+                 nn.Dropout(0.2),
+             ))
+
+         self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
+
+     def forward(self, x, input_lengths, m):
+         x = self.embedding(x)  # [B, T, emb]
+         x = x.transpose(1, 2)  # [B, emb, T]
+         m = m.to(input_lengths.device).unsqueeze(1)
+         x.masked_fill_(m, 0.0)
+
+         for c in self.cnn:
+             x = c(x)
+             x.masked_fill_(m, 0.0)
+
+         x = x.transpose(1, 2)  # [B, T, chn]
+
+         input_lengths = input_lengths.cpu().numpy()
+         x = nn.utils.rnn.pack_padded_sequence(
+             x, input_lengths, batch_first=True, enforce_sorted=False)
+
+         self.lstm.flatten_parameters()
+         x, _ = self.lstm(x)
+         x, _ = nn.utils.rnn.pad_packed_sequence(
+             x, batch_first=True)
+
+         x = x.transpose(-1, -2)
+         x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+         x_pad[:, :, :x.shape[-1]] = x
+         x = x_pad.to(x.device)
+
+         x.masked_fill_(m, 0.0)
+
+         return x
+
+     def inference(self, x):
+         x = self.embedding(x)
+         x = x.transpose(1, 2)
+         # self.cnn is a ModuleList, so apply its blocks one by one
+         # (calling self.cnn(x) directly would raise).
+         for c in self.cnn:
+             x = c(x)
+         x = x.transpose(1, 2)
+         self.lstm.flatten_parameters()
+         x, _ = self.lstm(x)
+         return x
+
+     def length_to_mask(self, lengths):
+         mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+         mask = torch.gt(mask+1, lengths.unsqueeze(1))
+         return mask
+
+
+ class AdaIN1d(nn.Module):
+     def __init__(self, style_dim, num_features):
+         super().__init__()
+         self.norm = nn.InstanceNorm1d(num_features, affine=False)
+         self.fc = nn.Linear(style_dim, num_features*2)
+
+     def forward(self, x, s):
+         h = self.fc(s)
+         h = h.view(h.size(0), h.size(1), 1)
+         gamma, beta = torch.chunk(h, chunks=2, dim=1)
+         return (1 + gamma) * self.norm(x) + beta
+
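+ # Minimal usage sketch (shapes are assumptions, not from this file): a style
+ # vector s of shape [B, style_dim] modulates features x of shape [B, C, T]:
+ #   adain = AdaIN1d(style_dim=64, num_features=256)
+ #   y = adain(torch.randn(4, 256, 100), torch.randn(4, 64))  # -> [4, 256, 100]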
+ class UpSample1d(nn.Module):
+     def __init__(self, layer_type):
+         super().__init__()
+         self.layer_type = layer_type
+
+     def forward(self, x):
+         if self.layer_type == 'none':
+             return x
+         else:
+             return F.interpolate(x, scale_factor=2, mode='nearest')
+
+ class AdainResBlk1d(nn.Module):
+     def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
+                  upsample='none', dropout_p=0.0):
+         super().__init__()
+         self.actv = actv
+         self.upsample_type = upsample
+         self.upsample = UpSample1d(upsample)
+         self.learned_sc = dim_in != dim_out
+         self._build_weights(dim_in, dim_out, style_dim)
+         self.dropout = nn.Dropout(dropout_p)
+
+         if upsample == 'none':
+             self.pool = nn.Identity()
+         else:
+             self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
+
+     def _build_weights(self, dim_in, dim_out, style_dim):
+         self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
+         self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
+         self.norm1 = AdaIN1d(style_dim, dim_in)
+         self.norm2 = AdaIN1d(style_dim, dim_out)
+         if self.learned_sc:
+             self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
+
+     def _shortcut(self, x):
+         x = self.upsample(x)
+         if self.learned_sc:
+             x = self.conv1x1(x)
+         return x
+
+     def _residual(self, x, s):
+         x = self.norm1(x, s)
+         x = self.actv(x)
+         x = self.pool(x)
+         x = self.conv1(self.dropout(x))
+         x = self.norm2(x, s)
+         x = self.actv(x)
+         x = self.conv2(self.dropout(x))
+         return x
+
+     def forward(self, x, s):
+         out = self._residual(x, s)
+         out = (out + self._shortcut(x)) / math.sqrt(2)
+         return out
+
+
+ class Decoder(nn.Module):
+     def __init__(self, dim_in=512, style_dim=64, residual_dim=64, dim_out=80):
+         super().__init__()
+
+         self.decode = nn.ModuleList()
+
+         self.bottleneck_dim = dim_in * 2
+
+         self.encode = nn.Sequential(ResBlk1d(dim_in + 2, self.bottleneck_dim, normalize=True),
+                                     ResBlk1d(self.bottleneck_dim, self.bottleneck_dim, normalize=True))
+
+         self.decode.append(AdainResBlk1d(self.bottleneck_dim + residual_dim + 2, self.bottleneck_dim, style_dim))
+         self.decode.append(AdainResBlk1d(self.bottleneck_dim + residual_dim + 2, self.bottleneck_dim, style_dim))
+         # upsample=True is not one of the named string modes, but any value
+         # other than 'none' selects the resampling path, so it acts like 'half'.
+         self.decode.append(AdainResBlk1d(self.bottleneck_dim + residual_dim + 2, dim_in, style_dim, upsample=True))
+         self.decode.append(AdainResBlk1d(dim_in, dim_in, style_dim))
+         self.decode.append(AdainResBlk1d(dim_in, dim_in, style_dim))
+
+         # downsample=True likewise behaves like 'half' in ResBlk1d.
+         self.F0_conv = nn.Sequential(
+             ResBlk1d(1, residual_dim, normalize=True, downsample=True),
+             weight_norm(nn.Conv1d(residual_dim, 1, kernel_size=1)),
+             nn.InstanceNorm1d(1, affine=True)
+         )
+
+         self.N_conv = nn.Sequential(
+             ResBlk1d(1, residual_dim, normalize=True, downsample=True),
+             weight_norm(nn.Conv1d(residual_dim, 1, kernel_size=1)),
+             nn.InstanceNorm1d(1, affine=True)
+         )
+
+         self.asr_res = nn.Sequential(
+             weight_norm(nn.Conv1d(dim_in, residual_dim, kernel_size=1)),
+             nn.InstanceNorm1d(residual_dim, affine=True)
+         )
+
+         self.to_out = nn.Sequential(weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0)))
+
+     def forward(self, asr, F0, N, s):
+         F0 = self.F0_conv(F0.unsqueeze(1))
+         N = self.N_conv(N.unsqueeze(1))
+
+         x = torch.cat([asr, F0, N], axis=1)
+         x = self.encode(x)
+
+         asr_res = self.asr_res(asr)
+
+         # Re-inject the ASR, F0 and energy residuals before each block until
+         # the first upsampling block changes the time resolution.
+         res = True
+         for block in self.decode:
+             if res:
+                 x = torch.cat([x, asr_res, F0, N], axis=1)
+             x = block(x, s)
+             if block.upsample_type != "none":
+                 res = False
+
+         x = self.to_out(x)
+         return x
+
+
+ class AdaLayerNorm(nn.Module):
+     def __init__(self, style_dim, channels, eps=1e-5):
+         super().__init__()
+         self.channels = channels
+         self.eps = eps
+
+         self.fc = nn.Linear(style_dim, channels*2)
+
+     def forward(self, x, s):
+         x = x.transpose(-1, -2)
+         x = x.transpose(1, -1)
+
+         h = self.fc(s)
+         h = h.view(h.size(0), h.size(1), 1)
+         gamma, beta = torch.chunk(h, chunks=2, dim=1)
+         gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
+
+         x = F.layer_norm(x, (self.channels,), eps=self.eps)
+         x = (1 + gamma) * x + beta
+         return x.transpose(1, -1).transpose(-1, -2)
+
+ class ProsodyPredictor(nn.Module):
+
+     def __init__(self, style_dim, d_hid, nlayers, dropout=0.1):
+         super().__init__()
+
+         self.text_encoder = DurationEncoder(sty_dim=style_dim,
+                                             d_model=d_hid,
+                                             nlayers=nlayers,
+                                             dropout=dropout)
+
+         self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+         self.duration_proj = LinearNorm(d_hid, 1)
+
+         self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
+         self.F0 = nn.ModuleList()
+         self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+         self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+         self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+         self.N = nn.ModuleList()
+         self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
+         self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
+         self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
+
+         self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+         self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
+
+     def forward(self, texts, style, text_lengths, alignment, m):
+         d = self.text_encoder(texts, style, text_lengths, m)
+
+         batch_size = d.shape[0]
+         text_size = d.shape[1]
+
+         # predict duration
+         input_lengths = text_lengths.cpu().numpy()
+         x = nn.utils.rnn.pack_padded_sequence(
+             d, input_lengths, batch_first=True, enforce_sorted=False)
+
+         m = m.to(text_lengths.device).unsqueeze(1)
+
+         self.lstm.flatten_parameters()
+         x, _ = self.lstm(x)
+         x, _ = nn.utils.rnn.pad_packed_sequence(
+             x, batch_first=True)
+
+         x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
+
+         x_pad[:, :x.shape[1], :] = x
+         x = x_pad.to(x.device)
+
+         duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
+
+         # expand encoder features to frame level via the alignment matrix
+         en = (d.transpose(-1, -2) @ alignment)
+
+         return duration.squeeze(-1), en
+
+     def F0Ntrain(self, x, s):
+         x, _ = self.shared(x.transpose(-1, -2))
+
+         F0 = x.transpose(-1, -2)
+         for block in self.F0:
+             F0 = block(F0, s)
+         F0 = self.F0_proj(F0)
+
+         N = x.transpose(-1, -2)
+         for block in self.N:
+             N = block(N, s)
+         N = self.N_proj(N)
+
+         return F0.squeeze(1), N.squeeze(1)
+
+     def length_to_mask(self, lengths):
+         mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+         mask = torch.gt(mask+1, lengths.unsqueeze(1))
+         return mask
+
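+ # Training flow sketch (inferred from the methods above, not stated in this
+ # commit): forward() returns per-token durations plus frame-aligned features
+ # `en`; F0Ntrain(en, s) then predicts frame-level F0 and energy curves.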
+ class DurationEncoder(nn.Module):
+
+     def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
+         super().__init__()
+         self.lstms = nn.ModuleList()
+         for _ in range(nlayers):
+             self.lstms.append(nn.LSTM(d_model + sty_dim,
+                                       d_model // 2,
+                                       num_layers=1,
+                                       batch_first=True,
+                                       bidirectional=True,
+                                       dropout=dropout))
+             self.lstms.append(AdaLayerNorm(sty_dim, d_model))
+
+         self.dropout = dropout
+         self.d_model = d_model
+         self.sty_dim = sty_dim
+
+     def forward(self, x, style, text_lengths, m):
+         masks = m.to(text_lengths.device)
+
+         x = x.permute(2, 0, 1)
+         s = style.expand(x.shape[0], x.shape[1], -1)
+         x = torch.cat([x, s], axis=-1)
+         x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
+
+         x = x.transpose(0, 1)
+         input_lengths = text_lengths.cpu().numpy()
+         x = x.transpose(-1, -2)
+
+         for block in self.lstms:
+             if isinstance(block, AdaLayerNorm):
+                 x = block(x.transpose(-1, -2), style).transpose(-1, -2)
+                 x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
+                 x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
+             else:
+                 x = x.transpose(-1, -2)
+                 x = nn.utils.rnn.pack_padded_sequence(
+                     x, input_lengths, batch_first=True, enforce_sorted=False)
+                 block.flatten_parameters()
+                 x, _ = block(x)
+                 x, _ = nn.utils.rnn.pad_packed_sequence(
+                     x, batch_first=True)
+                 x = F.dropout(x, p=self.dropout, training=self.training)
+                 x = x.transpose(-1, -2)
+
+                 x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
+
+                 x_pad[:, :, :x.shape[-1]] = x
+                 x = x_pad.to(x.device)
+
+         return x.transpose(-1, -2)
+
+     def inference(self, x, style):
+         # NOTE: this method references self.embedding, self.pos_encoder and
+         # self.transformer_encoder, none of which are defined on this class;
+         # it appears to be a leftover from a transformer-based variant and
+         # will fail if called.
+         x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
+         style = style.expand(x.shape[0], x.shape[1], -1)
+         x = torch.cat([x, style], axis=-1)
+         src = self.pos_encoder(x)
+         output = self.transformer_encoder(src).transpose(0, 1)
+         return output
+
+     def length_to_mask(self, lengths):
+         mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+         mask = torch.gt(mask+1, lengths.unsqueeze(1))
+         return mask
+
+ def load_F0_models(path):
+     # load F0 model
+     F0_model = JDCNet(num_class=1, seq_len=192)
+     params = torch.load(path, map_location='cpu')['net']
+     F0_model.load_state_dict(params)
+     _ = F0_model.train()
+
+     return F0_model
+
+ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
+     # load ASR model
+     def _load_config(path):
+         with open(path) as f:
+             config = yaml.safe_load(f)
+         model_config = config['model_params']
+         return model_config
+
+     def _load_model(model_config, model_path):
+         model = ASRCNN(**model_config)
+         params = torch.load(model_path, map_location='cpu')['model']
+         model.load_state_dict(params)
+         return model
+
+     asr_model_config = _load_config(ASR_MODEL_CONFIG)
+     asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
+     _ = asr_model.train()
+
+     return asr_model
+
+ def build_model(args, text_aligner, pitch_extractor):
+     decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels)
+     text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
+     predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, dropout=args.dropout)
+     style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)
+     discriminator = Discriminator2d(dim_in=args.dim_in, num_domains=1, max_conv_dim=args.hidden_dim)
+
+     nets = Munch(predictor=predictor,
+                  decoder=decoder,
+                  pitch_extractor=pitch_extractor,
+                  text_encoder=text_encoder,
+                  style_encoder=style_encoder,
+                  text_aligner=text_aligner,
+                  discriminator=discriminator)
+     return nets
+
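+ # Usage sketch (the argument values below are assumptions, not from this
+ # commit; only the field names are taken from build_model above):
+ #   args = Munch(dim_in=64, hidden_dim=512, style_dim=128, n_layer=3,
+ #                n_mels=80, n_token=178, dropout=0.2)
+ #   nets = build_model(args, text_aligner, pitch_extractor)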
+ def load_checkpoint(model, optimizer, path, load_only_params=True):
+     state = torch.load(path, map_location='cpu')
+     params = state['net']
+     for key in model:
+         if key in params:
+             print('%s loaded' % key)
+             model[key].load_state_dict(params[key])
+     _ = [model[key].eval() for key in model]
+
+     if not load_only_params:
+         epoch = state["epoch"]
+         iters = state["iters"]
+         optimizer.load_state_dict(state["optimizer"])
+     else:
+         epoch = 0
+         iters = 0
+
+     return model, optimizer, epoch, iters
utils.py ADDED
@@ -0,0 +1,83 @@
+ from monotonic_align import maximum_path
+ from monotonic_align import mask_from_lens
+ from monotonic_align.core import maximum_path_c
+ import numpy as np
+ import torch
+ import copy
+ from torch import nn
+ import torch.nn.functional as F
+ import torchaudio
+ import librosa
+ import matplotlib.pyplot as plt
+
+ # Note: this definition shadows the maximum_path imported above; the version
+ # below calls the Cython kernel directly.
+ def maximum_path(neg_cent, mask):
+     """ Cython optimized version.
+     neg_cent: [b, t_t, t_s]
+     mask: [b, t_t, t_s]
+     """
+     device = neg_cent.device
+     dtype = neg_cent.dtype
+     neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
+     path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
+
+     t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32))
+     t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32))
+     maximum_path_c(path, neg_cent, t_t_max, t_s_max)
+     return torch.from_numpy(path).to(device=device, dtype=dtype)
+
+ def get_data_path_list(train_path=None, val_path=None):
+     if train_path is None:
+         train_path = "Data/train_list.txt"
+     if val_path is None:
+         val_path = "Data/val_list.txt"
+
+     with open(train_path, 'r', encoding='utf-8', errors='ignore') as f:
+         train_list = f.readlines()
+     with open(val_path, 'r', encoding='utf-8', errors='ignore') as f:
+         val_list = f.readlines()
+
+     return train_list, val_list
+
+ def length_to_mask(lengths):
+     mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+     mask = torch.gt(mask+1, lengths.unsqueeze(1))
+     return mask
+
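+ # e.g. length_to_mask(torch.tensor([2, 4])) ->
+ #   [[False, False,  True,  True],
+ #    [False, False, False, False]]   (True marks padded positions)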
+ # for adversarial loss
+ def adv_loss(logits, target):
+     assert target in [1, 0]
+     if len(logits.shape) > 1:
+         logits = logits.reshape(-1)
+     targets = torch.full_like(logits, fill_value=target)
+     logits = logits.clamp(min=-10, max=10)  # prevent nan
+     loss = F.binary_cross_entropy_with_logits(logits, targets)
+     return loss
+
+ # for R1 regularization loss
+ def r1_reg(d_out, x_in):
+     # zero-centered gradient penalty for real images
+     batch_size = x_in.size(0)
+     grad_dout = torch.autograd.grad(
+         outputs=d_out.sum(), inputs=x_in,
+         create_graph=True, retain_graph=True, only_inputs=True
+     )[0]
+     grad_dout2 = grad_dout.pow(2)
+     assert grad_dout2.size() == x_in.size()
+     reg = 0.5 * grad_dout2.view(batch_size, -1).sum(1).mean(0)
+     return reg
+
+ # for norm consistency loss
+ def log_norm(x, mean=-4, std=4, dim=2):
+     """
+     normalized log mel -> mel -> norm -> log(norm):
+     undo the normalization (x * std + mean), exponentiate back to linear
+     magnitude, take the norm over `dim`, then return its log.
+     """
+     x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
+     return x
+
+ def get_image(arrs):
+     plt.switch_backend('agg')
+     fig = plt.figure()
+     ax = plt.gca()
+     ax.imshow(arrs)
+
+     return fig