Staticaliza committed: Upload 10 files
- modules/audio.py +82 -0
- modules/commons.py +490 -0
- modules/diffusion_transformer.py +240 -0
- modules/encodec.py +292 -0
- modules/flow_matching.py +155 -0
- modules/layers.py +354 -0
- modules/length_regulator.py +141 -0
- modules/quantize.py +229 -0
- modules/rmvpe.py +600 -0
- modules/wavenet.py +174 -0
modules/audio.py
ADDED
@@ -0,0 +1,82 @@
import numpy as np
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(sampling_rate) + "_" + str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
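For orientation, here is a minimal usage sketch of mel_spectrogram. The STFT parameters below are illustrative placeholders (typical HiFi-GAN-style 22.05 kHz values), not values read from this repo's configs.

# Hypothetical usage sketch for modules/audio.py (parameter values are assumptions).
import torch
from modules.audio import mel_spectrogram

y = torch.rand(1, 22050) * 2 - 1  # one second of fake audio in [-1, 1], batch of 1
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000,
)
print(mel.shape)  # (1, 80, n_frames): log-compressed mel magnitudes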
modules/commons.py
ADDED
@@ -0,0 +1,490 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from munch import Munch
import json


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += (
        0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    )
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def slice_segments_audio(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
        dtype=torch.long
    )
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def avg_with_mask(x, mask):
    assert mask.dtype == torch.float, "Mask should be float"

    if mask.ndim == 2:
        mask = mask.unsqueeze(1)

    if mask.shape[1] == 1:
        mask = mask.expand_as(x)

    return (x * mask).sum() / mask.sum()


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def log_norm(x, mean=-4, std=4, dim=2):
    """
    normalized log mel -> mel -> norm -> log(norm)
    """
    x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
    return x


def load_F0_models(path):
    # load F0 model
    from .JDC.model import JDCNet

    F0_model = JDCNet(num_class=1, seq_len=192)
    params = torch.load(path, map_location="cpu")["net"]
    F0_model.load_state_dict(params)
    _ = F0_model.train()

    return F0_model


def modify_w2v_forward(self, output_layer=15):
    """
    change forward method of w2v encoder to get its intermediate layer output
    :param self:
    :param layer:
    :return:
    """
    from transformers.modeling_outputs import BaseModelOutput

    def forward(
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        conv_attention_mask = attention_mask
        if attention_mask is not None:
            # make sure padded tokens output 0
            hidden_states = hidden_states.masked_fill(
                ~attention_mask.bool().unsqueeze(-1), 0.0
            )

            # extend attention_mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(
                dtype=hidden_states.dtype
            )
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )

        hidden_states = self.dropout(hidden_states)

        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)
        else:
            relative_position_embeddings = None

        deepspeed_zero3_is_enabled = False

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = (
                True
                if self.training and (dropout_probability < self.config.layerdrop)
                else False
            )
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        relative_position_embeddings,
                        output_attentions,
                        conv_attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        relative_position_embeddings=relative_position_embeddings,
                        output_attentions=output_attentions,
                        conv_attention_mask=conv_attention_mask,
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

            if i == output_layer - 1:
                break

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, all_hidden_states, all_self_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    return forward


MATPLOTLIB_FLAG = False


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        import logging

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def normalize_f0(f0_sequence):
    # Remove unvoiced frames (replace with -1)
    voiced_indices = np.where(f0_sequence > 0)[0]
    f0_voiced = f0_sequence[voiced_indices]

    # Convert to log scale
    log_f0 = np.log2(f0_voiced)

    # Calculate mean and standard deviation
    mean_f0 = np.mean(log_f0)
    std_f0 = np.std(log_f0)

    # Normalize the F0 sequence
    normalized_f0 = (log_f0 - mean_f0) / std_f0

    # Create the normalized F0 sequence with unvoiced frames
    normalized_sequence = np.zeros_like(f0_sequence)
    normalized_sequence[voiced_indices] = normalized_f0
    normalized_sequence[f0_sequence <= 0] = -1  # Assign -1 to unvoiced frames

    return normalized_sequence


def build_model(args, stage="DiT"):
    if stage == "DiT":
        from modules.flow_matching import CFM
        from modules.length_regulator import InterpolateRegulator

        length_regulator = InterpolateRegulator(
            channels=args.length_regulator.channels,
            sampling_ratios=args.length_regulator.sampling_ratios,
            is_discrete=args.length_regulator.is_discrete,
            in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
            vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
            codebook_size=args.length_regulator.content_codebook_size,
            n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
            quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
            f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
            n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
        )
        cfm = CFM(args)
        nets = Munch(
            cfm=cfm,
            length_regulator=length_regulator,
        )
    elif stage == 'codec':
        from dac.model.dac import Encoder
        from modules.quantize import (
            FAquantizer,
        )

        encoder = Encoder(
            d_model=args.DAC.encoder_dim,
            strides=args.DAC.encoder_rates,
            d_latent=1024,
            causal=args.causal,
            lstm=args.lstm,
        )

        quantizer = FAquantizer(
            in_dim=1024,
            n_p_codebooks=1,
            n_c_codebooks=args.n_c_codebooks,
            n_t_codebooks=2,
            n_r_codebooks=3,
            codebook_size=1024,
            codebook_dim=8,
            quantizer_dropout=0.5,
            causal=args.causal,
            separate_prosody_encoder=args.separate_prosody_encoder,
            timbre_norm=args.timbre_norm,
        )

        nets = Munch(
            encoder=encoder,
            quantizer=quantizer,
        )
    else:
        raise ValueError(f"Unknown stage: {stage}")

    return nets


def load_checkpoint(
    model,
    optimizer,
    path,
    load_only_params=True,
    ignore_modules=[],
    is_distributed=False,
):
    state = torch.load(path, map_location="cpu")
    params = state["net"]
    for key in model:
        if key in params and key not in ignore_modules:
            if not is_distributed:
                # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
                for k in list(params[key].keys()):
                    if k.startswith("module."):
                        params[key][k[len("module.") :]] = params[key][k]
                        del params[key][k]
            model_state_dict = model[key].state_dict()
            # keep only the key/value pairs whose shapes match the model
            filtered_state_dict = {
                k: v
                for k, v in params[key].items()
                if k in model_state_dict and v.shape == model_state_dict[k].shape
            }
            skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
            if skipped_keys:
                print(
                    f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
                )
            print("%s loaded" % key)
            model[key].load_state_dict(filtered_state_dict, strict=False)
    _ = [model[key].eval() for key in model]

    if not load_only_params:
        epoch = state["epoch"] + 1
        iters = state["iters"]
        optimizer.load_state_dict(state["optimizer"])
        optimizer.load_scheduler_state_dict(state["scheduler"])

    else:
        epoch = 0
        iters = 0

    return model, optimizer, epoch, iters


def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
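A quick shape check for two of the helpers above, sequence_mask and rand_slice_segments; the tensor sizes are made up for illustration.

# Illustrative shape check for modules/commons.py helpers (sizes are made up).
import torch
from modules.commons import sequence_mask, rand_slice_segments

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths))  # boolean (2, 5) mask: True where index < length

x = torch.randn(2, 80, 100)    # (batch, channels, frames)
segments, ids = rand_slice_segments(x, torch.tensor([100, 80]), segment_size=32)
print(segments.shape)          # torch.Size([2, 80, 32]), one random window per item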
modules/diffusion_transformer.py
ADDED
@@ -0,0 +1,240 @@
import torch
from torch import nn
import math

from modules.gpt_fast.model import ModelArgs, Transformer
# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
from modules.wavenet import WN
from modules.commons import sequence_mask

from torch.nn.utils import weight_norm


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = 10000
        self.scale = 1000

        half = frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        )
        self.register_buffer("freqs", freqs)

    def timestep_embedding(self, t):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py

        args = self.scale * t[:, None].float() * self.freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.frequency_embedding_size % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t)
        t_emb = self.mlp(t_freq)
        return t_emb


class StyleEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, input_size, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
        self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
        self.input_size = input_size
        self.dropout_prob = dropout_prob

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        else:
            labels = self.style_in(labels)
        embeddings = labels
        return embeddings


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(torch.nn.Module):
    def __init__(
        self,
        args
    ):
        super(DiT, self).__init__()
        self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
        self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
        self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
        model_args = ModelArgs(
            block_size=16384,  # args.DiT.block_size,
            n_layer=args.DiT.depth,
            n_head=args.DiT.num_heads,
            dim=args.DiT.hidden_dim,
            head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
            vocab_size=1024,
            uvit_skip_connection=self.uvit_skip_connection,
        )
        self.transformer = Transformer(model_args)
        self.in_channels = args.DiT.in_channels
        self.out_channels = args.DiT.in_channels
        self.num_heads = args.DiT.num_heads

        self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))

        self.content_type = args.DiT.content_type  # 'discrete' or 'continuous'
        self.content_codebook_size = args.DiT.content_codebook_size  # for discrete content
        self.content_dim = args.DiT.content_dim  # for continuous content
        self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim)  # discrete content
        self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True)  # continuous content

        self.is_causal = args.DiT.is_causal

        self.n_f0_bins = args.DiT.n_f0_bins
        self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
        self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
        self.f0_condition = args.DiT.f0_condition

        self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
        self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
        # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
        # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))

        input_pos = torch.arange(16384)
        self.register_buffer("input_pos", input_pos)

        self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
        self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
        self.final_layer_type = args.DiT.final_layer_type  # mlp or wavenet
        if self.final_layer_type == 'wavenet':
            self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
                              kernel_size=args.wavenet.kernel_size,
                              dilation_rate=args.wavenet.dilation_rate,
                              n_layers=args.wavenet.num_layers,
                              gin_channels=args.wavenet.hidden_dim,
                              p_dropout=args.wavenet.p_dropout,
                              causal=False)
            self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
        else:
            self.final_mlp = nn.Sequential(
                nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
                nn.SiLU(),
                nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
            )
        self.transformer_style_condition = args.DiT.style_condition
        self.wavenet_style_condition = args.wavenet.style_condition
        assert args.DiT.style_condition == args.wavenet.style_condition

        self.class_dropout_prob = args.DiT.class_dropout_prob
        self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
        self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)  # residual connection from transformer output to final output
        self.long_skip_connection = args.DiT.long_skip_connection
        self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)

        self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
                                             args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
                                             args.DiT.hidden_dim)
        if self.style_as_token:
            self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)

    def setup_caches(self, max_batch_size, max_seq_length):
        self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)

    def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
        class_dropout = False
        if self.training and torch.rand(1) < self.class_dropout_prob:
            class_dropout = True
        if not self.training and mask_content:
            class_dropout = True
        # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
        cond_in_module = self.cond_projection

        B, _, T = x.size()

        t1 = self.t_embedder(t)  # (N, D)

        cond = cond_in_module(cond)
        if self.f0_condition and f0 is not None:
            quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
            cond = cond + self.f0_embedder(quantized_f0)

        x = x.transpose(1, 2)
        prompt_x = prompt_x.transpose(1, 2)

        x_in = torch.cat([x, prompt_x, cond], dim=-1)
        if self.transformer_style_condition and not self.style_as_token:
            x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
        if class_dropout:
            x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
        x_in = self.cond_x_merge_linear(x_in)  # (N, T, D)

        if self.style_as_token:
            style = self.style_in(style)
            style = torch.zeros_like(style) if class_dropout else style
            x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
        if self.time_as_token:
            x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
        x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
        input_pos = self.input_pos[:x_in.size(1)]  # (T,)
        x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
        x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
        x_res = x_res[:, 1:] if self.time_as_token else x_res
        x_res = x_res[:, 1:] if self.style_as_token else x_res
        if self.long_skip_connection:
            x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
        if self.final_layer_type == 'wavenet':
            x = self.conv1(x_res)
            x = x.transpose(1, 2)
            t2 = self.t_embedder2(t)
            x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
                x_res)  # long residual connection
            x = self.final_layer(x, t1).transpose(1, 2)
            x = self.conv2(x)
        else:
            x = self.final_mlp(x_res)
            x = x.transpose(1, 2)
        return x
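TimestepEmbedder is self-contained, so its behavior can be checked standalone; the hidden size below is an arbitrary example value.

# Standalone sketch of TimestepEmbedder (hidden_size is an arbitrary example).
import torch
from modules.diffusion_transformer import TimestepEmbedder

embedder = TimestepEmbedder(hidden_size=512)
t = torch.rand(4)         # one flow-matching time in [0, 1] per batch element
print(embedder(t).shape)  # torch.Size([4, 512]): sinusoidal features through the MLP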
modules/encodec.py
ADDED
@@ -0,0 +1,292 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Convolutional layers wrappers and utilities."""

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm

import einops


class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to last dimensions
    before running the normalization and moves them back to original position right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = einops.rearrange(x, 'b ... t -> b t ...')
        x = super().forward(x)
        x = einops.rearrange(x, 'b t ... -> b ... t')
        return x


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already checked that norm is in CONV_NORMALIZATIONS,
        # so any other choice doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happen.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class SConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect', **kwargs):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('SConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)


class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio
            # if trim_right_ratio = 1.0, trim everything from right
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y


class SLSTM(nn.Module):
    """
    LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input as convolutional layout.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)
        self.hidden = None

    def forward(self, x):
        x = x.permute(2, 0, 1)
        if self.training:
            y, _ = self.lstm(x)
        else:
            y, self.hidden = self.lstm(x, self.hidden)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)
        return y
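The padding arithmetic in SConv1d keeps the output length an exact function of the stride; a small sketch (channel and kernel sizes are illustrative assumptions):

# Illustrative check of SConv1d's causal padding (sizes are assumptions).
import torch
from modules.encodec import SConv1d

conv = SConv1d(1, 8, kernel_size=7, stride=2, causal=True)
x = torch.randn(1, 1, 16000)
print(conv(x).shape)  # torch.Size([1, 8, 8000]): time axis is exactly length / stride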
modules/flow_matching.py
ADDED
@@ -0,0 +1,155 @@
from abc import ABC

import torch
import torch.nn.functional as F

from modules.diffusion_transformer import DiT
from modules.commons import sequence_mask

from tqdm import tqdm


class BASECFM(torch.nn.Module, ABC):
    def __init__(
        self,
        args,
    ):
        super().__init__()
        self.sigma_min = 1e-6

        self.estimator = None

        self.in_channels = args.DiT.in_channels

        self.criterion = torch.nn.MSELoss() if args.reg_loss_type == "l2" else torch.nn.L1Loss()

        if hasattr(args.DiT, 'zero_prompt_speech_token'):
            self.zero_prompt_speech_token = args.DiT.zero_prompt_speech_token
        else:
            self.zero_prompt_speech_token = False

    @torch.inference_mode()
    def inference(self, mu, x_lens, prompt, style, f0, n_timesteps, temperature=1.0, inference_cfg_rate=0.5):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """
        B, T = mu.size(0), mu.size(1)
        z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)

    def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]

        # I am storing this because I can later plot it by putting a debugger here and saving it to a file
        # Or in future might add like a return_all_steps flag
        sol = []
        # apply prompt
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0
        if self.zero_prompt_speech_token:
            mu[..., :prompt_len] = 0
        for step in tqdm(range(1, len(t_span))):
            dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
            # Classifier-Free Guidance inference introduced in VoiceBox
            if inference_cfg_rate > 0:
                cfg_dphi_dt = self.estimator(
                    x, torch.zeros_like(prompt_x), x_lens, t.unsqueeze(0),
                    torch.zeros_like(style),
                    torch.zeros_like(mu), None
                )
                dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
                           inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            sol.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
            x[:, :, :prompt_len] = 0

        return sol[-1]

    def forward(self, x1, x_lens, prompt_lens, mu, style, f0=None):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            spks (torch.Tensor, optional): speaker embedding. Defaults to None.
                shape: (batch_size, spk_emb_dim)

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = x1.shape

        # random timestep
        t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        prompt = torch.zeros_like(x1)
        for bib in range(b):
            prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
            # the range covered by the prompt is set to 0
            y[bib, :, :prompt_lens[bib]] = 0
            if self.zero_prompt_speech_token:
                mu[bib, :, :prompt_lens[bib]] = 0

        estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu, f0)
        loss = 0
        for bib in range(b):
            loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
        loss /= b

        return loss, y


class CFM(BASECFM):
    def __init__(self, args):
        super().__init__(
            args
        )
        if args.dit_type == "DiT":
            self.estimator = DiT(args)
        else:
            raise NotImplementedError(f"Unknown diffusion type {args.dit_type}")
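solve_euler above is a plain fixed-step Euler integrator, with classifier-free guidance mixed into the estimated vector field at each step. Stripped of the model, the update rule reduces to the following toy sketch (the vector field here is made up for illustration):

# Toy fixed-step Euler integration matching solve_euler's update rule (no model).
import torch

def euler(v, x0, n_steps=10):
    # integrate dx/dt = v(x, t) from t = 0 to t = 1
    x, ts = x0, torch.linspace(0, 1, n_steps + 1)
    for i in range(1, len(ts)):
        x = x + (ts[i] - ts[i - 1]) * v(x, ts[i - 1])
    return x

print(euler(lambda x, t: -x, torch.ones(1)))  # approaches exp(-1) ≈ 0.368 as n_steps grows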
modules/layers.py
ADDED
@@ -0,0 +1,354 @@
import math
import torch
from torch import nn
from typing import Optional, Any
from torch import Tensor
import torch.nn.functional as F
import torchaudio
import torchaudio.functional as audio_F

import random
random.seed(0)


def _get_activation_fn(activ):
    if activ == 'relu':
        return nn.ReLU()
    elif activ == 'lrelu':
        return nn.LeakyReLU(0.2)
    elif activ == 'swish':
        return lambda x: x * torch.sigmoid(x)
    else:
        raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)

class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal

class CausualConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            # fix: the original computed this padding into a local variable and
            # never stored it, so forward() would crash; keep it on self instead
            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=self.padding,
                              dilation=dilation,
                              bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        x = self.conv(x)
        x = x[:, :, :-self.padding]
        return x

class CausualBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
        layers = [
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)

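# Illustrative sketch (not part of the original upload): CausualConv pads by
# 2 * padding on both sides and then trims the tail, so the output at frame t
# only depends on inputs at frames <= t. The check below perturbs a future
# sample and verifies the past outputs are unchanged.
def _demo_causality_check():
    torch.manual_seed(0)
    conv = CausualConv(1, 1, kernel_size=3, padding=1, dilation=1)
    x = torch.randn(1, 1, 16)
    y_ref = conv(x)
    x2 = x.clone()
    x2[..., 10:] += 1.0          # change only the future (frames >= 10)
    y_mod = conv(x2)
    return torch.allclose(y_ref[..., :10], y_mod[..., :10])
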
class ConvBlock(nn.Module):
    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
        layers = [
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)

class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class ForwardAttentionV2(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, log_alpha):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        # forward attention: each frame may stay or advance by one step, so the
        # new log-alpha is a logsumexp over the previous alpha shifted by 0 and 1
        log_alpha_shift_padded = []
        max_time = log_energy.size(1)
        for sft in range(2):
            shifted = log_alpha[:, :max_time - sft]
            shift_padded = F.pad(shifted, (sft, 0), 'constant', self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))

        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)

        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)

        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights, log_alpha_new


class PhaseShuffle2d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle2d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, M, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :, :move]
            right = x[:, :, :, move:]
            shuffled = torch.cat([right, left], dim=3)
            return shuffled

class PhaseShuffle1d(nn.Module):
    def __init__(self, n=2):
        super(PhaseShuffle1d, self).__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, L)
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x
        else:
            left = x[:, :, :move]
            right = x[:, :, move:]
            shuffled = torch.cat([right, left], dim=2)

            return shuffled

class MFCC(nn.Module):
    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = 'ortho'
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)

    def forward(self, mel_specgram):
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # unpack batch
        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc
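
# Illustrative sketch (not part of the original upload): MFCC above is just a
# DCT applied over the mel axis of a (log-)mel spectrogram. Shape check:
def _demo_mfcc_shapes():
    mfcc_layer = MFCC(n_mfcc=40, n_mels=80)
    mel = torch.randn(3, 80, 120)             # (batch, n_mels, time)
    out = mfcc_layer(mel)
    assert out.shape == (3, 40, 120)           # (batch, n_mfcc, time)
    return out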
modules/length_regulator.py
ADDED
@@ -0,0 +1,141 @@
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from modules.commons import sequence_mask
import numpy as np
from dac.nn.quantize import VectorQuantize

# f0_bin = 256
f0_max = 1100.0
f0_min = 50.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)

def f0_to_coarse(f0, f0_bin):
    f0_mel = 1127 * (1 + f0 / 700).log()
    a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
    b = f0_mel_min * a - 1.
    f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
    # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
    f0_coarse = torch.round(f0_mel).long()
    f0_coarse = f0_coarse * (f0_coarse > 0)
    f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
    f0_coarse = f0_coarse * (f0_coarse < f0_bin)
    f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
    return f0_coarse

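# Illustrative sketch (not part of the original upload): f0_to_coarse maps Hz to
# the mel scale (1127 * ln(1 + f/700)) and then linearly to integer bins in
# [1, f0_bin - 1], with unvoiced frames (f0 == 0) landing in bin 1.
def _demo_f0_bins():
    f0 = torch.tensor([0.0, f0_min, 220.0, f0_max])
    bins = f0_to_coarse(f0, 256)
    assert bins.min() >= 1 and bins.max() <= 255
    return bins
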
class InterpolateRegulator(nn.Module):
    def __init__(
            self,
            channels: int,
            sampling_ratios: Tuple,
            is_discrete: bool = False,
            in_channels: int = None,  # only applies to continuous input
            vector_quantize: bool = False,  # whether to use vector quantization, only applies to continuous input
            codebook_size: int = 1024,  # for discrete only
            out_channels: int = None,
            groups: int = 1,
            n_codebooks: int = 1,  # number of codebooks
            quantizer_dropout: float = 0.0,  # dropout for quantizer
            f0_condition: bool = False,
            n_f0_bins: int = 512,
    ):
        super().__init__()
        self.sampling_ratios = sampling_ratios
        out_channels = out_channels or channels
        model = nn.ModuleList([])
        if len(sampling_ratios) > 0:
            self.interpolate = True
            for _ in sampling_ratios:
                module = nn.Conv1d(channels, channels, 3, 1, 1)
                norm = nn.GroupNorm(groups, channels)
                act = nn.Mish()
                model.extend([module, norm, act])
        else:
            self.interpolate = False
        model.append(
            nn.Conv1d(channels, out_channels, 1, 1)
        )
        self.model = nn.Sequential(*model)
        self.embedding = nn.Embedding(codebook_size, channels)
        self.is_discrete = is_discrete

        self.mask_token = nn.Parameter(torch.zeros(1, channels))

        self.n_codebooks = n_codebooks
        if n_codebooks > 1:
            self.extra_codebooks = nn.ModuleList([
                nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
            ])
            self.extra_codebook_mask_tokens = nn.ParameterList([
                nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
            ])
        self.quantizer_dropout = quantizer_dropout

        if f0_condition:
            self.f0_embedding = nn.Embedding(n_f0_bins, channels)
            self.f0_condition = f0_condition
            self.n_f0_bins = n_f0_bins
            self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
            self.f0_mask = nn.Parameter(torch.zeros(1, channels))
        else:
            self.f0_condition = False

        if not is_discrete:
            self.content_in_proj = nn.Linear(in_channels, channels)
            if vector_quantize:
                self.vq = VectorQuantize(channels, codebook_size, 8)

    def forward(self, x, ylens=None, n_quantizers=None, f0=None):
        # apply token drop
        if self.training:
            n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
            dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
            n_dropout = int(x.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)
            # decide whether to drop for each sample in batch
        else:
            n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
        if self.is_discrete:
            if self.n_codebooks > 1:
                assert len(x.size()) == 3
                x_emb = self.embedding(x[:, 0])
                for i, emb in enumerate(self.extra_codebooks):
                    x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
                    # add mask token if not using this codebook
                    # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
                x = x_emb
            elif self.n_codebooks == 1:
                if len(x.size()) == 2:
                    x = self.embedding(x)
                else:
                    x = self.embedding(x[:, 0])
        else:
            x = self.content_in_proj(x)
        # x in (B, T, D)
        mask = sequence_mask(ylens).unsqueeze(-1)
        if self.interpolate:
            x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
        else:
            x = x.transpose(1, 2).contiguous()
            mask = mask[:, :x.size(2), :]
            ylens = ylens.clamp(max=x.size(2)).long()
        if self.f0_condition:
            if f0 is None:
                x = x + self.f0_mask.unsqueeze(-1)
            else:
                # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device))  # (N, T)
                quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
                quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
                f0_emb = self.f0_embedding(quantized_f0)
                f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
                x = x + f0_emb
        out = self.model(x).transpose(1, 2).contiguous()
        if hasattr(self, 'vq'):
            out_q, commitment_loss, codebook_loss, codes, out = self.vq(out.transpose(1, 2))
            out_q = out_q.transpose(1, 2)
            return out_q * mask, ylens, codes, commitment_loss, codebook_loss
        olens = ylens
        return out * mask, olens, None, None, None
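
# Illustrative sketch (not part of the original upload): the length regulation
# itself is nearest-neighbour interpolation along time, stretching (B, T, D)
# content features to a target number of frames:
def _demo_length_regulation():
    x = torch.randn(2, 50, 192)                       # (B, T_in, D)
    target_len = 125
    y = F.interpolate(x.transpose(1, 2), size=target_len, mode='nearest')
    assert y.shape == (2, 192, 125)                   # (B, D, T_out)
    return y.transpose(1, 2)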
modules/quantize.py
ADDED
@@ -0,0 +1,229 @@
from dac.nn.quantize import ResidualVectorQuantize
from torch import nn
from modules.wavenet import WN
import torch
import torchaudio
import torchaudio.functional as audio_F
import numpy as np
from .alias_free_torch import *
from torch.nn.utils import weight_norm
from torch import nn, sin, pow
from einops.layers.torch import Rearrange
from dac.model.encodec import SConv1d

def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))

class SnakeBeta(nn.Module):
    """
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snakebeta(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
    ):
        """
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha will be trained along with the rest of your model.
        """
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
            self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        """
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x

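# Illustrative sketch (not part of the original upload): with alpha = beta = 1
# on the linear scale, SnakeBeta reduces to the plain snake activation
# x + sin(x)^2 (up to the tiny epsilon guard), which the check below verifies.
def _demo_snakebeta_identity():
    act = SnakeBeta(4, alpha=1.0, alpha_logscale=False)
    x = torch.randn(2, 4, 16)
    expected = x + torch.sin(x) ** 2 / (1.0 + act.no_div_by_zero)
    return torch.allclose(act(x), expected)
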
class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Activation1d(activation=SnakeBeta(dim, alpha_logscale=True)),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        return x + self.block(x)

class CNNLSTM(nn.Module):
    def __init__(self, indim, outdim, head, global_pred=False):
        super().__init__()
        self.global_pred = global_pred
        self.model = nn.Sequential(
            ResidualUnit(indim, dilation=1),
            ResidualUnit(indim, dilation=2),
            ResidualUnit(indim, dilation=3),
            Activation1d(activation=SnakeBeta(indim, alpha_logscale=True)),
            Rearrange("b c t -> b t c"),
        )
        self.heads = nn.ModuleList([nn.Linear(indim, outdim) for i in range(head)])

    def forward(self, x):
        # x: [B, C, T]
        x = self.model(x)
        if self.global_pred:
            x = torch.mean(x, dim=1, keepdim=False)
        outs = [head(x) for head in self.heads]
        return outs

def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)

class FAquantizer(nn.Module):
    def __init__(self, in_dim=1024,
                 n_p_codebooks=1,
                 n_c_codebooks=2,
                 n_t_codebooks=2,
                 n_r_codebooks=3,
                 codebook_size=1024,
                 codebook_dim=8,
                 quantizer_dropout=0.5,
                 causal=False,
                 separate_prosody_encoder=False,
                 timbre_norm=False,):
        super(FAquantizer, self).__init__()
        conv1d_type = SConv1d  # if causal else nn.Conv1d
        self.prosody_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_p_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.content_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_c_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.residual_quantizer = ResidualVectorQuantize(
            input_dim=in_dim,
            n_codebooks=n_r_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.melspec_linear = conv1d_type(in_channels=20, out_channels=256, kernel_size=1, causal=causal)
        self.melspec_encoder = WN(hidden_channels=256, kernel_size=5, dilation_rate=1, n_layers=8, gin_channels=0, p_dropout=0.2, causal=causal)
        self.melspec_linear2 = conv1d_type(in_channels=256, out_channels=1024, kernel_size=1, causal=causal)

        self.prob_random_mask_residual = 0.75

        SPECT_PARAMS = {
            "n_fft": 2048,
            "win_length": 1200,
            "hop_length": 300,
        }
        MEL_PARAMS = {
            "n_mels": 80,
        }

        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=MEL_PARAMS["n_mels"], sample_rate=24000, **SPECT_PARAMS
        )
        self.mel_mean, self.mel_std = -4, 4
        self.frame_rate = 24000 / 300
        self.hop_length = 300

    def preprocess(self, wave_tensor, n_bins=20):
        mel_tensor = self.to_mel(wave_tensor.squeeze(1))
        mel_tensor = (torch.log(1e-5 + mel_tensor) - self.mel_mean) / self.mel_std
        return mel_tensor[:, :n_bins, :int(wave_tensor.size(-1) / self.hop_length)]

    def forward(self, x, wave_segments):
        outs = 0
        prosody_feature = self.preprocess(wave_segments)

        f0_input = prosody_feature  # (B, 20, T)
        f0_input = self.melspec_linear(f0_input)
        f0_input = self.melspec_encoder(f0_input, torch.ones(f0_input.shape[0], 1, f0_input.shape[2]).to(
            f0_input.device).bool())
        f0_input = self.melspec_linear2(f0_input)

        common_min_size = min(f0_input.size(2), x.size(2))
        f0_input = f0_input[:, :, :common_min_size]

        x = x[:, :, :common_min_size]

        z_p, codes_p, latents_p, commitment_loss_p, codebook_loss_p = self.prosody_quantizer(
            f0_input, 1
        )
        outs += z_p.detach()

        z_c, codes_c, latents_c, commitment_loss_c, codebook_loss_c = self.content_quantizer(
            x, 2
        )
        outs += z_c.detach()

        residual_feature = x - z_p.detach() - z_c.detach()

        z_r, codes_r, latents_r, commitment_loss_r, codebook_loss_r = self.residual_quantizer(
            residual_feature, 3
        )

        quantized = [z_p, z_c, z_r]
        codes = [codes_p, codes_c, codes_r]

        return quantized, codes
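
# Illustrative sketch (not part of the original upload): the forward above
# factorizes a feature into prosody + content + residual, where the residual
# branch quantizes whatever the first two quantizers did not capture. With
# identity "quantizers" the decomposition is exact:
def _demo_residual_decomposition():
    x = torch.randn(1, 8, 10)
    z_p = 0.3 * x                      # stand-in for the prosody quantizer output
    z_c = 0.5 * x                      # stand-in for the content quantizer output
    residual = x - z_p - z_c           # what the residual quantizer sees
    return torch.allclose(z_p + z_c + residual, x)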
modules/rmvpe.py
ADDED
@@ -0,0 +1,600 @@
from io import BytesIO
import os
from typing import List, Optional, Tuple
import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window

import logging

logger = logging.getLogger(__name__)


class STFT(torch.nn.Module):
    def __init__(
        self, filter_length=1024, hop_length=512, win_length=None, window="hann"
    ):
        """
        This module implements an STFT using 1D convolution and 1D transpose convolutions.
        This is a bit tricky so there are some cases that probably won't work as working
        out the same sizes before and after in all overlap add setups is tough. Right now,
        this code should work with hop lengths that are half the filter length (50% overlap
        between frames).

        Keyword Arguments:
            filter_length {int} -- Length of filters used (default: {1024})
            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
                equals the filter length). (default: {None})
            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
                (default: {'hann'})
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length if win_length else filter_length
        self.window = window
        self.forward_transform = None
        self.pad_amount = int(self.filter_length / 2)
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))

        assert filter_length >= self.win_length
        # get window and zero center pad it to filter_length
        fft_window = get_window(window, self.win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()

        # window the bases
        forward_basis *= fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window.float())

    def transform(self, input_data, return_phase=False):
        """Take input data (audio) to STFT domain.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)
        """
        input_data = F.pad(
            input_data,
            (self.pad_amount, self.pad_amount),
            mode="reflect",
        )
        forward_transform = input_data.unfold(
            1, self.filter_length, self.hop_length
        ).permute(0, 2, 1)
        forward_transform = torch.matmul(self.forward_basis, forward_transform)
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        if return_phase:
            phase = torch.atan2(imag_part.data, real_part.data)
            return magnitude, phase
        else:
            return magnitude

    def inverse(self, magnitude, phase):
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ```transform``` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        cat = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        fold = torch.nn.Fold(
            output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
            kernel_size=(1, self.filter_length),
            stride=(1, self.hop_length),
        )
        inverse_transform = torch.matmul(self.inverse_basis, cat)
        inverse_transform = fold(inverse_transform)[
            :, 0, 0, self.pad_amount : -self.pad_amount
        ]
        window_square_sum = (
            self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
        )
        window_square_sum = fold(window_square_sum)[
            :, 0, 0, self.pad_amount : -self.pad_amount
        ]
        inverse_transform /= window_square_sum
        return inverse_transform

    def forward(self, input_data):
        """Take input data (audio) to STFT domain and then back to audio.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        self.magnitude, self.phase = self.transform(input_data, return_phase=True)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction

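# Illustrative sketch (not part of the original upload): because inverse()
# divides out the squared-window overlap-add, transform followed by inverse
# should reproduce the interior of the signal. A quick round-trip check
# (edges suffer from padding effects, so an interior slice is compared):
def _demo_stft_roundtrip():
    stft = STFT(filter_length=1024, hop_length=512, win_length=1024)
    audio = torch.randn(1, 8192)
    mag, phase = stft.transform(audio, return_phase=True)
    recon = stft.inverse(mag, phase)
    n = min(audio.size(-1), recon.size(-1))
    return torch.allclose(audio[:, 1024:n - 1024], recon[:, 1024:n - 1024], atol=1e-3)
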
class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        return self.gru(x)[0]


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        # self.shortcut: Optional[nn.Module] = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))

    def forward(self, x: torch.Tensor):
        if not hasattr(self, "shortcut"):
            return self.conv(x) + x
        else:
            return self.conv(x) + self.shortcut(x)


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels,
        in_size,
        n_encoders,
        kernel_size,
        n_blocks,
        out_channels=16,
        momentum=0.01,
    ):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        for i in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(
                    in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
                )
            )
            self.latent_channels.append([out_channels, in_size])
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2
        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x: torch.Tensor):
        concat_tensors: List[torch.Tensor] = []
        x = self.bn(x)
        for i, layer in enumerate(self.layers):
            t, x = layer(x)
            concat_tensors.append(t)
        return x, concat_tensors


class ResEncoderBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
    ):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for i, conv in enumerate(self.conv):
            x = conv(x)
        if self.kernel_size is not None:
            return x, self.pool(x)
        else:
            return x


class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        self.layers.append(
            ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
        )
        for i in range(self.n_inters - 1):
            self.layers.append(
                ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.n_blocks = n_blocks
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        x = torch.cat((x, concat_tensor), dim=1)
        for i, conv2 in enumerate(self.conv2):
            x = conv2(x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        for i in range(self.n_decoders):
            out_channels = in_channels // 2
            self.layers.append(
                ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
            )
            in_channels = out_channels

    def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
        for i, layer in enumerate(self.layers):
            x = layer(x, concat_tensors[-1 - i])
        return x


class DeepUnet(nn.Module):
    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(
            self.encoder.out_channel // 2,
            self.encoder.out_channel,
            inter_layers,
            n_blocks,
        )
        self.decoder = Decoder(
            self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, concat_tensors = self.encoder(x)
        x = self.intermediate(x)
        x = self.decoder(x, concat_tensors)
        return x


# constants for the GRU-less head below; the original referenced undefined
# `nn.N_MELS` / `nn.N_CLASS`, so they are pinned here to the values used
# elsewhere in this file (128 mel bins, 360 pitch classes)
N_MELS = 128
N_CLASS = 360


class E2E(nn.Module):
    def __init__(
        self,
        n_blocks,
        n_gru,
        kernel_size,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
            )

    def forward(self, mel):
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        return x


from librosa.filters import mel


class MelSpectrogram(torch.nn.Module):
    def __init__(
        self,
        is_half,
        n_mel_channels,
        sampling_rate,
        win_length,
        hop_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = win_length if n_fft is None else n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )
        if "privateuseone" in str(audio.device):
            if not hasattr(self, "stft"):
                self.stft = STFT(
                    filter_length=n_fft_new,
                    hop_length=hop_length_new,
                    win_length=win_length_new,
                    window="hann",
                ).to(audio.device)
            magnitude = self.stft.transform(audio)
        else:
            fft = torch.stft(
                audio,
                n_fft=n_fft_new,
                hop_length=hop_length_new,
                win_length=win_length_new,
                window=self.hann_window[keyshift_key],
                center=center,
                return_complex=True,
            )
            magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec


class RMVPE:
    def __init__(self, model_path: str, is_half, device=None, use_jit=False):
        self.resample_kernel = {}
        self.is_half = is_half
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, 160, None, 30, 8000
        ).to(device)
        if "privateuseone" in str(device):
            import onnxruntime as ort

            ort_session = ort.InferenceSession(
                "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
                providers=["DmlExecutionProvider"],
            )
            self.model = ort_session
        else:
            if str(self.device) == "cuda":
                self.device = torch.device("cuda:0")

            def get_default_model():
                model = E2E(4, 1, (2, 2))
                ckpt = torch.load(model_path, map_location="cpu")
                model.load_state_dict(ckpt)
                model.eval()
                if is_half:
                    model = model.half()
                else:
                    model = model.float()
                return model

            self.model = get_default_model()

            self.model = self.model.to(device)
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            if n_pad > 0:
                mel = F.pad(mel, (0, n_pad), mode="constant")
            if "privateuseone" in str(self.device):
                onnx_input_name = self.model.get_inputs()[0].name
                onnx_outputs_names = self.model.get_outputs()[0].name
                hidden = self.model.run(
                    [onnx_outputs_names],
                    input_feed={onnx_input_name: mel.cpu().numpy()},
                )[0]
            else:
                mel = mel.half() if self.is_half else mel.float()
                hidden = self.model(mel)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)
        mel = self.mel_extractor(
            audio.float().to(self.device).unsqueeze(0), center=True
        )
        hidden = self.mel2hidden(mel)
        if "privateuseone" not in str(self.device):
            hidden = hidden.squeeze(0).cpu().numpy()
        else:
            hidden = hidden[0]
        if self.is_half:
            hidden = hidden.astype("float32")

        f0 = self.decode(hidden, thred=thred)
        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        center = np.argmax(salience, axis=1)  # (n_frames,) index of the peak bin
        salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
        center += 4
        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
        todo_salience = np.array(todo_salience)  # (n_frames, 9)
        todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
        devided = product_sum / weight_sum  # (n_frames,) weighted average in cents
        maxx = np.max(salience, axis=1)  # (n_frames,)
        devided[maxx <= thred] = 0
        return devided
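
# Illustrative sketch (not part of the original upload): decode() converts the
# weighted-average cents value back to Hz with f0 = 10 * 2^(cents / 1200)
# (10 Hz reference, 1200 cents per octave), and zero cents is treated as
# unvoiced. Bin 0 of the 360-bin head sits at ~1997.38 cents:
def _demo_cents_to_hz():
    cents = 1997.3794084376191
    f0 = 10 * (2 ** (cents / 1200))
    assert abs(f0 - 31.7) < 0.1   # lowest pitch bin is roughly 31.7 Hz
    return f0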
modules/wavenet.py
ADDED
@@ -0,0 +1,174 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from modules.encodec import SConv1d

from . import commons

LRELU_SLOPE = 0.1

class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # (B, C, T) -> (B, T, C), normalize over channels, then transpose back
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)
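This LayerNorm is channels-first: it transposes (batch, channels, time) input so that F.layer_norm normalizes over the channel axis, then transposes back. A minimal shape check, with an arbitrary illustrative width:

import torch

norm = LayerNorm(channels=192)   # 192 is an example width, not from the source
x = torch.randn(2, 192, 100)     # (batch, channels, time)
y = norm(x)                      # normalization runs over the 192-channel axis
assert y.shape == x.shape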

class ConvReluNorm(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # zero-init so the block starts out as an identity on the residual path
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
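Because the output projection is zero-initialized, ConvReluNorm is an identity map over masked positions at initialization; note also that the residual add requires in_channels == out_channels. A quick check under assumed shapes:

import torch

block = ConvReluNorm(in_channels=192, hidden_channels=192, out_channels=192,
                     kernel_size=5, n_layers=3, p_dropout=0.0)
block.eval()                          # disable dropout for a deterministic check
x = torch.randn(1, 192, 50)
x_mask = torch.ones(1, 1, 50)         # (batch, 1, time) binary mask
print(torch.allclose(block(x, x_mask), x * x_mask))  # True at initialization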

class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            dilation = kernel_size ** i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
                                            groups=channels, dilation=dilation, padding=padding))
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y
        return x * x_mask
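Because dilation grows as kernel_size ** i, the receptive field of the depthwise stack grows geometrically: n_layers layers of kernel size k see k ** n_layers frames. A back-of-envelope check with assumed values:

k, n_layers = 3, 4                   # illustrative values
rf = 1
for i in range(n_layers):
    rf += (k - 1) * (k ** i)         # each layer adds (k - 1) * dilation frames of context
print(rf)                            # 81, i.e. 3 ** 4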

class WN(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0, causal=False):
        super(WN, self).__init__()
        conv1d_type = SConv1d
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            self.cond_layer = conv1d_type(gin_channels, 2 * hidden_channels * n_layers, 1, norm='weight_norm')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = conv1d_type(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation,
                                   padding=padding, norm='weight_norm', causal=causal)
            self.in_layers.append(in_layer)

            # the last layer outputs only the skip half; no residual branch is needed
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = conv1d_type(hidden_channels, res_skip_channels, 1, norm='weight_norm', causal=causal)
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(
                x_in,
                g_l,
                n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, :self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels:, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
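For reference, WN.forward relies on commons.fused_add_tanh_sigmoid_multiply for the WaveNet gated activation. The sketch below shows the usual VITS-style definition; whether the commons.py in this commit matches it exactly is an assumption:

import torch

def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # add conditioning, then split the 2*C channels into a tanh half and a sigmoid gate
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    return t_act * s_act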