import math
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
class PCmer(nn.Module):
    """The encoder that is used in the Transformer model."""

    def __init__(self,
                 num_layers,
                 num_heads,
                 dim_model,
                 dim_keys,
                 dim_values,
                 residual_dropout,
                 attention_dropout):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.dim_keys = dim_keys
        self.residual_dropout = residual_dropout
        self.attention_dropout = attention_dropout

        self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])

    def forward(self, phone, mask=None):
        for layer in self._layers:
            phone = layer(phone, mask)
        return phone
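
# Note: judging from the layers below, `phone` is a (batch, time, dim_model) tensor and
# is returned with the same shape; `mask`, if given, is forwarded unchanged to each
# layer's attention (see Attention.forward).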


class _EncoderLayer(nn.Module):
    """One layer of the encoder.

    Attributes:
        attn (:class:`Attention`): The attention mechanism that is used to read the input sequence.
        conformer (:class:`ConformerConvModule`): The convolution module applied after the attention mechanism.
    """

    def __init__(self, parent: PCmer):
        """Creates a new instance of ``_EncoderLayer``.

        Args:
            parent (PCmer): The encoder that the layer is created for.
        """
        super().__init__()

        self.conformer = ConformerConvModule(parent.dim_model)
        self.norm = nn.LayerNorm(parent.dim_model)
        self.dropout = nn.Dropout(parent.residual_dropout)

        self.attn = Attention(dim=parent.dim_model,
                              heads=parent.num_heads,
                              dim_head=32)

    def forward(self, phone, mask=None):
        # Pre-norm self-attention with a residual connection, followed by the
        # conformer convolution module, also with a residual connection.
        phone = phone + self.attn(self.norm(phone), mask=mask)
        phone = phone + self.conformer(phone)
        return phone


def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
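# Illustrative values: calc_same_padding(31) == (15, 15) and calc_same_padding(4) == (2, 1),
# i.e. the (left, right) padding that keeps the output of a stride-1 Conv1d the same
# length as its input.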


class Swish(nn.Module):
    # Swish / SiLU activation: x * sigmoid(x).
    def forward(self, x):
        return x * x.sigmoid()


class Transpose(nn.Module):
    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, 'dims must be a tuple of two dimensions'
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # Split in half along `dim` and gate one half with the sigmoid of the other.
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)
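# Note: padding is applied manually with F.pad rather than inside nn.Conv1d, so the same
# module works with both the symmetric "same" padding from calc_same_padding and the
# left-only (causal) padding that ConformerConvModule computes when causal=True.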


class ConformerConvModule(nn.Module):
    def __init__(
            self,
            dim,
            causal=False,
            expansion_factor=2,
            kernel_size=31,
            dropout=0.):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Transpose((1, 2)),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Transpose((1, 2)),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
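# Note: the module is shape-preserving; a (batch, time, dim) input is normalized,
# transposed to (batch, dim, time) for the 1x1 conv -> GLU -> depthwise conv -> Swish ->
# 1x1 conv stack, then transposed back to (batch, time, dim).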


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=64, conditiondim=None):
        super().__init__()
        if conditiondim is None:
            conditiondim = dim

        # Unused in forward: F.scaled_dot_product_attention applies its own 1/sqrt(dim_head) scaling.
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False)

        self.to_out = nn.Sequential(nn.Linear(hidden_dim, dim))

    def forward(self, q, kv=None, mask=None):
        # Self-attention when no conditioning sequence is given.
        if kv is None:
            kv = q

        q = self.to_q(q)
        k, v = self.to_kv(kv).chunk(2, dim=2)

        q, k, v = map(
            lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v)
        )

        if mask is not None:
            # (b, t) -> (b, 1, 1, t) so the mask broadcasts over heads and query positions.
            mask = mask.unsqueeze(1).unsqueeze(1)

        # Fused scaled-dot-product attention; all SDPA backends are left at their defaults.
        with torch.backends.cuda.sdp_kernel():
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        out = rearrange(out, "b h t c -> b t (h c)")
        return self.to_out(out)
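

# Minimal smoke test (illustrative only; the shapes and hyper-parameters below are
# assumptions for demonstration, not values from the original training setup).
if __name__ == "__main__":
    encoder = PCmer(num_layers=3,
                    num_heads=8,
                    dim_model=256,
                    dim_keys=256,
                    dim_values=256,
                    residual_dropout=0.1,
                    attention_dropout=0.1)
    phone = torch.randn(2, 100, 256)              # (batch, time, dim_model)
    mask = torch.ones(2, 100, dtype=torch.bool)   # True = position may be attended to
    out = encoder(phone, mask=mask)
    print(out.shape)  # expected: torch.Size([2, 100, 256])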