Commit c2d34ab (verified) · committed by kevinwang676 · 1 Parent(s): c120033

Add files using upload-large-folder tool
third_party/Matcha-TTS/matcha/hifigan/xutils.py ADDED
@@ -0,0 +1,60 @@
+""" from https://github.com/jik876/hifi-gan """
+
+import glob
+import os
+
+import matplotlib
+import torch
+from torch.nn.utils import weight_norm
+
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+
+
+def plot_spectrogram(spectrogram):
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+
+    fig.canvas.draw()
+    plt.close()
+
+    return fig
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def load_checkpoint(filepath, device):
+    assert os.path.isfile(filepath)
+    print(f"Loading '{filepath}'")
+    checkpoint_dict = torch.load(filepath, map_location=device)
+    print("Complete.")
+    return checkpoint_dict
+
+
+def save_checkpoint(filepath, obj):
+    print(f"Saving checkpoint to {filepath}")
+    torch.save(obj, filepath)
+    print("Complete.")
+
+
+def scan_checkpoint(cp_dir, prefix):
+    pattern = os.path.join(cp_dir, prefix + "????????")
+    cp_list = glob.glob(pattern)
+    if len(cp_list) == 0:
+        return None
+    return sorted(cp_list)[-1]
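
Note (not part of the committed file): a minimal sketch of how these checkpoint helpers are typically combined; the directory and prefix below are illustrative.

import torch

cp_dir, prefix = "logs/hifigan", "g_"      # hypothetical run directory / file prefix
latest = scan_checkpoint(cp_dir, prefix)   # newest file matching "g_????????", or None
if latest is not None:
    state = load_checkpoint(latest, device=torch.device("cpu"))
    # ... restore model / optimizer state from `state` ...
save_checkpoint(f"{cp_dir}/{prefix}00000001", {"step": 1})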
third_party/Matcha-TTS/matcha/models/components/decoder.py ADDED
@@ -0,0 +1,443 @@
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from conformer import ConformerBlock
+from diffusers.models.activations import get_activation
+from einops import pack, rearrange, repeat
+
+from matcha.models.components.transformer import BasicTransformerBlock
+
+
+class SinusoidalPosEmb(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
+
+    def forward(self, x, scale=1000):
+        if x.ndim < 1:
+            x = x.unsqueeze(0)
+        device = x.device
+        half_dim = self.dim // 2
+        emb = math.log(10000) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+
+class Block1D(torch.nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.block = torch.nn.Sequential(
+            torch.nn.Conv1d(dim, dim_out, 3, padding=1),
+            torch.nn.GroupNorm(groups, dim_out),
+            nn.Mish(),
+        )
+
+    def forward(self, x, mask):
+        output = self.block(x * mask)
+        return output * mask
+
+
+class ResnetBlock1D(torch.nn.Module):
+    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
+        super().__init__()
+        self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
+
+        self.block1 = Block1D(dim, dim_out, groups=groups)
+        self.block2 = Block1D(dim_out, dim_out, groups=groups)
+
+        self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
+
+    def forward(self, x, mask, time_emb):
+        h = self.block1(x, mask)
+        h += self.mlp(time_emb).unsqueeze(-1)
+        h = self.block2(h, mask)
+        output = h + self.res_conv(x * mask)
+        return output
+
+
+class Downsample1D(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class TimestepEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: int = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim=None,
+    ):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+
+        if cond_proj_dim is not None:
+            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+        else:
+            self.cond_proj = None
+
+        self.act = get_activation(act_fn)
+
+        if out_dim is not None:
+            time_embed_dim_out = out_dim
+        else:
+            time_embed_dim_out = time_embed_dim
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+
+        if post_act_fn is None:
+            self.post_act = None
+        else:
+            self.post_act = get_activation(post_act_fn)
+
+    def forward(self, sample, condition=None):
+        if condition is not None:
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
+
+
+class Upsample1D(nn.Module):
+    """A 1D upsampling layer with an optional convolution.
+
+    Parameters:
+        channels (`int`):
+            number of channels in the inputs and outputs.
+        use_conv (`bool`, default `False`):
+            option to use a convolution.
+        use_conv_transpose (`bool`, default `False`):
+            option to use a convolution transpose.
+        out_channels (`int`, optional):
+            number of output channels. Defaults to `channels`.
+    """
+
+    def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
+        super().__init__()
+        self.channels = channels
+        self.out_channels = out_channels or channels
+        self.use_conv = use_conv
+        self.use_conv_transpose = use_conv_transpose
+        self.name = name
+
+        self.conv = None
+        if use_conv_transpose:
+            self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
+        elif use_conv:
+            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
+
+    def forward(self, inputs):
+        assert inputs.shape[1] == self.channels
+        if self.use_conv_transpose:
+            return self.conv(inputs)
+
+        outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
+
+        if self.use_conv:
+            outputs = self.conv(outputs)
+
+        return outputs
+
+
+class ConformerWrapper(ConformerBlock):
+    def __init__(  # pylint: disable=useless-super-delegation
+        self,
+        *,
+        dim,
+        dim_head=64,
+        heads=8,
+        ff_mult=4,
+        conv_expansion_factor=2,
+        conv_kernel_size=31,
+        attn_dropout=0,
+        ff_dropout=0,
+        conv_dropout=0,
+        conv_causal=False,
+    ):
+        super().__init__(
+            dim=dim,
+            dim_head=dim_head,
+            heads=heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            conv_dropout=conv_dropout,
+            conv_causal=conv_causal,
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        timestep=None,
+    ):
+        return super().forward(x=hidden_states, mask=attention_mask.bool())
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        channels=(256, 256),
+        dropout=0.05,
+        attention_head_dim=64,
+        n_blocks=1,
+        num_mid_blocks=2,
+        num_heads=4,
+        act_fn="snake",
+        down_block_type="transformer",
+        mid_block_type="transformer",
+        up_block_type="transformer",
+    ):
+        super().__init__()
+        channels = tuple(channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.time_embeddings = SinusoidalPosEmb(in_channels)
+        time_embed_dim = channels[0] * 4
+        self.time_mlp = TimestepEmbedding(
+            in_channels=in_channels,
+            time_embed_dim=time_embed_dim,
+            act_fn="silu",
+        )
+
+        self.down_blocks = nn.ModuleList([])
+        self.mid_blocks = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
+
+        output_channel = in_channels
+        for i in range(len(channels)):  # pylint: disable=consider-using-enumerate
+            input_channel = output_channel
+            output_channel = channels[i]
+            is_last = i == len(channels) - 1
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+            transformer_blocks = nn.ModuleList(
+                [
+                    self.get_block(
+                        down_block_type,
+                        output_channel,
+                        attention_head_dim,
+                        num_heads,
+                        dropout,
+                        act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            downsample = (
+                Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+
+            self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
+
+        for i in range(num_mid_blocks):
+            input_channel = channels[-1]
+            out_channels = channels[-1]
+
+            resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
+
+            transformer_blocks = nn.ModuleList(
+                [
+                    self.get_block(
+                        mid_block_type,
+                        output_channel,
+                        attention_head_dim,
+                        num_heads,
+                        dropout,
+                        act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+
+            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
+
+        channels = channels[::-1] + (channels[0],)
+        for i in range(len(channels) - 1):
+            input_channel = channels[i]
+            output_channel = channels[i + 1]
+            is_last = i == len(channels) - 2
+
+            resnet = ResnetBlock1D(
+                dim=2 * input_channel,
+                dim_out=output_channel,
+                time_emb_dim=time_embed_dim,
+            )
+            transformer_blocks = nn.ModuleList(
+                [
+                    self.get_block(
+                        up_block_type,
+                        output_channel,
+                        attention_head_dim,
+                        num_heads,
+                        dropout,
+                        act_fn,
+                    )
+                    for _ in range(n_blocks)
+                ]
+            )
+            upsample = (
+                Upsample1D(output_channel, use_conv_transpose=True)
+                if not is_last
+                else nn.Conv1d(output_channel, output_channel, 3, padding=1)
+            )
+
+            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
+
+        self.final_block = Block1D(channels[-1], channels[-1])
+        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
+
+        self.initialize_weights()
+        # nn.init.normal_(self.final_proj.weight)
+
+    @staticmethod
+    def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
+        if block_type == "conformer":
+            block = ConformerWrapper(
+                dim=dim,
+                dim_head=attention_head_dim,
+                heads=num_heads,
+                ff_mult=1,
+                conv_expansion_factor=2,
+                ff_dropout=dropout,
+                attn_dropout=dropout,
+                conv_dropout=dropout,
+                conv_kernel_size=31,
+            )
+        elif block_type == "transformer":
+            block = BasicTransformerBlock(
+                dim=dim,
+                num_attention_heads=num_heads,
+                attention_head_dim=attention_head_dim,
+                dropout=dropout,
+                activation_fn=act_fn,
+            )
+        else:
+            raise ValueError(f"Unknown block type {block_type}")
+
+        return block
+
+    def initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.GroupNorm):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+
+    def forward(self, x, mask, mu, t, spks=None, cond=None):
+        """Forward pass of the UNet1DConditional model.
+
+        Args:
+            x (torch.Tensor): shape (batch_size, in_channels, time)
+            mask (_type_): shape (batch_size, 1, time)
+            t (_type_): shape (batch_size)
+            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
+            cond (_type_, optional): placeholder for future use. Defaults to None.
+
+        Raises:
+            ValueError: _description_
+            ValueError: _description_
+
+        Returns:
+            _type_: _description_
+        """
+
+        t = self.time_embeddings(t)
+        t = self.time_mlp(t)
+
+        x = pack([x, mu], "b * t")[0]
+
+        if spks is not None:
+            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
+            x = pack([x, spks], "b * t")[0]
+
+        hiddens = []
+        masks = [mask]
+        for resnet, transformer_blocks, downsample in self.down_blocks:
+            mask_down = masks[-1]
+            x = resnet(x, mask_down, t)
+            x = rearrange(x, "b c t -> b t c")
+            mask_down = rearrange(mask_down, "b 1 t -> b t")
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=mask_down,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t")
+            mask_down = rearrange(mask_down, "b t -> b 1 t")
+            hiddens.append(x)  # Save hidden states for skip connections
+            x = downsample(x * mask_down)
+            masks.append(mask_down[:, :, ::2])
+
+        masks = masks[:-1]
+        mask_mid = masks[-1]
+
+        for resnet, transformer_blocks in self.mid_blocks:
+            x = resnet(x, mask_mid, t)
+            x = rearrange(x, "b c t -> b t c")
+            mask_mid = rearrange(mask_mid, "b 1 t -> b t")
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=mask_mid,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t")
+            mask_mid = rearrange(mask_mid, "b t -> b 1 t")
+
+        for resnet, transformer_blocks, upsample in self.up_blocks:
+            mask_up = masks.pop()
+            x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
+            x = rearrange(x, "b c t -> b t c")
+            mask_up = rearrange(mask_up, "b 1 t -> b t")
+            for transformer_block in transformer_blocks:
+                x = transformer_block(
+                    hidden_states=x,
+                    attention_mask=mask_up,
+                    timestep=t,
+                )
+            x = rearrange(x, "b t c -> b c t")
+            mask_up = rearrange(mask_up, "b t -> b 1 t")
+            x = upsample(x * mask_up)
+
+        x = self.final_block(x, mask_up)
+        output = self.final_proj(x * mask_up)
+
+        return output * mask
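
Note (not part of the committed file): a shape-only sketch of driving the Decoder above, assuming 80 mel bins, an even frame count, and the "snakebeta" activation that Matcha-TTS configures elsewhere.

import torch

n_feats, T = 80, 100
dec = Decoder(in_channels=2 * n_feats, out_channels=n_feats, act_fn="snakebeta")  # x and mu are packed channel-wise
x = torch.randn(2, n_feats, T)      # current (noisy) sample
mu = torch.randn(2, n_feats, T)     # encoder output used as conditioning
mask = torch.ones(2, 1, T)
t = torch.rand(2)                   # flow-matching time for each batch element
v = dec(x, mask, mu, t)             # -> (2, 80, 100), the predicted vector field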
third_party/Matcha-TTS/matcha/models/components/text_encoder.py ADDED
@@ -0,0 +1,410 @@
+""" from https://github.com/jaywalnut310/glow-tts """
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+import matcha.utils as utils
+from matcha.utils.model import sequence_mask
+
+log = utils.get_pylogger(__name__)
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, channels, eps=1e-4):
+        super().__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = torch.nn.Parameter(torch.ones(channels))
+        self.beta = torch.nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        n_dims = len(x.shape)
+        mean = torch.mean(x, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+        x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+        shape = [1, -1] + [1] * (n_dims - 2)
+        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+class ConvReluNorm(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+
+        self.conv_layers = torch.nn.ModuleList()
+        self.norm_layers = torch.nn.ModuleList()
+        self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+        self.norm_layers.append(LayerNorm(hidden_channels))
+        self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(
+                torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
+            )
+            self.norm_layers.append(LayerNorm(hidden_channels))
+        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        x_org = x
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.norm_layers[i](x)
+            x = self.relu_drop(x)
+        x = x_org + self.proj(x)
+        return x * x_mask
+
+
+class DurationPredictor(nn.Module):
+    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
+        super().__init__()
+        self.in_channels = in_channels
+        self.filter_channels = filter_channels
+        self.p_dropout = p_dropout
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_1 = LayerNorm(filter_channels)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.norm_2 = LayerNorm(filter_channels)
+        self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_1(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        x = torch.relu(x)
+        x = self.norm_2(x)
+        x = self.drop(x)
+        x = self.proj(x * x_mask)
+        return x * x_mask
+
+
+class RotaryPositionalEmbeddings(nn.Module):
+    """
+    ## RoPE module
+
+    Rotary encoding transforms pairs of features by rotating in the 2D plane.
+    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
+    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
+    by an angle depending on the position of the token.
+    """
+
+    def __init__(self, d: int, base: int = 10_000):
+        r"""
+        * `d` is the number of features $d$
+        * `base` is the constant used for calculating $\Theta$
+        """
+        super().__init__()
+
+        self.base = base
+        self.d = int(d)
+        self.cos_cached = None
+        self.sin_cached = None
+
+    def _build_cache(self, x: torch.Tensor):
+        r"""
+        Cache $\cos$ and $\sin$ values
+        """
+        # Return if cache is already built
+        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
+            return
+
+        # Get sequence length
+        seq_len = x.shape[0]
+
+        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
+
+        # Create position indexes `[0, 1, ..., seq_len - 1]`
+        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
+
+        # Calculate the product of position index and $\theta_i$
+        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
+
+        # Concatenate so that for row $m$ we have
+        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
+        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
+
+        # Cache them
+        self.cos_cached = idx_theta2.cos()[:, None, None, :]
+        self.sin_cached = idx_theta2.sin()[:, None, None, :]
+
+    def _neg_half(self, x: torch.Tensor):
+        # $\frac{d}{2}$
+        d_2 = self.d // 2
+
+        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
+
+    def forward(self, x: torch.Tensor):
+        """
+        * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
+        """
+        # Cache $\cos$ and $\sin$ values
+        x = rearrange(x, "b h t d -> t b h d")
+
+        self._build_cache(x)
+
+        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
+        x_rope, x_pass = x[..., : self.d], x[..., self.d :]
+
+        # Calculate
+        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
+        neg_half_x = self._neg_half(x_rope)
+
+        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
+
+        return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        heads_share=True,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super().__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.heads_share = heads_share
+        self.proximal_bias = proximal_bias
+        self.p_dropout = p_dropout
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
+        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
+        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
+
+        # from https://nn.labml.ai/transformers/rope/index.html
+        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
+
+        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
+        self.drop = torch.nn.Dropout(p_dropout)
+
+        torch.nn.init.xavier_uniform_(self.conv_q.weight)
+        torch.nn.init.xavier_uniform_(self.conv_k.weight)
+        if proximal_init:
+            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+        torch.nn.init.xavier_uniform_(self.conv_v.weight)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
+        key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
+        value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
+
+        query = self.query_rotary_pe(query)
+        key = self.key_rotary_pe(key)
+
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+        p_attn = torch.nn.functional.softmax(scores, dim=-1)
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
+        return output, p_attn
+
+    @staticmethod
+    def _attention_bias_proximal(length):
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
+        self.drop = torch.nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        return x * x_mask
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads,
+        n_layers,
+        kernel_size=1,
+        p_dropout=0.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.attn_layers = torch.nn.ModuleList()
+        self.norm_layers_1 = torch.nn.ModuleList()
+        self.ffn_layers = torch.nn.ModuleList()
+        self.norm_layers_2 = torch.nn.ModuleList()
+        for _ in range(self.n_layers):
+            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        for i in range(self.n_layers):
+            x = x * x_mask
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class TextEncoder(nn.Module):
+    def __init__(
+        self,
+        encoder_type,
+        encoder_params,
+        duration_predictor_params,
+        n_vocab,
+        n_spks=1,
+        spk_emb_dim=128,
+    ):
+        super().__init__()
+        self.encoder_type = encoder_type
+        self.n_vocab = n_vocab
+        self.n_feats = encoder_params.n_feats
+        self.n_channels = encoder_params.n_channels
+        self.spk_emb_dim = spk_emb_dim
+        self.n_spks = n_spks
+
+        self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
+        torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
+
+        if encoder_params.prenet:
+            self.prenet = ConvReluNorm(
+                self.n_channels,
+                self.n_channels,
+                self.n_channels,
+                kernel_size=5,
+                n_layers=3,
+                p_dropout=0.5,
+            )
+        else:
+            self.prenet = lambda x, x_mask: x
+
+        self.encoder = Encoder(
+            encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
+            encoder_params.filter_channels,
+            encoder_params.n_heads,
+            encoder_params.n_layers,
+            encoder_params.kernel_size,
+            encoder_params.p_dropout,
+        )
+
+        self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
+        self.proj_w = DurationPredictor(
+            self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
+            duration_predictor_params.filter_channels_dp,
+            duration_predictor_params.kernel_size,
+            duration_predictor_params.p_dropout,
+        )
+
+    def forward(self, x, x_lengths, spks=None):
+        """Run forward pass to the transformer based encoder and duration predictor
+
+        Args:
+            x (torch.Tensor): text input
+                shape: (batch_size, max_text_length)
+            x_lengths (torch.Tensor): text input lengths
+                shape: (batch_size,)
+            spks (torch.Tensor, optional): speaker ids. Defaults to None.
+                shape: (batch_size,)
+
+        Returns:
+            mu (torch.Tensor): average output of the encoder
+                shape: (batch_size, n_feats, max_text_length)
+            logw (torch.Tensor): log duration predicted by the duration predictor
+                shape: (batch_size, 1, max_text_length)
+            x_mask (torch.Tensor): mask for the text input
+                shape: (batch_size, 1, max_text_length)
+        """
+        x = self.emb(x) * math.sqrt(self.n_channels)
+        x = torch.transpose(x, 1, -1)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+        x = self.prenet(x, x_mask)
+        if self.n_spks > 1:
+            x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
+        x = self.encoder(x, x_mask)
+        mu = self.proj_m(x) * x_mask
+
+        x_dp = torch.detach(x)
+        logw = self.proj_w(x_dp, x_mask)
+
+        return mu, logw, x_mask
third_party/Matcha-TTS/matcha/models/components/transformer.py ADDED
@@ -0,0 +1,316 @@
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn as nn
+from diffusers.models.attention import (
+    GEGLU,
+    GELU,
+    AdaLayerNorm,
+    AdaLayerNormZero,
+    ApproximateGELU,
+)
+from diffusers.models.attention_processor import Attention
+from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = snakebeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super().__init__()
+        self.in_features = out_features if isinstance(out_features, list) else [out_features]
+        self.proj = LoRACompatibleLinear(in_features, out_features)
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
+            self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
+            self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta ∶= x + 1/b * sin^2 (xa)
+        """
+        x = self.proj(x)
+        if self.alpha_logscale:
+            alpha = torch.exp(self.alpha)
+            beta = torch.exp(self.beta)
+        else:
+            alpha = self.alpha
+            beta = self.beta
+
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
+
+        return x
+
+
+class FeedForward(nn.Module):
+    r"""
+    A feed-forward layer.
+
+    Parameters:
+        dim (`int`): The number of channels in the input.
+        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        dim_out: Optional[int] = None,
+        mult: int = 4,
+        dropout: float = 0.0,
+        activation_fn: str = "geglu",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        if activation_fn == "gelu-approximate":
+            act_fn = GELU(dim, inner_dim, approximate="tanh")
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim)
+        elif activation_fn == "snakebeta":
+            act_fn = SnakeBeta(dim, inner_dim)
+
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(act_fn)
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+        if final_dropout:
+            self.net.append(nn.Dropout(dropout))
+
+    def forward(self, hidden_states):
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+    r"""
+    A basic Transformer block.
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+        only_cross_attention (`bool`, *optional*):
+            Whether to use only cross-attention layers. In this case two cross attention layers are used.
+        double_self_attention (`bool`, *optional*):
+            Whether to use two self-attention layers. In this case no cross attention layers are used.
+        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (:
+            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (:
+            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        dropout=0.0,
+        cross_attention_dim: Optional[int] = None,
+        activation_fn: str = "geglu",
+        num_embeds_ada_norm: Optional[int] = None,
+        attention_bias: bool = False,
+        only_cross_attention: bool = False,
+        double_self_attention: bool = False,
+        upcast_attention: bool = False,
+        norm_elementwise_affine: bool = True,
+        norm_type: str = "layer_norm",
+        final_dropout: bool = False,
+    ):
+        super().__init__()
+        self.only_cross_attention = only_cross_attention
+
+        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+
+        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+            raise ValueError(
+                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+            )
+
+        # Define 3 blocks. Each block has its own normalization layer.
+        # 1. Self-Attn
+        if self.use_ada_layer_norm:
+            self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+        elif self.use_ada_layer_norm_zero:
+            self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+        else:
+            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            dropout=dropout,
+            bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+            upcast_attention=upcast_attention,
+        )
+
+        # 2. Cross-Attn
+        if cross_attention_dim is not None or double_self_attention:
+            # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+            # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+            # the second cross attention block.
+            self.norm2 = (
+                AdaLayerNorm(dim, num_embeds_ada_norm)
+                if self.use_ada_layer_norm
+                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+            )
+            self.attn2 = Attention(
+                query_dim=dim,
+                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                # scale_qk=False, # uncomment this to not to use flash attention
+            )  # is self-attn if encoder_hidden_states is none
+        else:
+            self.norm2 = None
+            self.attn2 = None
+
+        # 3. Feed-forward
+        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        timestep: Optional[torch.LongTensor] = None,
+        cross_attention_kwargs: Dict[str, Any] = None,
+        class_labels: Optional[torch.LongTensor] = None,
+    ):
+        # Notice that normalization is always applied before the real computation in the following blocks.
+        # 1. Self-Attention
+        if self.use_ada_layer_norm:
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.use_ada_layer_norm_zero:
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            norm_hidden_states = self.norm1(hidden_states)
+
+        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+        attn_output = self.attn1(
+            norm_hidden_states,
+            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+            attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
+            **cross_attention_kwargs,
+        )
+        if self.use_ada_layer_norm_zero:
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = attn_output + hidden_states
+
+        # 2. Cross-Attention
+        if self.attn2 is not None:
+            norm_hidden_states = (
+                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+            )
+
+            attn_output = self.attn2(
+                norm_hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                **cross_attention_kwargs,
+            )
+            hidden_states = attn_output + hidden_states
+
+        # 3. Feed-forward
+        norm_hidden_states = self.norm3(hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
+                raise ValueError(
+                    f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+                )
+
+            num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
+            ff_output = torch.cat(
+                [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
+                dim=self._chunk_dim,
+            )
+        else:
+            ff_output = self.ff(norm_hidden_states)
+
+        if self.use_ada_layer_norm_zero:
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = ff_output + hidden_states
+
+        return hidden_states
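
Note (not part of the committed file): a minimal standalone use of the block above with the "snakebeta" feed-forward activation.

import torch

block = BasicTransformerBlock(dim=256, num_attention_heads=4, attention_head_dim=64,
                              dropout=0.05, activation_fn="snakebeta")
x = torch.randn(2, 50, 256)   # (batch, time, channels)
y = block(x)                  # self-attention + SnakeBeta feed-forward, same shape as x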
third_party/Matcha-TTS/matcha/text/numbers.py ADDED
@@ -0,0 +1,71 @@
+""" from https://github.com/keithito/tacotron """
+
+import re
+
+import inflect
+
+_inflect = inflect.engine()
+_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
+_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
+_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
+_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
+_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
+_number_re = re.compile(r"[0-9]+")
+
+
+def _remove_commas(m):
+    return m.group(1).replace(",", "")
+
+
+def _expand_decimal_point(m):
+    return m.group(1).replace(".", " point ")
+
+
+def _expand_dollars(m):
+    match = m.group(1)
+    parts = match.split(".")
+    if len(parts) > 2:
+        return match + " dollars"
+    dollars = int(parts[0]) if parts[0] else 0
+    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+    if dollars and cents:
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        cent_unit = "cent" if cents == 1 else "cents"
+        return f"{dollars} {dollar_unit}, {cents} {cent_unit}"
+    elif dollars:
+        dollar_unit = "dollar" if dollars == 1 else "dollars"
+        return f"{dollars} {dollar_unit}"
+    elif cents:
+        cent_unit = "cent" if cents == 1 else "cents"
+        return f"{cents} {cent_unit}"
+    else:
+        return "zero dollars"
+
+
+def _expand_ordinal(m):
+    return _inflect.number_to_words(m.group(0))
+
+
+def _expand_number(m):
+    num = int(m.group(0))
+    if num > 1000 and num < 3000:
+        if num == 2000:
+            return "two thousand"
+        elif num > 2000 and num < 2010:
+            return "two thousand " + _inflect.number_to_words(num % 100)
+        elif num % 100 == 0:
+            return _inflect.number_to_words(num // 100) + " hundred"
+        else:
+            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
+    else:
+        return _inflect.number_to_words(num, andword="")
+
+
+def normalize_numbers(text):
+    text = re.sub(_comma_number_re, _remove_commas, text)
+    text = re.sub(_pounds_re, r"\1 pounds", text)
+    text = re.sub(_dollars_re, _expand_dollars, text)
+    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+    text = re.sub(_ordinal_re, _expand_ordinal, text)
+    text = re.sub(_number_re, _expand_number, text)
+    return text
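
Note (not part of the committed file): example behaviour of the normaliser; the expansions shown in the comments are approximate.

print(normalize_numbers("The bill was $3.50 on March 3rd."))
# -> roughly "The bill was three dollars, fifty cents on March third."
print(normalize_numbers("It sold 1,500 copies in 1987."))
# -> roughly "It sold fifteen hundred copies in nineteen eighty-seven."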
third_party/Matcha-TTS/matcha/text/symbols.py ADDED
@@ -0,0 +1,17 @@
+""" from https://github.com/keithito/tacotron
+
+Defines the set of symbols used in text input to the model.
+"""
+_pad = "_"
+_punctuation = ';:,.!?¡¿—…"«»“” '
+_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+_letters_ipa = (
+    "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+)
+
+
+# Export all symbols:
+symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+# Special symbol ids
+SPACE_ID = symbols.index(" ")
third_party/Matcha-TTS/matcha/utils/audio.py ADDED
@@ -0,0 +1,82 @@
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+    sampling_rate, data = read(full_path)
+    return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+    return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+    output = dynamic_range_decompression_torch(magnitudes)
+    return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window  # pylint: disable=global-statement
+    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
+        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+    )
+    y = y.squeeze(1)
+
+    spec = torch.view_as_real(
+        torch.stft(
+            y,
+            n_fft,
+            hop_length=hop_size,
+            win_length=win_size,
+            window=hann_window[str(y.device)],
+            center=center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+    )
+
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
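
Note (not part of the committed file): a shape sketch for mel_spectrogram; the STFT settings below are the common LJ Speech values and are illustrative, not fixed by this file.

import torch

y = torch.rand(1, 22050) * 2 - 1   # one second of audio in [-1, 1], shape (batch, samples)
mel = mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050,
                      hop_size=256, win_size=1024, fmin=0, fmax=8000)
print(mel.shape)                   # (1, 80, n_frames): log-compressed mel spectrogram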
third_party/Matcha-TTS/matcha/utils/logging_utils.py ADDED
@@ -0,0 +1,53 @@
+from typing import Any, Dict
+
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import OmegaConf
+
+from matcha.utils import pylogger
+
+log = pylogger.get_pylogger(__name__)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
+    """Controls which config parts are saved by Lightning loggers.
+
+    Additionally saves:
+        - Number of model parameters
+
+    :param object_dict: A dictionary containing the following objects:
+        - `"cfg"`: A DictConfig object containing the main config.
+        - `"model"`: The Lightning model.
+        - `"trainer"`: The Lightning trainer.
+    """
+    hparams = {}
+
+    cfg = OmegaConf.to_container(object_dict["cfg"])
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging...")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+
+    hparams["data"] = cfg["data"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["task_name"] = cfg.get("task_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    for logger in trainer.loggers:
+        logger.log_hyperparams(hparams)
third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py ADDED
@@ -0,0 +1,22 @@
+import numpy as np
+import torch
+
+from matcha.utils.monotonic_align.core import maximum_path_c
+
+
+def maximum_path(value, mask):
+    """Cython optimised version.
+    value: [b, t_x, t_y]
+    mask: [b, t_x, t_y]
+    """
+    value = value * mask
+    device = value.device
+    dtype = value.dtype
+    value = value.data.cpu().numpy().astype(np.float32)
+    path = np.zeros_like(value).astype(np.int32)
+    mask = mask.data.cpu().numpy()
+
+    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
+    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
+    maximum_path_c(path, value, t_x_max, t_y_max)
+    return torch.from_numpy(path).to(device=device, dtype=dtype)
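
Note (not part of the committed file): expected shapes for maximum_path, assuming the Cython extension below has been built.

import torch

b, t_x, t_y = 2, 12, 40               # batch, text positions, mel frames
log_prob = torch.randn(b, t_x, t_y)   # alignment log-likelihoods
mask = torch.ones(b, t_x, t_y)        # in practice the outer product of text and mel masks
path = maximum_path(log_prob, mask)   # hard monotonic alignment, a 0/1 tensor of the same shape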
third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx ADDED
@@ -0,0 +1,47 @@
+import numpy as np
+
+cimport cython
+cimport numpy as np
+
+from cython.parallel import prange
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
+    cdef int x
+    cdef int y
+    cdef float v_prev
+    cdef float v_cur
+    cdef float tmp
+    cdef int index = t_x - 1
+
+    for y in range(t_y):
+        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
+            if x == y:
+                v_cur = max_neg_val
+            else:
+                v_cur = value[x, y-1]
+            if x == 0:
+                if y == 0:
+                    v_prev = 0.
+                else:
+                    v_prev = max_neg_val
+            else:
+                v_prev = value[x-1, y-1]
+            value[x, y] = max(v_cur, v_prev) + value[x, y]
+
+    for y in range(t_y - 1, -1, -1):
+        path[index, y] = 1
+        if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
+            index = index - 1
+
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
+    cdef int b = values.shape[0]
+
+    cdef int i
+    for i in prange(b, nogil=True):
+        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
third_party/Matcha-TTS/matcha/utils/pylogger.py ADDED
@@ -0,0 +1,21 @@
+import logging
+
+from lightning.pytorch.utilities import rank_zero_only
+
+
+def get_pylogger(name: str = __name__) -> logging.Logger:
+    """Initializes a multi-GPU-friendly python command line logger.
+
+    :param name: The name of the logger, defaults to ``__name__``.
+
+    :return: A logger object.
+    """
+    logger = logging.getLogger(name)
+
+    # this ensures all logging levels get marked with the rank zero decorator
+    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
+    logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
+    for level in logging_levels:
+        setattr(logger, level, rank_zero_only(getattr(logger, level)))
+
+    return logger
third_party/Matcha-TTS/notebooks/.gitkeep ADDED
File without changes
third_party/Matcha-TTS/scripts/schedule.sh ADDED
@@ -0,0 +1,7 @@
+#!/bin/bash
+# Schedule execution of many runs
+# Run from root folder with: bash scripts/schedule.sh
+
+python src/train.py trainer.max_epochs=5 logger=csv
+
+python src/train.py trainer.max_epochs=10 logger=csv
tools/extract_embedding.py ADDED
@@ -0,0 +1,77 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import argparse
16
+ from concurrent.futures import ThreadPoolExecutor, as_completed
17
+ import onnxruntime
18
+ import torch
19
+ import torchaudio
20
+ import torchaudio.compliance.kaldi as kaldi
21
+ from tqdm import tqdm
22
+
23
+
24
+ def single_job(utt):
25
+ audio, sample_rate = torchaudio.load(utt2wav[utt])
26
+ if sample_rate != 16000:
27
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
28
+ feat = kaldi.fbank(audio,
29
+ num_mel_bins=80,
30
+ dither=0,
31
+ sample_frequency=16000)
32
+ feat = feat - feat.mean(dim=0, keepdim=True)
33
+ embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
34
+ return utt, embedding
35
+
36
+
37
+ def main(args):
38
+ all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
39
+ utt2embedding, spk2embedding = {}, {}
40
+ for future in tqdm(as_completed(all_task)):
41
+ utt, embedding = future.result()
42
+ utt2embedding[utt] = embedding
43
+ spk = utt2spk[utt]
44
+ if spk not in spk2embedding:
45
+ spk2embedding[spk] = []
46
+ spk2embedding[spk].append(embedding)
47
+ for k, v in spk2embedding.items():
48
+ spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
49
+ torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
50
+ torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
51
+
52
+
53
+ if __name__ == "__main__":
54
+ parser = argparse.ArgumentParser()
55
+ parser.add_argument("--dir", type=str)
56
+ parser.add_argument("--onnx_path", type=str)
57
+ parser.add_argument("--num_thread", type=int, default=8)
58
+ args = parser.parse_args()
59
+
60
+ utt2wav, utt2spk = {}, {}
61
+ with open('{}/wav.scp'.format(args.dir)) as f:
62
+ for l in f:
63
+ l = l.replace('\n', '').split()
64
+ utt2wav[l[0]] = l[1]
65
+ with open('{}/utt2spk'.format(args.dir)) as f:
66
+ for l in f:
67
+ l = l.replace('\n', '').split()
68
+ utt2spk[l[0]] = l[1]
69
+
70
+ option = onnxruntime.SessionOptions()
71
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
72
+ option.intra_op_num_threads = 1
73
+ providers = ["CPUExecutionProvider"]
74
+ ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
75
+ executor = ThreadPoolExecutor(max_workers=args.num_thread)
76
+
77
+ main(args)
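A brief sketch (not part of the commit) of how the two files saved by this script can be loaded back; the data/train path is only an example, and the embedding dimension depends on the speaker-verification ONNX model passed via --onnx_path.

import torch

# Both files are plain dicts written with torch.save by extract_embedding.py
utt2embedding = torch.load("data/train/utt2embedding.pt")  # utt -> list[float]
spk2embedding = torch.load("data/train/spk2embedding.pt")  # spk -> list[float], mean over that speaker's utterances

utt, emb = next(iter(utt2embedding.items()))
print(utt, len(emb))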
tools/extract_speech_token.py ADDED
@@ -0,0 +1,72 @@
+ #!/usr/bin/env python3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import argparse
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import logging
+ import torch
+ from tqdm import tqdm
+ import onnxruntime
+ import numpy as np
+ import torchaudio
+ import whisper
+
+
+ def single_job(utt):
+     audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile')
+     if sample_rate != 16000:
+         audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
+     # Convert audio to mono
+     if audio.shape[0] > 1:
+         audio = audio.mean(dim=0, keepdim=True)
+     if audio.shape[1] / 16000 > 30:
+         logging.warning('do not support extract speech token for audio longer than 30s')
+         speech_token = []
+     else:
+         feat = whisper.log_mel_spectrogram(audio, n_mels=128)
+         speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
+                                               ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+     return utt, speech_token
+
+
+ def main(args):
+     all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
+     utt2speech_token = {}
+     for future in tqdm(as_completed(all_task)):
+         utt, speech_token = future.result()
+         utt2speech_token[utt] = speech_token
+     torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--dir", type=str)
+     parser.add_argument("--onnx_path", type=str)
+     parser.add_argument("--num_thread", type=int, default=8)
+     args = parser.parse_args()
+
+     utt2wav = {}
+     with open('{}/wav.scp'.format(args.dir)) as f:
+         for l in f:
+             l = l.replace('\n', '').split()
+             utt2wav[l[0]] = l[1]
+
+     option = onnxruntime.SessionOptions()
+     option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
+     option.intra_op_num_threads = 1
+     providers = ["CUDAExecutionProvider"]
+     ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
+     executor = ThreadPoolExecutor(max_workers=args.num_thread)
+
+     main(args)
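Similarly, a small sketch (not part of the commit) for inspecting the token file written above; utterances longer than 30 s are stored as empty lists by the script, and the example path is illustrative.

import torch

utt2speech_token = torch.load("data/train/utt2speech_token.pt")  # utt -> list[int]

n_empty = sum(1 for tokens in utt2speech_token.values() if len(tokens) == 0)
print(len(utt2speech_token), "utterances,", n_empty, "skipped (> 30 s)")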
tools/make_parquet_list.py ADDED
@@ -0,0 +1,113 @@
+ #!/usr/bin/env python3
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import argparse
+ import logging
+ import os
+ import json
+ from tqdm import tqdm
+ import pandas as pd
+ import multiprocessing
+ import time
+ import torch
+
+
+ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
+     start_time = time.time()
+     data_list = []
+     for utt in tqdm(utt_list):
+         data = open(utt2wav[utt], 'rb').read()
+         data_list.append(data)
+     wav_list = [utt2wav[utt] for utt in utt_list]
+     text_list = [utt2text[utt] for utt in utt_list]
+     spk_list = [utt2spk[utt] for utt in utt_list]
+     uttembedding_list = [utt2embedding[utt] for utt in utt_list]
+     spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
+     speech_token_list = [utt2speech_token[utt] for utt in utt_list]
+
+     # Save to parquet, and write the utt2parquet_file / spk2parquet_file index maps
+     df = pd.DataFrame()
+     df['utt'] = utt_list
+     df['wav'] = wav_list
+     df['audio_data'] = data_list
+     df['text'] = text_list
+     df['spk'] = spk_list
+     df['utt_embedding'] = uttembedding_list
+     df['spk_embedding'] = spkembedding_list
+     df['speech_token'] = speech_token_list
+     df.to_parquet(parquet_file)
+     with open(utt2parquet_file, 'w') as f:
+         json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
+     with open(spk2parquet_file, 'w') as f:
+         json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
+     logging.info('spend time {}'.format(time.time() - start_time))
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--num_utts_per_parquet',
+                         type=int,
+                         default=1000,
+                         help='num utts per parquet')
+     parser.add_argument('--num_processes',
+                         type=int,
+                         default=1,
+                         help='num processes for make parquets')
+     parser.add_argument('--src_dir',
+                         type=str)
+     parser.add_argument('--des_dir',
+                         type=str)
+     args = parser.parse_args()
+
+     utt2wav, utt2text, utt2spk = {}, {}, {}
+     with open('{}/wav.scp'.format(args.src_dir)) as f:
+         for l in f:
+             l = l.replace('\n', '').split()
+             utt2wav[l[0]] = l[1]
+     with open('{}/text'.format(args.src_dir)) as f:
+         for l in f:
+             l = l.replace('\n', '').split()
+             utt2text[l[0]] = ' '.join(l[1:])
+     with open('{}/utt2spk'.format(args.src_dir)) as f:
+         for l in f:
+             l = l.replace('\n', '').split()
+             utt2spk[l[0]] = l[1]
+     utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
+     spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
+     utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
+     utts = list(utt2wav.keys())
+
+     # Using process pool to speedup
+     pool = multiprocessing.Pool(processes=args.num_processes)
+     parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
+     for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
+         parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
+         utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
+         spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
+         parquet_list.append(parquet_file)
+         utt2parquet_list.append(utt2parquet_file)
+         spk2parquet_list.append(spk2parquet_file)
+         pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
+     pool.close()
+     pool.join()
+
+     with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
+             open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
+             open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
+         for name in parquet_list:
+             f1.write(name + '\n')
+         for name in utt2parquet_list:
+             f2.write(name + '\n')
+         for name in spk2parquet_list:
+             f3.write(name + '\n')
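To sanity-check a generated shard (sketch only, not part of the commit): despite the .tar suffix in the naming pattern, each shard is an ordinary parquet table written with DataFrame.to_parquet, so pandas can read it directly; the path below is illustrative.

import pandas as pd

df = pd.read_parquet("data/train/parquet_000000000.tar")  # illustrative shard path
print(df.columns.tolist())  # utt, wav, audio_data, text, spk, utt_embedding, spk_embedding, speech_token
print(len(df))              # at most --num_utts_per_parquet rows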