Staticaliza committed (verified)
Commit b326959 · 1 parent: 9cb9281

Upload 10 files
modules/audio.py ADDED
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
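A minimal usage sketch for the mel_spectrogram helper above (not part of the uploaded file; the sample rate and STFT/mel settings are illustrative placeholders, not values taken from this repository's configs):

# Illustrative only: placeholder audio and analysis settings.
import torch
from modules.audio import mel_spectrogram

waveform = torch.rand(1, 22050) * 2 - 1   # one second of fake audio, kept inside [-1, 1]
mel = mel_spectrogram(
    waveform,
    n_fft=1024,
    num_mels=80,
    sampling_rate=22050,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=8000,
)
print(mel.shape)  # (1, 80, n_frames): log-compressed mel magnitudes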
modules/commons.py ADDED
@@ -0,0 +1,490 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments_audio(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
        dtype=torch.long
    )
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def avg_with_mask(x, mask):
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask.unsqueeze(1)

    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    return (x * mask).sum() / mask.sum()


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)
    """
    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
    return x


def load_F0_models(path):
    # load F0 model
    from .JDC.model import JDCNet

    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(path, map_location="cpu")["net"]
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model


def modify_w2v_forward(self, output_layer=15):
    """
    change forward method of w2v encoder to get its intermediate layer output
    :param self:
    :param layer:
    :return:
    """
    from transformers.modeling_outputs import BaseModelOutput

    def forward(
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        conv_attention_mask = attention_mask
        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = hidden_states.masked_fill(
                ~attention_mask.bool().unsqueeze(-1), 0.0
            )

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype
            )
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )

        hidden_states = self.dropout(hidden_states)

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)
        else:
            relative_position_embeddings = None

        deepspeed_zero3_is_enabled = False

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = (
                True
                if self.training and (dropout_probability < self.config.layerdrop)
                else False
            )
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        relative_position_embeddings,
                        output_attentions,
                        conv_attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        relative_position_embeddings=relative_position_embeddings,
                        output_attentions=output_attentions,
                        conv_attention_mask=conv_attention_mask,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

            if i == output_layer - 1:
                break

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward


MATPLOTLIB_FLAG = False


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        import logging

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def normalize_f0(f0_sequence):
    # Remove unvoiced frames (replace with -1)
    voiced_indices = np.where(f0_sequence > 0)[0]
    f0_voiced = f0_sequence[voiced_indices]

    # Convert to log scale
    log_f0 = np.log2(f0_voiced)

    # Calculate mean and standard deviation
    mean_f0 = np.mean(log_f0)
    std_f0 = np.std(log_f0)

    # Normalize the F0 sequence
    normalized_f0 = (log_f0 - mean_f0) / std_f0

    # Create the normalized F0 sequence with unvoiced frames
    normalized_sequence = np.zeros_like(f0_sequence)
    normalized_sequence[voiced_indices] = normalized_f0
    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames

    return normalized_sequence


def build_model(args, stage="DiT"):
    if stage == "DiT":
        from modules.flow_matching import CFM
        from modules.length_regulator import InterpolateRegulator

        length_regulator = InterpolateRegulator(
            channels=args.length_regulator.channels,
            sampling_ratios=args.length_regulator.sampling_ratios,
            is_discrete=args.length_regulator.is_discrete,
            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
            codebook_size=args.length_regulator.content_codebook_size,
            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
        )
        cfm = CFM(args)
        nets = Munch(
            cfm=cfm,
            length_regulator=length_regulator,
        )
    elif stage == 'codec':
        from dac.model.dac import Encoder
        from modules.quantize import (
            FAquantizer,
        )

        encoder = Encoder(
            d_model=args.DAC.encoder_dim,
            strides=args.DAC.encoder_rates,
            d_latent=1024,
            causal=args.causal,
            lstm=args.lstm,
        )

        quantizer = FAquantizer(
            in_dim=1024,
            n_p_codebooks=1,
            n_c_codebooks=args.n_c_codebooks,
            n_t_codebooks=2,
            n_r_codebooks=3,
            codebook_size=1024,
            codebook_dim=8,
            quantizer_dropout=0.5,
            causal=args.causal,
            separate_prosody_encoder=args.separate_prosody_encoder,
            timbre_norm=args.timbre_norm,
        )

        nets = Munch(
            encoder=encoder,
            quantizer=quantizer,
        )
    else:
        raise ValueError(f"Unknown stage: {stage}")

    return nets


def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module."):]] = params[key][k]
                        del params[key][k]
            model_state_dict = model[key].state_dict()
            # keep only the key-value pairs whose shapes match the model
            filtered_state_dict = {
                k: v
                for k, v in params[key].items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape
            }
            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
            if skipped_keys:
                print(
                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
                )
            print("%s loaded" % key)
            model[key].load_state_dict(filtered_state_dict, strict=False)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters


def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
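A minimal sketch of two helpers defined above, sequence_mask and recursive_munch (not part of the uploaded file; the config keys and values are placeholders):

# Illustrative only: placeholder lengths and config values.
import torch
from modules.commons import sequence_mask, recursive_munch

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths)          # shape (2, 5); True where the frame index is < length
print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])

cfg = recursive_munch({"DiT": {"hidden_dim": 512, "depth": 13}})
print(cfg.DiT.hidden_dim)              # attribute-style access via Munch -> 512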
modules/diffusion_transformer.py ADDED
@@ -0,0 +1,240 @@
import torch
from torch import nn
import math

from modules.gpt_fast.model import ModelArgs, Transformer
# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
from modules.wavenet import WN
from modules.commons import sequence_mask

from torch.nn.utils import weight_norm


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = 10000
        self.scale = 1000

        half = frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        )
        self.register_buffer("freqs", freqs)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        args = self.scale * t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.frequency_embedding_size % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t)
        t_emb = self.mlp(t_freq)
        return t_emb


class StyleEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
        self.input_size = input_size
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        else:
            labels = self.style_in(labels)
        embeddings = labels
        return embeddings


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(torch.nn.Module):
    def __init__(
        self,
        args
    ):
        super(DiT, self).__init__()
        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
        model_args = ModelArgs(
            block_size=16384,  # args.DiT.block_size,
            n_layer=args.DiT.depth,
            n_head=args.DiT.num_heads,
            dim=args.DiT.hidden_dim,
            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
            vocab_size=1024,
            uvit_skip_connection=self.uvit_skip_connection,
        )
        self.transformer = Transformer(model_args)
        self.in_channels = args.DiT.in_channels
        self.out_channels = args.DiT.in_channels
        self.num_heads = args.DiT.num_heads

        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))

        self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
        self.content_codebook_size = args.DiT.content_codebook_size  # for discrete content
        self.content_dim = args.DiT.content_dim  # for continuous content
        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True)  # continuous content

        self.is_causal = args.DiT.is_causal

        self.n_f0_bins = args.DiT.n_f0_bins
        self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
        self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
        self.f0_condition = args.DiT.f0_condition

        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
        self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
        # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
        # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))

        input_pos = torch.arange(16384)
        self.register_buffer("input_pos", input_pos)

        self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
        self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
        self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
        if self.final_layer_type == 'wavenet':
            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
                              kernel_size=args.wavenet.kernel_size,
                              dilation_rate=args.wavenet.dilation_rate,
                              n_layers=args.wavenet.num_layers,
                              gin_channels=args.wavenet.hidden_dim,
                              p_dropout=args.wavenet.p_dropout,
                              causal=False)
            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
        else:
            self.final_mlp = nn.Sequential(
                nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
                nn.SiLU(),
                nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
            )
        self.transformer_style_condition = args.DiT.style_condition
        self.wavenet_style_condition = args.wavenet.style_condition
        assert args.DiT.style_condition == args.wavenet.style_condition

        self.class_dropout_prob = args.DiT.class_dropout_prob
        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
        self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)  # residual connection from transformer output to final output
        self.long_skip_connection = args.DiT.long_skip_connection
        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)

        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
                                             args.DiT.hidden_dim)
        if self.style_as_token:
            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)

    def setup_caches(self, max_batch_size, max_seq_length):
        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)

    def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
        class_dropout = False
        if self.training and torch.rand(1) < self.class_dropout_prob:
            class_dropout = True
        if not self.training and mask_content:
            class_dropout = True
        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
        cond_in_module = self.cond_projection

        B, _, T = x.size()

        t1 = self.t_embedder(t)  # (N, D)

        cond = cond_in_module(cond)
        if self.f0_condition and f0 is not None:
            quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
            cond = cond + self.f0_embedder(quantized_f0)

        x = x.transpose(1, 2)
        prompt_x = prompt_x.transpose(1, 2)

        x_in = torch.cat([x, prompt_x, cond], dim=-1)
        if self.transformer_style_condition and not self.style_as_token:
            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
        if class_dropout:
            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D)

        if self.style_as_token:
            style = self.style_in(style)
            style = torch.zeros_like(style) if class_dropout else style
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
        if self.time_as_token:
            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
        input_pos = self.input_pos[:x_in.size(1)]  # (T,)
        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
        x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
        x_res = x_res[:, 1:] if self.time_as_token else x_res
        x_res = x_res[:, 1:] if self.style_as_token else x_res
        if self.long_skip_connection:
            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
        if self.final_layer_type == 'wavenet':
            x = self.conv1(x_res)
            x = x.transpose(1, 2)
            t2 = self.t_embedder2(t)
            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
                x_res)  # long residual connection
            x = self.final_layer(x, t1).transpose(1, 2)
            x = self.conv2(x)
        else:
            x = self.final_mlp(x_res)
            x = x.transpose(1, 2)
        return x
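A minimal sketch exercising TimestepEmbedder on its own (not part of the uploaded file; the hidden size is an arbitrary choice, and importing the module assumes the sibling files from this upload, such as modules/gpt_fast/model.py and modules/wavenet.py, are importable):

# Illustrative only: standalone check of the timestep embedding shapes.
import torch
from modules.diffusion_transformer import TimestepEmbedder

embedder = TimestepEmbedder(hidden_size=256)
t = torch.rand(4)        # one flow-matching timestep in [0, 1] per batch item
t_emb = embedder(t)      # sinusoidal features (scaled by 1000) passed through the MLP
print(t_emb.shape)       # torch.Size([4, 256])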
modules/encodec.py ADDED
@@ -0,0 +1,292 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Convolutional layers wrappers and utilities."""

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

import einops


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        return x


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already check was in CONV_NORMALIZATION, so any other choice
        # doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect', **kwargs):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
        self.hidden = None

    def forward(self, x):
        x = x.permute(2, 0, 1)
        if self.training:
            y, _ = self.lstm(x)
        else:
            y, self.hidden = self.lstm(x, self.hidden)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
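A small sketch of the causal padding behaviour of SConv1d above (not part of the uploaded file; the channel count, kernel size, and norm choice are arbitrary). With stride 1, all of the (kernel_size - 1) * dilation padding is applied on the left, so the output keeps the input's time length:

# Illustrative only: causal SConv1d preserves sequence length at stride 1.
import torch
from modules.encodec import SConv1d

conv = SConv1d(in_channels=8, out_channels=8, kernel_size=7, stride=1,
               causal=True, norm='weight_norm')
x = torch.randn(2, 8, 100)        # (batch, channels, time)
y = conv(x)
print(y.shape)                    # torch.Size([2, 8, 100])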
modules/flow_matching.py ADDED
@@ -0,0 +1,155 @@
from abc import ABC

import torch
import torch.nn.functional as F

from modules.diffusion_transformer import DiT
from modules.commons import sequence_mask

from tqdm import tqdm


class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        args,
    ):
        super().__init__()
        self.sigma_min = 1e-6

        self.estimator = None

        self.in_channels = args.DiT.in_channels

        self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()

        if hasattr(args.DiT, 'zero_prompt_speech_token'):
            self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
        else:
            self.zero_prompt_speech_token = False

    @torch.inference_mode()
    def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        B, T = mu.size(0), mu.size(1)
        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)

    def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []
        # apply prompt
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0
        if self.zero_prompt_speech_token:
            mu[..., :prompt_len] = 0
        for step in tqdm(range(1, len(t_span))):
            dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
            # Classifier-Free Guidance inference introduced in VoiceBox
            if inference_cfg_rate > 0:
                cfg_dphi_dt = self.estimator(
                    x, torch.zeros_like(prompt_x), x_lens, t.unsqueeze(0),
                    torch.zeros_like(style),
                    torch.zeros_like(mu), None
                )
                dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
                           inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
            x[:, :, :prompt_len] = 0

        return sol[-1]

    def forward(self, x1, x_lens, prompt_lens, mu, style, f0=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = x1.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        prompt = torch.zeros_like(x1)
        for bib in range(b):
            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
            # range covered by prompt are set to 0
            y[bib, :, :prompt_lens[bib]] = 0
            if self.zero_prompt_speech_token:
                mu[bib, :, :prompt_lens[bib]] = 0

        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu, f0)
        loss = 0
        for bib in range(b):
            loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
        loss /= b

        return loss, y


class CFM(BASECFM):
    def __init__(self, args):
        super().__init__(
            args
        )
        if args.dit_type == "DiT":
            self.estimator = DiT(args)
        else:
            raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
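For reference, the straight-line interpolation that BASECFM.forward regresses against, reproduced standalone with arbitrary tensor sizes (not part of the uploaded file; sigma_min matches the value set in BASECFM.__init__ above):

# Illustrative only: the conditional flow-matching path and its target velocity.
import torch

sigma_min = 1e-6
x1 = torch.randn(2, 80, 100)          # target features (batch, n_feats, frames)
z = torch.randn_like(x1)              # noise sample x_0
t = torch.rand(2, 1, 1)               # one random timestep per batch item

y = (1 - (1 - sigma_min) * t) * z + t * x1   # point on the path from noise to data
u = x1 - (1 - sigma_min) * z                 # constant velocity the estimator should predict
print(y.shape, u.shape)               # torch.Size([2, 80, 100]) twice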
modules/layers.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cumulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
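The forward-attention recursion above only lets the alignment hold its position or advance one encoder step per decoder step. A self-contained sketch (not part of this commit) of the shift-and-logsumexp update on a toy log_alpha, with made-up sizes:

import torch
import torch.nn.functional as F

B, max_time = 1, 6
score_mask_value = -float(1e20)
log_alpha = torch.full((B, max_time), score_mask_value)
log_alpha[:, 0] = 0.0                          # all alignment mass on the first encoder step
log_energy = torch.zeros(B, max_time)          # stand-in for get_alignment_energies output

shifted = []
for sft in range(2):                           # sft=0: stay, sft=1: advance one step
    part = log_alpha[:, :max_time - sft]
    part = F.pad(part, (sft, 0), 'constant', score_mask_value)
    shifted.append(part.unsqueeze(2))
biased = torch.logsumexp(torch.cat(shifted, 2), 2)

log_alpha_new = biased + log_energy
attention_weights = F.softmax(log_alpha_new, dim=1)
print(attention_weights)                       # mass only on steps 0 and 1: hold or move forward once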
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
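Both phase-shuffle modules above simply roll the last (time) axis by a random offset in [-n, n]. A tiny standalone illustration (not part of this commit) of the shift for move=2:

import torch

x = torch.arange(6).view(1, 1, 6).float()      # (B, C, L) = [[[0, 1, 2, 3, 4, 5]]]
move = 2
left, right = x[:, :, :move], x[:, :, move:]
print(torch.cat([right, left], dim=2))         # tensor([[[2., 3., 4., 5., 0., 1.]]])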
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).transpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
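A standalone sketch (not part of this commit) of what MFCC.forward computes, using torchaudio.functional.create_dct directly; the shapes are made up:

import torch
import torchaudio.functional as audio_F

n_mels, n_mfcc, time = 80, 40, 100
mel_specgram = torch.randn(n_mels, time)                          # unbatched (n_mels, time) input

dct_mat = audio_F.create_dct(n_mfcc, n_mels, "ortho")             # (n_mels, n_mfcc)
x = mel_specgram.unsqueeze(0)                                     # -> (1, n_mels, time)
mfcc = torch.matmul(x.transpose(1, 2), dct_mat).transpose(1, 2)   # (1, n_mfcc, time)
mfcc = mfcc.squeeze(0)                                            # back to (n_mfcc, time)
print(mfcc.shape)                                                 # torch.Size([40, 100])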
modules/length_regulator.py ADDED
@@ -0,0 +1,141 @@
1
+ from typing import Tuple
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+ from modules.commons import sequence_mask
6
+ import numpy as np
7
+ from dac.nn.quantize import VectorQuantize
8
+
9
+ # f0_bin = 256
10
+ f0_max = 1100.0
11
+ f0_min = 50.0
12
+ f0_mel_min = 1127 * np.log(1 + f0_min / 700)
13
+ f0_mel_max = 1127 * np.log(1 + f0_max / 700)
14
+
15
+ def f0_to_coarse(f0, f0_bin):
16
+ f0_mel = 1127 * (1 + f0 / 700).log()
17
+ a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
18
+ b = f0_mel_min * a - 1.
19
+ f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
20
+ # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
21
+ f0_coarse = torch.round(f0_mel).long()
22
+ f0_coarse = f0_coarse * (f0_coarse > 0)
23
+ f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
24
+ f0_coarse = f0_coarse * (f0_coarse < f0_bin)
25
+ f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
26
+ return f0_coarse
27
+
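A quick numeric check (not part of this commit) of f0_to_coarse: it maps F0 values in Hz onto mel-scaled bins, with bin 1 effectively reserved for unvoiced (zero) frames. The import assumes this repository and its dac dependency are on the path:

import torch
from modules.length_regulator import f0_to_coarse   # assumes the repo layout above

f0 = torch.tensor([[0.0, 55.0, 220.0, 440.0, 1000.0]])   # Hz, first frame unvoiced
bins = f0_to_coarse(f0, 512)
print(bins)   # the unvoiced frame maps to bin 1; the voiced frames map to higher bins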
28
+ class InterpolateRegulator(nn.Module):
29
+ def __init__(
30
+ self,
31
+ channels: int,
32
+ sampling_ratios: Tuple,
33
+ is_discrete: bool = False,
34
+ in_channels: int = None, # only applies to continuous input
35
+ vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input
36
+ codebook_size: int = 1024, # for discrete only
37
+ out_channels: int = None,
38
+ groups: int = 1,
39
+ n_codebooks: int = 1, # number of codebooks
40
+ quantizer_dropout: float = 0.0, # dropout for quantizer
41
+ f0_condition: bool = False,
42
+ n_f0_bins: int = 512,
43
+ ):
44
+ super().__init__()
45
+ self.sampling_ratios = sampling_ratios
46
+ out_channels = out_channels or channels
47
+ model = nn.ModuleList([])
48
+ if len(sampling_ratios) > 0:
49
+ self.interpolate = True
50
+ for _ in sampling_ratios:
51
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
52
+ norm = nn.GroupNorm(groups, channels)
53
+ act = nn.Mish()
54
+ model.extend([module, norm, act])
55
+ else:
56
+ self.interpolate = False
57
+ model.append(
58
+ nn.Conv1d(channels, out_channels, 1, 1)
59
+ )
60
+ self.model = nn.Sequential(*model)
61
+ self.embedding = nn.Embedding(codebook_size, channels)
62
+ self.is_discrete = is_discrete
63
+
64
+ self.mask_token = nn.Parameter(torch.zeros(1, channels))
65
+
66
+ self.n_codebooks = n_codebooks
67
+ if n_codebooks > 1:
68
+ self.extra_codebooks = nn.ModuleList([
69
+ nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
70
+ ])
71
+ self.extra_codebook_mask_tokens = nn.ParameterList([
72
+ nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
73
+ ])
74
+ self.quantizer_dropout = quantizer_dropout
75
+
76
+ if f0_condition:
77
+ self.f0_embedding = nn.Embedding(n_f0_bins, channels)
78
+ self.f0_condition = f0_condition
79
+ self.n_f0_bins = n_f0_bins
80
+ self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
81
+ self.f0_mask = nn.Parameter(torch.zeros(1, channels))
82
+ else:
83
+ self.f0_condition = False
84
+
85
+ if not is_discrete:
86
+ self.content_in_proj = nn.Linear(in_channels, channels)
87
+ if vector_quantize:
88
+ self.vq = VectorQuantize(channels, codebook_size, 8)
89
+
90
+ def forward(self, x, ylens=None, n_quantizers=None, f0=None):
91
+ # apply token drop
92
+ if self.training:
93
+ n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
94
+ dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
95
+ n_dropout = int(x.shape[0] * self.quantizer_dropout)
96
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
97
+ n_quantizers = n_quantizers.to(x.device)
98
+ # decide whether to drop for each sample in batch
99
+ else:
100
+ n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
101
+ if self.is_discrete:
102
+ if self.n_codebooks > 1:
103
+ assert len(x.size()) == 3
104
+ x_emb = self.embedding(x[:, 0])
105
+ for i, emb in enumerate(self.extra_codebooks):
106
+ x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
107
+ # add mask token if not using this codebook
108
+ # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
109
+ x = x_emb
110
+ elif self.n_codebooks == 1:
111
+ if len(x.size()) == 2:
112
+ x = self.embedding(x)
113
+ else:
114
+ x = self.embedding(x[:, 0])
115
+ else:
116
+ x = self.content_in_proj(x)
117
+ # x in (B, T, D)
118
+ mask = sequence_mask(ylens).unsqueeze(-1)
119
+ if self.interpolate:
120
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
121
+ else:
122
+ x = x.transpose(1, 2).contiguous()
123
+ mask = mask[:, :x.size(2), :]
124
+ ylens = ylens.clamp(max=x.size(2)).long()
125
+ if self.f0_condition:
126
+ if f0 is None:
127
+ x = x + self.f0_mask.unsqueeze(-1)
128
+ else:
129
+ #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
130
+ quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
131
+ quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
132
+ f0_emb = self.f0_embedding(quantized_f0)
133
+ f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
134
+ x = x + f0_emb
135
+ out = self.model(x).transpose(1, 2).contiguous()
136
+ if hasattr(self, 'vq'):
137
+ out_q, commitment_loss, codebook_loss, codes, out = self.vq(out.transpose(1, 2))
138
+ out_q = out_q.transpose(1, 2)
139
+ return out_q * mask, ylens, codes, commitment_loss, codebook_loss
140
+ olens = ylens
141
+ return out * mask, olens, None, None, None
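A hedged usage sketch (not part of this commit) of InterpolateRegulator on discrete tokens; the channel count, codebook size, ratios, and lengths are illustrative only, and the import assumes the repo's modules and dac packages resolve:

import torch
from modules.length_regulator import InterpolateRegulator

regulator = InterpolateRegulator(
    channels=256,
    sampling_ratios=(1, 1),        # two conv/GroupNorm/Mish blocks, interpolation enabled
    is_discrete=True,
    codebook_size=1024,
)
tokens = torch.randint(0, 1024, (2, 50))       # (B, T) discrete content tokens
target_lens = torch.tensor([120, 96])          # desired output length per sample
out, olens, codes, commit_loss, cb_loss = regulator(tokens, ylens=target_lens)
print(out.shape, olens)                        # torch.Size([2, 120, 256]) tensor([120, 96])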
modules/quantize.py ADDED
@@ -0,0 +1,229 @@
1
+ from dac.nn.quantize import ResidualVectorQuantize
2
+ from torch import nn
3
+ from modules.wavenet import WN
4
+ import torch
5
+ import torchaudio
6
+ import torchaudio.functional as audio_F
7
+ import numpy as np
8
+ from .alias_free_torch import *
9
+ from torch.nn.utils import weight_norm
10
+ from torch import nn, sin, pow
11
+ from einops.layers.torch import Rearrange
12
+ from dac.model.encodec import SConv1d
13
+
14
+ def init_weights(m):
15
+ if isinstance(m, nn.Conv1d):
16
+ nn.init.trunc_normal_(m.weight, std=0.02)
17
+ nn.init.constant_(m.bias, 0)
18
+
19
+
20
+ def WNConv1d(*args, **kwargs):
21
+ return weight_norm(nn.Conv1d(*args, **kwargs))
22
+
23
+
24
+ def WNConvTranspose1d(*args, **kwargs):
25
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
26
+
27
+ class SnakeBeta(nn.Module):
28
+ """
29
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
30
+ Shape:
31
+ - Input: (B, C, T)
32
+ - Output: (B, C, T), same shape as the input
33
+ Parameters:
34
+ - alpha - trainable parameter that controls frequency
35
+ - beta - trainable parameter that controls magnitude
36
+ References:
37
+ - This activation function is a modified version of the Snake activation introduced in this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
38
+ https://arxiv.org/abs/2006.08195
39
+ Examples:
40
+ >>> a1 = SnakeBeta(256)
41
+ >>> x = torch.randn(256)
42
+ >>> x = a1(x)
43
+ """
44
+
45
+ def __init__(
46
+ self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
47
+ ):
48
+ """
49
+ Initialization.
50
+ INPUT:
51
+ - in_features: shape of the input
52
+ - alpha - trainable parameter that controls frequency
53
+ - beta - trainable parameter that controls magnitude
54
+ alpha is initialized to 1 by default, higher values = higher-frequency.
55
+ beta is initialized to 1 by default, higher values = higher-magnitude.
56
+ alpha will be trained along with the rest of your model.
57
+ """
58
+ super(SnakeBeta, self).__init__()
59
+ self.in_features = in_features
60
+
61
+ # initialize alpha
62
+ self.alpha_logscale = alpha_logscale
63
+ if self.alpha_logscale: # log scale alphas initialized to zeros
64
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
65
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
68
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
69
+
70
+ self.alpha.requires_grad = alpha_trainable
71
+ self.beta.requires_grad = alpha_trainable
72
+
73
+ self.no_div_by_zero = 0.000000001
74
+
75
+ def forward(self, x):
76
+ """
77
+ Forward pass of the function.
78
+ Applies the function to the input elementwise.
79
+ SnakeBeta := x + 1/b * sin^2 (xa)
80
+ """
81
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
82
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
83
+ if self.alpha_logscale:
84
+ alpha = torch.exp(alpha)
85
+ beta = torch.exp(beta)
86
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
87
+
88
+ return x
89
+
90
+ class ResidualUnit(nn.Module):
91
+ def __init__(self, dim: int = 16, dilation: int = 1):
92
+ super().__init__()
93
+ pad = ((7 - 1) * dilation) // 2
94
+ self.block = nn.Sequential(
95
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
96
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
97
+ Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
98
+ WNConv1d(dim, dim, kernel_size=1),
99
+ )
100
+
101
+ def forward(self, x):
102
+ return x + self.block(x)
103
+
104
+ class CNNLSTM(nn.Module):
105
+ def __init__(self, indim, outdim, head, global_pred=False):
106
+ super().__init__()
107
+ self.global_pred = global_pred
108
+ self.model = nn.Sequential(
109
+ ResidualUnit(indim, dilation=1),
110
+ ResidualUnit(indim, dilation=2),
111
+ ResidualUnit(indim, dilation=3),
112
+ Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
113
+ Rearrange("b c t -> b t c"),
114
+ )
115
+ self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])
116
+
117
+ def forward(self, x):
118
+ # x: [B, C, T]
119
+ x = self.model(x)
120
+ if self.global_pred:
121
+ x = torch.mean(x, dim=1, keepdim=False)
122
+ outs = [head(x) for head in self.heads]
123
+ return outs
124
+
125
+ def sequence_mask(length, max_length=None):
126
+ if max_length is None:
127
+ max_length = length.max()
128
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
129
+ return x.unsqueeze(0) < length.unsqueeze(1)
130
+ class FAquantizer(nn.Module):
131
+ def __init__(self, in_dim=1024,
132
+ n_p_codebooks=1,
133
+ n_c_codebooks=2,
134
+ n_t_codebooks=2,
135
+ n_r_codebooks=3,
136
+ codebook_size=1024,
137
+ codebook_dim=8,
138
+ quantizer_dropout=0.5,
139
+ causal=False,
140
+ separate_prosody_encoder=False,
141
+ timbre_norm=False,):
142
+ super(FAquantizer, self).__init__()
143
+ conv1d_type = SConv1d  # if causal else nn.Conv1d
144
+ self.prosody_quantizer = ResidualVectorQuantize(
145
+ input_dim=in_dim,
146
+ n_codebooks=n_p_codebooks,
147
+ codebook_size=codebook_size,
148
+ codebook_dim=codebook_dim,
149
+ quantizer_dropout=quantizer_dropout,
150
+ )
151
+
152
+ self.content_quantizer = ResidualVectorQuantize(
153
+ input_dim=in_dim,
154
+ n_codebooks=n_c_codebooks,
155
+ codebook_size=codebook_size,
156
+ codebook_dim=codebook_dim,
157
+ quantizer_dropout=quantizer_dropout,
158
+ )
159
+
160
+ self.residual_quantizer = ResidualVectorQuantize(
161
+ input_dim=in_dim,
162
+ n_codebooks=n_r_codebooks,
163
+ codebook_size=codebook_size,
164
+ codebook_dim=codebook_dim,
165
+ quantizer_dropout=quantizer_dropout,
166
+ )
167
+
168
+ self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal)
169
+ self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal)
170
+ self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal)
171
+
172
+ self.prob_random_mask_residual = 0.75
173
+
174
+ SPECT_PARAMS = {
175
+ "n_fft": 2048,
176
+ "win_length": 1200,
177
+ "hop_length": 300,
178
+ }
179
+ MEL_PARAMS = {
180
+ "n_mels": 80,
181
+ }
182
+
183
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
184
+ n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
185
+ )
186
+ self.mel_mean, self.mel_std = -4, 4
187
+ self.frame_rate = 24000 / 300
188
+ self.hop_length = 300
189
+
190
+ def preprocess(self, wave_tensor, n_bins=20):
191
+ mel_tensor = self.to_mel(wave_tensor.squeeze(1))
192
+ mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
193
+ return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)]
194
+
195
+ def forward(self, x, wave_segments):
196
+ outs = 0
197
+ prosody_feature = self.preprocess(wave_segments)
198
+
199
+ f0_input = prosody_feature  # (B, 20, T)
200
+ f0_input = self.melspec_linear(f0_input)
201
+ f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(
202
+ f0_input.device).bool())
203
+ f0_input = self.melspec_linear2(f0_input)
204
+
205
+ common_min_size = min(f0_input.size(2), x.size(2))
206
+ f0_input = f0_input[:, :, :common_min_size]
207
+
208
+ x = x[:, :, :common_min_size]
209
+
210
+ z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
211
+ f0_input, 1
212
+ )
213
+ outs += z_p.detach()
214
+
215
+ z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
216
+ x, 2
217
+ )
218
+ outs += z_c.detach()
219
+
220
+ residual_feature = x - z_p.detach() - z_c.detach()
221
+
222
+ z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
223
+ residual_feature, 3
224
+ )
225
+
226
+ quantized = [z_p, z_c, z_r]
227
+ codes = [codes_p, codes_c, codes_r]
228
+
229
+ return quantized, codes
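A usage sketch (not part of this commit) of FAquantizer with random inputs; all shapes are made up, and the import assumes the repo's modules package plus its bundled dac and torchaudio dependencies are available:

import torch
from modules.quantize import FAquantizer

quantizer = FAquantizer(in_dim=1024)           # defaults from the constructor above
wave = torch.randn(2, 1, 24000)                # one second of fake 24 kHz audio per sample
feats = torch.randn(2, 1024, 80)               # fake content features at 80 frames per second

(z_p, z_c, z_r), (codes_p, codes_c, codes_r) = quantizer(feats, wave)
print(z_p.shape, codes_p.shape)                # e.g. torch.Size([2, 1024, 80]) torch.Size([2, 1, 80])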
modules/rmvpe.py ADDED
@@ -0,0 +1,600 @@
1
+ from io import BytesIO
2
+ import os
3
+ from typing import List, Optional, Tuple
4
+ import numpy as np
5
+ import torch
6
+
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from librosa.util import normalize, pad_center, tiny
10
+ from scipy.signal import get_window
11
+
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class STFT(torch.nn.Module):
18
+ def __init__(
19
+ self, filter_length=1024, hop_length=512, win_length=None, window="hann"
20
+ ):
21
+ """
22
+ This module implements an STFT using 1D convolution and 1D transpose convolutions.
23
+ This is a bit tricky so there are some cases that probably won't work as working
24
+ out the same sizes before and after in all overlap add setups is tough. Right now,
25
+ this code should work with hop lengths that are half the filter length (50% overlap
26
+ between frames).
27
+
28
+ Keyword Arguments:
29
+ filter_length {int} -- Length of filters used (default: {1024})
30
+ hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
31
+ win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
32
+ equals the filter length). (default: {None})
33
+ window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
34
+ (default: {'hann'})
35
+ """
36
+ super(STFT, self).__init__()
37
+ self.filter_length = filter_length
38
+ self.hop_length = hop_length
39
+ self.win_length = win_length if win_length else filter_length
40
+ self.window = window
41
+ self.forward_transform = None
42
+ self.pad_amount = int(self.filter_length / 2)
43
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
44
+
45
+ cutoff = int((self.filter_length / 2 + 1))
46
+ fourier_basis = np.vstack(
47
+ [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
48
+ )
49
+ forward_basis = torch.FloatTensor(fourier_basis)
50
+ inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))
51
+
52
+ assert filter_length >= self.win_length
53
+ # get window and zero center pad it to filter_length
54
+ fft_window = get_window(window, self.win_length, fftbins=True)
55
+ fft_window = pad_center(fft_window, size=filter_length)
56
+ fft_window = torch.from_numpy(fft_window).float()
57
+
58
+ # window the bases
59
+ forward_basis *= fft_window
60
+ inverse_basis = (inverse_basis.T * fft_window).T
61
+
62
+ self.register_buffer("forward_basis", forward_basis.float())
63
+ self.register_buffer("inverse_basis", inverse_basis.float())
64
+ self.register_buffer("fft_window", fft_window.float())
65
+
66
+ def transform(self, input_data, return_phase=False):
67
+ """Take input data (audio) to STFT domain.
68
+
69
+ Arguments:
70
+ input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
71
+
72
+ Returns:
73
+ magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
74
+ num_frequencies, num_frames)
75
+ phase {tensor} -- Phase of STFT with shape (num_batch,
76
+ num_frequencies, num_frames)
77
+ """
78
+ input_data = F.pad(
79
+ input_data,
80
+ (self.pad_amount, self.pad_amount),
81
+ mode="reflect",
82
+ )
83
+ forward_transform = input_data.unfold(
84
+ 1, self.filter_length, self.hop_length
85
+ ).permute(0, 2, 1)
86
+ forward_transform = torch.matmul(self.forward_basis, forward_transform)
87
+ cutoff = int((self.filter_length / 2) + 1)
88
+ real_part = forward_transform[:, :cutoff, :]
89
+ imag_part = forward_transform[:, cutoff:, :]
90
+ magnitude = torch.sqrt(real_part**2 + imag_part**2)
91
+ if return_phase:
92
+ phase = torch.atan2(imag_part.data, real_part.data)
93
+ return magnitude, phase
94
+ else:
95
+ return magnitude
96
+
97
+ def inverse(self, magnitude, phase):
98
+ """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
99
+ by the ```transform``` function.
100
+
101
+ Arguments:
102
+ magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
103
+ num_frequencies, num_frames)
104
+ phase {tensor} -- Phase of STFT with shape (num_batch,
105
+ num_frequencies, num_frames)
106
+
107
+ Returns:
108
+ inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
109
+ shape (num_batch, num_samples)
110
+ """
111
+ cat = torch.cat(
112
+ [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
113
+ )
114
+ fold = torch.nn.Fold(
115
+ output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
116
+ kernel_size=(1, self.filter_length),
117
+ stride=(1, self.hop_length),
118
+ )
119
+ inverse_transform = torch.matmul(self.inverse_basis, cat)
120
+ inverse_transform = fold(inverse_transform)[
121
+ :, 0, 0, self.pad_amount : -self.pad_amount
122
+ ]
123
+ window_square_sum = (
124
+ self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
125
+ )
126
+ window_square_sum = fold(window_square_sum)[
127
+ :, 0, 0, self.pad_amount : -self.pad_amount
128
+ ]
129
+ inverse_transform /= window_square_sum
130
+ return inverse_transform
131
+
132
+ def forward(self, input_data):
133
+ """Take input data (audio) to STFT domain and then back to audio.
134
+
135
+ Arguments:
136
+ input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
137
+
138
+ Returns:
139
+ reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
140
+ shape (num_batch, num_samples)
141
+ """
142
+ self.magnitude, self.phase = self.transform(input_data, return_phase=True)
143
+ reconstruction = self.inverse(self.magnitude, self.phase)
144
+ return reconstruction
145
+
146
+
147
+ from time import time as ttime
148
+
149
+
150
+ class BiGRU(nn.Module):
151
+ def __init__(self, input_features, hidden_features, num_layers):
152
+ super(BiGRU, self).__init__()
153
+ self.gru = nn.GRU(
154
+ input_features,
155
+ hidden_features,
156
+ num_layers=num_layers,
157
+ batch_first=True,
158
+ bidirectional=True,
159
+ )
160
+
161
+ def forward(self, x):
162
+ return self.gru(x)[0]
163
+
164
+
165
+ class ConvBlockRes(nn.Module):
166
+ def __init__(self, in_channels, out_channels, momentum=0.01):
167
+ super(ConvBlockRes, self).__init__()
168
+ self.conv = nn.Sequential(
169
+ nn.Conv2d(
170
+ in_channels=in_channels,
171
+ out_channels=out_channels,
172
+ kernel_size=(3, 3),
173
+ stride=(1, 1),
174
+ padding=(1, 1),
175
+ bias=False,
176
+ ),
177
+ nn.BatchNorm2d(out_channels, momentum=momentum),
178
+ nn.ReLU(),
179
+ nn.Conv2d(
180
+ in_channels=out_channels,
181
+ out_channels=out_channels,
182
+ kernel_size=(3, 3),
183
+ stride=(1, 1),
184
+ padding=(1, 1),
185
+ bias=False,
186
+ ),
187
+ nn.BatchNorm2d(out_channels, momentum=momentum),
188
+ nn.ReLU(),
189
+ )
190
+ # self.shortcut:Optional[nn.Module] = None
191
+ if in_channels != out_channels:
192
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
193
+
194
+ def forward(self, x: torch.Tensor):
195
+ if not hasattr(self, "shortcut"):
196
+ return self.conv(x) + x
197
+ else:
198
+ return self.conv(x) + self.shortcut(x)
199
+
200
+
201
+ class Encoder(nn.Module):
202
+ def __init__(
203
+ self,
204
+ in_channels,
205
+ in_size,
206
+ n_encoders,
207
+ kernel_size,
208
+ n_blocks,
209
+ out_channels=16,
210
+ momentum=0.01,
211
+ ):
212
+ super(Encoder, self).__init__()
213
+ self.n_encoders = n_encoders
214
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
215
+ self.layers = nn.ModuleList()
216
+ self.latent_channels = []
217
+ for i in range(self.n_encoders):
218
+ self.layers.append(
219
+ ResEncoderBlock(
220
+ in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
221
+ )
222
+ )
223
+ self.latent_channels.append([out_channels, in_size])
224
+ in_channels = out_channels
225
+ out_channels *= 2
226
+ in_size //= 2
227
+ self.out_size = in_size
228
+ self.out_channel = out_channels
229
+
230
+ def forward(self, x: torch.Tensor):
231
+ concat_tensors: List[torch.Tensor] = []
232
+ x = self.bn(x)
233
+ for i, layer in enumerate(self.layers):
234
+ t, x = layer(x)
235
+ concat_tensors.append(t)
236
+ return x, concat_tensors
237
+
238
+
239
+ class ResEncoderBlock(nn.Module):
240
+ def __init__(
241
+ self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
242
+ ):
243
+ super(ResEncoderBlock, self).__init__()
244
+ self.n_blocks = n_blocks
245
+ self.conv = nn.ModuleList()
246
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
247
+ for i in range(n_blocks - 1):
248
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
249
+ self.kernel_size = kernel_size
250
+ if self.kernel_size is not None:
251
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
252
+
253
+ def forward(self, x):
254
+ for i, conv in enumerate(self.conv):
255
+ x = conv(x)
256
+ if self.kernel_size is not None:
257
+ return x, self.pool(x)
258
+ else:
259
+ return x
260
+
261
+
262
+ class Intermediate(nn.Module):
263
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
264
+ super(Intermediate, self).__init__()
265
+ self.n_inters = n_inters
266
+ self.layers = nn.ModuleList()
267
+ self.layers.append(
268
+ ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
269
+ )
270
+ for i in range(self.n_inters - 1):
271
+ self.layers.append(
272
+ ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
273
+ )
274
+
275
+ def forward(self, x):
276
+ for i, layer in enumerate(self.layers):
277
+ x = layer(x)
278
+ return x
279
+
280
+
281
+ class ResDecoderBlock(nn.Module):
282
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
283
+ super(ResDecoderBlock, self).__init__()
284
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
285
+ self.n_blocks = n_blocks
286
+ self.conv1 = nn.Sequential(
287
+ nn.ConvTranspose2d(
288
+ in_channels=in_channels,
289
+ out_channels=out_channels,
290
+ kernel_size=(3, 3),
291
+ stride=stride,
292
+ padding=(1, 1),
293
+ output_padding=out_padding,
294
+ bias=False,
295
+ ),
296
+ nn.BatchNorm2d(out_channels, momentum=momentum),
297
+ nn.ReLU(),
298
+ )
299
+ self.conv2 = nn.ModuleList()
300
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
301
+ for i in range(n_blocks - 1):
302
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
303
+
304
+ def forward(self, x, concat_tensor):
305
+ x = self.conv1(x)
306
+ x = torch.cat((x, concat_tensor), dim=1)
307
+ for i, conv2 in enumerate(self.conv2):
308
+ x = conv2(x)
309
+ return x
310
+
311
+
312
+ class Decoder(nn.Module):
313
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
314
+ super(Decoder, self).__init__()
315
+ self.layers = nn.ModuleList()
316
+ self.n_decoders = n_decoders
317
+ for i in range(self.n_decoders):
318
+ out_channels = in_channels // 2
319
+ self.layers.append(
320
+ ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
321
+ )
322
+ in_channels = out_channels
323
+
324
+ def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
325
+ for i, layer in enumerate(self.layers):
326
+ x = layer(x, concat_tensors[-1 - i])
327
+ return x
328
+
329
+
330
+ class DeepUnet(nn.Module):
331
+ def __init__(
332
+ self,
333
+ kernel_size,
334
+ n_blocks,
335
+ en_de_layers=5,
336
+ inter_layers=4,
337
+ in_channels=1,
338
+ en_out_channels=16,
339
+ ):
340
+ super(DeepUnet, self).__init__()
341
+ self.encoder = Encoder(
342
+ in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
343
+ )
344
+ self.intermediate = Intermediate(
345
+ self.encoder.out_channel // 2,
346
+ self.encoder.out_channel,
347
+ inter_layers,
348
+ n_blocks,
349
+ )
350
+ self.decoder = Decoder(
351
+ self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
352
+ )
353
+
354
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
355
+ x, concat_tensors = self.encoder(x)
356
+ x = self.intermediate(x)
357
+ x = self.decoder(x, concat_tensors)
358
+ return x
359
+
360
+
361
+ class E2E(nn.Module):
362
+ def __init__(
363
+ self,
364
+ n_blocks,
365
+ n_gru,
366
+ kernel_size,
367
+ en_de_layers=5,
368
+ inter_layers=4,
369
+ in_channels=1,
370
+ en_out_channels=16,
371
+ ):
372
+ super(E2E, self).__init__()
373
+ self.unet = DeepUnet(
374
+ kernel_size,
375
+ n_blocks,
376
+ en_de_layers,
377
+ inter_layers,
378
+ in_channels,
379
+ en_out_channels,
380
+ )
381
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
382
+ if n_gru:
383
+ self.fc = nn.Sequential(
384
+ BiGRU(3 * 128, 256, n_gru),
385
+ nn.Linear(512, 360),
386
+ nn.Dropout(0.25),
387
+ nn.Sigmoid(),
388
+ )
389
+ else:
390
+ self.fc = nn.Sequential(
391
+ nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
392
+ )
393
+
394
+ def forward(self, mel):
395
+ # print(mel.shape)
396
+ mel = mel.transpose(-1, -2).unsqueeze(1)
397
+ x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
398
+ x = self.fc(x)
399
+ # print(x.shape)
400
+ return x
401
+
402
+
403
+ from librosa.filters import mel
404
+
405
+
406
+ class MelSpectrogram(torch.nn.Module):
407
+ def __init__(
408
+ self,
409
+ is_half,
410
+ n_mel_channels,
411
+ sampling_rate,
412
+ win_length,
413
+ hop_length,
414
+ n_fft=None,
415
+ mel_fmin=0,
416
+ mel_fmax=None,
417
+ clamp=1e-5,
418
+ ):
419
+ super().__init__()
420
+ n_fft = win_length if n_fft is None else n_fft
421
+ self.hann_window = {}
422
+ mel_basis = mel(
423
+ sr=sampling_rate,
424
+ n_fft=n_fft,
425
+ n_mels=n_mel_channels,
426
+ fmin=mel_fmin,
427
+ fmax=mel_fmax,
428
+ htk=True,
429
+ )
430
+ mel_basis = torch.from_numpy(mel_basis).float()
431
+ self.register_buffer("mel_basis", mel_basis)
432
+ self.n_fft = win_length if n_fft is None else n_fft
433
+ self.hop_length = hop_length
434
+ self.win_length = win_length
435
+ self.sampling_rate = sampling_rate
436
+ self.n_mel_channels = n_mel_channels
437
+ self.clamp = clamp
438
+ self.is_half = is_half
439
+
440
+ def forward(self, audio, keyshift=0, speed=1, center=True):
441
+ factor = 2 ** (keyshift / 12)
442
+ n_fft_new = int(np.round(self.n_fft * factor))
443
+ win_length_new = int(np.round(self.win_length * factor))
444
+ hop_length_new = int(np.round(self.hop_length * speed))
445
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
446
+ if keyshift_key not in self.hann_window:
447
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
448
+ audio.device
449
+ )
450
+ if "privateuseone" in str(audio.device):
451
+ if not hasattr(self, "stft"):
452
+ self.stft = STFT(
453
+ filter_length=n_fft_new,
454
+ hop_length=hop_length_new,
455
+ win_length=win_length_new,
456
+ window="hann",
457
+ ).to(audio.device)
458
+ magnitude = self.stft.transform(audio)
459
+ else:
460
+ fft = torch.stft(
461
+ audio,
462
+ n_fft=n_fft_new,
463
+ hop_length=hop_length_new,
464
+ win_length=win_length_new,
465
+ window=self.hann_window[keyshift_key],
466
+ center=center,
467
+ return_complex=True,
468
+ )
469
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
470
+ if keyshift != 0:
471
+ size = self.n_fft // 2 + 1
472
+ resize = magnitude.size(1)
473
+ if resize < size:
474
+ magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
475
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
476
+ mel_output = torch.matmul(self.mel_basis, magnitude)
477
+ if self.is_half:
478
+ mel_output = mel_output.half()
479
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
480
+ return log_mel_spec
481
+
482
+
483
+ class RMVPE:
484
+ def __init__(self, model_path: str, is_half, device=None, use_jit=False):
485
+ self.resample_kernel = {}
486
+ self.resample_kernel = {}
487
+ self.is_half = is_half
488
+ if device is None:
489
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
490
+ self.device = device
491
+ self.mel_extractor = MelSpectrogram(
492
+ is_half, 128, 16000, 1024, 160, None, 30, 8000
493
+ ).to(device)
494
+ if "privateuseone" in str(device):
495
+ import onnxruntime as ort
496
+
497
+ ort_session = ort.InferenceSession(
498
+ "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
499
+ providers=["DmlExecutionProvider"],
500
+ )
501
+ self.model = ort_session
502
+ else:
503
+ if str(self.device) == "cuda":
504
+ self.device = torch.device("cuda:0")
505
+
506
+ def get_default_model():
507
+ model = E2E(4, 1, (2, 2))
508
+ ckpt = torch.load(model_path, map_location="cpu")
509
+ model.load_state_dict(ckpt)
510
+ model.eval()
511
+ if is_half:
512
+ model = model.half()
513
+ else:
514
+ model = model.float()
515
+ return model
516
+
517
+ self.model = get_default_model()
518
+
519
+ self.model = self.model.to(device)
520
+ cents_mapping = 20 * np.arange(360) + 1997.3794084376191
521
+ self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
522
+
523
+ def mel2hidden(self, mel):
524
+ with torch.no_grad():
525
+ n_frames = mel.shape[-1]
526
+ n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
527
+ if n_pad > 0:
528
+ mel = F.pad(mel, (0, n_pad), mode="constant")
529
+ if "privateuseone" in str(self.device):
530
+ onnx_input_name = self.model.get_inputs()[0].name
531
+ onnx_outputs_names = self.model.get_outputs()[0].name
532
+ hidden = self.model.run(
533
+ [onnx_outputs_names],
534
+ input_feed={onnx_input_name: mel.cpu().numpy()},
535
+ )[0]
536
+ else:
537
+ mel = mel.half() if self.is_half else mel.float()
538
+ hidden = self.model(mel)
539
+ return hidden[:, :n_frames]
540
+
541
+ def decode(self, hidden, thred=0.03):
542
+ cents_pred = self.to_local_average_cents(hidden, thred=thred)
543
+ f0 = 10 * (2 ** (cents_pred / 1200))
544
+ f0[f0 == 10] = 0
545
+ # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
546
+ return f0
547
+
548
+ def infer_from_audio(self, audio, thred=0.03):
549
+ # torch.cuda.synchronize()
550
+ # t0 = ttime()
551
+ if not torch.is_tensor(audio):
552
+ audio = torch.from_numpy(audio)
553
+ mel = self.mel_extractor(
554
+ audio.float().to(self.device).unsqueeze(0), center=True
555
+ )
556
+ # print(123123123,mel.device.type)
557
+ # torch.cuda.synchronize()
558
+ # t1 = ttime()
559
+ hidden = self.mel2hidden(mel)
560
+ # torch.cuda.synchronize()
561
+ # t2 = ttime()
562
+ # print(234234,hidden.device.type)
563
+ if "privateuseone" not in str(self.device):
564
+ hidden = hidden.squeeze(0).cpu().numpy()
565
+ else:
566
+ hidden = hidden[0]
567
+ if self.is_half:
568
+ hidden = hidden.astype("float32")
569
+
570
+ f0 = self.decode(hidden, thred=thred)
571
+ # torch.cuda.synchronize()
572
+ # t3 = ttime()
573
+ # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
574
+ return f0
575
+
576
+ def to_local_average_cents(self, salience, thred=0.05):
577
+ # t0 = ttime()
578
+ center = np.argmax(salience, axis=1)  # (n_frames,) index of the peak bin per frame
579
+ salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
580
+ # t1 = ttime()
581
+ center += 4
582
+ todo_salience = []
583
+ todo_cents_mapping = []
584
+ starts = center - 4
585
+ ends = center + 5
586
+ for idx in range(salience.shape[0]):
587
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
588
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
589
+ # t2 = ttime()
590
+ todo_salience = np.array(todo_salience)  # (n_frames, 9)
591
+ todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
592
+ product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
593
+ weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
594
+ divided = product_sum / weight_sum  # (n_frames,)
595
+ # t3 = ttime()
596
+ maxx = np.max(salience, axis=1)  # (n_frames,)
597
+ divided[maxx <= thred] = 0
598
+ # t4 = ttime()
599
+ # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
600
+ return divided
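A hedged end-to-end sketch (not part of this commit) of F0 extraction with RMVPE; the checkpoint path is hypothetical and the pretrained weights must be obtained separately:

import numpy as np
from modules.rmvpe import RMVPE                # assumes this repository layout

rmvpe = RMVPE("checkpoints/rmvpe.pt", is_half=False, device="cpu")   # hypothetical weights path
audio = np.random.randn(16000).astype(np.float32)                    # one second of fake 16 kHz audio
f0 = rmvpe.infer_from_audio(audio, thred=0.03)                       # per-frame F0 in Hz, 0 where unvoiced
print(f0.shape)                                                      # (101,) for this input length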
modules/wavenet.py ADDED
@@ -0,0 +1,174 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from modules.encodec import SConv1d
7
+
8
+ from . import commons
9
+ LRELU_SLOPE = 0.1
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels, eps=1e-5):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x):
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ class ConvReluNorm(nn.Module):
27
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
28
+ super().__init__()
29
+ self.in_channels = in_channels
30
+ self.hidden_channels = hidden_channels
31
+ self.out_channels = out_channels
32
+ self.kernel_size = kernel_size
33
+ self.n_layers = n_layers
34
+ self.p_dropout = p_dropout
35
+ assert n_layers > 1, "Number of layers should be larger than 1."
36
+
37
+ self.conv_layers = nn.ModuleList()
38
+ self.norm_layers = nn.ModuleList()
39
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
40
+ self.norm_layers.append(LayerNorm(hidden_channels))
41
+ self.relu_drop = nn.Sequential(
42
+ nn.ReLU(),
43
+ nn.Dropout(p_dropout))
44
+ for _ in range(n_layers - 1):
45
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
46
+ self.norm_layers.append(LayerNorm(hidden_channels))
47
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
48
+ self.proj.weight.data.zero_()
49
+ self.proj.bias.data.zero_()
50
+
51
+ def forward(self, x, x_mask):
52
+ x_org = x
53
+ for i in range(self.n_layers):
54
+ x = self.conv_layers[i](x * x_mask)
55
+ x = self.norm_layers[i](x)
56
+ x = self.relu_drop(x)
57
+ x = x_org + self.proj(x)
58
+ return x * x_mask
59
+
60
+
61
+ class DDSConv(nn.Module):
62
+ """
63
+ Dilated and Depth-Separable Convolution
64
+ """
65
+
66
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
67
+ super().__init__()
68
+ self.channels = channels
69
+ self.kernel_size = kernel_size
70
+ self.n_layers = n_layers
71
+ self.p_dropout = p_dropout
72
+
73
+ self.drop = nn.Dropout(p_dropout)
74
+ self.convs_sep = nn.ModuleList()
75
+ self.convs_1x1 = nn.ModuleList()
76
+ self.norms_1 = nn.ModuleList()
77
+ self.norms_2 = nn.ModuleList()
78
+ for i in range(n_layers):
79
+ dilation = kernel_size ** i
80
+ padding = (kernel_size * dilation - dilation) // 2
81
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
82
+ groups=channels, dilation=dilation, padding=padding
83
+ ))
84
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
85
+ self.norms_1.append(LayerNorm(channels))
86
+ self.norms_2.append(LayerNorm(channels))
87
+
88
+ def forward(self, x, x_mask, g=None):
89
+ if g is not None:
90
+ x = x + g
91
+ for i in range(self.n_layers):
92
+ y = self.convs_sep[i](x * x_mask)
93
+ y = self.norms_1[i](y)
94
+ y = F.gelu(y)
95
+ y = self.convs_1x1[i](y)
96
+ y = self.norms_2[i](y)
97
+ y = F.gelu(y)
98
+ y = self.drop(y)
99
+ x = x + y
100
+ return x * x_mask
101
+
102
+
103
+ class WN(torch.nn.Module):
104
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False):
105
+ super(WN, self).__init__()
106
+ conv1d_type = SConv1d
107
+ assert (kernel_size % 2 == 1)
108
+ self.hidden_channels = hidden_channels
109
+ self.kernel_size = kernel_size
110
+ self.dilation_rate = dilation_rate
111
+ self.n_layers = n_layers
112
+ self.gin_channels = gin_channels
113
+ self.p_dropout = p_dropout
114
+
115
+ self.in_layers = torch.nn.ModuleList()
116
+ self.res_skip_layers = torch.nn.ModuleList()
117
+ self.drop = nn.Dropout(p_dropout)
118
+
119
+ if gin_channels != 0:
120
+ self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm')
121
+
122
+ for i in range(n_layers):
123
+ dilation = dilation_rate ** i
124
+ padding = int((kernel_size * dilation - dilation) / 2)
125
+ in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation,
126
+ padding=padding, norm='weight_norm', causal=causal)
127
+ self.in_layers.append(in_layer)
128
+
129
+ # last one is not necessary
130
+ if i < n_layers - 1:
131
+ res_skip_channels = 2 * hidden_channels
132
+ else:
133
+ res_skip_channels = hidden_channels
134
+
135
+ res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal)
136
+ self.res_skip_layers.append(res_skip_layer)
137
+
138
+ def forward(self, x, x_mask, g=None, **kwargs):
139
+ output = torch.zeros_like(x)
140
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
141
+
142
+ if g is not None:
143
+ g = self.cond_layer(g)
144
+
145
+ for i in range(self.n_layers):
146
+ x_in = self.in_layers[i](x)
147
+ if g is not None:
148
+ cond_offset = i * 2 * self.hidden_channels
149
+ g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
150
+ else:
151
+ g_l = torch.zeros_like(x_in)
152
+
153
+ acts = commons.fused_add_tanh_sigmoid_multiply(
154
+ x_in,
155
+ g_l,
156
+ n_channels_tensor)
157
+ acts = self.drop(acts)
158
+
159
+ res_skip_acts = self.res_skip_layers[i](acts)
160
+ if i < self.n_layers - 1:
161
+ res_acts = res_skip_acts[:, :self.hidden_channels, :]
162
+ x = (x + res_acts) * x_mask
163
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
164
+ else:
165
+ output = output + res_skip_acts
166
+ return output * x_mask
167
+
168
+ def remove_weight_norm(self):
169
+ if self.gin_channels != 0:
170
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
171
+ for l in self.in_layers:
172
+ torch.nn.utils.remove_weight_norm(l)
173
+ for l in self.res_skip_layers:
174
+ torch.nn.utils.remove_weight_norm(l)
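A usage sketch (not part of this commit) of WN with the same hyperparameters FAquantizer uses for its mel encoder; it assumes modules/encodec.py and modules/commons.py from this repo resolve, and the shapes are made up:

import torch
from modules.wavenet import WN

wn = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8,
        gin_channels=0, p_dropout=0.2)
x = torch.randn(2, 256, 80)                    # (B, hidden_channels, T)
x_mask = torch.ones(2, 1, 80)
y = wn(x, x_mask)
print(y.shape)                                 # torch.Size([2, 256, 80])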