SazerLife committed on
Commit
36a67ca
1 Parent(s): fe6ff1b

feat: added model

config.json ADDED
@@ -0,0 +1,53 @@
+ {
+   "architectures": [
+     "DinoHuVits"
+   ],
+   "auto_map": {
+     "AutoModel": "config.DinoHuVitsConfig"
+   },
+   "gin_channels": 256,
+   "hidden_channels": 192,
+   "hubert_downsample_channels": 192,
+   "hubert_feature_channels": 768,
+   "hubert_output_layer": 11,
+   "inter_channels": 192,
+   "model_type": "DINO-HuVITS",
+   "resblock": "1",
+   "resblock_dilation_sizes": [
+     [1, 3, 5],
+     [1, 3, 5],
+     [1, 3, 5]
+   ],
+   "resblock_kernel_sizes": [3, 7, 11],
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.2",
+   "upsample_initial_channel": 512,
+   "upsample_kernel_sizes": [20, 16, 4, 4],
+   "upsample_rates": [10, 8, 2, 2]
+ }
config.py ADDED
@@ -0,0 +1,37 @@
+ from transformers import PretrainedConfig
+ from typing import List
+
+
+ class DinoHuVitsConfig(PretrainedConfig):
+     model_type = "DinoHuVits"
+
+     def __init__(
+         self,
+         inter_channels=192,
+         hidden_channels=192,
+         resblock="1",
+         resblock_kernel_sizes=[3, 7, 11],
+         resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+         upsample_rates=[10, 8, 2, 2],
+         upsample_initial_channel=512,
+         upsample_kernel_sizes=[20, 16, 4, 4],
+         gin_channels=256,
+         hubert_feature_channels=768,
+         hubert_downsample_channels=192,
+         hubert_output_layer=11,
+         **kwargs
+     ):
+         self.inter_channels = inter_channels
+         self.hidden_channels = hidden_channels
+         self.resblock = resblock
+         self.resblock_kernel_sizes = resblock_kernel_sizes
+         self.resblock_dilation_sizes = resblock_dilation_sizes
+         self.upsample_rates = upsample_rates
+         self.upsample_initial_channel = upsample_initial_channel
+         self.upsample_kernel_sizes = upsample_kernel_sizes
+         self.gin_channels = gin_channels
+
+         self.hubert_feature_channels = hubert_feature_channels
+         self.hubert_downsample_channels = hubert_downsample_channels
+         self.hubert_output_layer = hubert_output_layer
+         super().__init__(**kwargs)
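DinoHuVitsConfig is a standard PretrainedConfig subclass, so it can be built, saved, and reloaded on its own before any weights are involved. A minimal sketch, assuming the file above is importable as `config` from the repository root:

# Build and serialize the same settings that ship in config.json above.
from config import DinoHuVitsConfig

cfg = DinoHuVitsConfig(gin_channels=256, hubert_output_layer=11)
cfg.save_pretrained("./dino-huvits")                  # writes config.json
cfg = DinoHuVitsConfig.from_pretrained("./dino-huvits")
print(cfg.upsample_rates)                             # [10, 8, 2, 2]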
dino_huvits.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ from transformers import PreTrainedModel
+
+ from config import DinoHuVitsConfig
+ from src import CAMPPlus, Flow, HiFiGAN, PosteriorHubert
+
+
+ class DinoHuVits(PreTrainedModel):
+     config_class = DinoHuVitsConfig
+
+     def __init__(self, config: DinoHuVitsConfig):
+         super().__init__(config)
+
+         self.enc_r = CAMPPlus(embed_dim=config.gin_channels, pooling_func="TSTP")
+         self.enc_q = PosteriorHubert(
+             out_channels=config.inter_channels,
+             feature_channels=config.hubert_feature_channels,
+             downsample_channels=config.hubert_downsample_channels,
+             output_layer=config.hubert_output_layer,
+         )
+         self.flow = Flow(
+             channels=config.inter_channels,
+             hidden_channels=config.hidden_channels,
+             kernel_size=5,
+             dilation_rate=1,
+             n_layers=4,
+             gin_channels=config.gin_channels,
+         )
+         self.dec = HiFiGAN(
+             initial_channel=config.inter_channels,
+             resblock=config.resblock,
+             resblock_kernel_sizes=config.resblock_kernel_sizes,
+             resblock_dilation_sizes=config.resblock_dilation_sizes,
+             upsample_rates=config.upsample_rates,
+             upsample_initial_channel=config.upsample_initial_channel,
+             upsample_kernel_sizes=config.upsample_kernel_sizes,
+             gin_channels=config.gin_channels,
+         )
+
+     def forward(
+         self, content: torch.Tensor, lengths: torch.Tensor, reference: torch.Tensor
+     ):
+         g_src = self.__get_style_embedding(content)
+         g_tgt = self.__get_style_embedding(reference)
+         z, _, _, y_mask = self.enc_q(content, lengths, g=g_src)
+         z_p = self.flow(z, y_mask, g=g_src)
+         z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+         o_hat = self.dec(z_hat * y_mask, g=g_tgt)
+         return o_hat, y_mask
+
+     def __get_style_embedding(self, waveform: torch.Tensor):
+         g = self.enc_r(waveform)  # [b, h, 1]
+         g = torch.nn.functional.normalize(g, dim=1)
+         return g.unsqueeze(-1)
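Taken together, forward runs one pass of voice conversion: a style embedding is extracted from the source (content) and the reference waveform, the HuBERT posterior encodes the source, the flow maps it to the prior space under the source speaker and back under the target speaker, and HiFi-GAN decodes the result. A shape-checking sketch with random tensors standing in for 16 kHz waveforms (no trained checkpoint is loaded here, so the output is noise):

# Shape check of the conversion path; real use would load trained weights.
import torch
from config import DinoHuVitsConfig
from dino_huvits import DinoHuVits

model = DinoHuVits(DinoHuVitsConfig()).eval()
content = torch.randn(1, 32000)                  # source speech to convert
reference = torch.randn(1, 48000)                # target-speaker reference
lengths = torch.tensor([content.shape[-1]])

with torch.no_grad():
    converted, mask = model(content, lengths, reference)
print(converted.shape)                           # [1, 1, T]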
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ecb11979e09bbd727f5fac56234beb2f25ee5d7ab572eebecd5cf061f538eef7
+ size 513863296
module/__init__.py ADDED
@@ -0,0 +1,561 @@
1
+ import copy
2
+ import math
3
+
4
+ import numpy as np
5
+ import scipy
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
9
+ from torch.nn import functional as F
10
+ from torch.nn.utils import remove_weight_norm, weight_norm
11
+
12
+ from tools import commons
13
+ from tools.commons import get_padding, init_weights
14
+ from tools.transforms import piecewise_rational_quadratic_transform
15
+
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class LayerNorm(nn.Module):
21
+ def __init__(self, channels, eps=1e-5):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.eps = eps
25
+
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1).contiguous()
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1).contiguous()
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels,
39
+ hidden_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ n_layers,
43
+ p_dropout,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.hidden_channels = hidden_channels
48
+ self.out_channels = out_channels
49
+ self.kernel_size = kernel_size
50
+ self.n_layers = n_layers
51
+ self.p_dropout = p_dropout
52
+ assert n_layers > 1, "Number of layers should be larger than 1."
53
+
54
+ self.conv_layers = nn.ModuleList()
55
+ self.norm_layers = nn.ModuleList()
56
+ self.conv_layers.append(
57
+ nn.Conv1d(
58
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dilated and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_2 = nn.ModuleList()
103
+ for i in range(n_layers):
104
+ dilation = kernel_size**i
105
+
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ conv = nn.Conv1d(
108
+ channels,
109
+ channels,
110
+ kernel_size,
111
+ groups=channels,
112
+ dilation=dilation,
113
+ padding=padding,
114
+ )
115
+ self.convs_sep.append(conv)
116
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
117
+ self.norms_2.append(LayerNorm(channels))
118
+
119
+ def forward(self, x, x_mask, g=None):
120
+ if g is not None:
121
+ x = x + g
122
+ for i in range(self.n_layers):
123
+ y = self.convs_sep[i](x * x_mask)
124
+ y = F.gelu(y)
125
+ y = self.convs_1x1[i](y)
126
+ y = self.norms_2[i](y)
127
+ y = F.gelu(y)
128
+ y = self.drop(y)
129
+ x = x + y
130
+ return x * x_mask
131
+
132
+
133
+ class WN(torch.nn.Module):
134
+ def __init__(
135
+ self,
136
+ hidden_channels,
137
+ kernel_size,
138
+ dilation_rate,
139
+ n_layers,
140
+ gin_channels=0,
141
+ p_dropout=0,
142
+ ):
143
+ super(WN, self).__init__()
144
+ assert kernel_size % 2 == 1
145
+ self.hidden_channels = hidden_channels
146
+ self.kernel_size = (kernel_size,)
147
+ self.dilation_rate = dilation_rate
148
+ self.n_layers = n_layers
149
+ self.gin_channels = gin_channels
150
+ self.p_dropout = p_dropout
151
+
152
+ self.in_layers = torch.nn.ModuleList()
153
+ self.res_skip_layers = torch.nn.ModuleList()
154
+ self.drop = nn.Dropout(p_dropout)
155
+
156
+ if gin_channels != 0:
157
+ cond_layer = torch.nn.Conv1d(
158
+ gin_channels, 2 * hidden_channels * n_layers, 1
159
+ )
160
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
161
+
162
+ for i in range(n_layers):
163
+ dilation = dilation_rate**i
164
+ padding = int((kernel_size * dilation - dilation) / 2)
165
+
166
+ in_layer = Conv1d(
167
+ hidden_channels,
168
+ 2 * hidden_channels,
169
+ kernel_size,
170
+ padding=padding,
171
+ dilation=dilation,
172
+ )
173
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
174
+ self.in_layers.append(in_layer)
175
+
176
+ # last one is not necessary
177
+ if i < n_layers - 1:
178
+ res_skip_channels = 2 * hidden_channels
179
+ else:
180
+ res_skip_channels = hidden_channels
181
+
182
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
183
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
184
+ self.res_skip_layers.append(res_skip_layer)
185
+
186
+ def forward(self, x, x_mask, g=None, **kwargs):
187
+ output = torch.zeros_like(x)
188
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
189
+
190
+ if g is not None:
191
+ g = self.cond_layer(g)
192
+
193
+ for i in range(self.n_layers):
194
+ x_in = self.in_layers[i](x)
195
+ if g is not None:
196
+ cond_offset = i * 2 * self.hidden_channels
197
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
198
+ else:
199
+ g_l = torch.zeros_like(x_in)
200
+
201
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
202
+ acts = self.drop(acts)
203
+
204
+ res_skip_acts = self.res_skip_layers[i](acts)
205
+ if i < self.n_layers - 1:
206
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
207
+ x = (x + res_acts) * x_mask
208
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
209
+ else:
210
+ output = output + res_skip_acts
211
+ return output * x_mask
212
+
213
+ def remove_weight_norm(self):
214
+ if self.gin_channels != 0:
215
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
216
+ for l in self.in_layers:
217
+ torch.nn.utils.remove_weight_norm(l)
218
+ for l in self.res_skip_layers:
219
+ torch.nn.utils.remove_weight_norm(l)
220
+
221
+
222
+ class Log(nn.Module):
223
+ def forward(self, x, x_mask, reverse=False, **kwargs):
224
+ if not reverse:
225
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
226
+ logdet = torch.sum(-y, [1, 2])
227
+ return y, logdet
228
+ else:
229
+ x = torch.exp(x) * x_mask
230
+ return x
231
+
232
+
233
+ class Flip(nn.Module):
234
+ def forward(self, x, *args, reverse=False, **kwargs):
235
+ x = torch.flip(x, [1])
236
+ if not reverse:
237
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
238
+ return x, logdet
239
+ else:
240
+ return x
241
+
242
+
243
+ class ElementwiseAffine(nn.Module):
244
+ def __init__(self, channels):
245
+ super().__init__()
246
+ self.channels = channels
247
+ self.m = nn.Parameter(torch.zeros(channels, 1))
248
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
249
+
250
+ def forward(self, x, x_mask, reverse=False, **kwargs):
251
+ if not reverse:
252
+ y = self.m + torch.exp(self.logs) * x
253
+ y = y * x_mask
254
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
255
+ return y, logdet
256
+ else:
257
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
258
+ return x
259
+
260
+
261
+ class ResidualCouplingLayer(nn.Module):
262
+ def __init__(
263
+ self,
264
+ channels,
265
+ hidden_channels,
266
+ kernel_size,
267
+ dilation_rate,
268
+ n_layers,
269
+ p_dropout=0,
270
+ gin_channels=0,
271
+ mean_only=False,
272
+ ):
273
+ assert channels % 2 == 0, "channels should be divisible by 2"
274
+ super().__init__()
275
+ self.channels = channels
276
+ self.hidden_channels = hidden_channels
277
+ self.kernel_size = kernel_size
278
+ self.dilation_rate = dilation_rate
279
+ self.n_layers = n_layers
280
+ self.half_channels = channels // 2
281
+ self.mean_only = mean_only
282
+
283
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
284
+ self.enc = WN(
285
+ hidden_channels,
286
+ kernel_size,
287
+ dilation_rate,
288
+ n_layers,
289
+ p_dropout=p_dropout,
290
+ gin_channels=gin_channels,
291
+ )
292
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
293
+ self.post.weight.data.zero_()
294
+ self.post.bias.data.zero_()
295
+
296
+ def forward(self, x, x_mask, g=None, reverse=False):
297
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
298
+ h = self.pre(x0) * x_mask
299
+ h = self.enc(h, x_mask, g=g)
300
+ stats = self.post(h) * x_mask
301
+ if not self.mean_only:
302
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
303
+ else:
304
+ m = stats
305
+ logs = torch.zeros_like(m)
306
+
307
+ if not reverse:
308
+ x1 = m + x1 * torch.exp(logs) * x_mask
309
+ x = torch.cat([x0, x1], 1)
310
+ logdet = torch.sum(logs, [1, 2])
311
+ return x, logdet
312
+ else:
313
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
314
+ x = torch.cat([x0, x1], 1)
315
+ return x
316
+
317
+
318
+ class ConvFlow(nn.Module):
319
+ def __init__(
320
+ self,
321
+ in_channels,
322
+ filter_channels,
323
+ kernel_size,
324
+ n_layers,
325
+ num_bins=10,
326
+ tail_bound=5.0,
327
+ ):
328
+ super().__init__()
329
+ self.in_channels = in_channels
330
+ self.filter_channels = filter_channels
331
+ self.kernel_size = kernel_size
332
+ self.n_layers = n_layers
333
+ self.num_bins = num_bins
334
+ self.tail_bound = tail_bound
335
+ self.half_channels = in_channels // 2
336
+
337
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
338
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
339
+ self.proj = nn.Conv1d(
340
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
341
+ )
342
+ self.proj.weight.data.zero_()
343
+ self.proj.bias.data.zero_()
344
+
345
+ def forward(self, x, x_mask, g=None, reverse=False):
346
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
347
+
348
+ h = self.pre(x0)
349
+ h = self.convs(h, x_mask, g=g)
350
+ h = self.proj(h) * x_mask
351
+
352
+ b, c, t = x0.shape
353
+ h = (
354
+ h.reshape(b, c, -1, t).permute(0, 1, 3, 2).contiguous()
355
+ ) # [b, cx?, t] -> [b, c, t, ?]
356
+
357
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
358
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
359
+ self.filter_channels
360
+ )
361
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
362
+
363
+ x1, logabsdet = piecewise_rational_quadratic_transform(
364
+ x1,
365
+ unnormalized_widths,
366
+ unnormalized_heights,
367
+ unnormalized_derivatives,
368
+ inverse=reverse,
369
+ tails="linear",
370
+ tail_bound=self.tail_bound,
371
+ )
372
+
373
+ x = torch.cat([x0, x1], 1) * x_mask
374
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
375
+ if not reverse:
376
+ return x, logdet
377
+ else:
378
+ return x
379
+
380
+
381
+ class LinearNorm(nn.Module):
382
+ def __init__(
383
+ self,
384
+ in_channels,
385
+ out_channels,
386
+ bias=True,
387
+ spectral_norm=False,
388
+ ):
389
+ super(LinearNorm, self).__init__()
390
+ self.fc = nn.Linear(in_channels, out_channels, bias)
391
+
392
+ if spectral_norm:
393
+ self.fc = nn.utils.spectral_norm(self.fc)
394
+
395
+ def forward(self, input):
396
+ out = self.fc(input)
397
+ return out
398
+
399
+
400
+ class Mish(nn.Module):
401
+ def __init__(self):
402
+ super(Mish, self).__init__()
403
+
404
+ def forward(self, x):
405
+ return x * torch.tanh(F.softplus(x))
406
+
407
+
408
+ class LinearNorm(nn.Module):
409
+ def __init__(
410
+ self,
411
+ in_channels,
412
+ out_channels,
413
+ bias=True,
414
+ spectral_norm=False,
415
+ ):
416
+ super(LinearNorm, self).__init__()
417
+ self.fc = nn.Linear(in_channels, out_channels, bias)
418
+
419
+ if spectral_norm:
420
+ self.fc = nn.utils.spectral_norm(self.fc)
421
+
422
+ def forward(self, input):
423
+ out = self.fc(input)
424
+ return out
425
+
426
+
427
+ class ConvNorm(nn.Module):
428
+ def __init__(
429
+ self,
430
+ in_channels,
431
+ out_channels,
432
+ kernel_size=1,
433
+ stride=1,
434
+ padding=None,
435
+ dilation=1,
436
+ bias=True,
437
+ spectral_norm=False,
438
+ ):
439
+ super(ConvNorm, self).__init__()
440
+
441
+ if padding is None:
442
+ assert kernel_size % 2 == 1
443
+ padding = int(dilation * (kernel_size - 1) / 2)
444
+
445
+ self.conv = torch.nn.Conv1d(
446
+ in_channels,
447
+ out_channels,
448
+ kernel_size=kernel_size,
449
+ stride=stride,
450
+ padding=padding,
451
+ dilation=dilation,
452
+ bias=bias,
453
+ )
454
+
455
+ if spectral_norm:
456
+ self.conv = nn.utils.spectral_norm(self.conv)
457
+
458
+ def forward(self, input):
459
+ out = self.conv(input)
460
+ return out
461
+
462
+
463
+ class MultiHeadAttention(nn.Module):
464
+ """Multi-Head Attention module"""
465
+
466
+ def __init__(self, n_head, d_model, d_k, d_v, dropout=0.0, spectral_norm=False):
467
+ super().__init__()
468
+
469
+ self.n_head = n_head
470
+ self.d_k = d_k
471
+ self.d_v = d_v
472
+
473
+ self.w_qs = nn.Linear(d_model, n_head * d_k)
474
+ self.w_ks = nn.Linear(d_model, n_head * d_k)
475
+ self.w_vs = nn.Linear(d_model, n_head * d_v)
476
+
477
+ self.attention = ScaledDotProductAttention(
478
+ temperature=np.power(d_model, 0.5), dropout=dropout
479
+ )
480
+
481
+ self.fc = nn.Linear(n_head * d_v, d_model)
482
+ self.dropout = nn.Dropout(dropout)
483
+
484
+ if spectral_norm:
485
+ self.w_qs = nn.utils.spectral_norm(self.w_qs)
486
+ self.w_ks = nn.utils.spectral_norm(self.w_ks)
487
+ self.w_vs = nn.utils.spectral_norm(self.w_vs)
488
+ self.fc = nn.utils.spectral_norm(self.fc)
489
+
490
+ def forward(self, x, mask=None):
491
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
492
+ sz_b, len_x, _ = x.size()
493
+
494
+ residual = x
495
+
496
+ q = self.w_qs(x).view(sz_b, len_x, n_head, d_k)
497
+ k = self.w_ks(x).view(sz_b, len_x, n_head, d_k)
498
+ v = self.w_vs(x).view(sz_b, len_x, n_head, d_v)
499
+ q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lq x dk
500
+ k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_k) # (n*b) x lk x dk
501
+ v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_v) # (n*b) x lv x dv
502
+
503
+ if mask is not None:
504
+ slf_mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
505
+ else:
506
+ slf_mask = None
507
+ output, attn = self.attention(q, k, v, mask=slf_mask)
508
+
509
+ output = output.view(n_head, sz_b, len_x, d_v)
510
+ output = (
511
+ output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
512
+ ) # b x lq x (n*dv)
513
+
514
+ output = self.fc(output)
515
+
516
+ output = self.dropout(output) + residual
517
+ return output, attn
518
+
519
+
520
+ class ScaledDotProductAttention(nn.Module):
521
+ """Scaled Dot-Product Attention"""
522
+
523
+ def __init__(self, temperature, dropout):
524
+ super().__init__()
525
+ self.temperature = temperature
526
+ self.softmax = nn.Softmax(dim=2)
527
+ self.dropout = nn.Dropout(dropout)
528
+
529
+ def forward(self, q, k, v, mask=None):
530
+ attn = torch.bmm(q, k.transpose(1, 2).contiguous())
531
+ attn = attn / self.temperature
532
+
533
+ if mask is not None:
534
+ attn = attn.masked_fill(mask, -np.inf)
535
+
536
+ attn = self.softmax(attn)
537
+ p_attn = self.dropout(attn)
538
+
539
+ output = torch.bmm(p_attn, v)
540
+ return output, attn
541
+
542
+
543
+ class Conv1dGLU(nn.Module):
544
+ """
545
+ Conv1d + GLU(Gated Linear Unit) with residual connection.
546
+ For GLU refer to https://arxiv.org/abs/1612.08083 paper.
547
+ """
548
+
549
+ def __init__(self, in_channels, out_channels, kernel_size, dropout):
550
+ super(Conv1dGLU, self).__init__()
551
+ self.out_channels = out_channels
552
+ self.conv1 = ConvNorm(in_channels, 2 * out_channels, kernel_size=kernel_size)
553
+ self.dropout = nn.Dropout(dropout)
554
+
555
+ def forward(self, x):
556
+ residual = x
557
+ x = self.conv1(x)
558
+ x1, x2 = torch.split(x, split_size_or_sections=self.out_channels, dim=1)
559
+ x = x1 * torch.sigmoid(x2)
560
+ x = residual + self.dropout(x)
561
+ return x
module/resblocks.py ADDED
@@ -0,0 +1,150 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv1d
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+
7
+ from . import LRELU_SLOPE
8
+ from tools.commons import get_padding, init_weights
9
+
10
+
11
+ class ResBlock1(torch.nn.Module):
12
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
13
+ super(ResBlock1, self).__init__()
14
+ self.convs1 = nn.ModuleList(
15
+ [
16
+ weight_norm(
17
+ Conv1d(
18
+ channels,
19
+ channels,
20
+ kernel_size,
21
+ 1,
22
+ dilation=dilation[0],
23
+ padding=get_padding(kernel_size, dilation[0]),
24
+ )
25
+ ),
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[1],
33
+ padding=get_padding(kernel_size, dilation[1]),
34
+ )
35
+ ),
36
+ weight_norm(
37
+ Conv1d(
38
+ channels,
39
+ channels,
40
+ kernel_size,
41
+ 1,
42
+ dilation=dilation[2],
43
+ padding=get_padding(kernel_size, dilation[2]),
44
+ )
45
+ ),
46
+ ]
47
+ )
48
+ self.convs1.apply(init_weights)
49
+
50
+ self.convs2 = nn.ModuleList(
51
+ [
52
+ weight_norm(
53
+ Conv1d(
54
+ channels,
55
+ channels,
56
+ kernel_size,
57
+ 1,
58
+ dilation=1,
59
+ padding=get_padding(kernel_size, 1),
60
+ )
61
+ ),
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=1,
69
+ padding=get_padding(kernel_size, 1),
70
+ )
71
+ ),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ ]
83
+ )
84
+ self.convs2.apply(init_weights)
85
+
86
+ def forward(self, x, x_mask=None):
87
+ for c1, c2 in zip(self.convs1, self.convs2):
88
+ xt = F.leaky_relu(x, LRELU_SLOPE)
89
+ if x_mask is not None:
90
+ xt = xt * x_mask
91
+ xt = c1(xt)
92
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
93
+ if x_mask is not None:
94
+ xt = xt * x_mask
95
+ xt = c2(xt)
96
+ x = xt + x
97
+ if x_mask is not None:
98
+ x = x * x_mask
99
+ return x
100
+
101
+ def remove_weight_norm(self):
102
+ for l in self.convs1:
103
+ remove_weight_norm(l)
104
+ for l in self.convs2:
105
+ remove_weight_norm(l)
106
+
107
+
108
+ class ResBlock2(torch.nn.Module):
109
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
110
+ super(ResBlock2, self).__init__()
111
+ self.convs = nn.ModuleList(
112
+ [
113
+ weight_norm(
114
+ Conv1d(
115
+ channels,
116
+ channels,
117
+ kernel_size,
118
+ 1,
119
+ dilation=dilation[0],
120
+ padding=get_padding(kernel_size, dilation[0]),
121
+ )
122
+ ),
123
+ weight_norm(
124
+ Conv1d(
125
+ channels,
126
+ channels,
127
+ kernel_size,
128
+ 1,
129
+ dilation=dilation[1],
130
+ padding=get_padding(kernel_size, dilation[1]),
131
+ )
132
+ ),
133
+ ]
134
+ )
135
+ self.convs.apply(init_weights)
136
+
137
+ def forward(self, x, x_mask=None):
138
+ for c in self.convs:
139
+ xt = F.leaky_relu(x, LRELU_SLOPE)
140
+ if x_mask is not None:
141
+ xt = xt * x_mask
142
+ xt = c(xt)
143
+ x = xt + x
144
+ if x_mask is not None:
145
+ x = x * x_mask
146
+ return x
147
+
148
+ def remove_weight_norm(self):
149
+ for l in self.convs:
150
+ remove_weight_norm(l)
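Both residual blocks are shape-preserving (stride 1, with get_padding keeping the time axis fixed), which is what lets the HiFi-GAN generator sum several of them per upsampling stage. A quick sketch, assuming the repository root is on PYTHONPATH:

# ResBlock1 keeps the (batch, channels, time) layout unchanged.
import torch
from module.resblocks import ResBlock1

block = ResBlock1(channels=64, kernel_size=3, dilation=(1, 3, 5))
x = torch.randn(2, 64, 128)
print(block(x).shape)        # torch.Size([2, 64, 128])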
src/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .campplus import CAMPPlus
+ from .flow import ResidualCouplingBlock as Flow
+ from .hifi_gan import Generator as HiFiGAN
+ from .hubert_posterior import PosteriorHubert
src/attentions.py ADDED
@@ -0,0 +1,219 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import tools.commons as commons
8
+
9
+
10
+ class MultiHeadAttention(nn.Module):
11
+ def __init__(
12
+ self,
13
+ channels,
14
+ out_channels,
15
+ n_heads,
16
+ p_dropout=0.0,
17
+ window_size=None,
18
+ heads_share=True,
19
+ block_length=None,
20
+ proximal_bias=False,
21
+ proximal_init=False,
22
+ ):
23
+ super().__init__()
24
+ assert channels % n_heads == 0
25
+
26
+ self.channels = channels
27
+ self.out_channels = out_channels
28
+ self.n_heads = n_heads
29
+ self.p_dropout = p_dropout
30
+ self.window_size = window_size
31
+ self.heads_share = heads_share
32
+ self.block_length = block_length
33
+ self.proximal_bias = proximal_bias
34
+ self.proximal_init = proximal_init
35
+ self.attn = None
36
+
37
+ self.k_channels = channels // n_heads
38
+ self.conv_q = nn.Conv1d(channels, channels, 1)
39
+ self.conv_k = nn.Conv1d(channels, channels, 1)
40
+ self.conv_v = nn.Conv1d(channels, channels, 1)
41
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
42
+ self.drop = nn.Dropout(p_dropout)
43
+
44
+ if window_size is not None:
45
+ n_heads_rel = 1 if heads_share else n_heads
46
+ rel_stddev = self.k_channels**-0.5
47
+ self.emb_rel_k = nn.Parameter(
48
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
49
+ * rel_stddev
50
+ )
51
+ self.emb_rel_v = nn.Parameter(
52
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
53
+ * rel_stddev
54
+ )
55
+
56
+ nn.init.xavier_uniform_(self.conv_q.weight)
57
+ nn.init.xavier_uniform_(self.conv_k.weight)
58
+ nn.init.xavier_uniform_(self.conv_v.weight)
59
+ if proximal_init:
60
+ with torch.no_grad():
61
+ self.conv_k.weight.copy_(self.conv_q.weight)
62
+ self.conv_k.bias.copy_(self.conv_q.bias)
63
+
64
+ def forward(self, x, c, attn_mask=None):
65
+ q = self.conv_q(x)
66
+ k = self.conv_k(c)
67
+ v = self.conv_v(c)
68
+
69
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
70
+
71
+ x = self.conv_o(x)
72
+ return x
73
+
74
+ def attention(self, query, key, value, mask=None):
75
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
76
+ b, d, t_s, t_t = (*key.size(), query.size(2))
77
+ query = (
78
+ query.view(b, self.n_heads, self.k_channels, t_t)
79
+ .transpose(2, 3)
80
+ .contiguous()
81
+ )
82
+ key = (
83
+ key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3).contiguous()
84
+ )
85
+ value = (
86
+ value.view(b, self.n_heads, self.k_channels, t_s)
87
+ .transpose(2, 3)
88
+ .contiguous()
89
+ )
90
+
91
+ scores = torch.matmul(
92
+ query / math.sqrt(self.k_channels), key.transpose(-2, -1).contiguous()
93
+ )
94
+ if self.window_size is not None:
95
+ assert (
96
+ t_s == t_t
97
+ ), "Relative attention is only available for self-attention."
98
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
99
+ rel_logits = self._matmul_with_relative_keys(
100
+ query / math.sqrt(self.k_channels), key_relative_embeddings
101
+ )
102
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
103
+ scores = scores + scores_local
104
+ if self.proximal_bias:
105
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
106
+ scores = scores + self._attention_bias_proximal(t_s).to(
107
+ device=scores.device, dtype=scores.dtype
108
+ )
109
+ if mask is not None:
110
+ scores = scores.masked_fill(mask == 0, -1e4)
111
+ if self.block_length is not None:
112
+ assert (
113
+ t_s == t_t
114
+ ), "Local attention is only available for self-attention."
115
+ block_mask = (
116
+ torch.ones_like(scores)
117
+ .triu(-self.block_length)
118
+ .tril(self.block_length)
119
+ )
120
+ scores = scores.masked_fill(block_mask == 0, -1e4)
121
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
122
+ p_attn = self.drop(p_attn)
123
+ output = torch.matmul(p_attn, value)
124
+ if self.window_size is not None:
125
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
126
+ value_relative_embeddings = self._get_relative_embeddings(
127
+ self.emb_rel_v, t_s
128
+ )
129
+ output = output + self._matmul_with_relative_values(
130
+ relative_weights, value_relative_embeddings
131
+ )
132
+ output = (
133
+ output.transpose(2, 3).contiguous().view(b, d, t_t).contiguous()
134
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
135
+ return output, p_attn
136
+
137
+ def _matmul_with_relative_values(self, x, y):
138
+ """
139
+ x: [b, h, l, m]
140
+ y: [h or 1, m, d]
141
+ ret: [b, h, l, d]
142
+ """
143
+ ret = torch.matmul(x, y.unsqueeze(0))
144
+ return ret
145
+
146
+ def _matmul_with_relative_keys(self, x, y):
147
+ """
148
+ x: [b, h, l, d]
149
+ y: [h or 1, m, d]
150
+ ret: [b, h, l, m]
151
+ """
152
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1).contiguous())
153
+ return ret
154
+
155
+ def _get_relative_embeddings(self, relative_embeddings, length):
156
+ max_relative_position = 2 * self.window_size + 1
157
+ # Pad first before slice to avoid using cond ops.
158
+ pad_length = max(length - (self.window_size + 1), 0)
159
+ slice_start_position = max((self.window_size + 1) - length, 0)
160
+ slice_end_position = slice_start_position + 2 * length - 1
161
+ if pad_length > 0:
162
+ padded_relative_embeddings = F.pad(
163
+ relative_embeddings,
164
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
165
+ )
166
+ else:
167
+ padded_relative_embeddings = relative_embeddings
168
+ used_relative_embeddings = padded_relative_embeddings[
169
+ :, slice_start_position:slice_end_position
170
+ ]
171
+ return used_relative_embeddings
172
+
173
+ def _relative_position_to_absolute_position(self, x):
174
+ """
175
+ x: [b, h, l, 2*l-1]
176
+ ret: [b, h, l, l]
177
+ """
178
+ batch, heads, length, _ = x.size()
179
+ # Concat columns of pad to shift from relative to absolute indexing.
180
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
181
+
182
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
183
+ x_flat = x.view([batch, heads, length * 2 * length])
184
+ x_flat = F.pad(
185
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
186
+ )
187
+
188
+ # Reshape and slice out the padded elements.
189
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
190
+ :, :, :length, length - 1 :
191
+ ]
192
+ return x_final
193
+
194
+ def _absolute_position_to_relative_position(self, x):
195
+ """
196
+ x: [b, h, l, l]
197
+ ret: [b, h, l, 2*l-1]
198
+ """
199
+ batch, heads, length, _ = x.size()
200
+ # pad along column
201
+ x = F.pad(
202
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
203
+ )
204
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
205
+ # add 0's in the beginning that will skew the elements after reshape
206
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
207
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
208
+ return x_final
209
+
210
+ def _attention_bias_proximal(self, length):
211
+ """Bias for self-attention to encourage attention to close positions.
212
+ Args:
213
+ length: an integer scalar.
214
+ Returns:
215
+ a Tensor with shape [1, 1, length, length]
216
+ """
217
+ r = torch.arange(length, dtype=torch.float32)
218
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
219
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
src/campplus.py ADDED
@@ -0,0 +1,407 @@
1
+ # Copyright (c) 2023 Hongji Wang (jijijiang77@gmail.com)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ This implementation is adapted from github repo:
17
+ https://github.com/alibaba-damo-academy/3D-Speaker
18
+
19
+ Some modifications:
20
+ 1. Reuse the pooling layers in wespeaker
21
+ 2. Remove the memory_efficient mechanism to meet the torch.jit.script
22
+ export requirements
23
+
24
+ Reference:
25
+ [1] Hui Wang, Siqi Zheng, Yafeng Chen, Luyao Cheng and Qian Chen.
26
+ "CAM++: A Fast and Efficient Network for Speaker Verification
27
+ Using Context-Aware Masking". arXiv preprint arXiv:2303.00332
28
+ """
29
+
30
+ from collections import OrderedDict
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+
36
+ from .wespeaker_campplus import pooling_layers
37
+ from .wespeaker_campplus.fbank_feature_extractor import FbankFeatureExtractor
38
+
39
+
40
+ def get_nonlinear(config_str, channels):
41
+ nonlinear = nn.Sequential()
42
+ for name in config_str.split("-"):
43
+ if name == "relu":
44
+ nonlinear.add_module("relu", nn.ReLU(inplace=True))
45
+ elif name == "prelu":
46
+ nonlinear.add_module("prelu", nn.PReLU(channels))
47
+ elif name == "batchnorm":
48
+ nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels))
49
+ elif name == "batchnorm_":
50
+ nonlinear.add_module("batchnorm", nn.BatchNorm1d(channels, affine=False))
51
+ else:
52
+ raise ValueError("Unexpected module ({}).".format(name))
53
+ return nonlinear
54
+
55
+
56
+ class TDNNLayer(nn.Module):
57
+ def __init__(
58
+ self,
59
+ in_channels,
60
+ out_channels,
61
+ kernel_size,
62
+ stride=1,
63
+ padding=0,
64
+ dilation=1,
65
+ bias=False,
66
+ config_str="batchnorm-relu",
67
+ ):
68
+ super(TDNNLayer, self).__init__()
69
+ if padding < 0:
70
+ assert (
71
+ kernel_size % 2 == 1
72
+ ), "Expect equal paddings, \
73
+ but got even kernel size ({})".format(
74
+ kernel_size
75
+ )
76
+ padding = (kernel_size - 1) // 2 * dilation
77
+ self.linear = nn.Conv1d(
78
+ in_channels,
79
+ out_channels,
80
+ kernel_size,
81
+ stride=stride,
82
+ padding=padding,
83
+ dilation=dilation,
84
+ bias=bias,
85
+ )
86
+ self.nonlinear = get_nonlinear(config_str, out_channels)
87
+
88
+ def forward(self, x):
89
+ x = self.linear(x)
90
+ x = self.nonlinear(x)
91
+ return x
92
+
93
+
94
+ class CAMLayer(nn.Module):
95
+ def __init__(
96
+ self,
97
+ bn_channels,
98
+ out_channels,
99
+ kernel_size,
100
+ stride,
101
+ padding,
102
+ dilation,
103
+ bias,
104
+ reduction=2,
105
+ ):
106
+ super(CAMLayer, self).__init__()
107
+ self.linear_local = nn.Conv1d(
108
+ bn_channels,
109
+ out_channels,
110
+ kernel_size,
111
+ stride=stride,
112
+ padding=padding,
113
+ dilation=dilation,
114
+ bias=bias,
115
+ )
116
+ self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
117
+ self.relu = nn.ReLU(inplace=True)
118
+ self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
119
+ self.sigmoid = nn.Sigmoid()
120
+
121
+ def forward(self, x):
122
+ y = self.linear_local(x)
123
+ context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
124
+ context = self.relu(self.linear1(context))
125
+ m = self.sigmoid(self.linear2(context))
126
+ return y * m
127
+
128
+ def seg_pooling(self, x, seg_len: int = 100, stype: str = "avg"):
129
+ if stype == "avg":
130
+ seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
131
+ elif stype == "max":
132
+ seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
133
+ else:
134
+ raise ValueError("Wrong segment pooling type.")
135
+ shape = seg.shape
136
+ seg = (
137
+ seg.unsqueeze(-1)
138
+ .expand(shape[0], shape[1], shape[2], seg_len)
139
+ .reshape(shape[0], shape[1], -1)
140
+ )
141
+ seg = seg[..., : x.shape[-1]]
142
+ return seg
143
+
144
+
145
+ class CAMDenseTDNNLayer(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels,
149
+ out_channels,
150
+ bn_channels,
151
+ kernel_size,
152
+ stride=1,
153
+ dilation=1,
154
+ bias=False,
155
+ config_str="batchnorm-relu",
156
+ ):
157
+ super(CAMDenseTDNNLayer, self).__init__()
158
+ assert (
159
+ kernel_size % 2 == 1
160
+ ), "Expect equal paddings, \
161
+ but got even kernel size ({})".format(
162
+ kernel_size
163
+ )
164
+ padding = (kernel_size - 1) // 2 * dilation
165
+ self.nonlinear1 = get_nonlinear(config_str, in_channels)
166
+ self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
167
+ self.nonlinear2 = get_nonlinear(config_str, bn_channels)
168
+ self.cam_layer = CAMLayer(
169
+ bn_channels,
170
+ out_channels,
171
+ kernel_size,
172
+ stride=stride,
173
+ padding=padding,
174
+ dilation=dilation,
175
+ bias=bias,
176
+ )
177
+
178
+ def bn_function(self, x):
179
+ return self.linear1(self.nonlinear1(x))
180
+
181
+ def forward(self, x):
182
+ x = self.bn_function(x)
183
+ x = self.cam_layer(self.nonlinear2(x))
184
+ return x
185
+
186
+
187
+ class CAMDenseTDNNBlock(nn.ModuleList):
188
+ def __init__(
189
+ self,
190
+ num_layers,
191
+ in_channels,
192
+ out_channels,
193
+ bn_channels,
194
+ kernel_size,
195
+ stride=1,
196
+ dilation=1,
197
+ bias=False,
198
+ config_str="batchnorm-relu",
199
+ ):
200
+ super(CAMDenseTDNNBlock, self).__init__()
201
+ for i in range(num_layers):
202
+ layer = CAMDenseTDNNLayer(
203
+ in_channels=in_channels + i * out_channels,
204
+ out_channels=out_channels,
205
+ bn_channels=bn_channels,
206
+ kernel_size=kernel_size,
207
+ stride=stride,
208
+ dilation=dilation,
209
+ bias=bias,
210
+ config_str=config_str,
211
+ )
212
+ self.add_module("tdnnd%d" % (i + 1), layer)
213
+
214
+ def forward(self, x):
215
+ for layer in self:
216
+ x = torch.cat([x, layer(x)], dim=1)
217
+ return x
218
+
219
+
220
+ class TransitLayer(nn.Module):
221
+ def __init__(
222
+ self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"
223
+ ):
224
+ super(TransitLayer, self).__init__()
225
+ self.nonlinear = get_nonlinear(config_str, in_channels)
226
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
227
+
228
+ def forward(self, x):
229
+ x = self.nonlinear(x)
230
+ x = self.linear(x)
231
+ return x
232
+
233
+
234
+ class DenseLayer(nn.Module):
235
+ def __init__(
236
+ self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"
237
+ ):
238
+ super(DenseLayer, self).__init__()
239
+ self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
240
+ self.nonlinear = get_nonlinear(config_str, out_channels)
241
+
242
+ def forward(self, x):
243
+ if len(x.shape) == 2:
244
+ x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
245
+ else:
246
+ x = self.linear(x)
247
+ x = self.nonlinear(x)
248
+ return x
249
+
250
+
251
+ """Note: The stride used here is different from that in Resnet
252
+ """
253
+
254
+
255
+ class BasicResBlock(nn.Module):
256
+ expansion = 1
257
+
258
+ def __init__(self, in_planes, planes, stride=1):
259
+ super(BasicResBlock, self).__init__()
260
+ self.conv1 = nn.Conv2d(
261
+ in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False
262
+ )
263
+ self.bn1 = nn.BatchNorm2d(planes)
264
+ self.conv2 = nn.Conv2d(
265
+ planes, planes, kernel_size=3, stride=1, padding=1, bias=False
266
+ )
267
+ self.bn2 = nn.BatchNorm2d(planes)
268
+
269
+ self.shortcut = nn.Sequential()
270
+ if stride != 1 or in_planes != self.expansion * planes:
271
+ self.shortcut = nn.Sequential(
272
+ nn.Conv2d(
273
+ in_planes,
274
+ self.expansion * planes,
275
+ kernel_size=1,
276
+ stride=(stride, 1),
277
+ bias=False,
278
+ ),
279
+ nn.BatchNorm2d(self.expansion * planes),
280
+ )
281
+
282
+ def forward(self, x):
283
+ out = F.relu(self.bn1(self.conv1(x)))
284
+ out = self.bn2(self.conv2(out))
285
+ out += self.shortcut(x)
286
+ out = F.relu(out)
287
+ return out
288
+
289
+
290
+ class FCM(nn.Module):
291
+ def __init__(self, block, num_blocks, m_channels=32, feat_dim=80):
292
+ super(FCM, self).__init__()
293
+ self.in_planes = m_channels
294
+ self.conv1 = nn.Conv2d(
295
+ 1, m_channels, kernel_size=3, stride=1, padding=1, bias=False
296
+ )
297
+ self.bn1 = nn.BatchNorm2d(m_channels)
298
+
299
+ self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
300
+ self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2)
301
+
302
+ self.conv2 = nn.Conv2d(
303
+ m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False
304
+ )
305
+ self.bn2 = nn.BatchNorm2d(m_channels)
306
+ self.out_channels = m_channels * (feat_dim // 8)
307
+
308
+ def _make_layer(self, block, planes, num_blocks, stride):
309
+ strides = [stride] + [1] * (num_blocks - 1)
310
+ layers = []
311
+ for stride in strides:
312
+ layers.append(block(self.in_planes, planes, stride))
313
+ self.in_planes = planes * block.expansion
314
+ return nn.Sequential(*layers)
315
+
316
+ def forward(self, x):
317
+ x = x.unsqueeze(1)
318
+ out = F.relu(self.bn1(self.conv1(x)))
319
+ out = self.layer1(out)
320
+ out = self.layer2(out)
321
+ out = F.relu(self.bn2(self.conv2(out)))
322
+
323
+ shape = out.shape
324
+ out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
325
+ return out
326
+
327
+
328
+ class CAMPPlus(nn.Module):
329
+ def __init__(
330
+ self,
331
+ feat_dim=80,
332
+ embed_dim=512,
333
+ pooling_func="TSTP",
334
+ growth_rate=32,
335
+ bn_size=4,
336
+ init_channels=128,
337
+ config_str="batchnorm-relu",
338
+ ):
339
+ super(CAMPPlus, self).__init__()
340
+
341
+ self.feature_extractor = FbankFeatureExtractor(feat_dim=80)
342
+ self.head = FCM(block=BasicResBlock, num_blocks=[2, 2], feat_dim=feat_dim)
343
+ channels = self.head.out_channels
344
+
345
+ self.xvector = nn.Sequential(
346
+ OrderedDict(
347
+ [
348
+ (
349
+ "tdnn",
350
+ TDNNLayer(
351
+ channels,
352
+ init_channels,
353
+ 5,
354
+ stride=2,
355
+ dilation=1,
356
+ padding=-1,
357
+ config_str=config_str,
358
+ ),
359
+ ),
360
+ ]
361
+ )
362
+ )
363
+ channels = init_channels
364
+ for i, (num_layers, kernel_size, dilation) in enumerate(
365
+ zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
366
+ ):
367
+ block = CAMDenseTDNNBlock(
368
+ num_layers=num_layers,
369
+ in_channels=channels,
370
+ out_channels=growth_rate,
371
+ bn_channels=bn_size * growth_rate,
372
+ kernel_size=kernel_size,
373
+ dilation=dilation,
374
+ config_str=config_str,
375
+ )
376
+ self.xvector.add_module("block%d" % (i + 1), block)
377
+ channels = channels + num_layers * growth_rate
378
+ self.xvector.add_module(
379
+ "transit%d" % (i + 1),
380
+ TransitLayer(
381
+ channels, channels // 2, bias=False, config_str=config_str
382
+ ),
383
+ )
384
+ channels //= 2
385
+
386
+ self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))
387
+
388
+ self.pool = getattr(pooling_layers, pooling_func)(in_dim=channels)
389
+ self.pool_out_dim = self.pool.get_out_dim()
390
+ self.xvector.add_module("stats", self.pool)
391
+ self.xvector.add_module(
392
+ "dense", DenseLayer(self.pool_out_dim, embed_dim, config_str="batchnorm_")
393
+ )
394
+
395
+ for m in self.modules():
396
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
397
+ nn.init.kaiming_normal_(m.weight.data)
398
+ if m.bias is not None:
399
+ nn.init.zeros_(m.bias)
400
+
401
+ def forward(self, x):
402
+ x = self.feature_extractor(x)
403
+ # x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
404
+ x = self.head(x)
405
+ x = self.xvector(x)
406
+
407
+ return x
src/flow.py ADDED
@@ -0,0 +1,47 @@
+ import torch.nn as nn
+ import module as modules
+
+
+ class ResidualCouplingBlock(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         n_flows=4,
+         gin_channels=0,
+     ):
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.n_flows = n_flows
+         self.gin_channels = gin_channels
+
+         self.flows = nn.ModuleList()
+         for i in range(n_flows):
+             self.flows.append(
+                 modules.ResidualCouplingLayer(
+                     channels,
+                     hidden_channels,
+                     kernel_size,
+                     dilation_rate,
+                     n_layers,
+                     gin_channels=gin_channels,
+                     mean_only=True,
+                 )
+             )
+             self.flows.append(modules.Flip())
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         if not reverse:
+             for flow in self.flows:
+                 x, _ = flow(x, x_mask, g=g, reverse=reverse)
+         else:
+             for flow in reversed(self.flows):
+                 x = flow(x, x_mask, g=g, reverse=reverse)
+         return x
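Because every coupling layer is affine with mean_only=True, the block is exactly invertible: running it forward and then with reverse=True under the same conditioning recovers the input up to numerical precision. A small sketch using the channel sizes from config.json:

# Round trip through the flow: forward, then reverse with the same speaker embedding.
import torch
from src.flow import ResidualCouplingBlock

flow = ResidualCouplingBlock(
    192, 192, kernel_size=5, dilation_rate=1, n_layers=4, gin_channels=256
).eval()
z = torch.randn(1, 192, 50)
mask = torch.ones(1, 1, 50)
g = torch.randn(1, 256, 1)

z_p = flow(z, mask, g=g)
z_rec = flow(z_p, mask, g=g, reverse=True)
print(torch.allclose(z, z_rec, atol=1e-4))        # True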
src/hifi_gan.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import Conv1d, ConvTranspose1d
4
+ from torch.nn import functional as F
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+
7
+ import module as modules
8
+ from module.resblocks import ResBlock1, ResBlock2
9
+ from tools.commons import init_weights
10
+
11
+
12
+ class Generator(torch.nn.Module):
13
+ def __init__(
14
+ self,
15
+ initial_channel,
16
+ resblock,
17
+ resblock_kernel_sizes,
18
+ resblock_dilation_sizes,
19
+ upsample_rates,
20
+ upsample_initial_channel,
21
+ upsample_kernel_sizes,
22
+ gin_channels,
23
+ activation="snakebeta",
24
+ snake_logscale=True,
25
+ ):
26
+ super(Generator, self).__init__()
27
+ self.num_kernels = len(resblock_kernel_sizes)
28
+ self.num_upsamples = len(upsample_rates)
29
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, 3)
30
+ resblock = ResBlock1 if resblock == "1" else ResBlock2
31
+
32
+ self.ups = nn.ModuleList()
33
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
34
+ self.ups.append(
35
+ weight_norm(
36
+ ConvTranspose1d(
37
+ in_channels=upsample_initial_channel // (2**i),
38
+ out_channels=upsample_initial_channel // (2 ** (i + 1)),
39
+ kernel_size=k,
40
+ stride=u,
41
+ padding=(k - u) // 2,
42
+ )
43
+ )
44
+ )
45
+
46
+ self.resblocks = nn.ModuleList()
47
+ for i in range(len(self.ups)):
48
+ ch = upsample_initial_channel // (2 ** (i + 1))
49
+ for j, (k, d) in enumerate(
50
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
51
+ ):
52
+ self.resblocks.append(resblock(ch, k, d))
53
+
54
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
55
+ self.ups.apply(init_weights)
56
+
57
+ if gin_channels != 0:
58
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
59
+
60
+ def forward(self, x, g=None):
61
+ x = self.conv_pre(x)
62
+ if g is not None:
63
+ x = x + self.cond(g)
64
+
65
+ for i in range(self.num_upsamples):
66
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
67
+ x = self.ups[i](x)
68
+ xs = None
69
+ for j in range(self.num_kernels):
70
+ if xs is None:
71
+ xs = self.resblocks[i * self.num_kernels + j](x)
72
+ else:
73
+ xs += self.resblocks[i * self.num_kernels + j](x)
74
+ x = xs / self.num_kernels
75
+ x = F.leaky_relu(x)
76
+ x = self.conv_post(x)
77
+ x = torch.tanh(x)
78
+
79
+ return x
80
+
81
+ def remove_weight_norm(self):
82
+ print("Removing weight norm...")
83
+ for l in self.ups:
84
+ remove_weight_norm(l)
85
+ for l in self.resblocks:
86
+ l.remove_weight_norm()
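With the defaults from config.json, the generator upsamples by 10 × 8 × 2 × 2 = 320×, i.e. one latent frame becomes 320 audio samples (20 ms at 16 kHz, HuBERT's frame rate). A shape-only sketch, assuming the repository root is importable:

# 50 latent frames -> 16 000 samples with the default upsample_rates.
import torch
from src.hifi_gan import Generator

dec = Generator(
    initial_channel=192,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[10, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[20, 16, 4, 4],
    gin_channels=256,
)
z = torch.randn(1, 192, 50)
g = torch.randn(1, 256, 1)
print(dec(z, g=g).shape)      # torch.Size([1, 1, 16000])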
src/hubert_posterior.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ import torch.nn as nn
+
+ from .inference_hubert import InferenceHubertBase
+ from .vae_memory_bank import VAEMemoryBank
+
+
+ def create_padding_mask(waveforms_lengths: torch.Tensor = None):
+     if waveforms_lengths is None:
+         return None
+     batch = waveforms_lengths.shape[0]
+     max_len = waveforms_lengths.max()
+     device = waveforms_lengths.device
+     padding_mask = torch.ones(batch, max_len, dtype=torch.bool, device=device)
+     for idx, length in enumerate(waveforms_lengths):
+         padding_mask[idx, :length] = 0
+     return padding_mask
+
+
+ def unfreeze_layers(model: nn.Module, root_name: str):
+     for name, param in model.named_parameters():
+         if root_name in name[: len(root_name)]:
+             param.requires_grad = True
+
+
+ class PosteriorHubert(nn.Module):
+     def __init__(
+         self, out_channels, feature_channels, downsample_channels, output_layer=11
+     ) -> None:
+         super().__init__()
+         self.out_channels = out_channels
+         self.feature_channels = feature_channels
+         self.downsample_channels = downsample_channels
+         self.output_layer = output_layer
+
+         self.hubert = InferenceHubertBase()
+         self.memory_bank = VAEMemoryBank(
+             n_hidden_dims=feature_channels,
+             bank_size=1000,
+             output_channels=downsample_channels,
+         )
+
+         self.proj = nn.Conv1d(downsample_channels, out_channels * 2, 1)
+
+     def forward(self, waveforms: torch.Tensor, waveforms_lengths: torch.Tensor, g=None):
+         features, features_mask = self.hubert.extract_features(
+             source=waveforms,
+             padding_mask=create_padding_mask(waveforms_lengths),
+             output_layer=self.output_layer,
+         )
+         x = self.memory_bank(features.transpose(1, 2))
+         x_mask = (~features_mask).unsqueeze(1).to(torch.float32)
+         x = x[:, :, : x_mask.shape[-1]]
+
+         stats = self.proj(x) * x_mask
+         m, logs = torch.split(stats, self.out_channels, dim=1)
+
+         z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
+         return z, m, logs, x_mask
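create_padding_mask follows the fairseq convention: True marks padded positions and False marks real samples, which is what InferenceHubertBase.extract_features expects for padding_mask. For example:

# Two utterances of lengths 4 and 2 in a batch padded to length 4.
import torch
from src.hubert_posterior import create_padding_mask

print(create_padding_mask(torch.tensor([4, 2])))
# tensor([[False, False, False, False],
#         [False, False,  True,  True]])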
src/inference_hubert/__init__.py ADDED
@@ -0,0 +1 @@
+ from .hubert import InferenceHubertBase
src/inference_hubert/fairseq_modules.py ADDED
@@ -0,0 +1,52 @@
+ """
+ Classes reused from:
+ 1. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/fp32_group_norm.py
+ 2. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/same_pad.py
+ 3. https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/fairseq_dropout.py
+ """
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Fp32GroupNorm(nn.GroupNorm):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def forward(self, input):
+         output = F.group_norm(
+             input.float(),
+             self.num_groups,
+             self.weight.float() if self.weight is not None else None,
+             self.bias.float() if self.bias is not None else None,
+             self.eps,
+         )
+         return output.type_as(input)
+
+
+ class SamePad(nn.Module):
+     def __init__(self, kernel_size, causal=False):
+         super().__init__()
+         if causal:
+             self.remove = kernel_size - 1
+         else:
+             self.remove = 1 if kernel_size % 2 == 0 else 0
+
+     def forward(self, x):
+         if self.remove > 0:
+             x = x[:, :, : -self.remove]
+         return x
+
+
+ class FairseqDropout(nn.Module):
+     def __init__(self, p, module_name=None):
+         super().__init__()
+         self.p = p
+         self.module_name = module_name
+         self.apply_during_inference = False
+
+     def forward(self, x, inplace: bool = False):
+         if self.p > 0 and (self.training or self.apply_during_inference):
+             return F.dropout(x, p=self.p, training=True, inplace=inplace)
+         else:
+             return x
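SamePad exists because a Conv1d with an even kernel and symmetric kernel_size // 2 padding produces one surplus time step; trimming it keeps the positional convolution in the encoder length-preserving. A small sketch:

# The 128-wide positional conv yields T + 1 frames; SamePad(128) trims back to T.
import torch
import torch.nn as nn
from src.inference_hubert.fairseq_modules import SamePad

conv = nn.Conv1d(8, 8, kernel_size=128, padding=64, groups=8)
x = torch.randn(1, 8, 100)
y = SamePad(128)(conv(x))
print(y.shape)                # torch.Size([1, 8, 100])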
src/inference_hubert/hubert.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ The `InferenceHubertBase` class is a lightweight version of the model from this repository:
3
+ https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/hubert/hubert.py#L248C5-L248C6
4
+ """
5
+
6
+ import math
7
+ from typing import Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch import Tensor
13
+
14
+ from .fairseq_modules import Fp32GroupNorm, SamePad, FairseqDropout
15
+
16
+
17
+ class InferenceHubertBase(nn.Module):
18
+ def __init__(self, *args, **kwargs) -> None:
19
+ super().__init__(*args, **kwargs)
20
+ self.feature_extractor = ConvFeatureExtractor()
21
+ self.layer_norm = nn.LayerNorm((512,), eps=1e-05, elementwise_affine=True)
22
+ self.post_extract_proj = nn.Linear(in_features=512, out_features=768, bias=True)
23
+ self.dropout_input = nn.Dropout(p=0.1, inplace=False)
24
+ self.dropout_features = nn.Dropout(p=0.1, inplace=False)
25
+ self.encoder = TransformerEncoder()
26
+
27
+ def extract_features(
28
+ self,
29
+ source: Tensor,
30
+ padding_mask: Optional[Tensor] = None,
31
+ output_layer: int = 12,
32
+ ) -> Tuple[Tensor, Tensor]:
33
+ features = self.feature_extractor(source).transpose(1, 2)
34
+ features = self.layer_norm(features)
35
+ if padding_mask is not None:
36
+ padding_mask = self.__apply_padding_mask(features, padding_mask)
37
+ features = self.post_extract_proj(features)
38
+ features = self.dropout_input(features)
39
+ features = self.encoder(
40
+ features, padding_mask=padding_mask, tgt_layer=output_layer - 1
41
+ )
42
+ return features, padding_mask
43
+
44
+ def __apply_padding_mask(self, features: Tensor, padding_mask: Tensor) -> Tensor:
45
+ extra = padding_mask.size(1) % features.size(1)
46
+ if extra > 0:
47
+ padding_mask = padding_mask[:, :-extra]
48
+ padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
49
+ padding_mask = padding_mask.all(-1)
50
+ return padding_mask
51
+
52
+
53
+ class ConvFeatureExtractor(nn.Module):
54
+ def __init__(self, *args, **kwargs) -> None:
55
+ super().__init__(*args, **kwargs)
56
+ conv_layers = [
57
+ nn.Sequential(
58
+ nn.Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False),
59
+ nn.Dropout(p=0.0, inplace=False),
60
+ Fp32GroupNorm(512, 512, eps=1e-05, affine=True),
61
+ nn.GELU(approximate="none"),
62
+ ),
63
+ *[
64
+ nn.Sequential(
65
+ nn.Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False),
66
+ nn.Dropout(p=0.0, inplace=False),
67
+ nn.GELU(approximate="none"),
68
+ )
69
+ for _ in range(4)
70
+ ],
71
+ *[
72
+ nn.Sequential(
73
+ nn.Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False),
74
+ nn.Dropout(p=0.0, inplace=False),
75
+ nn.GELU(approximate="none"),
76
+ )
77
+ for _ in range(2)
78
+ ],
79
+ ]
80
+ self.conv_layers = nn.ModuleList(conv_layers)
81
+
82
+ def forward(self, x: Tensor):
83
+ x = x.unsqueeze(1)
84
+ for conv in self.conv_layers:
85
+ x = conv(x)
86
+ return x
87
+
88
+
89
+ class TransformerEncoder(nn.Module):
90
+ def __init__(
91
+ self, dropout=0.1, required_seq_len_multiple=2, *args, **kwargs
92
+ ) -> None:
93
+ super().__init__(*args, **kwargs)
94
+ self.dropout = dropout # 0.1
95
+ self.required_seq_len_multiple = required_seq_len_multiple # 2
96
+
97
+ pos_conv = nn.Conv1d(
98
+ 768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16
99
+ )
100
+ self.pos_conv = nn.Sequential(
101
+ nn.utils.weight_norm(pos_conv, name="weight", dim=2),
102
+ SamePad(128),
103
+ nn.GELU(approximate="none"),
104
+ )
105
+ self.layers = nn.ModuleList(
106
+ [TransformerSentenceEncoderLayer() for _ in range(12)]
107
+ )
108
+ self.layer_norm = nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True)
109
+
110
+ @torch.no_grad()
111
+ def forward(self, x: Tensor, padding_mask=None, tgt_layer=None):
112
+ if padding_mask is not None:
113
+ # x = index_put(x, padding_mask, 0)
114
+ x[padding_mask] = 0
115
+
116
+ x_conv = self.pos_conv(x.transpose(1, 2))
117
+ x_conv = x_conv.transpose(1, 2)
118
+ x = x + x_conv
119
+
120
+ x = self.layer_norm(x)
121
+
122
+ # pad to the sequence length dimension
123
+ x, pad_length = pad_to_multiple(
124
+ x, self.required_seq_len_multiple, dim=-2, value=0
125
+ )
126
+ if pad_length > 0 and padding_mask is None:
127
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
128
+ padding_mask[:, -pad_length:] = True
129
+ else:
130
+ padding_mask, _ = pad_to_multiple(
131
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
132
+ )
133
+ x = F.dropout(x, p=self.dropout, training=self.training)
134
+
135
+ # B x T x C -> T x B x C
136
+ x = x.transpose(0, 1)
137
+
138
+ for i, layer in enumerate(self.layers):
139
+ x, _ = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
140
+ if i == tgt_layer:
141
+ break
142
+
143
+ # T x B x C -> B x T x C
144
+ x = x.transpose(0, 1)
145
+ return x
146
+
147
+
148
+ class TransformerSentenceEncoderLayer(nn.Module):
149
+ def __init__(
150
+ self,
151
+ embedding_dim: float = 768,
152
+ ffn_embedding_dim: float = 3072,
153
+ num_attention_heads: int = 12,
154
+ dropout: float = 0.1,
155
+ attention_dropout: float = 0.1,
156
+ activation_dropout: float = 0.1,
157
+ layer_norm_first: bool = False,
158
+ *args,
159
+ **kwargs,
160
+ ) -> None:
161
+ super().__init__(*args, **kwargs)
162
+ self.embedding_dim = embedding_dim
163
+ self.ffn_embedding_dim = ffn_embedding_dim
164
+ self.num_attention_heads = num_attention_heads
165
+
166
+ self.self_attn = MultiheadAttention(
167
+ self.embedding_dim, # 768
168
+ num_attention_heads, # 12
169
+ dropout=attention_dropout, # 0.1
170
+ )
171
+ self.dropout1 = nn.Dropout(dropout)
172
+ self.dropout2 = nn.Dropout(activation_dropout)
173
+ self.dropout3 = nn.Dropout(dropout)
174
+ self.layer_norm_first = layer_norm_first
175
+ self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
176
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
177
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
178
+ self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
179
+
180
+ def forward(
181
+ self,
182
+ x: torch.Tensor,
183
+ self_attn_mask: torch.Tensor = None,
184
+ self_attn_padding_mask: torch.Tensor = None,
185
+ need_weights: bool = False,
186
+ att_args=None,
187
+ ):
188
+ residual = x
189
+ x, attn = self.self_attn(
190
+ query=x,
191
+ key=x,
192
+ value=x,
193
+ key_padding_mask=self_attn_padding_mask,
194
+ need_weights=False,
195
+ )
196
+
197
+ x = self.dropout1(x)
198
+ x = residual + x
199
+
200
+ x = self.self_attn_layer_norm(x)
201
+
202
+ residual = x
203
+ x = F.gelu(self.fc1(x).float()).type_as(x)
204
+ x = self.dropout2(x)
205
+ x = self.fc2(x)
206
+
207
+ layer_result = x
208
+
209
+ x = self.dropout3(x)
210
+ x = residual + x
211
+ x = self.final_layer_norm(x)
212
+
213
+ return x, (attn, layer_result)
214
+
215
+
216
+ class MultiheadAttention(nn.Module):
217
+ def __init__(
218
+ self, embed_dim: int, num_heads: int, dropout=0.1, bias=True, *args, **kwargs
219
+ ) -> None:
220
+ super().__init__(*args, **kwargs)
221
+ self.embed_dim = embed_dim
222
+ self.num_heads = num_heads
223
+
224
+ self.dropout_module = FairseqDropout(p=dropout)
225
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
226
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
227
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
228
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
229
+
230
+ def forward(
231
+ self,
232
+ query: Tensor,
233
+ key: Tensor,
234
+ value: Tensor,
235
+ key_padding_mask: Optional[Tensor] = None,
236
+ need_weights: bool = False,
237
+ attn_mask: Optional[Tensor] = None,
238
+ ) -> Tuple[Tensor, Optional[Tensor]]:
239
+
240
+ tgt_len, bsz, embed_dim = query.size()
241
+ src_len = tgt_len
242
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
243
+ src_len, key_bsz, _ = key.size()
244
+ assert (src_len, key_bsz) == value.shape[:2]
245
+ return F.multi_head_attention_forward(
246
+ query,
247
+ key,
248
+ value,
249
+ self.embed_dim,
250
+ self.num_heads,
251
+ torch.empty([0]),
252
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
253
+ None,
254
+ None,
255
+ False,
256
+ self.dropout_module.p,
257
+ self.out_proj.weight,
258
+ self.out_proj.bias,
259
+ self.training or self.dropout_module.apply_during_inference,
260
+ key_padding_mask.bool() if key_padding_mask is not None else None,
261
+ need_weights,
262
+ attn_mask,
263
+ use_separate_proj_weight=True,
264
+ q_proj_weight=self.q_proj.weight,
265
+ k_proj_weight=self.k_proj.weight,
266
+ v_proj_weight=self.v_proj.weight,
267
+ )
268
+
269
+
270
+ def pad_to_multiple(x, multiple, dim=-1, value=0):
271
+ # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
272
+ if x is None:
273
+ return None, 0
274
+ tsz = x.size(dim)
275
+ m = tsz / multiple
276
+ remainder = math.ceil(m) * multiple - tsz
277
+ if m.is_integer():
278
+ return x, 0
279
+ pad_offset = (0,) * (-1 - dim) * 2
280
+
281
+ return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
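
A minimal usage sketch (not part of the commit), assuming pretrained HuBERT weights are loaded elsewhere (e.g. via load_state_dict) and the input is a batch of 16 kHz waveforms; extract_features returns hidden states from the requested transformer layer.

import torch

model = InferenceHubertBase().eval()   # weights would be loaded separately
wav = torch.randn(1, 16000)            # (batch, samples), 1 s at 16 kHz
with torch.no_grad():
    feats, _ = model.extract_features(wav, padding_mask=None, output_layer=11)
# The conv front-end downsamples by 320x, so 1 s of audio gives roughly 50 frames
print(feats.shape)                     # (1, ~50, 768)
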
src/vae_memory_bank.py ADDED
@@ -0,0 +1,40 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .attentions import MultiHeadAttention
5
+
6
+
7
+ class VAEMemoryBank(nn.Module):
8
+ def __init__(
9
+ self,
10
+ bank_size=1000,
11
+ n_hidden_dims=512,
12
+ n_attn_heads=2,
13
+ init_values=None,
14
+ output_channels=192,
15
+ ):
16
+ super().__init__()
17
+
18
+ self.bank_size = bank_size
19
+ self.n_hidden_dims = n_hidden_dims
20
+ self.n_attn_heads = n_attn_heads
21
+
22
+ self.encoder = MultiHeadAttention(
23
+ channels=n_hidden_dims,
24
+ out_channels=n_hidden_dims,
25
+ n_heads=n_attn_heads,
26
+ )
27
+
28
+ self.memory_bank = nn.Parameter(torch.randn(n_hidden_dims, bank_size))
29
+ self.proj = nn.Conv1d(n_hidden_dims, output_channels, 1)
30
+ if init_values is not None:
31
+ with torch.no_grad():
32
+ self.memory_bank.copy_(init_values)
33
+
34
+ def forward(self, z: torch.Tensor):
35
+ b, _, _ = z.shape
36
+ ret = self.encoder(
37
+ z, self.memory_bank.unsqueeze(0).repeat(b, 1, 1), attn_mask=None
38
+ )
39
+ ret = self.proj(ret)
40
+ return ret
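
A minimal shape sketch (not part of the commit), assuming the VITS-style MultiHeadAttention imported from .attentions takes channel-first tensors (batch, channels, time) for both the query and the memory:

import torch

bank = VAEMemoryBank(bank_size=1000, n_hidden_dims=512, n_attn_heads=2, output_channels=192)
z = torch.randn(4, 512, 80)   # latent frames: (batch, n_hidden_dims, time)
out = bank(z)                 # attend over the learned memory, then 1x1 conv projection
print(out.shape)              # (4, 192, 80) under the assumptions above
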
src/wespeaker_campplus/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ """
2
+ The methods in this module are reused from this repository:
3
+ https://github.com/wenet-e2e/wespeaker
4
+ """
src/wespeaker_campplus/fbank_feature_extractor.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+
6
+ class PreEmphasis(torch.nn.Module):
7
+
8
+ def __init__(self, coef: float = 0.97):
9
+ super().__init__()
10
+ self.coef = coef
11
+ self.register_buffer(
12
+ 'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
13
+ )
14
+
15
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
16
+ input = input.unsqueeze(1)
17
+ input = F.pad(input, (1, 0), 'reflect')
18
+ return F.conv1d(input, self.flipped_filter).squeeze(1)
19
+
20
+ class FbankFeatureExtractor(nn.Module):
21
+ """Some Information about MyModule"""
22
+ def __init__(self, feat_dim = 80, f_max = 7600, **kwargs):
23
+ super().__init__()
24
+
25
+ self.torchfbank = torch.nn.Sequential(
26
+ PreEmphasis(),
27
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, \
28
+ f_min = 20, f_max = f_max, window_fn=torch.hamming_window, n_mels=feat_dim),
29
+ )
30
+
31
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
32
+
33
+ def forward(self, x):
34
+ with torch.no_grad():
35
+ x = self.torchfbank(x)+1e-6
36
+ x = x.log()
37
+ x = x - torch.mean(x, dim=-1, keepdim=True)
38
+ return x
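
A minimal usage sketch (not part of the commit), assuming 16 kHz input audio as the MelSpectrogram settings above imply:

import torch

extractor = FbankFeatureExtractor(feat_dim=80)
wav = torch.randn(2, 32000)   # (batch, samples), 2 s at 16 kHz
fbank = extractor(wav)        # pre-emphasis -> 80-dim log-Mel -> per-utterance mean removal
print(fbank.shape)            # (2, 80, ~201) with hop_length=160
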
src/wespeaker_campplus/pooling_layers.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Pooling functions to aggregate frame-level deep features
16
+ into segment-level speaker embeddings
17
+
18
+ High-order statistics are surprisingly effective: on VoxCeleb, TSDP performs
19
+ similarly to TSTP even though the mean statistic is removed.
20
+ """
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+
26
+
27
+ class TAP(nn.Module):
28
+ """
29
+ Temporal average pooling, only first-order mean is considered
30
+ """
31
+
32
+ def __init__(self, in_dim=0, **kwargs):
33
+ super(TAP, self).__init__()
34
+ self.in_dim = in_dim
35
+
36
+ def forward(self, x):
37
+ pooling_mean = x.mean(dim=-1)
38
+ # To be compatible with 2D input
39
+ pooling_mean = pooling_mean.flatten(start_dim=1)
40
+ return pooling_mean
41
+
42
+ def get_out_dim(self):
43
+ self.out_dim = self.in_dim
44
+ return self.out_dim
45
+
46
+
47
+ class TSDP(nn.Module):
48
+ """
49
+ Temporal standard deviation pooling, only second-order std is considered
50
+ """
51
+
52
+ def __init__(self, in_dim=0, **kwargs):
53
+ super(TSDP, self).__init__()
54
+ self.in_dim = in_dim
55
+
56
+ def forward(self, x):
57
+ # The last dimension is the temporal axis
58
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
59
+ pooling_std = pooling_std.flatten(start_dim=1)
60
+ return pooling_std
61
+
62
+ def get_out_dim(self):
63
+ self.out_dim = self.in_dim
64
+ return self.out_dim
65
+
66
+
67
+ class TSTP(nn.Module):
68
+ """
69
+ Temporal statistics pooling, concatenate mean and std, which is used in
70
+ x-vector
71
+ Comment: simple concatenation cannot make full use of both statistics
72
+ """
73
+
74
+ def __init__(self, in_dim=0, **kwargs):
75
+ super(TSTP, self).__init__()
76
+ self.in_dim = in_dim
77
+
78
+ def forward(self, x):
79
+ # The last dimension is the temporal axis
80
+ pooling_mean = x.mean(dim=-1)
81
+ pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7)
82
+ pooling_mean = pooling_mean.flatten(start_dim=1)
83
+ pooling_std = pooling_std.flatten(start_dim=1)
84
+ stats = torch.cat((pooling_mean, pooling_std), 1)
85
+ return stats
86
+
87
+ def get_out_dim(self):
88
+ self.out_dim = self.in_dim * 2
89
+ return self.out_dim
90
+
91
+
92
+ class ASTP(nn.Module):
93
+ """Attentive statistics pooling: Channel- and context-dependent
94
+ statistics pooling, first used in ECAPA_TDNN.
95
+ """
96
+
97
+ def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False, **kwargs):
98
+ super(ASTP, self).__init__()
99
+ self.in_dim = in_dim
100
+ self.global_context_att = global_context_att
101
+
102
+ # Use Conv1d with stride == 1 rather than Linear, then we don't
103
+ # need to transpose inputs.
104
+ if global_context_att:
105
+ self.linear1 = nn.Conv1d(
106
+ in_dim * 3, bottleneck_dim, kernel_size=1
107
+ ) # equals W and b in the paper
108
+ else:
109
+ self.linear1 = nn.Conv1d(
110
+ in_dim, bottleneck_dim, kernel_size=1
111
+ ) # equals W and b in the paper
112
+ self.linear2 = nn.Conv1d(
113
+ bottleneck_dim, in_dim, kernel_size=1
114
+ ) # equals V and k in the paper
115
+
116
+ def forward(self, x):
117
+ """
118
+ x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
119
+ or a 4-dimensional tensor in resnet architecture (B,C,F,T)
120
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
121
+ """
122
+ if len(x.shape) == 4:
123
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
124
+ assert len(x.shape) == 3
125
+
126
+ if self.global_context_att:
127
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
128
+ context_std = torch.sqrt(
129
+ torch.var(x, dim=-1, keepdim=True) + 1e-7
130
+ ).expand_as(x)
131
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
132
+ else:
133
+ x_in = x
134
+
135
+ # DON'T use ReLU here! ReLU may make the attention hard to converge.
136
+ alpha = torch.tanh(self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
137
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
138
+ mean = torch.sum(alpha * x, dim=2)
139
+ var = torch.sum(alpha * (x**2), dim=2) - mean**2
140
+ std = torch.sqrt(var.clamp(min=1e-7))
141
+ return torch.cat([mean, std], dim=1)
142
+
143
+ def get_out_dim(self):
144
+ self.out_dim = 2 * self.in_dim
145
+ return self.out_dim
146
+
147
+
148
+ class MHASTP(torch.nn.Module):
149
+ """Multi head attentive statistics pooling
150
+ Reference:
151
+ Self Multi-Head Attention for Speaker Recognition
152
+ https://arxiv.org/pdf/1906.09890.pdf
153
+ """
154
+
155
+ def __init__(
156
+ self, in_dim, layer_num=2, head_num=2, d_s=1, bottleneck_dim=64, **kwargs
157
+ ):
158
+ super(MHASTP, self).__init__()
159
+ assert (
160
+ in_dim % head_num
161
+ ) == 0  # make sure that in_dim is divisible by head_num
162
+ self.in_dim = in_dim
163
+ self.head_num = head_num
164
+ d_model = int(in_dim / head_num)
165
+ channel_dims = [bottleneck_dim for i in range(layer_num + 1)]
166
+ if d_s > 1:
167
+ d_s = d_model
168
+ else:
169
+ d_s = 1
170
+ self.d_s = d_s
171
+ channel_dims[0], channel_dims[-1] = d_model, d_s
172
+ heads_att_trans = []
173
+ for i in range(self.head_num):
174
+ att_trans = nn.Sequential()
175
+ for i in range(layer_num - 1):
176
+ att_trans.add_module(
177
+ "att_" + str(i),
178
+ nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1),
179
+ )
180
+ att_trans.add_module("tanh" + str(i), nn.Tanh())
181
+ att_trans.add_module(
182
+ "att_" + str(layer_num - 1),
183
+ nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], 1, 1),
184
+ )
185
+ heads_att_trans.append(att_trans)
186
+ self.heads_att_trans = nn.ModuleList(heads_att_trans)
187
+
188
+ def forward(self, input):
189
+ """
190
+ input: a 3-dimensional tensor in xvector architecture
191
+ or a 4-dimensional tensor in resnet architecture
192
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
193
+ """
194
+ if len(input.shape) == 4:  # B x C x F x T
195
+ input = input.reshape(
196
+ input.shape[0], input.shape[1] * input.shape[2], input.shape[3]
197
+ )
198
+ assert len(input.shape) == 3
199
+ bs, f_dim, t_dim = input.shape
200
+ chunks = torch.chunk(input, self.head_num, 1)
201
+ # split
202
+ chunks_out = []
203
+ # for i in range(self.head_num):
204
+ # att_score = self.heads_att_trans[i](chunks[i])
205
+ for i, layer in enumerate(self.heads_att_trans):
206
+ att_score = layer(chunks[i])
207
+ alpha = F.softmax(att_score, dim=-1)
208
+ mean = torch.sum(alpha * chunks[i], dim=2)
209
+ var = torch.sum(alpha * chunks[i] ** 2, dim=2) - mean**2
210
+ std = torch.sqrt(var.clamp(min=1e-7))
211
+ chunks_out.append(torch.cat((mean, std), dim=1))
212
+ out = torch.cat(chunks_out, dim=1)
213
+ return out
214
+
215
+ def get_out_dim(self):
216
+ self.out_dim = 2 * self.in_dim
217
+ return self.out_dim
218
+
219
+
220
+ class MQMHASTP(torch.nn.Module):
221
+ """An attentive pooling
222
+ Reference:
223
+ multi query multi head attentive statistics pooling
224
+ https://arxiv.org/pdf/2110.05042.pdf
225
+ Args:
226
+ in_dim: the feature dimension of input
227
+ layer_num: the number of layer in the pooling layer
228
+ query_num: the number of queries
229
+ head_num: the number of heads
230
+ bottleneck_dim: the bottleneck dimension
231
+
232
+ SA (H = 1, Q = 1, n = 2, d_s = 1) ref:
233
+ https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf
234
+ MHA (H > 1, Q = 1, n = 1, d_s = 1) ref:
235
+ https://arxiv.org/pdf/1906.09890.pdf
236
+ AS (H = 1, Q > 1, n = 2, d_s = 1) ref:
237
+ https://arxiv.org/pdf/1803.10963.pdf
238
+ VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref:
239
+ http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ in_dim,
245
+ layer_num=2,
246
+ query_num=2,
247
+ head_num=8,
248
+ d_s=2,
249
+ bottleneck_dim=64,
250
+ **kwargs
251
+ ):
252
+ super(MQMHASTP, self).__init__()
253
+ self.n_query = nn.ModuleList(
254
+ [
255
+ MHASTP(
256
+ in_dim,
257
+ layer_num=layer_num,
258
+ head_num=head_num,
259
+ d_s=d_s,
260
+ bottleneck_dim=bottleneck_dim,
261
+ )
262
+ for i in range(query_num)
263
+ ]
264
+ )
265
+ self.query_num = query_num
266
+ self.in_dim = in_dim
267
+
268
+ def forward(self, input):
269
+ """
270
+ input: a 3-dimensional tensor in xvector architecture
271
+ or a 4-dimensional tensor in resnet architecture
272
+ 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
273
+ """
274
+ if len(input.shape) == 4:  # B x C x F x T
275
+ input = input.reshape(
276
+ input.shape[0], input.shape[1] * input.shape[2], input.shape[3]
277
+ )
278
+ assert len(input.shape) == 3
279
+ res = []
280
+ for i, layer in enumerate(self.n_query):
281
+ res.append(layer(input))
282
+ out = torch.cat(res, dim=-1)
283
+ return out
284
+
285
+ def get_out_dim(self):
286
+ self.out_dim = self.in_dim * 2 * self.query_num
287
+ return self.out_dim
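
A minimal sketch (not part of the commit) of how the pooling layers above are used, assuming they are in scope: frame-level features of shape (batch, feature_dim, time) are reduced to a fixed-length utterance embedding.

import torch

frames = torch.randn(4, 512, 200)            # (batch, feature_dim, frames)

pool = TSTP(in_dim=512)
embedding = pool(frames)                     # mean and std over time, concatenated
print(embedding.shape, pool.get_out_dim())   # torch.Size([4, 1024]) 1024

attentive = ASTP(in_dim=512)
print(attentive(frames).shape)               # torch.Size([4, 1024])
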
tools/commons.py ADDED
@@ -0,0 +1,181 @@
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+
9
+ def init_weights(m, mean=0.0, std=0.01):
10
+ classname = m.__class__.__name__
11
+ if classname.find("Conv") != -1:
12
+ try:
13
+ m.weight.data.normal_(mean, std)
14
+ except AttributeError:
15
+ try:
16
+ m.conv.weight.data.normal_(mean, std)
17
+ except AttributeError:
18
+ m.conv.conv.weight.data.normal_(mean, std)
19
+
20
+
21
+ def get_padding(kernel_size, dilation=1):
22
+ return int((kernel_size * dilation - dilation) / 2)
23
+
24
+
25
+ def convert_pad_shape(pad_shape):
26
+ l = pad_shape[::-1]
27
+ pad_shape = [item for sublist in l for item in sublist]
28
+ return pad_shape
29
+
30
+
31
+ def intersperse(lst, item):
32
+ result = [item] * (len(lst) * 2 + 1)
33
+ result[1::2] = lst
34
+ return result
35
+
36
+
37
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
38
+ """KL(P||Q)"""
39
+ kl = (logs_q - logs_p) - 0.5
40
+ kl += (
41
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
42
+ )
43
+ return kl
44
+
45
+
46
+ def rand_gumbel(shape):
47
+ """Sample from the Gumbel distribution, protect from overflows."""
48
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
49
+ return -torch.log(-torch.log(uniform_samples))
50
+
51
+
52
+ def rand_gumbel_like(x):
53
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
54
+ return g
55
+
56
+
57
+ def slice_segments(x, ids_str, segment_size=4):
58
+ ret = torch.zeros_like(x[:, :, :segment_size])
59
+ for i in range(x.size(0)):
60
+ idx_str = ids_str[i]
61
+ idx_end = idx_str + segment_size
62
+ ret[i] = x[i, :, idx_str:idx_end]
63
+ return ret
64
+
65
+
66
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
67
+ b, d, t = x.size()
68
+ if x_lengths is None:
69
+ x_lengths = t
70
+ ids_str_max = x_lengths - segment_size + 1
71
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
72
+ ret = slice_segments(x, ids_str, segment_size)
73
+ return ret, ids_str
74
+
75
+
76
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
77
+ position = torch.arange(length, dtype=torch.float)
78
+ num_timescales = channels // 2
79
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
80
+ num_timescales - 1
81
+ )
82
+ inv_timescales = min_timescale * torch.exp(
83
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
84
+ )
85
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
86
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
87
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
88
+ signal = signal.view(1, channels, length)
89
+ return signal
90
+
91
+
92
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
93
+ b, channels, length = x.size()
94
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
95
+ return x + signal.to(dtype=x.dtype, device=x.device)
96
+
97
+
98
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
99
+ b, channels, length = x.size()
100
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
101
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
102
+
103
+
104
+ def subsequent_mask(length):
105
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
106
+ return mask
107
+
108
+
109
+ @torch.jit.script
110
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
111
+ n_channels_int = n_channels[0]
112
+ in_act = input_a + input_b
113
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
114
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
115
+ acts = t_act * s_act
116
+ return acts
117
+
118
+
119
+ def convert_pad_shape(pad_shape):
120
+ l = pad_shape[::-1]
121
+ pad_shape = [item for sublist in l for item in sublist]
122
+ return pad_shape
123
+
124
+
125
+ def shift_1d(x):
126
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
127
+ return x
128
+
129
+
130
+ def sequence_mask(length: torch.LongTensor, max_length: int = None):
131
+ """_summary_
132
+
133
+ Args:
134
+ length (torch.LongTensor): 1d sequence of lengths of some sequence [BATCH_SIZE]
135
+ max_length (int, optional) : max length of sequence. Defaults to None.
136
+
137
+ Returns:
138
+ _type_: _description_
139
+ """
140
+ if max_length is None:
141
+ max_length = length.max()
142
+
143
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
144
+
145
+ return x.unsqueeze(0) < length.unsqueeze(1)
146
+
147
+
148
+ def generate_path(duration, mask):
149
+ """
150
+ duration: [b, 1, t_x]
151
+ mask: [b, 1, t_y, t_x]
152
+ """
153
+ device = duration.device
154
+
155
+ b, _, t_y, t_x = mask.shape
156
+ cum_duration = torch.cumsum(duration, -1)
157
+
158
+ cum_duration_flat = cum_duration.view(b * t_x)
159
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
160
+ path = path.view(b, t_x, t_y)
161
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
162
+ path = path.unsqueeze(1).transpose(2, 3).contiguous() * mask
163
+ return path
164
+
165
+
166
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
167
+ if isinstance(parameters, torch.Tensor):
168
+ parameters = [parameters]
169
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
170
+ norm_type = float(norm_type)
171
+ if clip_value is not None:
172
+ clip_value = float(clip_value)
173
+
174
+ total_norm = 0
175
+ for p in parameters:
176
+ param_norm = p.grad.data.norm(norm_type)
177
+ total_norm += param_norm.item() ** norm_type
178
+ if clip_value is not None:
179
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
180
+ total_norm = total_norm ** (1.0 / norm_type)
181
+ return total_norm
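
A minimal sketch (not part of the commit) of two of the helpers above, assuming they are in scope: sequence_mask builds a boolean mask from lengths, and rand_slice_segments cuts random fixed-size windows from a batch of feature frames.

import torch

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths, max_length=5)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

x = torch.randn(2, 192, 100)                 # (batch, channels, frames)
segments, ids_str = rand_slice_segments(x, segment_size=32)
print(segments.shape)                        # torch.Size([2, 192, 32])
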
tools/transforms.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
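
A minimal sketch (not part of the commit) of the spline transform above, with hypothetical shapes: 10 bins over a batch of inputs, linear tails outside ±5, and an inverse pass that recovers the forward pass up to numerical error.

import torch

num_bins = 10
inputs = torch.randn(4, 192)
widths = torch.randn(4, 192, num_bins)
heights = torch.randn(4, 192, num_bins)
derivatives = torch.randn(4, 192, num_bins - 1)

outputs, logabsdet = piecewise_rational_quadratic_transform(
    inputs, widths, heights, derivatives,
    inverse=False, tails="linear", tail_bound=5.0,
)
recovered, _ = piecewise_rational_quadratic_transform(
    outputs, widths, heights, derivatives,
    inverse=True, tails="linear", tail_bound=5.0,
)
print(outputs.shape, logabsdet.shape)                 # torch.Size([4, 192]) torch.Size([4, 192])
print(torch.allclose(recovered, inputs, atol=1e-4))   # True
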