import torch
import math
import inspect
from torch import nn
from torch import Tensor
from typing import Tuple
from typing import Optional
from torch.nn.functional import fold, unfold
import numpy as np

from . import activations, normalizations
from .normalizations import gLN


def has_arg(fn, name):
    """Checks if a callable accepts a given keyword argument.

    Args:
        fn (callable): Callable to inspect.
        name (str): Check if ``fn`` can be called with ``name`` as a keyword
            argument.

    Returns:
        bool: whether ``fn`` accepts a ``name`` keyword argument.
    """
    signature = inspect.signature(fn)
    parameter = signature.parameters.get(name)
    if parameter is None:
        return False
    return parameter.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
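

# Usage sketch (illustrative, not part of the original module): ``has_arg`` is used
# below to decide whether an activation class accepts a ``dim`` keyword argument.
#   >>> has_arg(nn.Softmax, "dim")
#   True
#   >>> has_arg(nn.ReLU, "dim")
#   False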


class SingleRNN(nn.Module):
    """Module for an RNN block.

    Inspired by https://github.com/yluo42/TAC/blob/master/utility/models.py
    Licensed under CC BY-NC-SA 3.0 US.

    Args:
        rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can
            also be passed in lowercase letters.
        input_size (int): Dimension of the input feature. The input should have
            shape [batch, seq_len, input_size].
        hidden_size (int): Dimension of the hidden state.
        n_layers (int, optional): Number of layers used in RNN. Default is 1.
        dropout (float, optional): Dropout ratio. Default is 0.
        bidirectional (bool, optional): Whether the RNN layers are
            bidirectional. Default is ``False``.
    """

    def __init__(
        self,
        rnn_type,
        input_size,
        hidden_size,
        n_layers=1,
        dropout=0,
        bidirectional=False,
    ):
        super(SingleRNN, self).__init__()
        assert rnn_type.upper() in ["RNN", "LSTM", "GRU"]
        rnn_type = rnn_type.upper()
        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.rnn = getattr(nn, rnn_type)(
            input_size,
            hidden_size,
            num_layers=n_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=bool(bidirectional),
        )

    @property
    def output_size(self):
        # Accessed as an attribute below (e.g. ``self.RNN.output_size``).
        return self.hidden_size * (2 if self.bidirectional else 1)

    def forward(self, inp):
        """Input shape [batch, seq, feats]."""
        self.rnn.flatten_parameters()  # Enables faster multi-GPU training.
        output = inp
        rnn_output, _ = self.rnn(output)
        return rnn_output
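

# Usage sketch (illustrative, sizes chosen arbitrarily): a bidirectional LSTM over
# [batch, seq, feats] doubles the feature dimension of its output.
#   >>> rnn = SingleRNN("LSTM", input_size=64, hidden_size=128, bidirectional=True)
#   >>> rnn.output_size
#   256
#   >>> rnn(torch.randn(2, 100, 64)).shape
#   torch.Size([2, 100, 256])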


class LSTMBlockTF(nn.Module):
    def __init__(
        self,
        in_chan,
        hid_size,
        norm_type="gLN",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(LSTMBlockTF, self).__init__()
        self.RNN = SingleRNN(
            rnn_type,
            in_chan,
            hid_size,
            num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.linear = nn.Linear(self.RNN.output_size, in_chan)
        self.norm = normalizations.get(norm_type)(in_chan)

    def forward(self, x):
        B, F, T = x.size()
        output = self.RNN(x.transpose(1, 2))  # B, T, N
        output = self.linear(output)
        output = output.transpose(1, -1)  # B, N, T
        output = self.norm(output)
        return output + x
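

# Shape sketch (illustrative, assuming the sibling ``normalizations`` module provides
# "gLN" as used above): the block is residual, so the output shape equals the input shape.
#   >>> block = LSTMBlockTF(in_chan=64, hid_size=128)
#   >>> block(torch.randn(2, 64, 100)).shape   # [batch, in_chan, time]
#   torch.Size([2, 64, 100])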


# ===================Transformer======================


class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weights are initialized with Xavier uniform initialization and biases are initialized to zeros.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        return self.linear(x)


class Swish(nn.Module):
    """
    Swish is a smooth, non-monotonic activation that consistently matches or outperforms ReLU on deep networks
    across a variety of challenging domains such as image classification and machine translation.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs):
        return inputs * inputs.sigmoid()


class Transpose(nn.Module):
    """Wrapper class of torch.transpose() for Sequential modules."""

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)


class GLU(nn.Module):
    """
    The Gated Linear Unit (GLU) gating mechanism, first introduced for natural language processing
    in the paper "Language Modeling with Gated Convolutional Networks".
    """

    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()


class FeedForwardModule(nn.Module):
    def __init__(
        self, encoder_dim: int = 512, expansion_factor: int = 4, dropout_p: float = 0.1,
    ) -> None:
        super(FeedForwardModule, self).__init__()
        self.sequential = nn.Sequential(
            nn.LayerNorm(encoder_dim),
            Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
            Swish(),
            nn.Dropout(p=dropout_p),
            Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs):
        return self.sequential(inputs)


class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    Since the Transformer contains no recurrence and no convolution, positional information must be
    added for the model to make use of the order of the sequence.
    "Attention Is All You Need" uses sine and cosine functions of different frequencies:

        PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
    """

    def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, length: int) -> Tensor:
        return self.pe[:, :length]
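

# Shape sketch (illustrative): calling the module with a length returns the first
# ``length`` sinusoidal encodings, with a leading broadcast dimension.
#   >>> pos_enc = PositionalEncoding(d_model=512)
#   >>> pos_enc(100).shape   # [1, length, d_model]
#   torch.Size([1, 100, 512])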


class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context".

    Args:
        d_model (int): The dimension of the model.
        num_heads (int): The number of attention heads.
        dropout_p (float): Probability of dropout.

    Inputs: query, key, value, pos_embedding, mask
        - **query** (batch, time, dim): Tensor containing the query vectors.
        - **key** (batch, time, dim): Tensor containing the key vectors.
        - **value** (batch, time, dim): Tensor containing the value vectors.
        - **pos_embedding** (batch, time, dim): Positional embedding tensor.
        - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked.

    Returns:
        - **outputs**: Tensor produced by the relative multi-head attention module.
    """

    def __init__(
        self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1,
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = int(d_model / num_heads)
        self.num_heads = num_heads
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = Linear(d_model, d_model)
        self.key_proj = Linear(d_model, d_model)
        self.value_proj = Linear(d_model, d_model)
        self.pos_proj = Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(p=dropout_p)
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        torch.nn.init.xavier_uniform_(self.u_bias)
        torch.nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = Linear(d_model, d_model)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_embedding: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        batch_size = value.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = (
            self.key_proj(key)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        value = (
            self.value_proj(value)
            .view(batch_size, -1, self.num_heads, self.d_head)
            .permute(0, 2, 1, 3)
        )
        pos_embedding = self.pos_proj(pos_embedding).view(
            batch_size, -1, self.num_heads, self.d_head
        )

        content_score = torch.matmul(
            (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
        )
        pos_score = torch.matmul(
            (query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1)
        )
        pos_score = self._relative_shift(pos_score)

        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)
            score.masked_fill_(mask, -1e9)

        attn = torch.nn.functional.softmax(score, -1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context)

    def _relative_shift(self, pos_score: Tensor) -> Tensor:
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
        padded_pos_score = padded_pos_score.view(
            batch_size, num_heads, seq_length2 + 1, seq_length1
        )
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
        return pos_score
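

# Shape sketch (illustrative, sizes chosen arbitrarily): self-attention with relative
# positions preserves the [batch, time, d_model] layout of its inputs.
#   >>> mha = RelativeMultiHeadAttention(d_model=64, num_heads=4)
#   >>> x = torch.randn(2, 50, 64)
#   >>> pos = PositionalEncoding(d_model=64)(50).repeat(2, 1, 1)
#   >>> mha(x, x, x, pos_embedding=pos).shape
#   torch.Size([2, 50, 64])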


class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
    module to generalize better to different input lengths, and the resulting encoder is more robust to the variance of
    the utterance length. Conformer uses pre-norm residual units with dropout, which helps when training
    and regularizing deeper models.

    Args:
        d_model (int): The dimension of the model.
        num_heads (int): The number of attention heads.
        dropout_p (float): Probability of dropout.
        is_casual (bool): If True, applies a causal (upper-triangular) attention mask.

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input vectors.

    Returns:
        - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
    """

    def __init__(
        self, d_model: int, num_heads: int, dropout_p: float = 0.1, is_casual=True
    ):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
        self.dropout = nn.Dropout(p=dropout_p)
        self.is_casual = is_casual

    def forward(self, inputs: Tensor):
        batch_size, seq_length, _ = inputs.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
        mask = None
        if self.is_casual:
            mask = torch.triu(
                torch.ones((seq_length, seq_length), dtype=torch.uint8).to(
                    inputs.device
                ),
                diagonal=1,
            )
            mask = mask.unsqueeze(0).expand(batch_size, -1, -1).bool()  # [B, L, L]
        inputs = self.layer_norm(inputs)
        outputs = self.attention(
            inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
        )
        return self.dropout(outputs)


class ResidualConnectionModule(nn.Module):
    """
    Residual Connection Module.
    outputs = (module(inputs) x module_factor + inputs x input_factor)
    """

    def __init__(
        self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0
    ):
        super(ResidualConnectionModule, self).__init__()
        self.module = module
        self.module_factor = module_factor
        self.input_factor = input_factor

    def forward(self, inputs):
        return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)


class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature a depthwise convolution.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Ignored; the padding is derived from ``kernel_size``
            and ``is_casual``. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
        is_casual (bool, optional): If True, pads with ``kernel_size - 1`` and trims the trailing
            frames so each output frame depends only on current and past inputs. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vectors.

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
        is_casual: bool = True,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert (
            out_channels % in_channels == 0
        ), "out_channels should be a constant multiple of in_channels"
        if is_casual:
            padding = kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )
        self.is_casual = is_casual
        self.kernel_size = kernel_size

    def forward(self, inputs: Tensor) -> Tensor:
        if self.is_casual:
            # Drop the trailing frames introduced by the causal padding.
            return self.conv(inputs)[:, :, : -(self.kernel_size - 1)]
        return self.conv(inputs)
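

# Shape sketch (illustrative, sizes chosen arbitrarily): causal padding of
# ``kernel_size - 1`` plus right trimming preserves the time dimension.
#   >>> conv = DepthwiseConv1d(in_channels=16, out_channels=16, kernel_size=31, is_casual=True)
#   >>> conv(torch.randn(2, 16, 100)).shape
#   torch.Size([2, 16, 100])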


class PointwiseConv1d(nn.Module):
    """
    A conv1d with kernel size 1, termed in the literature a pointwise convolution.
    This operation is often used to match channel dimensions.

    Args:
        in_channels (int): Number of channels in the input.
        out_channels (int): Number of channels produced by the convolution.
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: True

    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing the input vectors.

    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super(PointwiseConv1d, self).__init__()
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)


class ConformerConvModule(nn.Module):
    """
    The Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the
    convolution to aid training deep models.

    Args:
        in_channels (int): Number of channels in the input.
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 31
        expansion_factor (int, optional): Channel expansion factor of the first pointwise convolution. Default: 2
        dropout_p (float, optional): Probability of dropout.
        is_casual (bool, optional): If True, the depthwise convolution is causal. Default: True

    Inputs: inputs
        - **inputs** (batch, time, dim): Tensor containing the input sequences.

    Outputs: outputs
        - **outputs** (batch, time, dim): Tensor produced by the Conformer convolution module.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: int = 31,
        expansion_factor: int = 2,
        dropout_p: float = 0.1,
        is_casual: bool = True,
    ) -> None:
        super(ConformerConvModule, self).__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"

        self.sequential = nn.Sequential(
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),
            PointwiseConv1d(
                in_channels,
                in_channels * expansion_factor,
                stride=1,
                padding=0,
                bias=True,
            ),
            GLU(dim=1),
            DepthwiseConv1d(
                in_channels, in_channels, kernel_size, stride=1, is_casual=is_casual
            ),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
            nn.Dropout(p=dropout_p),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.sequential(inputs).transpose(1, 2)
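

# Shape sketch (illustrative, sizes chosen arbitrarily): the module consumes and
# produces [batch, time, channels]; the transposes handle the conv layout internally.
#   >>> conv_module = ConformerConvModule(in_channels=64, kernel_size=31)
#   >>> conv_module(torch.randn(2, 100, 64)).shape
#   torch.Size([2, 100, 64])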


class TransformerLayer(nn.Module):
    def __init__(
        self, in_chan=128, n_head=8, n_att=1, dropout=0.1, max_len=500, is_casual=True
    ):
        super(TransformerLayer, self).__init__()
        self.in_chan = in_chan
        self.n_head = n_head
        self.dropout = dropout
        self.max_len = max_len
        self.n_att = n_att
        self.seq = nn.Sequential(
            ResidualConnectionModule(
                FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout),
                module_factor=0.5,
            ),
            ResidualConnectionModule(
                MultiHeadedSelfAttentionModule(in_chan, n_head, dropout, is_casual)
            ),
            ResidualConnectionModule(
                ConformerConvModule(in_chan, 31, 2, dropout, is_casual=is_casual)
            ),
            ResidualConnectionModule(
                FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout),
                module_factor=0.5,
            ),
            nn.LayerNorm(in_chan),
        )

    def forward(self, x):
        return self.seq(x)


class TransformerBlockTF(nn.Module):
    def __init__(
        self,
        in_chan,
        n_head=8,
        n_att=1,
        dropout=0.1,
        max_len=500,
        norm_type="cLN",
        is_casual=True,
    ):
        super(TransformerBlockTF, self).__init__()
        self.transformer = TransformerLayer(
            in_chan, n_head, n_att, dropout, max_len, is_casual
        )
        self.norm = normalizations.get(norm_type)(in_chan)

    def forward(self, x):
        B, F, T = x.size()
        output = self.transformer(x.permute(0, 2, 1).contiguous())  # B, T, N
        output = output.permute(0, 2, 1).contiguous()  # B, N, T
        output = self.norm(output)
        return output + x


# ====================================================


class DPRNNBlock(nn.Module):
    def __init__(
        self,
        in_chan,
        hid_size,
        norm_type="gLN",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNNBlock, self).__init__()
        self.intra_RNN = SingleRNN(
            rnn_type,
            in_chan,
            hid_size,
            num_layers,
            dropout=dropout,
            bidirectional=True,
        )
        self.inter_RNN = SingleRNN(
            rnn_type,
            in_chan,
            hid_size,
            num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.intra_linear = nn.Linear(self.intra_RNN.output_size, in_chan)
        self.intra_norm = normalizations.get(norm_type)(in_chan)
        self.inter_linear = nn.Linear(self.inter_RNN.output_size, in_chan)
        self.inter_norm = normalizations.get(norm_type)(in_chan)

    def forward(self, x):
        """Input shape: [batch, feats, chunk_size, num_chunks]."""
        B, N, K, L = x.size()
        output = x  # for skip connection
        # Intra-chunk processing
        x = x.transpose(1, -1).reshape(B * L, K, N)
        x = self.intra_RNN(x)
        x = self.intra_linear(x)
        x = x.reshape(B, L, K, N).transpose(1, -1)
        x = self.intra_norm(x)
        output = output + x
        # Inter-chunk processing
        x = output.transpose(1, 2).transpose(2, -1).reshape(B * K, L, N)
        x = self.inter_RNN(x)
        x = self.inter_linear(x)
        x = x.reshape(B, K, L, N).transpose(1, -1).transpose(2, -1).contiguous()
        x = self.inter_norm(x)
        return output + x
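

# Shape sketch (illustrative, sizes chosen arbitrarily, assuming the sibling
# ``normalizations`` module provides "gLN"): both the intra- and inter-chunk passes
# are residual, so the chunked shape is preserved.
#   >>> block = DPRNNBlock(in_chan=64, hid_size=128)
#   >>> block(torch.randn(2, 64, 100, 10)).shape   # [batch, feats, chunk_size, num_chunks]
#   torch.Size([2, 64, 100, 10])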


class DPRNN(nn.Module):
    def __init__(
        self,
        in_chan,
        n_src,
        out_chan=None,
        bn_chan=128,
        hid_size=128,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        mask_act="relu",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNN, self).__init__()
        self.in_chan = in_chan
        out_chan = out_chan if out_chan is not None else in_chan
        self.out_chan = out_chan
        self.bn_chan = bn_chan
        self.hid_size = hid_size
        self.chunk_size = chunk_size
        hop_size = hop_size if hop_size is not None else chunk_size // 2
        self.hop_size = hop_size
        self.n_repeats = n_repeats
        self.n_src = n_src
        self.norm_type = norm_type
        self.mask_act = mask_act
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout

        layer_norm = normalizations.get(norm_type)(in_chan)
        bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1)
        self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv)

        # Succession of DPRNNBlocks.
        net = []
        for _ in range(self.n_repeats):
            net += [
                DPRNNBlock(
                    bn_chan,
                    hid_size,
                    norm_type=norm_type,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    num_layers=num_layers,
                    dropout=dropout,
                )
            ]
        self.net = nn.Sequential(*net)
        # Masking in 3D space
        net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
        self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
        self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False)

        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()

    def forward(self, mixture_w):
        r"""Forward.

        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$

        Returns:
            :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
        """
        batch, n_filters, n_frames = mixture_w.size()
        output = self.bottleneck(mixture_w)  # [batch, bn_chan, n_frames]
        output = unfold(
            output.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        n_chunks = output.shape[-1]
        output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks)
        # Apply stacked DPRNN Blocks sequentially
        output = self.net(output)
        # Map to sources with kind of 2D masks
        output = self.first_out(output)
        output = output.reshape(
            batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks
        )
        # Overlap and add:
        # [batch * n_src, bn_chan, chunk_size, n_chunks] -> [batch * n_src, bn_chan, n_frames]
        to_unfold = self.bn_chan * self.chunk_size
        output = fold(
            output.reshape(batch * self.n_src, to_unfold, n_chunks),
            (n_frames, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        output = output.reshape(batch * self.n_src, self.bn_chan, -1)
        # Gating is disabled in this variant (kept for reference):
        # output = self.net_out(output) * self.net_gate(output)
        # Compute mask
        score = self.mask_net(output)
        est_mask = self.output_act(score)
        est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames)
        return est_mask

    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "out_chan": self.out_chan,
            "bn_chan": self.bn_chan,
            "hid_size": self.hid_size,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "rnn_type": self.rnn_type,
            "num_layers": self.num_layers,
            "dropout": self.dropout,
        }
        return config
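

# Usage sketch (illustrative, sizes chosen arbitrarily, assuming the sibling
# ``activations``/``normalizations`` modules behave as used above): the masker maps a
# [batch, nfilters, nframes] representation to one mask per source.
#   >>> masker = DPRNN(in_chan=64, n_src=2, chunk_size=100)
#   >>> masker(torch.randn(2, 64, 1000)).shape   # [batch, n_src, out_chan, n_frames]
#   torch.Size([2, 2, 64, 1000])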


class DPRNNLinear(nn.Module):
    def __init__(
        self,
        in_chan,
        n_src,
        out_chan=None,
        bn_chan=128,
        hid_size=128,
        chunk_size=100,
        hop_size=None,
        n_repeats=6,
        norm_type="gLN",
        mask_act="relu",
        bidirectional=True,
        rnn_type="LSTM",
        num_layers=1,
        dropout=0,
    ):
        super(DPRNNLinear, self).__init__()
        self.in_chan = in_chan
        out_chan = out_chan if out_chan is not None else in_chan
        self.out_chan = out_chan
        self.bn_chan = bn_chan
        self.hid_size = hid_size
        self.chunk_size = chunk_size
        hop_size = hop_size if hop_size is not None else chunk_size // 2
        self.hop_size = hop_size
        self.n_repeats = n_repeats
        self.n_src = n_src
        self.norm_type = norm_type
        self.mask_act = mask_act
        self.bidirectional = bidirectional
        self.rnn_type = rnn_type
        self.num_layers = num_layers
        self.dropout = dropout

        layer_norm = normalizations.get(norm_type)(in_chan)
        bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1)
        self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv)

        # Succession of DPRNNBlocks.
        net = []
        for _ in range(self.n_repeats):
            net += [
                DPRNNBlock(
                    bn_chan,
                    hid_size,
                    norm_type=norm_type,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                    num_layers=num_layers,
                    dropout=dropout,
                )
            ]
        self.net = nn.Sequential(*net)
        # Masking in 3D space
        net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
        self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
        # Gating and masking in 2D space (after fold)
        # self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
        self.net_out = nn.Linear(bn_chan, out_chan)
        self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
        self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False)

        # Get activation function.
        mask_nl_class = activations.get(mask_act)
        # For softmax, feed the source dimension.
        if has_arg(mask_nl_class, "dim"):
            self.output_act = mask_nl_class(dim=1)
        else:
            self.output_act = mask_nl_class()

    def forward(self, mixture_w):
        r"""Forward.

        Args:
            mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$

        Returns:
            :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
        """
        batch, n_filters, n_frames = mixture_w.size()
        output = self.bottleneck(mixture_w)  # [batch, bn_chan, n_frames]
        output = unfold(
            output.unsqueeze(-1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        n_chunks = output.shape[-1]
        output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks)
        # Apply stacked DPRNN Blocks sequentially
        output = self.net(output)
        # Map to sources with kind of 2D masks
        output = self.first_out(output)
        output = output.reshape(
            batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks
        )
        # Overlap and add:
        # [batch * n_src, bn_chan, chunk_size, n_chunks] -> [batch * n_src, bn_chan, n_frames]
        to_unfold = self.bn_chan * self.chunk_size
        output = fold(
            output.reshape(batch * self.n_src, to_unfold, n_chunks),
            (n_frames, 1),
            kernel_size=(self.chunk_size, 1),
            padding=(self.chunk_size, 0),
            stride=(self.hop_size, 1),
        )
        # Apply gating: the Linear acts on the channel dimension, so move it last and back.
        output = output.reshape(batch * self.n_src, self.bn_chan, -1)
        output = self.net_out(output.transpose(1, 2)).transpose(1, 2) * self.net_gate(
            output
        )
        # Compute mask
        score = self.mask_net(output)
        est_mask = self.output_act(score)
        est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames)
        return est_mask

    def get_config(self):
        config = {
            "in_chan": self.in_chan,
            "out_chan": self.out_chan,
            "bn_chan": self.bn_chan,
            "hid_size": self.hid_size,
            "chunk_size": self.chunk_size,
            "hop_size": self.hop_size,
            "n_repeats": self.n_repeats,
            "n_src": self.n_src,
            "norm_type": self.norm_type,
            "mask_act": self.mask_act,
            "bidirectional": self.bidirectional,
            "rnn_type": self.rnn_type,
            "num_layers": self.num_layers,
            "dropout": self.dropout,
        }
        return config
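

# Usage sketch (illustrative, sizes chosen arbitrarily): same interface as DPRNN, but
# with an active Linear + sigmoid gate before the mask. Note that, as written, the
# elementwise gate and ``mask_net`` assume out_chan == bn_chan.
#   >>> masker = DPRNNLinear(in_chan=128, n_src=2, bn_chan=128)
#   >>> masker(torch.randn(2, 128, 1000)).shape
#   torch.Size([2, 2, 128, 1000])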