# JohnnySilverhand-Bert-VITS2 / models_onnx.py
# Last commit by hk-gosuto: "first commit" (47d9c0f)
import math
import torch
from torch import nn
from torch.nn import functional as F
import commons
import modules
import attentions_onnx
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding
from text import symbols, num_tones, num_languages
class DurationDiscriminator(nn.Module):  # vits2
    """VITS2-style duration discriminator.

    Shared convolutional features are computed once from the (detached)
    encoder output ``x``. Each duration sequence is then projected to
    ``filter_channels``, concatenated with those features and mapped to a
    per-frame sigmoid probability.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        # With stride 1, padding k // 2 keeps the time length for odd kernels.
        same_pad = kernel_size // 2

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        # Input is [features ; projected duration], hence 2 * filter_channels.
        self.pre_out_conv_1 = nn.Conv1d(
            2 * filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        """Return sigmoid scores for a single duration sequence ``dur``.

        ``g`` is accepted for signature symmetry with ``forward`` but unused.
        The features are transposed so the final Linear acts on the channel
        axis.
        """
        hidden = torch.cat([x, self.dur_proj(dur)], dim=1)
        stages = (
            (self.pre_out_conv_1, self.pre_out_norm_1),
            (self.pre_out_conv_2, self.pre_out_norm_2),
        )
        for conv, norm in stages:
            hidden = self.drop(norm(torch.relu(conv(hidden * x_mask))))
        hidden = (hidden * x_mask).transpose(1, 2)
        return self.output_layer(hidden)

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        """Score real and predicted durations; returns ``[prob_r, prob_hat]``."""
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            x = self.drop(norm(torch.relu(conv(x * x_mask))))
        return [
            self.forward_probability(x, x_mask, dur, g) for dur in (dur_r, dur_hat)
        ]
class TransformerCouplingBlock(nn.Module):
    """Flow of ``n_flows`` transformer coupling layers, each followed by a Flip.

    With ``share_parameter=True`` a single ``attentions_onnx.FFT`` is built and
    passed to every coupling layer as ``wn_sharing_parameter``. Otherwise
    ``self.wn`` is ``None``.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        if share_parameter:
            self.wn = attentions_onnx.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
        else:
            self.wn = None

        for _ in range(n_flows):
            coupling = modules.TransformerCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                n_layers,
                n_heads,
                p_dropout,
                filter_channels,
                mean_only=True,
                wn_sharing_parameter=self.wn,
                gin_channels=self.gin_channels,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=True):
        """Run the flow inverted (default) or forward (``reverse=False``).

        In the forward direction each layer returns a pair; only the first
        item is kept.
        """
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
class StochasticDurationPredictor(nn.Module):
    """Stochastic duration predictor, inference (reverse-flow) path only.

    ``log_flow`` and the ``post_*`` modules are constructed but never used by
    ``forward``. Presumably they are kept so trained checkpoints load
    unchanged.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        def make_flow_stack(depth):
            # ElementwiseAffine followed by ``depth`` (ConvFlow, Flip) pairs.
            stack = nn.ModuleList([modules.ElementwiseAffine(2)])
            for _ in range(depth):
                stack.append(
                    modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
                )
                stack.append(modules.Flip())
            return stack

        self.log_flow = modules.Log()
        self.flows = make_flow_stack(n_flows)

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = make_flow_stack(4)

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, z, g=None):
        """Map noise ``z`` (two channels) to log-durations via the reversed flows."""
        cond = self.pre(torch.detach(x))
        if g is not None:
            cond = cond + self.cond(torch.detach(g))
        cond = self.convs(cond, x_mask)
        cond = self.proj(cond) * x_mask

        reverse_flows = list(reversed(self.flows))
        # Drop the second-to-last entry of the reversed stack ("useless vflow").
        reverse_flows = reverse_flows[:-2] + [reverse_flows[-1]]
        for flow in reverse_flows:
            z = flow(z, x_mask, g=cond, reverse=True)

        logw, _ = torch.split(z, [1, 1], 1)
        return logw
class DurationPredictor(nn.Module):
    """Deterministic duration predictor.

    Two (conv -> ReLU -> LayerNorm -> dropout) stages are followed by a 1x1
    projection to a single channel. The output is multiplied by ``x_mask``.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        same_pad = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=same_pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked one-channel prediction (used as ``logw`` by callers)."""
        hidden = torch.detach(x)
        if g is not None:
            hidden = hidden + self.cond(torch.detach(g))
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            hidden = self.drop(norm(torch.relu(conv(hidden * x_mask))))
        return self.proj(hidden * x_mask) * x_mask
class TextEncoder(nn.Module):
    """Text encoder: token, tone and language embeddings, plus three projected
    BERT feature streams, fed through ``attentions_onnx.Encoder``.

    ONNX variant: ``forward`` builds an all-ones mask from ``x``, treats each
    BERT input as an unbatched ``[t, 1024]`` matrix and ignores ``x_lengths``.
    It therefore appears to assume batch size 1.
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels=0,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        emb_std = hidden_channels**-0.5
        # NOTE: the vocabulary size comes from len(symbols), not n_vocab.
        self.emb = nn.Embedding(len(symbols), hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, emb_std)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, emb_std)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, emb_std)

        # 1x1 projections from 1024-dim BERT features to the hidden size.
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)

        self.encoder = attentions_onnx.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, tone, language, bert, ja_bert, en_bert, g=None):
        """Encode the token sequence; returns ``(x, m, logs, x_mask)``.

        ``m`` and ``logs`` are the two ``out_channels``-wide halves of the
        output projection, each multiplied by the mask.
        """
        mask = torch.ones_like(x).unsqueeze(0)

        def project_bert(proj, feats):
            # [t, 1024] -> [1, 1024, t] -> conv -> [1, hidden, t] -> [1, t, hidden]
            return proj(feats.transpose(0, 1).unsqueeze(0)).transpose(1, 2)

        bert_emb = project_bert(self.bert_proj, bert)
        ja_bert_emb = project_bert(self.ja_bert_proj, ja_bert)
        en_bert_emb = project_bert(self.en_bert_proj, en_bert)

        summed = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
            + en_bert_emb
        )  # [b, t, h]
        hidden = torch.transpose(
            summed * math.sqrt(self.hidden_channels), 1, -1
        )  # [b, h, t]
        mask = mask.to(hidden.dtype)
        hidden = self.encoder(hidden * mask, mask, g=g)
        stats = self.proj(hidden) * mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return hidden, m, logs, mask
