ElesisSiegherts committed on
Commit
ed6c2db
1 Parent(s): 1b9cb8c

Upload 7 files

Files changed (7)
  1. losses.py +58 -0
  2. mel_processing.py +142 -0
  3. models.py +1044 -0
  4. models_onnx.py +986 -0
  5. modules.py +597 -0
  6. preprocess_text.py +140 -0
  7. re_matching.py +82 -0
losses.py ADDED
@@ -0,0 +1,58 @@
import torch


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss += torch.mean(torch.abs(rl - gl))

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l
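For reference, a minimal sketch (not part of the commit) of how these four losses are typically wired into a VITS-style adversarial training step. The discriminator interface follows MultiPeriodDiscriminator from models.py in this same upload; the helper names (net_d, c_kl) and the weighting are illustrative assumptions, not this repo's training script.

# Usage sketch, assuming net_d is a MultiPeriodDiscriminator and y / y_hat are
# real and generated waveforms of matching shape. Not the project's trainer.
import torch
from losses import discriminator_loss, feature_loss, generator_loss, kl_loss


def discriminator_step(net_d, y, y_hat):
    # Detach the generator output so only the discriminator receives gradients.
    y_d_rs, y_d_gs, _, _ = net_d(y, y_hat.detach())
    loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
    return loss_disc


def generator_step(net_d, y, y_hat, z_p, logs_q, m_p, logs_p, z_mask, c_kl=1.0):
    # Run both waveforms through the discriminator to get scores and feature maps.
    _, y_d_gs, fmap_rs, fmap_gs = net_d(y, y_hat)
    loss_fm = feature_loss(fmap_rs, fmap_gs)          # feature-matching L1
    loss_gen, _ = generator_loss(y_d_gs)              # LSGAN generator term
    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * c_kl  # prior KL
    return loss_gen + loss_fm + loss_kl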
mel_processing.py ADDED
@@ -0,0 +1,142 @@
import torch
import torch.utils.data
from librosa.filters import mel as librosa_mel_fn
import warnings

# warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings(action="ignore")
MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    if torch.min(y) < -1.0:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.0:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=y.dtype, device=y.device
        )
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
        mode="reflect",
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=False,
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
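A minimal usage sketch (not part of the commit). The hyperparameters below are illustrative 44.1 kHz values, not read from this repository's config. One environment caveat worth knowing: librosa 0.10+ made librosa.filters.mel keyword-only, so the positional call inside spec_to_mel_torch / mel_spectrogram_torch above only runs against an older librosa.

# Usage sketch, assuming a [batch, samples] waveform in [-1, 1]. The n_fft /
# hop_size / num_mels values here are typical assumptions, not this repo's config.
import torch
from mel_processing import mel_spectrogram_torch

wav = torch.rand(1, 44100) * 2 - 1  # one second of dummy audio in [-1, 1]
mel = mel_spectrogram_torch(
    wav,
    n_fft=2048,
    num_mels=128,
    sampling_rate=44100,
    hop_size=512,
    win_size=2048,
    fmin=0.0,
    fmax=None,
    center=False,
)
print(mel.shape)  # [1, 128, n_frames]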
models.py ADDED
@@ -0,0 +1,1044 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

import commons
import modules
import attentions
import monotonic_align

from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from vector_quantize_pytorch import VectorQuantize

from commons import init_weights, get_padding
from text import symbols, num_tones, num_languages


class DurationDiscriminator(nn.Module):  # vits2
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        self.pre_out_conv_1 = nn.Conv1d(
            2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        dur = self.dur_proj(dur)
        x = torch.cat([x, dur], dim=1)
        x = self.pre_out_conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_1(x)
        x = self.drop(x)
        x = self.pre_out_conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_2(x)
        x = self.drop(x)
        x = x * x_mask
        x = x.transpose(1, 2)
        output_prob = self.output_layer(x)
        return output_prob

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)

        output_probs = []
        for dur in [dur_r, dur_hat]:
            output_prob = self.forward_probability(x, x_mask, dur, g)
            output_probs.append(output_prob)

        return output_probs


class TransformerCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        self.wn = (
            attentions.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
            if share_parameter
            else None
        )

        for i in range(n_flows):
            self.flows.append(
                modules.TransformerCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    n_layers,
                    n_heads,
                    p_dropout,
                    filter_channels,
                    mean_only=True,
                    wn_sharing_parameter=self.wn,
                    gin_channels=self.gin_channels,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw


class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_speakers,
        gin_channels=0,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.emb = nn.Embedding(len(symbols), hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.emo_proj = nn.Linear(1024, 1024)
        self.emo_quantizer = [
            VectorQuantize(
                dim=1024,
                codebook_size=10,
                decay=0.8,
                commitment_weight=1.0,
                learnable_codebook=True,
                ema_update=False,
            )
        ] * n_speakers
        self.emo_q_proj = nn.Linear(1024, hidden_channels)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(
        self, x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=None
    ):
        sid = sid.cpu()
        bert_emb = self.bert_proj(bert).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
        en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
        if emo.size(-1) == 1024:
            emo_emb = self.emo_proj(emo.unsqueeze(1))
            emo_commit_loss = torch.zeros(1)
            emo_emb_ = []
            for i in range(emo_emb.size(0)):
                temp_emo_emb, _, temp_emo_commit_loss = self.emo_quantizer[sid[i]](
                    emo_emb[i].unsqueeze(0).cpu()
                )
                emo_commit_loss += temp_emo_commit_loss
                emo_emb_.append(temp_emo_emb)
            emo_emb = torch.cat(emo_emb_, dim=0).to(emo_emb.device)
            emo_commit_loss = emo_commit_loss.to(emo_emb.device)
        else:
            emo_emb = (
                self.emo_quantizer[sid[0]]
                .get_output_from_indices(emo.to(torch.int).cpu())
                .unsqueeze(0)
                .to(emo.device)
            )
            emo_commit_loss = torch.zeros(1)
        x = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
            + en_bert_emb
            + self.emo_q_proj(emo_emb)
        ) * math.sqrt(
            self.hidden_channels
        )  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        x = self.encoder(x * x_mask, x_mask, g=g)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask, emo_commit_loss


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])  # noqa: E501

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

    def forward(self, inputs, mask=None):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]

        return self.proj(out.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=4,
        flow_share_parameter=False,
        use_transformer_flow=True,
        **kwargs
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        self.current_mas_noise_scale = self.mas_noise_scale_initial
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            self.n_speakers,
            gin_channels=self.enc_gin_channels,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        if use_transformer_flow:
            self.flow = TransformerCouplingBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = ResidualCouplingBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )
        self.sdp = StochasticDurationPredictor(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )

        if n_speakers >= 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)

    def forward(
        self,
        x,
        x_lengths,
        y,
        y_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        en_bert,
        emo=None,
    ):
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        x, m_p, logs_p, x_mask, loss_commit = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
            if self.use_noise_scaled_mas:
                epsilon = (
                    torch.std(neg_cent)
                    * torch.randn_like(neg_cent)
                    * self.current_mas_noise_scale
                )
                neg_cent = neg_cent + epsilon

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        w = attn.sum(2)

        l_length_sdp = self.sdp(x, x_mask, w, g=g)
        l_length_sdp = l_length_sdp / torch.sum(x_mask)

        logw_ = torch.log(w + 1e-6) * x_mask
        logw = self.dp(x, x_mask, g=g)
        l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
            x_mask
        )  # for averaging

        l_length = l_length_dp + l_length_sdp

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            (x, logw, logw_),
            loss_commit,
        )

    def infer(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        en_bert,
        emo=None,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
        sdp_ratio=0,
        y=None,
    ):
        # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
        # g = self.gst(y)
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        x, m_p, logs_p, x_mask, _ = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, en_bert, emo, sid, g=g
        )
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
            sdp_ratio
        ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)
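A minimal inference sketch (not part of the commit) showing how SynthesizerTrn is wired up. Every hyperparameter below is an assumption in the spirit of a typical VITS2-style config, not a value read from this repository's config.json. Two observations about TextEncoder above: `[VectorQuantize(...)] * n_speakers` is a plain Python list holding n_speakers references to one shared quantizer instance, so its parameters are not registered on the module, and the quantizer is deliberately run on CPU.

# Usage sketch, assuming illustrative config values (hop = 8*8*2*2*2 = 512,
# spec_channels = n_fft // 2 + 1 for n_fft = 2048). All dummy inputs below are
# placeholders for real phoneme/tone/language ids and BERT features.
import torch
from models import SynthesizerTrn
from text import symbols

net_g = SynthesizerTrn(
    n_vocab=len(symbols),
    spec_channels=1025,
    segment_size=32,
    inter_channels=192,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    resblock="1",
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 8, 2, 2],
    n_speakers=256,
    gin_channels=256,
).eval()

t = 20  # phoneme sequence length
x = torch.randint(0, len(symbols), (1, t))
tone = torch.zeros(1, t, dtype=torch.long)
language = torch.zeros(1, t, dtype=torch.long)
bert = torch.zeros(1, 1024, t)  # per-phoneme BERT features, one per language stream
with torch.no_grad():
    audio, attn, y_mask, _ = net_g.infer(
        x,
        torch.LongTensor([t]),
        torch.LongTensor([0]),  # speaker id
        tone,
        language,
        bert,
        bert,
        bert,
        emo=torch.zeros(1, 1024),  # 1024-dim emotion embedding (or a codebook index)
    )
print(audio.shape)  # [1, 1, frames * 512]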
models_onnx.py ADDED
@@ -0,0 +1,986 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

import commons
import modules
import attentions_onnx

from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
from text import symbols, num_tones, num_languages


class DurationDiscriminator(nn.Module):  # vits2
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        self.pre_out_conv_1 = nn.Conv1d(
            2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        dur = self.dur_proj(dur)
        x = torch.cat([x, dur], dim=1)
        x = self.pre_out_conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_1(x)
        x = self.drop(x)
        x = self.pre_out_conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_2(x)
        x = self.drop(x)
        x = x * x_mask
        x = x.transpose(1, 2)
        output_prob = self.output_layer(x)
        return output_prob

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)

        output_probs = []
        for dur in [dur_r, dur_hat]:
            output_prob = self.forward_probability(x, x_mask, dur, g)
            output_probs.append(output_prob)

        return output_probs


class TransformerCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        self.wn = (
            attentions_onnx.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
            if share_parameter
            else None
        )

        for i in range(n_flows):
            self.flows.append(
                modules.TransformerCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    n_layers,
                    n_heads,
                    p_dropout,
                    filter_channels,
                    mean_only=True,
                    wn_sharing_parameter=self.wn,
                    gin_channels=self.gin_channels,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=True):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, z, g=None):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        flows = list(reversed(self.flows))
        flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
        for flow in flows:
            z = flow(z, x_mask, g=x, reverse=True)
        z0, z1 = torch.split(z, [1, 1], 1)
        logw = z0
        return logw


class DurationPredictor(nn.Module):
    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class TextEncoder(nn.Module):
    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels=0,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.emb = nn.Embedding(len(symbols), hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)

        self.encoder = attentions_onnx.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
        x_mask = torch.ones_like(x).unsqueeze(0)
        bert_emb = self.bert_proj(bert.transpose(0, 1).unsqueeze(0)).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert.transpose(0, 1).unsqueeze(0)).transpose(
            1, 2
        )
        en_bert_emb = self.en_bert_proj(en_bert.transpose(0, 1).unsqueeze(0)).transpose(
            1, 2
        )
        x = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
            + en_bert_emb
        ) * math.sqrt(
            self.hidden_channels
        )  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = x_mask.to(x.dtype)

        x = self.encoder(x * x_mask, x_mask, g=g)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask


class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=True):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [
            DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
        ]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])  # noqa: E501
666
+
667
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
668
+ self.gru = nn.GRU(
669
+ input_size=ref_enc_filters[-1] * out_channels,
670
+ hidden_size=256 // 2,
671
+ batch_first=True,
672
+ )
673
+ self.proj = nn.Linear(128, gin_channels)
674
+
675
+ def forward(self, inputs, mask=None):
676
+ N = inputs.size(0)
677
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
678
+ for conv in self.convs:
679
+ out = conv(out)
680
+ # out = wn(out)
681
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
682
+
683
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
684
+ T = out.size(1)
685
+ N = out.size(0)
686
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
687
+
688
+ self.gru.flatten_parameters()
689
+ memory, out = self.gru(out) # out --- [1, N, 128]
690
+
691
+ return self.proj(out.squeeze(0))
692
+
693
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
694
+ for i in range(n_convs):
695
+ L = (L - kernel_size + 2 * pad) // stride + 1
696
+ return L
697
+
698
+
699
+ class SynthesizerTrn(nn.Module):
700
+ """
701
+ Synthesizer for Training
702
+ """
703
+
704
+ def __init__(
705
+ self,
706
+ n_vocab,
707
+ spec_channels,
708
+ segment_size,
709
+ inter_channels,
710
+ hidden_channels,
711
+ filter_channels,
712
+ n_heads,
713
+ n_layers,
714
+ kernel_size,
715
+ p_dropout,
716
+ resblock,
717
+ resblock_kernel_sizes,
718
+ resblock_dilation_sizes,
719
+ upsample_rates,
720
+ upsample_initial_channel,
721
+ upsample_kernel_sizes,
722
+ n_speakers=256,
723
+ gin_channels=256,
724
+ use_sdp=True,
725
+ n_flow_layer=4,
726
+ n_layers_trans_flow=4,
727
+ flow_share_parameter=False,
728
+ use_transformer_flow=True,
729
+ **kwargs,
730
+ ):
731
+ super().__init__()
732
+ self.n_vocab = n_vocab
733
+ self.spec_channels = spec_channels
734
+ self.inter_channels = inter_channels
735
+ self.hidden_channels = hidden_channels
736
+ self.filter_channels = filter_channels
737
+ self.n_heads = n_heads
738
+ self.n_layers = n_layers
739
+ self.kernel_size = kernel_size
740
+ self.p_dropout = p_dropout
741
+ self.resblock = resblock
742
+ self.resblock_kernel_sizes = resblock_kernel_sizes
743
+ self.resblock_dilation_sizes = resblock_dilation_sizes
744
+ self.upsample_rates = upsample_rates
745
+ self.upsample_initial_channel = upsample_initial_channel
746
+ self.upsample_kernel_sizes = upsample_kernel_sizes
747
+ self.segment_size = segment_size
748
+ self.n_speakers = n_speakers
749
+ self.gin_channels = gin_channels
750
+ self.n_layers_trans_flow = n_layers_trans_flow
751
+ self.use_spk_conditioned_encoder = kwargs.get(
752
+ "use_spk_conditioned_encoder", True
753
+ )
754
+ self.use_sdp = use_sdp
755
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
756
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
757
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
758
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
759
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
760
+ self.enc_gin_channels = gin_channels
761
+ self.enc_p = TextEncoder(
762
+ n_vocab,
763
+ inter_channels,
764
+ hidden_channels,
765
+ filter_channels,
766
+ n_heads,
767
+ n_layers,
768
+ kernel_size,
769
+ p_dropout,
770
+ gin_channels=self.enc_gin_channels,
771
+ )
772
+ self.dec = Generator(
773
+ inter_channels,
774
+ resblock,
775
+ resblock_kernel_sizes,
776
+ resblock_dilation_sizes,
777
+ upsample_rates,
778
+ upsample_initial_channel,
779
+ upsample_kernel_sizes,
780
+ gin_channels=gin_channels,
781
+ )
782
+ self.enc_q = PosteriorEncoder(
783
+ spec_channels,
784
+ inter_channels,
785
+ hidden_channels,
786
+ 5,
787
+ 1,
788
+ 16,
789
+ gin_channels=gin_channels,
790
+ )
791
+ if use_transformer_flow:
792
+ self.flow = TransformerCouplingBlock(
793
+ inter_channels,
794
+ hidden_channels,
795
+ filter_channels,
796
+ n_heads,
797
+ n_layers_trans_flow,
798
+ 5,
799
+ p_dropout,
800
+ n_flow_layer,
801
+ gin_channels=gin_channels,
802
+ share_parameter=flow_share_parameter,
803
+ )
804
+ else:
805
+ self.flow = ResidualCouplingBlock(
806
+ inter_channels,
807
+ hidden_channels,
808
+ 5,
809
+ 1,
810
+ n_flow_layer,
811
+ gin_channels=gin_channels,
812
+ )
813
+ self.sdp = StochasticDurationPredictor(
814
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
815
+ )
816
+ self.dp = DurationPredictor(
817
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
818
+ )
819
+
820
+ if n_speakers >= 1:
821
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
822
+ else:
823
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
824
+
825
+ def export_onnx(
826
+ self,
827
+ path,
828
+ max_len=None,
829
+ sdp_ratio=0,
830
+ y=None,
831
+ ):
832
+ noise_scale = 0.667
833
+ length_scale = 1
834
+ noise_scale_w = 0.8
835
+ x = torch.LongTensor(
836
+ [
837
+ 0,
838
+ 97,
839
+ 0,
840
+ 8,
841
+ 0,
842
+ 78,
843
+ 0,
844
+ 8,
845
+ 0,
846
+ 76,
847
+ 0,
848
+ 37,
849
+ 0,
850
+ 40,
851
+ 0,
852
+ 97,
853
+ 0,
854
+ 8,
855
+ 0,
856
+ 23,
857
+ 0,
858
+ 8,
859
+ 0,
860
+ 74,
861
+ 0,
862
+ 26,
863
+ 0,
864
+ 104,
865
+ 0,
866
+ ]
867
+ ).unsqueeze(0)
868
+ tone = torch.zeros_like(x)
869
+ language = torch.zeros_like(x)
870
+ x_lengths = torch.LongTensor([x.shape[1]])
871
+ sid = torch.LongTensor([0])
872
+ bert = torch.randn(size=(x.shape[1], 1024))
873
+ ja_bert = torch.randn(size=(x.shape[1], 1024))
874
+ en_bert = torch.randn(size=(x.shape[1], 1024))
875
+
876
+ if self.n_speakers > 0:
877
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
878
+ torch.onnx.export(
879
+ self.emb_g,
880
+ (sid),
881
+ f"onnx/{path}/{path}_emb.onnx",
882
+ input_names=["sid"],
883
+ output_names=["g"],
884
+ verbose=True,
885
+ )
886
+ else:
887
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
888
+
889
+ torch.onnx.export(
890
+ self.enc_p,
891
+ (x, x_lengths, tone, language, bert, ja_bert, en_bert, g),
892
+ f"onnx/{path}/{path}_enc_p.onnx",
893
+ input_names=[
894
+ "x",
895
+ "x_lengths",
896
+ "t",
897
+ "language",
898
+ "bert_0",
899
+ "bert_1",
900
+ "bert_2",
901
+ "g",
902
+ ],
903
+ output_names=["xout", "m_p", "logs_p", "x_mask"],
904
+ dynamic_axes={
905
+ "x": [0, 1],
906
+ "t": [0, 1],
907
+ "language": [0, 1],
908
+ "bert_0": [0],
909
+ "bert_1": [0],
910
+ "bert_2": [0],
911
+ "xout": [0, 2],
912
+ "m_p": [0, 2],
913
+ "logs_p": [0, 2],
914
+ "x_mask": [0, 2],
915
+ },
916
+ verbose=True,
917
+ opset_version=16,
918
+ )
919
+ x, m_p, logs_p, x_mask = self.enc_p(
920
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
921
+ )
922
+ zinput = (
923
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
924
+ * noise_scale_w
925
+ )
926
+ torch.onnx.export(
927
+ self.sdp,
928
+ (x, x_mask, zinput, g),
929
+ f"onnx/{path}/{path}_sdp.onnx",
930
+ input_names=["x", "x_mask", "zin", "g"],
931
+ output_names=["logw"],
932
+ dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
933
+ verbose=True,
934
+ )
935
+ torch.onnx.export(
936
+ self.dp,
937
+ (x, x_mask, g),
938
+ f"onnx/{path}/{path}_dp.onnx",
939
+ input_names=["x", "x_mask", "g"],
940
+ output_names=["logw"],
941
+ dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
942
+ verbose=True,
943
+ )
944
+ logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
945
+ x, x_mask, g=g
946
+ ) * (1 - sdp_ratio)
947
+ w = torch.exp(logw) * x_mask * length_scale
948
+ w_ceil = torch.ceil(w)
949
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
950
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
951
+ x_mask.dtype
952
+ )
953
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
954
+ attn = commons.generate_path(w_ceil, attn_mask)
955
+
956
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
957
+ 1, 2
958
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
959
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
960
+ 1, 2
961
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
962
+
963
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
964
+ torch.onnx.export(
965
+ self.flow,
966
+ (z_p, y_mask, g),
967
+ f"onnx/{path}/{path}_flow.onnx",
968
+ input_names=["z_p", "y_mask", "g"],
969
+ output_names=["z"],
970
+ dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
971
+ verbose=True,
972
+ )
973
+
974
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
975
+ z_in = (z * y_mask)[:, :, :max_len]
976
+
977
+ torch.onnx.export(
978
+ self.dec,
979
+ (z_in, g),
980
+ f"onnx/{path}/{path}_dec.onnx",
981
+ input_names=["z_in", "g"],
982
+ output_names=["o"],
983
+ dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
984
+ verbose=True,
985
+ )
986
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
modules.py ADDED
@@ -0,0 +1,597 @@
+ import math
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from torch.nn import Conv1d
+ from torch.nn.utils import weight_norm, remove_weight_norm
+
+ import commons
+ from commons import init_weights, get_padding
+ from transforms import piecewise_rational_quadratic_transform
+ from attentions import Encoder
+
+ LRELU_SLOPE = 0.1
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, channels, eps=1e-5):
+         super().__init__()
+         self.channels = channels
+         self.eps = eps
+
+         self.gamma = nn.Parameter(torch.ones(channels))
+         self.beta = nn.Parameter(torch.zeros(channels))
+
+     def forward(self, x):
+         x = x.transpose(1, -1)
+         x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+         return x.transpose(1, -1)
+
+
+ class ConvReluNorm(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         hidden_channels,
+         out_channels,
+         kernel_size,
+         n_layers,
+         p_dropout,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.hidden_channels = hidden_channels
+         self.out_channels = out_channels
+         self.kernel_size = kernel_size
+         self.n_layers = n_layers
+         self.p_dropout = p_dropout
+         assert n_layers > 1, "Number of layers should be larger than 1."
+
+         self.conv_layers = nn.ModuleList()
+         self.norm_layers = nn.ModuleList()
+         self.conv_layers.append(
+             nn.Conv1d(
+                 in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+             )
+         )
+         self.norm_layers.append(LayerNorm(hidden_channels))
+         self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
+         for _ in range(n_layers - 1):
+             self.conv_layers.append(
+                 nn.Conv1d(
+                     hidden_channels,
+                     hidden_channels,
+                     kernel_size,
+                     padding=kernel_size // 2,
+                 )
+             )
+             self.norm_layers.append(LayerNorm(hidden_channels))
+         self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+         self.proj.weight.data.zero_()
+         self.proj.bias.data.zero_()
+
+     def forward(self, x, x_mask):
+         x_org = x
+         for i in range(self.n_layers):
+             x = self.conv_layers[i](x * x_mask)
+             x = self.norm_layers[i](x)
+             x = self.relu_drop(x)
+         x = x_org + self.proj(x)
+         return x * x_mask
+
+
+ class DDSConv(nn.Module):
+     """
+     Dilated and Depth-Separable Convolution
+     """
+
+     def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
+         super().__init__()
+         self.channels = channels
+         self.kernel_size = kernel_size
+         self.n_layers = n_layers
+         self.p_dropout = p_dropout
+
+         self.drop = nn.Dropout(p_dropout)
+         self.convs_sep = nn.ModuleList()
+         self.convs_1x1 = nn.ModuleList()
+         self.norms_1 = nn.ModuleList()
+         self.norms_2 = nn.ModuleList()
+         for i in range(n_layers):
+             dilation = kernel_size**i
+             padding = (kernel_size * dilation - dilation) // 2
+             self.convs_sep.append(
+                 nn.Conv1d(
+                     channels,
+                     channels,
+                     kernel_size,
+                     groups=channels,
+                     dilation=dilation,
+                     padding=padding,
+                 )
+             )
+             self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
+             self.norms_1.append(LayerNorm(channels))
+             self.norms_2.append(LayerNorm(channels))
+
+     def forward(self, x, x_mask, g=None):
+         if g is not None:
+             x = x + g
+         for i in range(self.n_layers):
+             y = self.convs_sep[i](x * x_mask)
+             y = self.norms_1[i](y)
+             y = F.gelu(y)
+             y = self.convs_1x1[i](y)
+             y = self.norms_2[i](y)
+             y = F.gelu(y)
+             y = self.drop(y)
+             x = x + y
+         return x * x_mask
+
+
+ # Non-causal WaveNet-style stack of gated dilated convolutions,
+ # shared by the posterior encoder and the residual coupling layers.
+ class WN(torch.nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         gin_channels=0,
+         p_dropout=0,
+     ):
+         super(WN, self).__init__()
+         assert kernel_size % 2 == 1
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.gin_channels = gin_channels
+         self.p_dropout = p_dropout
+
+         self.in_layers = torch.nn.ModuleList()
+         self.res_skip_layers = torch.nn.ModuleList()
+         self.drop = nn.Dropout(p_dropout)
+
+         if gin_channels != 0:
+             cond_layer = torch.nn.Conv1d(
+                 gin_channels, 2 * hidden_channels * n_layers, 1
+             )
+             self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+
+         for i in range(n_layers):
+             dilation = dilation_rate**i
+             padding = int((kernel_size * dilation - dilation) / 2)
+             in_layer = torch.nn.Conv1d(
+                 hidden_channels,
+                 2 * hidden_channels,
+                 kernel_size,
+                 dilation=dilation,
+                 padding=padding,
+             )
+             in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+             self.in_layers.append(in_layer)
+
+             # last one is not necessary
+             if i < n_layers - 1:
+                 res_skip_channels = 2 * hidden_channels
+             else:
+                 res_skip_channels = hidden_channels
+
+             res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+             res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+             self.res_skip_layers.append(res_skip_layer)
+
+     def forward(self, x, x_mask, g=None, **kwargs):
+         output = torch.zeros_like(x)
+         n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+         if g is not None:
+             g = self.cond_layer(g)
+
+         for i in range(self.n_layers):
+             x_in = self.in_layers[i](x)
+             if g is not None:
+                 cond_offset = i * 2 * self.hidden_channels
+                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+             else:
+                 g_l = torch.zeros_like(x_in)
+
+             acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+             acts = self.drop(acts)
+
+             res_skip_acts = self.res_skip_layers[i](acts)
+             if i < self.n_layers - 1:
+                 res_acts = res_skip_acts[:, : self.hidden_channels, :]
+                 x = (x + res_acts) * x_mask
+                 output = output + res_skip_acts[:, self.hidden_channels :, :]
+             else:
+                 output = output + res_skip_acts
+         return output * x_mask
+
+     def remove_weight_norm(self):
+         if self.gin_channels != 0:
+             torch.nn.utils.remove_weight_norm(self.cond_layer)
+         for layer in self.in_layers:
+             torch.nn.utils.remove_weight_norm(layer)
+         for layer in self.res_skip_layers:
+             torch.nn.utils.remove_weight_norm(layer)
+
+
+ class ResBlock1(torch.nn.Module):
+     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlock1, self).__init__()
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[0],
+                         padding=get_padding(kernel_size, dilation[0]),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[1],
+                         padding=get_padding(kernel_size, dilation[1]),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[2],
+                         padding=get_padding(kernel_size, dilation[2]),
+                     )
+                 ),
+             ]
+         )
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 ),
+             ]
+         )
+         self.convs2.apply(init_weights)
+
+     def forward(self, x, x_mask=None):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             if x_mask is not None:
+                 xt = xt * x_mask
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, LRELU_SLOPE)
+             if x_mask is not None:
+                 xt = xt * x_mask
+             xt = c2(xt)
+             x = xt + x
+         if x_mask is not None:
+             x = x * x_mask
+         return x
+
+     def remove_weight_norm(self):
+         for layer in self.convs1:
+             remove_weight_norm(layer)
+         for layer in self.convs2:
+             remove_weight_norm(layer)
+
+
+ class ResBlock2(torch.nn.Module):
+     def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+         super(ResBlock2, self).__init__()
+         self.convs = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[0],
+                         padding=get_padding(kernel_size, dilation[0]),
+                     )
+                 ),
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         1,
+                         dilation=dilation[1],
+                         padding=get_padding(kernel_size, dilation[1]),
+                     )
+                 ),
+             ]
+         )
+         self.convs.apply(init_weights)
+
+     def forward(self, x, x_mask=None):
+         for c in self.convs:
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             if x_mask is not None:
+                 xt = xt * x_mask
+             xt = c(xt)
+             x = xt + x
+         if x_mask is not None:
+             x = x * x_mask
+         return x
+
+     def remove_weight_norm(self):
+         for layer in self.convs:
+             remove_weight_norm(layer)
+
+
+ class Log(nn.Module):
+     def forward(self, x, x_mask, reverse=False, **kwargs):
+         if not reverse:
+             y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
+             logdet = torch.sum(-y, [1, 2])
+             return y, logdet
+         else:
+             x = torch.exp(x) * x_mask
+             return x
+
+
+ class Flip(nn.Module):
+     def forward(self, x, *args, reverse=False, **kwargs):
+         x = torch.flip(x, [1])
+         if not reverse:
+             logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+             return x, logdet
+         else:
+             return x
+
+
+ class ElementwiseAffine(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.channels = channels
+         self.m = nn.Parameter(torch.zeros(channels, 1))
+         self.logs = nn.Parameter(torch.zeros(channels, 1))
+
+     def forward(self, x, x_mask, reverse=False, **kwargs):
+         if not reverse:
+             y = self.m + torch.exp(self.logs) * x
+             y = y * x_mask
+             logdet = torch.sum(self.logs * x_mask, [1, 2])
+             return y, logdet
+         else:
+             x = (x - self.m) * torch.exp(-self.logs) * x_mask
+             return x
+
+
+ # Affine coupling layer: the first half of the channels conditions an
+ # affine transform of the second half, so the inverse is available in closed form.
+ class ResidualCouplingLayer(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         kernel_size,
+         dilation_rate,
+         n_layers,
+         p_dropout=0,
+         gin_channels=0,
+         mean_only=False,
+     ):
+         assert channels % 2 == 0, "channels should be divisible by 2"
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.dilation_rate = dilation_rate
+         self.n_layers = n_layers
+         self.half_channels = channels // 2
+         self.mean_only = mean_only
+
+         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+         self.enc = WN(
+             hidden_channels,
+             kernel_size,
+             dilation_rate,
+             n_layers,
+             p_dropout=p_dropout,
+             gin_channels=gin_channels,
+         )
+         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+         self.post.weight.data.zero_()
+         self.post.bias.data.zero_()
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+         h = self.pre(x0) * x_mask
+         h = self.enc(h, x_mask, g=g)
+         stats = self.post(h) * x_mask
+         if not self.mean_only:
+             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+         else:
+             m = stats
+             logs = torch.zeros_like(m)
+
+         if not reverse:
+             x1 = m + x1 * torch.exp(logs) * x_mask
+             x = torch.cat([x0, x1], 1)
+             logdet = torch.sum(logs, [1, 2])
+             return x, logdet
+         else:
+             x1 = (x1 - m) * torch.exp(-logs) * x_mask
+             x = torch.cat([x0, x1], 1)
+             return x
+
+
+ class ConvFlow(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         filter_channels,
+         kernel_size,
+         n_layers,
+         num_bins=10,
+         tail_bound=5.0,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.n_layers = n_layers
+         self.num_bins = num_bins
+         self.tail_bound = tail_bound
+         self.half_channels = in_channels // 2
+
+         self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
+         self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
+         self.proj = nn.Conv1d(
+             filter_channels, self.half_channels * (num_bins * 3 - 1), 1
+         )
+         self.proj.weight.data.zero_()
+         self.proj.bias.data.zero_()
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+         h = self.pre(x0)
+         h = self.convs(h, x_mask, g=g)
+         h = self.proj(h) * x_mask
+
+         b, c, t = x0.shape
+         h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]
+
+         unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
+         unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
+             self.filter_channels
+         )
+         unnormalized_derivatives = h[..., 2 * self.num_bins :]
+
+         x1, logabsdet = piecewise_rational_quadratic_transform(
+             x1,
+             unnormalized_widths,
+             unnormalized_heights,
+             unnormalized_derivatives,
+             inverse=reverse,
+             tails="linear",
+             tail_bound=self.tail_bound,
+         )
+
+         x = torch.cat([x0, x1], 1) * x_mask
+         logdet = torch.sum(logabsdet * x_mask, [1, 2])
+         if not reverse:
+             return x, logdet
+         else:
+             return x
+
+
+ class TransformerCouplingLayer(nn.Module):
+     def __init__(
+         self,
+         channels,
+         hidden_channels,
+         kernel_size,
+         n_layers,
+         n_heads,
+         p_dropout=0,
+         filter_channels=0,
+         mean_only=False,
+         wn_sharing_parameter=None,
+         gin_channels=0,
+     ):
+         assert channels % 2 == 0, "channels should be divisible by 2"
+         super().__init__()
+         self.channels = channels
+         self.hidden_channels = hidden_channels
+         self.kernel_size = kernel_size
+         self.n_layers = n_layers
+         self.half_channels = channels // 2
+         self.mean_only = mean_only
+
+         self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+         self.enc = (
+             Encoder(
+                 hidden_channels,
+                 filter_channels,
+                 n_heads,
+                 n_layers,
+                 kernel_size,
+                 p_dropout,
+                 isflow=True,
+                 gin_channels=gin_channels,
+             )
+             if wn_sharing_parameter is None
+             else wn_sharing_parameter
+         )
+         self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+         self.post.weight.data.zero_()
+         self.post.bias.data.zero_()
+
+     def forward(self, x, x_mask, g=None, reverse=False):
+         x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+         h = self.pre(x0) * x_mask
+         h = self.enc(h, x_mask, g=g)
+         stats = self.post(h) * x_mask
+         if not self.mean_only:
+             m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+         else:
+             m = stats
+             logs = torch.zeros_like(m)
+
+         if not reverse:
+             x1 = m + x1 * torch.exp(logs) * x_mask
+             x = torch.cat([x0, x1], 1)
+             logdet = torch.sum(logs, [1, 2])
+             return x, logdet
+         else:
+             x1 = (x1 - m) * torch.exp(-logs) * x_mask
+             x = torch.cat([x0, x1], 1)
+             return x
preprocess_text.py ADDED
@@ -0,0 +1,140 @@
+ import json
+ from collections import defaultdict
+ from random import shuffle
+ from typing import Optional
+ import os
+
+ from tqdm import tqdm
+ import click
+ from text.cleaner import clean_text
+ from config import config
+ from infer import latest_version
+
+ preprocess_text_config = config.preprocess_text_config
+
+
+ @click.command()
+ @click.option(
+     "--transcription-path",
+     default=preprocess_text_config.transcription_path,
+     type=click.Path(exists=True, file_okay=True, dir_okay=False),
+ )
+ @click.option("--cleaned-path", default=preprocess_text_config.cleaned_path)
+ @click.option("--train-path", default=preprocess_text_config.train_path)
+ @click.option("--val-path", default=preprocess_text_config.val_path)
+ @click.option(
+     "--config-path",
+     default=preprocess_text_config.config_path,
+     type=click.Path(exists=True, file_okay=True, dir_okay=False),
+ )
+ @click.option("--val-per-spk", default=preprocess_text_config.val_per_spk)
+ @click.option("--max-val-total", default=preprocess_text_config.max_val_total)
+ @click.option("--clean/--no-clean", default=preprocess_text_config.clean)
+ @click.option("-y", "--yml_config")
+ def preprocess(
+     transcription_path: str,
+     cleaned_path: Optional[str],
+     train_path: str,
+     val_path: str,
+     config_path: str,
+     val_per_spk: int,
+     max_val_total: int,
+     clean: bool,
+     yml_config: str,  # do not delete this
+ ):
+     if cleaned_path == "" or cleaned_path is None:
+         cleaned_path = transcription_path + ".cleaned"
+
+     if clean:
+         with open(cleaned_path, "w", encoding="utf-8") as out_file:
+             with open(transcription_path, "r", encoding="utf-8") as trans_file:
+                 lines = trans_file.readlines()
+                 # print(lines, ' ', len(lines))
+                 if len(lines) != 0:
+                     for line in tqdm(lines):
+                         try:
+                             utt, spk, language, text = line.strip().split("|")
+                             norm_text, phones, tones, word2ph = clean_text(
+                                 text, language
+                             )
+                             out_file.write(
+                                 "{}|{}|{}|{}|{}|{}|{}\n".format(
+                                     utt,
+                                     spk,
+                                     language,
+                                     norm_text,
+                                     " ".join(phones),
+                                     " ".join([str(i) for i in tones]),
+                                     " ".join([str(i) for i in word2ph]),
+                                 )
+                             )
+                         except Exception as e:
+                             print(line)
+                             print(
+                                 f"An error occurred while generating the training and validation sets! Details:\n{e}"
+                             )
+
+     transcription_path = cleaned_path
+     spk_utt_map = defaultdict(list)
+     spk_id_map = {}
+     current_sid = 0
+
+     with open(transcription_path, "r", encoding="utf-8") as f:
+         audioPaths = set()
+         countSame = 0
+         countNotFound = 0
+         for line in f.readlines():
+             utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
+             if utt in audioPaths:
+                 # Filter a dataset error: the same audio mapped to multiple texts, which breaks the later bert step
+                 print(f"Duplicate audio/text pair: {line}")
+                 countSame += 1
+                 continue
+             if not os.path.isfile(utt):
+                 # Filter a dataset error: the referenced audio file does not exist
+                 print(f"Audio file not found: {utt}")
+                 countNotFound += 1
+                 continue
+             audioPaths.add(utt)
+             spk_utt_map[spk].append(line)
+
+             if spk not in spk_id_map:
+                 spk_id_map[spk] = current_sid
+                 current_sid += 1
+     print(
+         f"Total duplicated audio entries: {countSame}, total missing audio files: {countNotFound}"
+     )
+
+     train_list = []
+     val_list = []
+
+     for spk, utts in spk_utt_map.items():
+         shuffle(utts)
+         val_list += utts[:val_per_spk]
+         train_list += utts[val_per_spk:]
+
+     if len(val_list) > max_val_total:
+         train_list += val_list[max_val_total:]
+         val_list = val_list[:max_val_total]
+
+     with open(train_path, "w", encoding="utf-8") as f:
+         for line in train_list:
+             f.write(line)
+
+     with open(val_path, "w", encoding="utf-8") as f:
+         for line in val_list:
+             f.write(line)
+
+     with open(config_path, encoding="utf-8") as f:
+         json_config = json.load(f)
+     json_config["data"]["spk2id"] = spk_id_map
+     # Also record the training version and the dataset paths in the config
+     json_config["version"] = latest_version
+     json_config["data"]["training_files"] = os.path.normpath(train_path).replace(
+         "\\", "/"
+     )
+     json_config["data"]["validation_files"] = os.path.normpath(val_path).replace(
+         "\\", "/"
+     )
+     with open(config_path, "w", encoding="utf-8") as f:
+         json.dump(json_config, f, indent=2, ensure_ascii=False)
+     print("Training and validation sets generated!")
+
+
+ if __name__ == "__main__":
+     preprocess()
re_matching.py ADDED
@@ -0,0 +1,82 @@
+ import re
+
+
+ def extract_language_and_text_updated(speaker, dialogue):
+     # Match each <language> tag and the text that follows it
+     pattern_language_text = r"<(\S+?)>([^<]+)"
+     matches = re.findall(pattern_language_text, dialogue, re.DOTALL)
+     speaker = speaker[1:-1]
+     # Clean the text: strip surrounding whitespace
+     matches_cleaned = [(lang.upper(), text.strip()) for lang, text in matches]
+     matches_cleaned.append(speaker)
+     return matches_cleaned
+
+
+ def validate_text(input_text):
+     # Regular expression for validating the speaker tag
+     pattern_speaker = r"(\[\S+?\])((?:\s*<\S+?>[^<\[\]]+?)+)"
+
+     # re.DOTALL makes . match any character, including newlines
+     matches = re.findall(pattern_speaker, input_text, re.DOTALL)
+
+     # Further validate the content matched for each speaker
+     for speaker, dialogue in matches:
+         language_text_matches = extract_language_and_text_updated(speaker, dialogue)
+         if not language_text_matches:
+             return (
+                 False,
+                 "Error: Invalid format detected in dialogue content. Please check your input.",
+             )
+
+     # No speaker pattern was found anywhere in the input text
+     if not matches:
+         return (
+             False,
+             "Error: No valid speaker format detected. Please check your input.",
+         )
+
+     return True, "Input is valid."
+
+
+ def text_matching(text: str) -> list:
+     speaker_pattern = r"(\[\S+?\])(.+?)(?=\[\S+?\]|$)"
+     matches = re.findall(speaker_pattern, text, re.DOTALL)
+     result = []
+     for speaker, dialogue in matches:
+         result.append(extract_language_and_text_updated(speaker, dialogue))
+     print(result)
+     return result
+
+
+ def cut_para(text):
+     splitted_para = re.split("[\n]", text)  # split into paragraphs
+     splitted_para = [
+         sentence.strip() for sentence in splitted_para if sentence.strip()
+     ]  # drop empty strings
+     return splitted_para
+
+
+ def cut_sent(para):
+     para = re.sub(r"([。!;?\?])([^”’])", r"\1\n\2", para)  # single-character sentence terminators
+     para = re.sub(r"(\.{6})([^”’])", r"\1\n\2", para)  # English ellipsis
+     para = re.sub(r"(\…{2})([^”’])", r"\1\n\2", para)  # Chinese ellipsis
+     para = re.sub(r"([。!?\?][”’])([^,。!?\?])", r"\1\n\2", para)
+     para = para.rstrip()  # strip any extra trailing \n at the end of the paragraph
+     return para.split("\n")
+
+
+ if __name__ == "__main__":
+     text = """
+     [说话人1]
+     [说话人2]<zh>你好吗?<jp>元気ですか?<jp>こんにちは,世界。<zh>你好吗?
+     [说话人3]<zh>谢谢。<jp>どういたしまして。
+     """
+     text_matching(text)
+     # test the functions
+     test_text = """
+     [说话人1]<zh>你好,こんにちは!<jp>こんにちは,世界。
+     [说话人2]<zh>你好吗?
+     """
+     text_matching(test_text)
+     res = validate_text(test_text)
+     print(res)