harveen
Add Tamil
3b92d66
raw history blame
No virus
12.9 kB
import math
import torch
from torch import nn
from torch.nn import functional as F
import modules
import commons
import attentions
import monotonic_align
class DurationPredictor(nn.Module):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
super().__init__()
self.in_channels = in_channels
self.filter_channels = filter_channels
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_1 = attentions.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.norm_2 = attentions.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
def forward(self, x, x_mask):
x = self.conv_1(x * x_mask)
x = torch.relu(x)
x = self.norm_1(x)
x = self.drop(x)
x = self.conv_2(x * x_mask)
x = torch.relu(x)
x = self.norm_2(x)
x = self.drop(x)
x = self.proj(x * x_mask)
return x * x_mask
class TextEncoder(nn.Module):
def __init__(
self,
n_vocab,
out_channels,
hidden_channels,
filter_channels,
filter_channels_dp,
n_heads,
n_layers,
kernel_size,
p_dropout,
window_size=None,
block_length=None,
mean_only=False,
prenet=False,
gin_channels=0,
):
super().__init__()
self.n_vocab = n_vocab
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.filter_channels_dp = filter_channels_dp
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.window_size = window_size
self.block_length = block_length
self.mean_only = mean_only
self.prenet = prenet
self.gin_channels = gin_channels
self.emb = nn.Embedding(n_vocab, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
if prenet:
self.pre = modules.ConvReluNorm(
hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
n_layers=3,
p_dropout=0.5,
)
self.encoder = attentions.Encoder(
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
window_size=window_size,
block_length=block_length,
)
self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
if not mean_only:
self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
self.proj_w = DurationPredictor(
hidden_channels + gin_channels, filter_channels_dp, kernel_size, p_dropout
)
def forward(self, x, x_lengths, g=None):
x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
x = torch.transpose(x, 1, -1) # [b, h, t]
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
if self.prenet:
x = self.pre(x, x_mask)
x = self.encoder(x, x_mask)
if g is not None:
g_exp = g.expand(-1, -1, x.size(-1))
x_dp = torch.cat([torch.detach(x), g_exp], 1)
else:
x_dp = torch.detach(x)
x_m = self.proj_m(x) * x_mask
if not self.mean_only:
x_logs = self.proj_s(x) * x_mask
else:
x_logs = torch.zeros_like(x_m)
logw = self.proj_w(x_dp, x_mask)
return x_m, x_logs, logw, x_mask
class FlowSpecDecoder(nn.Module):
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
n_blocks,
n_layers,
p_dropout=0.0,
n_split=4,
n_sqz=2,
sigmoid_scale=False,
gin_channels=0,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_blocks = n_blocks
self.n_layers = n_layers
self.p_dropout = p_dropout
self.n_split = n_split
self.n_sqz = n_sqz
self.sigmoid_scale = sigmoid_scale
self.gin_channels = gin_channels
self.flows = nn.ModuleList()
for b in range(n_blocks):
self.flows.append(modules.ActNorm(channels=in_channels * n_sqz))
self.flows.append(
modules.InvConvNear(channels=in_channels * n_sqz, n_split=n_split)
)
self.flows.append(
attentions.CouplingBlock(
in_channels * n_sqz,
hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
n_layers=n_layers,
gin_channels=gin_channels,
p_dropout=p_dropout,
sigmoid_scale=sigmoid_scale,
)
)
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
flows = self.flows
logdet_tot = 0
else:
flows = reversed(self.flows)
logdet_tot = None
if self.n_sqz > 1:
x, x_mask = commons.squeeze(x, x_mask, self.n_sqz)
for f in flows:
if not reverse:
x, logdet = f(x, x_mask, g=g, reverse=reverse)
logdet_tot += logdet
else:
x, logdet = f(x, x_mask, g=g, reverse=reverse)
if self.n_sqz > 1:
x, x_mask = commons.unsqueeze(x, x_mask, self.n_sqz)
return x, logdet_tot
def store_inverse(self):
for f in self.flows:
f.store_inverse()
class FlowGenerator(nn.Module):
def __init__(
self,
n_vocab,
hidden_channels,
filter_channels,
filter_channels_dp,
out_channels,
kernel_size=3,
n_heads=2,
n_layers_enc=6,
p_dropout=0.0,
n_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=5,
n_block_layers=4,
p_dropout_dec=0.0,
n_speakers=0,
gin_channels=0,
n_split=4,
n_sqz=1,
sigmoid_scale=False,
window_size=None,
block_length=None,
mean_only=False,
hidden_channels_enc=None,
hidden_channels_dec=None,
prenet=False,
**kwargs
):
super().__init__()
self.n_vocab = n_vocab
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.filter_channels_dp = filter_channels_dp
self.out_channels = out_channels
self.kernel_size = kernel_size
self.n_heads = n_heads
self.n_layers_enc = n_layers_enc
self.p_dropout = p_dropout
self.n_blocks_dec = n_blocks_dec
self.kernel_size_dec = kernel_size_dec
self.dilation_rate = dilation_rate
self.n_block_layers = n_block_layers
self.p_dropout_dec = p_dropout_dec
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.n_split = n_split
self.n_sqz = n_sqz
self.sigmoid_scale = sigmoid_scale
self.window_size = window_size
self.block_length = block_length
self.mean_only = mean_only
self.hidden_channels_enc = hidden_channels_enc
self.hidden_channels_dec = hidden_channels_dec
self.prenet = prenet
self.encoder = TextEncoder(
n_vocab,
out_channels,
hidden_channels_enc or hidden_channels,
filter_channels,
filter_channels_dp,
n_heads,
n_layers_enc,
kernel_size,
p_dropout,
window_size=window_size,
block_length=block_length,
mean_only=mean_only,
prenet=prenet,
gin_channels=gin_channels,
)
self.decoder = FlowSpecDecoder(
out_channels,
hidden_channels_dec or hidden_channels,
kernel_size_dec,
dilation_rate,
n_blocks_dec,
n_block_layers,
p_dropout=p_dropout_dec,
n_split=n_split,
n_sqz=n_sqz,
sigmoid_scale=sigmoid_scale,
gin_channels=gin_channels,
)
if n_speakers > 1:
self.emb_g = nn.Embedding(n_speakers, gin_channels)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
def forward(
self,
x,
x_lengths,
y=None,
y_lengths=None,
g=None,
gen=False,
noise_scale=1.0,
length_scale=1.0,
):
if g is not None:
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths, g=g)
if gen:
w = torch.exp(logw) * x_mask * length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None
else:
y_max_length = y.size(2)
y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
z_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y_max_length), 1).to(
x_mask.dtype
)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)
if gen:
attn = commons.generate_path(
w_ceil.squeeze(1), attn_mask.squeeze(1)
).unsqueeze(1)
z_m = torch.matmul(
attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
z_logs = torch.matmul(
attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
y, logdet = self.decoder(z, z_mask, g=g, reverse=True)
return (
(y, z_m, z_logs, logdet, z_mask),
(x_m, x_logs, x_mask),
(attn, logw, logw_),
)
else:
z, logdet = self.decoder(y, z_mask, g=g, reverse=False)
with torch.no_grad():
x_s_sq_r = torch.exp(-2 * x_logs)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(
-1
) # [b, t, 1]
logp2 = torch.matmul(
x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2)
) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul(
(x_m * x_s_sq_r).transpose(1, 2), z
) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(
-1
) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = (
monotonic_align.maximum_path(logp, attn_mask.squeeze(1))
.unsqueeze(1)
.detach()
)
z_m = torch.matmul(
attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
z_logs = torch.matmul(
attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
return (
(z, z_m, z_logs, logdet, z_mask),
(x_m, x_logs, x_mask),
(attn, logw, logw_),
)
def preprocess(self, y, y_lengths, y_max_length):
if y_max_length is not None:
y_max_length = (y_max_length // self.n_sqz) * self.n_sqz
y = y[:, :, :y_max_length]
y_lengths = (y_lengths // self.n_sqz) * self.n_sqz
return y, y_lengths, y_max_length
def store_inverse(self):
self.decoder.store_inverse()