class ResidualCouplingBlock(nn.Module):
    """Flow of ``n_flows`` residual coupling layers (mean-only), each followed by a Flip."""

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            self.flows.extend(
                [
                    modules.ResidualCouplingLayer(
                        channels,
                        hidden_channels,
                        kernel_size,
                        dilation_rate,
                        n_layers,
                        gin_channels=gin_channels,
                        mean_only=True,
                    ),
                    modules.Flip(),
                ]
            )

    def forward(self, x, x_mask, g=None, reverse=True):
        """Run the flow forward (``reverse=False``) or inverted (default)."""
        if not reverse:
            # Forward layers return a pair; only the transformed tensor is kept.
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in reversed(self.flows):
            x = flow(x, x_mask, g=g, reverse=reverse)
        return x
class PosteriorEncoder(nn.Module):
    """Posterior encoder.

    A 1x1 input conv and a ``modules.WN`` stack are followed by a projection
    to mean / log-scale. ``forward`` returns a reparameterised sample ``z``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` with ``z = (m + eps * exp(logs)) * x_mask``."""
        length_mask = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(length_mask, 1).to(x.dtype)

        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)

        eps = torch.randn_like(m)
        z = (m + eps * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
class Generator(torch.nn.Module):
    """Waveform decoder.

    The pipeline is ``conv_pre``, then for each upsampling stage a leaky-ReLU,
    a weight-normalised transposed conv and the mean of ``num_kernels``
    residual blocks. It finishes with a leaky-ReLU, ``conv_post`` and ``tanh``.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        # Stage i maps C // 2**i channels to C // 2**(i + 1).
        self.ups = nn.ModuleList(
            weight_norm(
                ConvTranspose1d(
                    upsample_initial_channel // (2**i),
                    upsample_initial_channel // (2 ** (i + 1)),
                    k,
                    u,
                    padding=(k - u) // 2,
                )
            )
            for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes))
        )

        # num_kernels residual blocks per stage, stored flat in stage order.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, k, d))

        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode latent frames ``x`` (optionally conditioned on ``g``) to a waveform."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = self.ups[stage](F.leaky_relu(x, modules.LRELU_SLOPE))
            first = stage * self.num_kernels
            total = None
            for block in self.resblocks[first : first + self.num_kernels]:
                out = block(x)
                total = out if total is None else total + out
            x = total / self.num_kernels

        # The final activation uses F.leaky_relu's default slope, not LRELU_SLOPE.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from every upsampling conv and resblock."""
        print("Removing weight norm...")
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
class DiscriminatorP(torch.nn.Module):
    """Period discriminator.

    The 1-D input is reflect-padded to a multiple of ``period`` and reshaped
    to ``[b, c, t // period, period]``. 2-D convs whose kernels span only the
    first spatial axis are then applied. Returns the flattened score and the
    list of intermediate feature maps.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm

        pad = (get_padding(kernel_size, 1), 0)
        widths = [1, 32, 128, 512, 1024, 1024]
        n_convs = len(widths) - 1
        layers = []
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # All layers stride along the first spatial axis except the last one.
            step = stride if idx < n_convs - 1 else 1
            layers.append(
                norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (step, 1), padding=pad))
            )
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_scores, feature_maps)`` for input ``x`` of shape ``[b, c, t]``."""
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:  # pad first
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t += n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = F.leaky_relu(layer(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of (mostly grouped, strided) 1-D convs on the raw signal.

    Returns the flattened score and the list of intermediate feature maps.
    """

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm is not False else weight_norm
        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(c_in, c_out, k, s, groups=grp, padding=p))
                for c_in, c_out, k, s, grp, p in layer_specs
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_scores, feature_maps)``."""
        fmap = []
        for layer in self.convs:
            x = F.leaky_relu(layer(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
class MultiPeriodDiscriminator(torch.nn.Module):
    """One DiscriminatorS plus a DiscriminatorP for each period in [2, 3, 5, 7, 11]."""

    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11]
        self.discriminators = nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
            + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods]
        )

    def forward(self, y, y_hat):
        """Run every discriminator on real ``y`` and generated ``y_hat``.

        Returns ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``, with one entry per
        discriminator in each list.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            score_real, feats_real = disc(y)
            score_fake, feats_fake = disc(y_hat)
            y_d_rs.append(score_real)
            y_d_gs.append(score_fake)
            fmap_rs.append(feats_real)
            fmap_gs.append(feats_fake)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class ReferenceEncoder(nn.Module):
    """Encode a spectrogram into one global conditioning vector.

    Six stride-2, weight-normalised 3x3 ``Conv2d`` layers, each followed by
    ReLU, downsample the input along time and frequency. A GRU then runs over
    the downsampled time axis, and its final hidden state is projected to
    ``gin_channels``.

    inputs --- [N, Ty, spec_channels] (viewed as [N, 1, Ty, spec_channels])
    outputs --- [N, gin_channels]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # Number of frequency bins left after the K stride-2 convolutions.
        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        # Derive the projection width from the GRU instead of repeating the
        # literal 128, so the two sizes cannot silently drift apart.
        self.proj = nn.Linear(self.gru.hidden_size, gin_channels)

    def forward(self, inputs, mask=None):
        """Return ``[N, gin_channels]`` embeddings; ``mask`` is accepted but unused."""
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            out = F.relu(conv(out))
        # out: [N, 128, Ty', n_freqs'] with Ty' and n_freqs' reduced by the strides
        out = out.transpose(1, 2)  # [N, Ty', 128, n_freqs']
        T = out.size(1)
        out = out.contiguous().view(N, T, -1)  # [N, Ty', 128 * n_freqs']
        self.gru.flatten_parameters()
        _, hidden = self.gru(out)  # hidden --- [1, N, 128]
        return self.proj(hidden.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Return the size of a length-``L`` dimension after ``n_convs`` identical convs.

        Applies ``(L - kernel_size + 2 * pad) // stride + 1`` once per conv.
        """
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    ONNX-export variant of the Bert-VITS2 synthesizer. ``__init__`` builds:
      * the text encoder (``enc_p``),
      * the waveform decoder (``dec``),
      * the posterior encoder (``enc_q``),
      * the flow,
      * both duration predictors (``sdp`` and ``dp``),
      * either a speaker embedding table (``emb_g``) or a reference encoder
        (``ref_enc``).
    ``export_onnx`` traces the inference sub-networks to separate ONNX files.
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=4,
        flow_share_parameter=False,
        use_transformer_flow=True,
        **kwargs,
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        # Optional settings read from extra keyword arguments.
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_sdp = use_sdp
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        self.current_mas_noise_scale = self.mas_noise_scale_initial
        if self.use_spk_conditioned_encoder and gin_channels > 0:
            self.enc_gin_channels = gin_channels
        else:
            # This attribute used to be left unset on this path, which made the
            # TextEncoder construction below fail with AttributeError. 0 is the
            # same "no conditioning" value the sub-modules use as their default.
            self.enc_gin_channels = 0
        self.enc_p = TextEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.enc_gin_channels,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        if use_transformer_flow:
            self.flow = TransformerCouplingBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = ResidualCouplingBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )
        # Both duration predictors are always built (independent of use_sdp);
        # export_onnx exports both and blends them with sdp_ratio.
        self.sdp = StochasticDurationPredictor(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationPredictor(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )
        if n_speakers >= 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)

    def export_onnx(
        self,
        path,
        max_len=None,
        sdp_ratio=0,
        y=None,
    ):
        """Trace the inference sub-networks and write each as an ONNX file.

        Files are written to ``onnx/{path}/{path}_<part>.onnx`` for the parts
        ``emb`` (only when ``n_speakers > 0``), ``enc_p``, ``sdp``, ``dp``,
        ``flow`` and ``dec``. The trace inputs are a fixed 29-token phoneme
        sequence with zero tones and languages and random BERT features. Each
        sub-network is run on them so the next export gets realistic inputs.

        Args:
            path: model name, used both as the sub-directory under ``onnx/``
                and as the file-name prefix.
            max_len: optional cap on the number of frames fed to the decoder.
            sdp_ratio: weight of the stochastic duration predictor when it is
                blended with the deterministic one for the trace.
            y: reference spectrogram. It is required when ``n_speakers == 0``,
                because the speaker vector then comes from ``ref_enc``.
                NOTE(review): it is passed as ``y.transpose(1, 2)``; confirm
                the expected layout with the caller.

        Raises:
            ValueError: if ``y`` is missing on the reference-encoder path.
        """
        import os

        if self.n_speakers <= 0 and y is None:
            raise ValueError(
                "export_onnx requires a reference spectrogram `y` when n_speakers == 0"
            )
        # Make sure the output directory exists before any export writes to it.
        os.makedirs(f"onnx/{path}", exist_ok=True)

        noise_scale = 0.667
        length_scale = 1
        noise_scale_w = 0.8
        # Fixed dummy phoneme-id sequence, used only as trace input.
        x = torch.LongTensor(
            [
                0, 97, 0, 8, 0, 78, 0, 8, 0, 76, 0, 37, 0, 40, 0, 97,
                0, 8, 0, 23, 0, 8, 0, 74, 0, 26, 0, 104, 0
            ]
        ).unsqueeze(0)
        tone = torch.zeros_like(x)
        language = torch.zeros_like(x)
        x_lengths = torch.LongTensor([x.shape[1]])
        sid = torch.LongTensor([0])
        bert = torch.randn(size=(x.shape[1], 1024))
        ja_bert = torch.randn(size=(x.shape[1], 1024))
        en_bert = torch.randn(size=(x.shape[1], 1024))

        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
            torch.onnx.export(
                self.emb_g,
                (sid,),
                f"onnx/{path}/{path}_emb.onnx",
                input_names=["sid"],
                output_names=["g"],
                verbose=True,
            )
        else:
            # The reference encoder itself is not exported; only g is computed.
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)

        torch.onnx.export(
            self.enc_p,
            (x, x_lengths, tone, language, bert, ja_bert, en_bert, g),
            f"onnx/{path}/{path}_enc_p.onnx",
            input_names=[
                "x",
                "x_lengths",
                "t",
                "language",
                "bert_0",
                "bert_1",
                "bert_2",
                "g",
            ],
            output_names=["xout", "m_p", "logs_p", "x_mask"],
            dynamic_axes={
                "x": [0, 1],
                "t": [0, 1],
                "language": [0, 1],
                "bert_0": [0],
                "bert_1": [0],
                "bert_2": [0],
                "xout": [0, 2],
                "m_p": [0, 2],
                "logs_p": [0, 2],
                "x_mask": [0, 2],
            },
            verbose=True,
            opset_version=16,
        )
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, en_bert, g=g
        )

        # Noise input for the stochastic duration predictor: [b, 2, t].
        zinput = (
            torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
            * noise_scale_w
        )
        torch.onnx.export(
            self.sdp,
            (x, x_mask, zinput, g),
            f"onnx/{path}/{path}_sdp.onnx",
            input_names=["x", "x_mask", "zin", "g"],
            output_names=["logw"],
            dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "zin": [0, 2], "logw": [0, 2]},
            verbose=True,
        )
        torch.onnx.export(
            self.dp,
            (x, x_mask, g),
            f"onnx/{path}/{path}_dp.onnx",
            input_names=["x", "x_mask", "g"],
            output_names=["logw"],
            dynamic_axes={"x": [0, 2], "x_mask": [0, 2], "logw": [0, 2]},
            verbose=True,
        )

        # Blend stochastic and deterministic log-durations, then align.
        logw = self.sdp(x, x_mask, zinput, g=g) * (sdp_ratio) + self.dp(
            x, x_mask, g=g
        ) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        # Expand token-level statistics to frame level through the alignment.
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        torch.onnx.export(
            self.flow,
            (z_p, y_mask, g),
            f"onnx/{path}/{path}_flow.onnx",
            input_names=["z_p", "y_mask", "g"],
            output_names=["z"],
            dynamic_axes={"z_p": [0, 2], "y_mask": [0, 2], "z": [0, 2]},
            verbose=True,
        )
        z = self.flow(z_p, y_mask, g=g, reverse=True)

        # Optionally truncate the decoder input to max_len frames.
        z_in = (z * y_mask)[:, :, :max_len]
        torch.onnx.export(
            self.dec,
            (z_in, g),
            f"onnx/{path}/{path}_dec.onnx",
            input_names=["z_in", "g"],
            output_names=["o"],
            dynamic_axes={"z_in": [0, 2], "o": [0, 2]},
            verbose=True,
        )
        # Final decoder run on the same input; the result is not returned.
        o = self.dec(z_in, g=g)