|
import torch |
|
import math |
|
import inspect |
|
from torch import nn |
|
from torch import Tensor |
|
from typing import Tuple |
|
from typing import Optional |
|
from torch.nn.functional import fold, unfold |
|
import numpy as np |
|
|
|
from . import activations, normalizations |
|
from .normalizations import gLN |
|
|
|
|
|
def has_arg(fn, name): |
|
"""Checks if a callable accepts a given keyword argument. |
|
Args: |
|
fn (callable): Callable to inspect. |
|
name (str): Check if ``fn`` can be called with ``name`` as a keyword |
|
argument. |
|
Returns: |
|
bool: whether ``fn`` accepts a ``name`` keyword argument. |
|
""" |
|
signature = inspect.signature(fn) |
|
parameter = signature.parameters.get(name) |
|
if parameter is None: |
|
return False |
|
return parameter.kind in ( |
|
inspect.Parameter.POSITIONAL_OR_KEYWORD, |
|
inspect.Parameter.KEYWORD_ONLY, |
|
) |
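
# Illustrative sketch (not in the original): ``has_arg`` inspects constructor /
# call signatures, e.g. ``nn.Softmax`` accepts a ``dim`` keyword while ``nn.ReLU`` does not:
#   has_arg(nn.Softmax, "dim")  ->  True
#   has_arg(nn.ReLU, "dim")     ->  False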
|
|
|
|
|
class SingleRNN(nn.Module): |
|
"""Module for a RNN block. |
|
Inspired from https://github.com/yluo42/TAC/blob/master/utility/models.py |
|
Licensed under CC BY-NC-SA 3.0 US. |
|
Args: |
|
rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can |
|
also be passed in lowercase letters. |
|
input_size (int): Dimension of the input feature. The input should have |
|
shape [batch, seq_len, input_size]. |
|
hidden_size (int): Dimension of the hidden state. |
|
n_layers (int, optional): Number of layers used in RNN. Default is 1. |
|
dropout (float, optional): Dropout ratio. Default is 0. |
|
bidirectional (bool, optional): Whether the RNN layers are |
|
bidirectional. Default is ``False``. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
rnn_type, |
|
input_size, |
|
hidden_size, |
|
n_layers=1, |
|
dropout=0, |
|
bidirectional=False, |
|
): |
|
super(SingleRNN, self).__init__() |
|
assert rnn_type.upper() in ["RNN", "LSTM", "GRU"] |
|
rnn_type = rnn_type.upper() |
|
self.rnn_type = rnn_type |
|
self.input_size = input_size |
|
self.hidden_size = hidden_size |
|
self.n_layers = n_layers |
|
self.dropout = dropout |
|
self.bidirectional = bidirectional |
|
self.rnn = getattr(nn, rnn_type)( |
|
input_size, |
|
hidden_size, |
|
num_layers=n_layers, |
|
dropout=dropout, |
|
batch_first=True, |
|
bidirectional=bool(bidirectional), |
|
) |
|
|
|
@property |
|
def output_size(self): |
|
return self.hidden_size * (2 if self.bidirectional else 1) |
|
|
|
def forward(self, inp): |
|
""" Input shape [batch, seq, feats] """ |
|
self.rnn.flatten_parameters() |
|
output = inp |
|
rnn_output, _ = self.rnn(output) |
|
return rnn_output |
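
# Usage sketch (illustrative, not in the original): with ``bidirectional=True``
# the output feature dimension is ``2 * hidden_size``.
#   rnn = SingleRNN("LSTM", input_size=64, hidden_size=128, bidirectional=True)
#   rnn(torch.randn(2, 100, 64)).shape  ->  torch.Size([2, 100, 256])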
|
|
|
|
|
class LSTMBlockTF(nn.Module): |
|
def __init__( |
|
self, |
|
in_chan, |
|
hid_size, |
|
norm_type="gLN", |
|
bidirectional=True, |
|
rnn_type="LSTM", |
|
num_layers=1, |
|
dropout=0, |
|
): |
|
super(LSTMBlockTF, self).__init__() |
|
self.RNN = SingleRNN( |
|
rnn_type, |
|
in_chan, |
|
hid_size, |
|
num_layers, |
|
dropout=dropout, |
|
bidirectional=bidirectional, |
|
) |
|
self.linear = nn.Linear(self.RNN.output_size, in_chan) |
|
self.norm = normalizations.get(norm_type)(in_chan) |
|
|
|
def forward(self, x): |
|
B, F, T = x.size() |
|
output = self.RNN(x.transpose(1, 2)) |
|
output = self.linear(output) |
|
output = output.transpose(1, -1) |
|
output = self.norm(output) |
|
return output + x |
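
# Shape sketch (illustrative): LSTMBlockTF is residual, so the output matches the
# [batch, in_chan, time] input, assuming ``normalizations.get("gLN")`` resolves to a
# channel-wise norm that accepts [batch, chan, time] tensors.
#   block = LSTMBlockTF(in_chan=64, hid_size=128)
#   block(torch.randn(2, 64, 100)).shape  ->  torch.Size([2, 64, 100])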
|
|
|
|
|
|
|
class Linear(nn.Module): |
|
""" |
|
Wrapper class of torch.nn.Linear |
|
    Weights are initialized with Xavier initialization and biases are initialized to zeros.
|
""" |
|
|
|
def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: |
|
super(Linear, self).__init__() |
|
self.linear = nn.Linear(in_features, out_features, bias=bias) |
|
nn.init.xavier_uniform_(self.linear.weight) |
|
if bias: |
|
nn.init.zeros_(self.linear.bias) |
|
|
|
def forward(self, x): |
|
return self.linear(x) |
|
|
|
|
|
class Swish(nn.Module): |
|
""" |
|
Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied |
|
    to a variety of challenging domains such as image classification and machine translation.
|
""" |
|
|
|
def __init__(self): |
|
super(Swish, self).__init__() |
|
|
|
def forward(self, inputs): |
|
return inputs * inputs.sigmoid() |
|
|
|
|
|
class Transpose(nn.Module): |
|
""" Wrapper class of torch.transpose() for Sequential module. """ |
|
|
|
def __init__(self, shape: tuple): |
|
super(Transpose, self).__init__() |
|
self.shape = shape |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
return x.transpose(*self.shape) |
|
|
|
|
|
class GLU(nn.Module): |
|
""" |
|
    The gating mechanism is called the Gated Linear Unit (GLU), which was first introduced for natural language processing

    in the paper “Language Modeling with Gated Convolutional Networks”.
|
""" |
|
|
|
def __init__(self, dim: int) -> None: |
|
super(GLU, self).__init__() |
|
self.dim = dim |
|
|
|
def forward(self, inputs: Tensor) -> Tensor: |
|
outputs, gate = inputs.chunk(2, dim=self.dim) |
|
return outputs * gate.sigmoid() |
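
# Example (illustrative): GLU halves the chosen dimension, which is why
# ConformerConvModule below expands channels by ``expansion_factor=2`` right before it.
#   GLU(dim=1)(torch.randn(2, 128, 50)).shape  ->  torch.Size([2, 64, 50])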
|
|
|
|
|
class FeedForwardModule(nn.Module): |
|
def __init__( |
|
self, encoder_dim: int = 512, expansion_factor: int = 4, dropout_p: float = 0.1, |
|
) -> None: |
|
super(FeedForwardModule, self).__init__() |
|
self.sequential = nn.Sequential( |
|
nn.LayerNorm(encoder_dim), |
|
Linear(encoder_dim, encoder_dim * expansion_factor, bias=True), |
|
Swish(), |
|
nn.Dropout(p=dropout_p), |
|
Linear(encoder_dim * expansion_factor, encoder_dim, bias=True), |
|
nn.Dropout(p=dropout_p), |
|
) |
|
|
|
def forward(self, inputs): |
|
return self.sequential(inputs) |
|
|
|
|
|
class PositionalEncoding(nn.Module): |
|
""" |
|
Positional Encoding proposed in "Attention Is All You Need". |
|
    Since the Transformer contains no recurrence and no convolution, the model must be given some positional

    information in order to make use of the order of the sequence.

    "Attention Is All You Need" uses sine and cosine functions of different frequencies:
|
PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model)) |
|
PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model)) |
|
""" |
|
|
|
def __init__(self, d_model: int = 512, max_len: int = 10000) -> None: |
|
super(PositionalEncoding, self).__init__() |
|
pe = torch.zeros(max_len, d_model, requires_grad=False) |
|
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
|
div_term = torch.exp( |
|
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) |
|
) |
|
pe[:, 0::2] = torch.sin(position * div_term) |
|
pe[:, 1::2] = torch.cos(position * div_term) |
|
pe = pe.unsqueeze(0) |
|
self.register_buffer("pe", pe) |
|
|
|
def forward(self, length: int) -> Tensor: |
|
return self.pe[:, :length] |
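
# Shape sketch (illustrative): the module returns the first ``length`` sinusoidal
# embeddings with a leading batch dimension of 1.
#   PositionalEncoding(d_model=512)(50).shape  ->  torch.Size([1, 50, 512])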
|
|
|
|
|
class RelativeMultiHeadAttention(nn.Module): |
|
""" |
|
Multi-head attention with relative positional encoding. |
|
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".
|
Args: |
|
d_model (int): The dimension of model |
|
num_heads (int): The number of attention heads. |
|
dropout_p (float): probability of dropout |
|
Inputs: query, key, value, pos_embedding, mask |
|
- **query** (batch, time, dim): Tensor containing query vector |
|
- **key** (batch, time, dim): Tensor containing key vector |
|
- **value** (batch, time, dim): Tensor containing value vector |
|
- **pos_embedding** (batch, time, dim): Positional embedding tensor |
|
- **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked |
|
Returns: |
|
        - **outputs**: Tensor produced by the relative multi-head attention module.
|
""" |
|
|
|
def __init__( |
|
self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1, |
|
): |
|
super(RelativeMultiHeadAttention, self).__init__() |
|
assert d_model % num_heads == 0, "d_model % num_heads should be zero." |
|
self.d_model = d_model |
|
self.d_head = int(d_model / num_heads) |
|
self.num_heads = num_heads |
|
self.sqrt_dim = math.sqrt(d_model) |
|
|
|
self.query_proj = Linear(d_model, d_model) |
|
self.key_proj = Linear(d_model, d_model) |
|
self.value_proj = Linear(d_model, d_model) |
|
self.pos_proj = Linear(d_model, d_model, bias=False) |
|
|
|
self.dropout = nn.Dropout(p=dropout_p) |
|
self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) |
|
self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) |
|
torch.nn.init.xavier_uniform_(self.u_bias) |
|
torch.nn.init.xavier_uniform_(self.v_bias) |
|
|
|
self.out_proj = Linear(d_model, d_model) |
|
|
|
def forward( |
|
self, |
|
query: Tensor, |
|
key: Tensor, |
|
value: Tensor, |
|
pos_embedding: Tensor, |
|
mask: Optional[Tensor] = None, |
|
) -> Tensor: |
|
batch_size = value.size(0) |
|
|
|
query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) |
|
key = ( |
|
self.key_proj(key) |
|
.view(batch_size, -1, self.num_heads, self.d_head) |
|
.permute(0, 2, 1, 3) |
|
) |
|
value = ( |
|
self.value_proj(value) |
|
.view(batch_size, -1, self.num_heads, self.d_head) |
|
.permute(0, 2, 1, 3) |
|
) |
|
pos_embedding = self.pos_proj(pos_embedding).view( |
|
batch_size, -1, self.num_heads, self.d_head |
|
) |
|
|
|
content_score = torch.matmul( |
|
(query + self.u_bias).transpose(1, 2), key.transpose(2, 3) |
|
) |
|
pos_score = torch.matmul( |
|
(query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1) |
|
) |
|
pos_score = self._relative_shift(pos_score) |
|
|
|
score = (content_score + pos_score) / self.sqrt_dim |
|
|
|
if mask is not None: |
|
mask = mask.unsqueeze(1) |
|
score.masked_fill_(mask, -1e9) |
|
|
|
attn = torch.nn.functional.softmax(score, -1) |
|
attn = self.dropout(attn) |
|
|
|
context = torch.matmul(attn, value).transpose(1, 2) |
|
context = context.contiguous().view(batch_size, -1, self.d_model) |
|
|
|
return self.out_proj(context) |
|
|
|
def _relative_shift(self, pos_score: Tensor) -> Tensor: |
|
batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() |
|
zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1) |
|
padded_pos_score = torch.cat([zeros, pos_score], dim=-1) |
|
|
|
padded_pos_score = padded_pos_score.view( |
|
batch_size, num_heads, seq_length2 + 1, seq_length1 |
|
) |
|
pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) |
|
|
|
return pos_score |
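
# Note on ``_relative_shift`` above (explanatory comment, not in the original):
# prepending a zero column and reshaping shifts each query row's scores so that
# entries indexed by absolute key position become indexed by relative distance,
# the same trick used in Transformer-XL (Dai et al., 2019).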
|
|
|
|
|
class MultiHeadedSelfAttentionModule(nn.Module): |
|
""" |
|
    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,

    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention

    module to generalize better to different input lengths, and the resulting encoder is more robust to variance in

    the utterance length. Conformer uses pre-norm residual units with dropout, which helps with training

    and regularizing deeper models.

    Args:

        d_model (int): The dimension of model

        num_heads (int): The number of attention heads.

        dropout_p (float): probability of dropout

        is_casual (bool): If True, a causal (upper-triangular) attention mask is applied. Default: True

    Inputs: inputs

        - **inputs** (batch, time, dim): Tensor containing input vector

    Returns:

        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
|
""" |
|
|
|
def __init__( |
|
self, d_model: int, num_heads: int, dropout_p: float = 0.1, is_casual=True |
|
): |
|
super(MultiHeadedSelfAttentionModule, self).__init__() |
|
self.positional_encoding = PositionalEncoding(d_model) |
|
self.layer_norm = nn.LayerNorm(d_model) |
|
self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p) |
|
self.dropout = nn.Dropout(p=dropout_p) |
|
self.is_casual = is_casual |
|
|
|
def forward(self, inputs: Tensor): |
|
batch_size, seq_length, _ = inputs.size() |
|
pos_embedding = self.positional_encoding(seq_length) |
|
pos_embedding = pos_embedding.repeat(batch_size, 1, 1) |
|
|
|
mask = None |
|
if self.is_casual: |
|
mask = torch.triu( |
|
torch.ones((seq_length, seq_length), dtype=torch.uint8).to( |
|
inputs.device |
|
), |
|
diagonal=1, |
|
) |
|
mask = mask.unsqueeze(0).expand(batch_size, -1, -1).bool() |
|
|
|
inputs = self.layer_norm(inputs) |
|
outputs = self.attention( |
|
inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask |
|
) |
|
|
|
return self.dropout(outputs) |
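
# Usage sketch (illustrative): relative-position self-attention preserves the
# [batch, time, d_model] shape; with ``is_casual=True`` future frames are masked out.
#   mhsa = MultiHeadedSelfAttentionModule(d_model=64, num_heads=4)
#   mhsa(torch.randn(2, 50, 64)).shape  ->  torch.Size([2, 50, 64])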
|
|
|
|
|
class ResidualConnectionModule(nn.Module): |
|
""" |
|
Residual Connection Module. |
|
outputs = (module(inputs) x module_factor + inputs x input_factor) |
|
""" |
|
|
|
def __init__( |
|
self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0 |
|
): |
|
super(ResidualConnectionModule, self).__init__() |
|
self.module = module |
|
self.module_factor = module_factor |
|
self.input_factor = input_factor |
|
|
|
def forward(self, inputs): |
|
return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor) |
|
|
|
|
|
class DepthwiseConv1d(nn.Module): |
|
""" |
|
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, |
|
    this operation is termed a depthwise convolution in the literature.
|
Args: |
|
in_channels (int): Number of channels in the input |
|
out_channels (int): Number of channels produced by the convolution |
|
kernel_size (int or tuple): Size of the convolving kernel |
|
stride (int, optional): Stride of the convolution. Default: 1 |
|
        padding (int or tuple, optional): Ignored; the effective padding is derived from ``kernel_size`` and ``is_casual``.

        bias (bool, optional): If True, adds a learnable bias to the output. Default: False

        is_casual (bool, optional): If True, pads and trims so that outputs do not depend on future inputs. Default: True
|
Inputs: inputs |
|
- **inputs** (batch, in_channels, time): Tensor containing input vector |
|
Returns: outputs |
|
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int, |
|
stride: int = 1, |
|
padding: int = 0, |
|
bias: bool = False, |
|
is_casual: bool = True, |
|
) -> None: |
|
super(DepthwiseConv1d, self).__init__() |
|
assert ( |
|
out_channels % in_channels == 0 |
|
), "out_channels should be constant multiple of in_channels" |
|
if is_casual: |
|
padding = kernel_size - 1 |
|
else: |
|
padding = (kernel_size - 1) // 2 |
|
self.conv = nn.Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=kernel_size, |
|
groups=in_channels, |
|
stride=stride, |
|
padding=padding, |
|
bias=bias, |
|
) |
|
self.is_casual = is_casual |
|
self.kernel_size = kernel_size |
|
|
|
def forward(self, inputs: Tensor) -> Tensor: |
|
if self.is_casual: |
|
return self.conv(inputs)[:, :, : -(self.kernel_size - 1)] |
|
return self.conv(inputs) |
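
# Padding sketch (illustrative): with ``is_casual=True`` the input is padded by
# ``kernel_size - 1`` and the trailing frames are trimmed, so the time length is
# preserved and no output frame depends on future samples. (With ``kernel_size=1``
# the trim above would return an empty tensor; the Conformer default of 31 avoids this.)
#   conv = DepthwiseConv1d(64, 64, kernel_size=31, is_casual=True)
#   conv(torch.randn(2, 64, 100)).shape  ->  torch.Size([2, 64, 100])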
|
|
|
|
|
class PointwiseConv1d(nn.Module): |
|
""" |
|
    A conv1d with kernel size == 1 is termed a pointwise convolution in the literature.

    This operation is often used to match dimensions.
|
Args: |
|
in_channels (int): Number of channels in the input |
|
out_channels (int): Number of channels produced by the convolution |
|
stride (int, optional): Stride of the convolution. Default: 1 |
|
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 |
|
bias (bool, optional): If True, adds a learnable bias to the output. Default: True |
|
Inputs: inputs |
|
- **inputs** (batch, in_channels, time): Tensor containing input vector |
|
Returns: outputs |
|
        - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
stride: int = 1, |
|
padding: int = 0, |
|
bias: bool = True, |
|
) -> None: |
|
super(PointwiseConv1d, self).__init__() |
|
self.conv = nn.Conv1d( |
|
in_channels=in_channels, |
|
out_channels=out_channels, |
|
kernel_size=1, |
|
stride=stride, |
|
padding=padding, |
|
bias=bias, |
|
) |
|
|
|
def forward(self, inputs: Tensor) -> Tensor: |
|
return self.conv(inputs) |
|
|
|
|
|
class ConformerConvModule(nn.Module): |
|
""" |
|
    The Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).

    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution

    to aid training deep models.

    Args:

        in_channels (int): Number of channels in the input

        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 31

        expansion_factor (int, optional): Channel expansion factor applied before the GLU. Default: 2

        dropout_p (float, optional): probability of dropout

        is_casual (bool, optional): If True, the depthwise convolution is causal. Default: True

    Inputs: inputs

        inputs (batch, time, dim): Tensor containing input sequences

    Outputs: outputs

        outputs (batch, time, dim): Tensor produced by the Conformer convolution module.
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
kernel_size: int = 31, |
|
expansion_factor: int = 2, |
|
dropout_p: float = 0.1, |
|
is_casual: bool = True, |
|
) -> None: |
|
super(ConformerConvModule, self).__init__() |
|
assert ( |
|
kernel_size - 1 |
|
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"

        assert expansion_factor == 2, "Currently, only expansion_factor = 2 is supported"
|
|
|
self.sequential = nn.Sequential( |
|
nn.LayerNorm(in_channels), |
|
Transpose(shape=(1, 2)), |
|
PointwiseConv1d( |
|
in_channels, |
|
in_channels * expansion_factor, |
|
stride=1, |
|
padding=0, |
|
bias=True, |
|
), |
|
GLU(dim=1), |
|
DepthwiseConv1d( |
|
in_channels, in_channels, kernel_size, stride=1, is_casual=is_casual |
|
), |
|
nn.BatchNorm1d(in_channels), |
|
Swish(), |
|
PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True), |
|
nn.Dropout(p=dropout_p), |
|
) |
|
|
|
def forward(self, inputs: Tensor) -> Tensor: |
|
return self.sequential(inputs).transpose(1, 2) |
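
# Usage sketch (illustrative): the module maps [batch, time, dim] to [batch, time, dim];
# channels are expanded by 2 for the GLU and reduced back by the final pointwise conv.
#   conv_module = ConformerConvModule(in_channels=64, kernel_size=31)
#   conv_module(torch.randn(2, 100, 64)).shape  ->  torch.Size([2, 100, 64])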
|
|
|
|
|
class TransformerLayer(nn.Module): |
|
def __init__( |
|
self, in_chan=128, n_head=8, n_att=1, dropout=0.1, max_len=500, is_casual=True |
|
): |
|
super(TransformerLayer, self).__init__() |
|
self.in_chan = in_chan |
|
self.n_head = n_head |
|
self.dropout = dropout |
|
self.max_len = max_len |
|
self.n_att = n_att |
|
|
|
self.seq = nn.Sequential( |
|
ResidualConnectionModule( |
|
FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout), |
|
module_factor=0.5, |
|
), |
|
ResidualConnectionModule( |
|
MultiHeadedSelfAttentionModule(in_chan, n_head, dropout, is_casual) |
|
), |
|
ResidualConnectionModule( |
|
ConformerConvModule(in_chan, 31, 2, dropout, is_casual=is_casual) |
|
), |
|
ResidualConnectionModule( |
|
FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout), |
|
module_factor=0.5, |
|
), |
|
nn.LayerNorm(in_chan), |
|
) |
|
|
|
def forward(self, x): |
|
return self.seq(x) |
|
|
|
|
|
class TransformerBlockTF(nn.Module): |
|
def __init__( |
|
self, |
|
in_chan, |
|
n_head=8, |
|
n_att=1, |
|
dropout=0.1, |
|
max_len=500, |
|
norm_type="cLN", |
|
is_casual=True, |
|
): |
|
super(TransformerBlockTF, self).__init__() |
|
self.transformer = TransformerLayer( |
|
in_chan, n_head, n_att, dropout, max_len, is_casual |
|
) |
|
self.norm = normalizations.get(norm_type)(in_chan) |
|
|
|
def forward(self, x): |
|
B, F, T = x.size() |
|
output = self.transformer(x.permute(0, 2, 1).contiguous()) |
|
output = output.permute(0, 2, 1).contiguous() |
|
output = self.norm(output) |
|
return output + x |
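
# Usage sketch (illustrative): like LSTMBlockTF, this is a residual block over
# [batch, in_chan, time] features, with a Conformer-style layer inside; it assumes
# ``in_chan`` is divisible by ``n_head`` and that ``normalizations.get("cLN")``
# resolves to a channel-wise norm.
#   block = TransformerBlockTF(in_chan=64, n_head=4)
#   block(torch.randn(2, 64, 100)).shape  ->  torch.Size([2, 64, 100])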
|
|
|
|
|
|
|
|
|
|
|
class DPRNNBlock(nn.Module): |
|
def __init__( |
|
self, |
|
in_chan, |
|
hid_size, |
|
norm_type="gLN", |
|
bidirectional=True, |
|
rnn_type="LSTM", |
|
num_layers=1, |
|
dropout=0, |
|
): |
|
super(DPRNNBlock, self).__init__() |
|
self.intra_RNN = SingleRNN( |
|
rnn_type, |
|
in_chan, |
|
hid_size, |
|
num_layers, |
|
dropout=dropout, |
|
bidirectional=True, |
|
) |
|
self.inter_RNN = SingleRNN( |
|
rnn_type, |
|
in_chan, |
|
hid_size, |
|
num_layers, |
|
dropout=dropout, |
|
bidirectional=bidirectional, |
|
) |
|
self.intra_linear = nn.Linear(self.intra_RNN.output_size, in_chan) |
|
self.intra_norm = normalizations.get(norm_type)(in_chan) |
|
|
|
self.inter_linear = nn.Linear(self.inter_RNN.output_size, in_chan) |
|
self.inter_norm = normalizations.get(norm_type)(in_chan) |
|
|
|
def forward(self, x): |
|
""" Input shape : [batch, feats, chunk_size, num_chunks] """ |
|
B, N, K, L = x.size() |
|
output = x |
|
|
|
x = x.transpose(1, -1).reshape(B * L, K, N) |
|
x = self.intra_RNN(x) |
|
x = self.intra_linear(x) |
|
x = x.reshape(B, L, K, N).transpose(1, -1) |
|
x = self.intra_norm(x) |
|
output = output + x |
|
|
|
x = output.transpose(1, 2).transpose(2, -1).reshape(B * K, L, N) |
|
x = self.inter_RNN(x) |
|
x = self.inter_linear(x) |
|
x = x.reshape(B, K, L, N).transpose(1, -1).transpose(2, -1).contiguous() |
|
x = self.inter_norm(x) |
|
return output + x |
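
# Shape sketch (illustrative): intra-chunk and inter-chunk RNN passes, each with a
# residual connection, keep the [batch, feats, chunk_size, num_chunks] shape
# (assuming the chosen norm accepts 4-D input, as the DPRNN module below relies on).
#   block = DPRNNBlock(in_chan=64, hid_size=128)
#   block(torch.randn(2, 64, 100, 7)).shape  ->  torch.Size([2, 64, 100, 7])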
|
|
|
|
|
class DPRNN(nn.Module): |
|
def __init__( |
|
self, |
|
in_chan, |
|
n_src, |
|
out_chan=None, |
|
bn_chan=128, |
|
hid_size=128, |
|
chunk_size=100, |
|
hop_size=None, |
|
n_repeats=6, |
|
norm_type="gLN", |
|
mask_act="relu", |
|
bidirectional=True, |
|
rnn_type="LSTM", |
|
num_layers=1, |
|
dropout=0, |
|
): |
|
super(DPRNN, self).__init__() |
|
self.in_chan = in_chan |
|
out_chan = out_chan if out_chan is not None else in_chan |
|
self.out_chan = out_chan |
|
self.bn_chan = bn_chan |
|
self.hid_size = hid_size |
|
self.chunk_size = chunk_size |
|
hop_size = hop_size if hop_size is not None else chunk_size // 2 |
|
self.hop_size = hop_size |
|
self.n_repeats = n_repeats |
|
self.n_src = n_src |
|
self.norm_type = norm_type |
|
self.mask_act = mask_act |
|
self.bidirectional = bidirectional |
|
self.rnn_type = rnn_type |
|
self.num_layers = num_layers |
|
self.dropout = dropout |
|
|
|
layer_norm = normalizations.get(norm_type)(in_chan) |
|
bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1) |
|
self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv) |
|
|
|
|
|
net = [] |
|
for x in range(self.n_repeats): |
|
net += [ |
|
DPRNNBlock( |
|
bn_chan, |
|
hid_size, |
|
norm_type=norm_type, |
|
bidirectional=bidirectional, |
|
rnn_type=rnn_type, |
|
num_layers=num_layers, |
|
dropout=dropout, |
|
) |
|
] |
|
self.net = nn.Sequential(*net) |
|
|
|
net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1) |
|
self.first_out = nn.Sequential(nn.PReLU(), net_out_conv) |
|
|
|
self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh()) |
|
self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid()) |
|
self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False) |
|
|
|
|
|
mask_nl_class = activations.get(mask_act) |
|
|
|
if has_arg(mask_nl_class, "dim"): |
|
self.output_act = mask_nl_class(dim=1) |
|
else: |
|
self.output_act = mask_nl_class() |
|
|
|
def forward(self, mixture_w): |
|
r"""Forward. |
|
Args: |
|
mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$ |
|
Returns: |
|
:class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$ |
|
""" |
|
batch, n_filters, n_frames = mixture_w.size() |
|
output = self.bottleneck(mixture_w) |
|
output = unfold( |
|
output.unsqueeze(-1), |
|
kernel_size=(self.chunk_size, 1), |
|
padding=(self.chunk_size, 0), |
|
stride=(self.hop_size, 1), |
|
) |
|
n_chunks = output.shape[-1] |
|
output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks) |
|
|
|
output = self.net(output) |
|
|
|
output = self.first_out(output) |
|
output = output.reshape( |
|
batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks |
|
) |
|
|
|
|
|
to_unfold = self.bn_chan * self.chunk_size |
|
output = fold( |
|
output.reshape(batch * self.n_src, to_unfold, n_chunks), |
|
(n_frames, 1), |
|
kernel_size=(self.chunk_size, 1), |
|
padding=(self.chunk_size, 0), |
|
stride=(self.hop_size, 1), |
|
) |
|
|
|
        output = output.reshape(batch * self.n_src, self.bn_chan, -1)

        # Apply the gated output layers defined in __init__ (net_out / net_gate).
        output = self.net_out(output) * self.net_gate(output)
|
|
|
|
|
score = self.mask_net(output) |
|
est_mask = self.output_act(score) |
|
est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames) |
|
return est_mask |
|
|
|
def get_config(self): |
|
config = { |
|
"in_chan": self.in_chan, |
|
"out_chan": self.out_chan, |
|
"bn_chan": self.bn_chan, |
|
"hid_size": self.hid_size, |
|
"chunk_size": self.chunk_size, |
|
"hop_size": self.hop_size, |
|
"n_repeats": self.n_repeats, |
|
"n_src": self.n_src, |
|
"norm_type": self.norm_type, |
|
"mask_act": self.mask_act, |
|
"bidirectional": self.bidirectional, |
|
"rnn_type": self.rnn_type, |
|
"num_layers": self.num_layers, |
|
"dropout": self.dropout, |
|
} |
|
return config |
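
# Usage sketch (illustrative): the DPRNN masker maps [batch, n_filters, n_frames]
# encoder features to per-source masks of shape [batch, n_src, n_filters, n_frames].
#   masker = DPRNN(in_chan=64, n_src=2, n_repeats=2)
#   masker(torch.randn(2, 64, 500)).shape  ->  torch.Size([2, 2, 64, 500])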
|
|
|
|
|
class DPRNNLinear(nn.Module): |
|
def __init__( |
|
self, |
|
in_chan, |
|
n_src, |
|
out_chan=None, |
|
bn_chan=128, |
|
hid_size=128, |
|
chunk_size=100, |
|
hop_size=None, |
|
n_repeats=6, |
|
norm_type="gLN", |
|
mask_act="relu", |
|
bidirectional=True, |
|
rnn_type="LSTM", |
|
num_layers=1, |
|
dropout=0, |
|
): |
|
super(DPRNNLinear, self).__init__() |
|
self.in_chan = in_chan |
|
out_chan = out_chan if out_chan is not None else in_chan |
|
self.out_chan = out_chan |
|
self.bn_chan = bn_chan |
|
self.hid_size = hid_size |
|
self.chunk_size = chunk_size |
|
hop_size = hop_size if hop_size is not None else chunk_size // 2 |
|
self.hop_size = hop_size |
|
self.n_repeats = n_repeats |
|
self.n_src = n_src |
|
self.norm_type = norm_type |
|
self.mask_act = mask_act |
|
self.bidirectional = bidirectional |
|
self.rnn_type = rnn_type |
|
self.num_layers = num_layers |
|
self.dropout = dropout |
|
|
|
layer_norm = normalizations.get(norm_type)(in_chan) |
|
bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1) |
|
self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv) |
|
|
|
|
|
net = [] |
|
for x in range(self.n_repeats): |
|
net += [ |
|
DPRNNBlock( |
|
bn_chan, |
|
hid_size, |
|
norm_type=norm_type, |
|
bidirectional=bidirectional, |
|
rnn_type=rnn_type, |
|
num_layers=num_layers, |
|
dropout=dropout, |
|
) |
|
] |
|
self.net = nn.Sequential(*net) |
|
|
|
net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1) |
|
self.first_out = nn.Sequential(nn.PReLU(), net_out_conv) |
|
|
|
|
|
self.net_out = nn.Linear(bn_chan, out_chan) |
|
self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid()) |
|
self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False) |
|
|
|
|
|
mask_nl_class = activations.get(mask_act) |
|
|
|
if has_arg(mask_nl_class, "dim"): |
|
self.output_act = mask_nl_class(dim=1) |
|
else: |
|
self.output_act = mask_nl_class() |
|
|
|
def forward(self, mixture_w): |
|
r"""Forward. |
|
Args: |
|
mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$ |
|
Returns: |
|
:class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$ |
|
""" |
|
batch, n_filters, n_frames = mixture_w.size() |
|
output = self.bottleneck(mixture_w) |
|
output = unfold( |
|
output.unsqueeze(-1), |
|
kernel_size=(self.chunk_size, 1), |
|
padding=(self.chunk_size, 0), |
|
stride=(self.hop_size, 1), |
|
) |
|
n_chunks = output.shape[-1] |
|
output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks) |
|
|
|
output = self.net(output) |
|
|
|
output = self.first_out(output) |
|
output = output.reshape( |
|
batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks |
|
) |
|
|
|
|
|
to_unfold = self.bn_chan * self.chunk_size |
|
output = fold( |
|
output.reshape(batch * self.n_src, to_unfold, n_chunks), |
|
(n_frames, 1), |
|
kernel_size=(self.chunk_size, 1), |
|
padding=(self.chunk_size, 0), |
|
stride=(self.hop_size, 1), |
|
) |
|
|
|
output = output.reshape(batch * self.n_src, self.bn_chan, -1) |
|
        output = self.net_out(output.transpose(1, 2)).transpose(1, 2) * self.net_gate(
|
output |
|
) |
|
|
|
score = self.mask_net(output) |
|
est_mask = self.output_act(score) |
|
est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames) |
|
return est_mask |
|
|
|
def get_config(self): |
|
config = { |
|
"in_chan": self.in_chan, |
|
"out_chan": self.out_chan, |
|
"bn_chan": self.bn_chan, |
|
"hid_size": self.hid_size, |
|
"chunk_size": self.chunk_size, |
|
"hop_size": self.hop_size, |
|
"n_repeats": self.n_repeats, |
|
"n_src": self.n_src, |
|
"norm_type": self.norm_type, |
|
"mask_act": self.mask_act, |
|
"bidirectional": self.bidirectional, |
|
"rnn_type": self.rnn_type, |
|
"num_layers": self.num_layers, |
|
"dropout": self.dropout, |
|
} |
|
return config |
|
|