File size: 4,677 Bytes
9b2107c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import torch
from torch import nn
from TTS.tts.layers.generic.normalization import ActNorm
from TTS.tts.layers.glow_tts.glow import CouplingBlock, InvConvNear
def squeeze(x, x_mask=None, num_sqz=2):
    """GlowTTS squeeze operation.

    Increase the number of channels and reduce the number of time steps
    by the same factor ``num_sqz``. Trailing time steps that do not fill a
    complete group of ``num_sqz`` frames are dropped.

    Note:
        each 's' is a n-dimensional vector.
        ``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]``

    Args:
        x (Tensor): 3-D input, unpacked as ``[B, C, T]``.
        x_mask (Tensor, optional): mask indexed along its third dimension,
            expected as ``[B, 1, T]``. When omitted, an all-ones mask is used.
        num_sqz (int): squeeze factor.

    Returns:
        Tuple[Tensor, Tensor]: the masked squeezed tensor of shape
        ``[B, C * num_sqz, T // num_sqz]`` and the mask that was applied to it.
    """
    b, c, t = x.size()
    t_sqz = t // num_sqz

    # Keep only the frames that form complete groups of `num_sqz`.
    x = x[:, :, : t_sqz * num_sqz]

    # Split time into (group, position-in-group), then move the position
    # axis in front of the channels so it is folded into the channel dim.
    x_sqz = x.view(b, c, t_sqz, num_sqz)
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t_sqz)

    if x_mask is None:
        # Allocate directly with the input's device/dtype instead of creating
        # a default CPU tensor and copying it over.
        x_mask = torch.ones(b, 1, t_sqz, device=x.device, dtype=x.dtype)
    else:
        # A squeezed frame takes the mask value of the last frame in its group.
        x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz]
    return x_sqz * x_mask, x_mask
def unsqueeze(x, x_mask=None, num_sqz=2):
    """GlowTTS unsqueeze operation (revert the squeeze).

    Decrease the number of channels and increase the number of time steps
    by the same factor ``num_sqz``, re-interleaving the channel groups back
    into consecutive time steps.

    Note:
        each 's' is a n-dimensional vector.
        ``[[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]``

    Args:
        x (Tensor): 3-D input, unpacked as ``[B, C, T]``; it is viewed as
            ``num_sqz`` groups of ``C // num_sqz`` channels.
        x_mask (Tensor, optional): mask that is reshaped to
            ``[B, 1, T * num_sqz]`` after repetition, i.e. expected as
            ``[B, 1, T]``. When omitted, an all-ones mask is used.
        num_sqz (int): squeeze factor to revert.

    Returns:
        Tuple[Tensor, Tensor]: the masked tensor of shape
        ``[B, C // num_sqz, T * num_sqz]`` and the mask that was applied to it.
    """
    b, c, t = x.size()
    c_unsqz = c // num_sqz
    t_unsqz = t * num_sqz

    # Channels are laid out as (position-in-group, channel); move the position
    # axis behind time so that flattening interleaves the frames again.
    x_unsqz = x.view(b, num_sqz, c_unsqz, t)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c_unsqz, t_unsqz)

    if x_mask is None:
        # Allocate directly with the input's device/dtype instead of creating
        # a default CPU tensor and copying it over.
        x_mask = torch.ones(b, 1, t_unsqz, device=x.device, dtype=x.dtype)
    else:
        # Every squeezed mask entry covers `num_sqz` consecutive output frames.
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t_unsqz)
    return x_unsqz * x_mask, x_mask
class Decoder(nn.Module):
    """Stack of Glow Decoder Modules.

    ::

        Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze

    Each flow block contributes three modules to ``self.flows``, in this
    order: ``ActNorm``, ``InvConvNear`` and ``CouplingBlock``. All of them
    operate on ``in_channels * num_squeeze`` channels.

    Args:
        in_channels (int): channels of input tensor.
        hidden_channels (int): hidden decoder channels.
        kernel_size (int): Coupling block kernel size. (Wavenet filter kernel size.)
        dilation_rate (int): rate to increase dilation by each layer in a decoder block.
        num_flow_blocks (int): number of decoder blocks.
        num_coupling_layers (int): number coupling layers. (number of wavenet layers.)
        dropout_p (float): wavenet dropout rate.
        num_splits (int): forwarded to ``InvConvNear`` as ``num_splits``.
        num_squeeze (int): squeeze factor applied around the flows; squeezing
            is skipped unless it is greater than 1.
        sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer.
        c_in_channels (int): forwarded to ``CouplingBlock`` as ``c_in_channels``
            (presumably the channel count of the conditioning input ``g``).
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        num_flow_blocks,
        num_coupling_layers,
        dropout_p=0.0,
        num_splits=4,
        num_squeeze=2,
        sigmoid_scale=False,
        c_in_channels=0,
    ):
        super().__init__()

        # Keep the full configuration available on the instance.
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.num_flow_blocks = num_flow_blocks
        self.num_coupling_layers = num_coupling_layers
        self.dropout_p = dropout_p
        self.num_splits = num_splits
        self.num_squeeze = num_squeeze
        self.sigmoid_scale = sigmoid_scale
        self.c_in_channels = c_in_channels

        # Channel count seen by every flow once the input has been squeezed.
        flow_channels = in_channels * num_squeeze

        self.flows = nn.ModuleList()
        for _ in range(num_flow_blocks):
            self.flows.extend(
                [
                    ActNorm(channels=flow_channels),
                    InvConvNear(channels=flow_channels, num_splits=num_splits),
                    CouplingBlock(
                        flow_channels,
                        hidden_channels,
                        kernel_size=kernel_size,
                        dilation_rate=dilation_rate,
                        num_layers=num_coupling_layers,
                        c_in_channels=c_in_channels,
                        dropout_p=dropout_p,
                        sigmoid_scale=sigmoid_scale,
                    ),
                ]
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run the flows in order, or in reversed order when ``reverse`` is set.

        Shapes:
            - x: :math:`[B, C, T]`
            - x_mask: :math:`[B, 1 ,T]`
            - g: :math:`[B, C]`

        Returns:
            Tuple of the transformed tensor and the total log-determinant.
            The log-determinant is the running sum of the values returned by
            the flows (starting from ``0``) in the forward direction, and
            ``None`` when ``reverse`` is set.
        """
        use_squeeze = self.num_squeeze > 1
        if use_squeeze:
            x, x_mask = squeeze(x, x_mask, self.num_squeeze)

        if reverse:
            # The log-determinant is not accumulated in the reverse direction.
            logdet_tot = None
            for flow in reversed(self.flows):
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            logdet_tot = 0
            for flow in self.flows:
                x, logdet = flow(x, x_mask, g=g, reverse=reverse)
                logdet_tot += logdet

        if use_squeeze:
            x, x_mask = unsqueeze(x, x_mask, self.num_squeeze)
        return x, logdet_tot

    def store_inverse(self):
        """Call ``store_inverse()`` on every flow module, in order."""
        for flow in self.flows:
            flow.store_inverse()
|