from typing import List, Tuple

import torch
import torch.nn as nn

from models.scnet_unofficial.utils import get_convtranspose_output_padding


class FusionLayer(nn.Module):
    """
    FusionLayer class implements a module for fusing two input tensors using convolutional operations.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - kernel_size (int, optional): Kernel size for the convolutional layer. Default is 3.
    - stride (int, optional): Stride value for the convolutional layer. Default is 1.
    - padding (int, optional): Padding value for the convolutional layer. Default is 1.

    Shapes:
    - Input: (B, F, T, C) and (B, F, T, C) where
        B is batch size,
        F is the number of features,
        T is sequence length,
        C is input dimensionality.
    - Output: (B, F, T, C) where
        B is batch size,
        F is the number of features,
        T is sequence length,
        C is input dimensionality.
    """

    def __init__(
        self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
    ):
        """
        Initializes FusionLayer with input dimension, kernel size, stride, and padding.
        """
        super().__init__()
        self.conv = nn.Conv2d(
            input_dim * 2,
            input_dim * 2,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )
        self.activation = nn.GLU()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the FusionLayer.

        Args:
        - x1 (torch.Tensor): First input tensor of shape (B, F, T, C).
        - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C).

        Returns:
        - torch.Tensor: Output tensor of shape (B, F, T, C).
        """
        x = x1 + x2
        x = x.repeat(1, 1, 1, 2)
        x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        x = self.activation(x)
        return x


class Upsample(nn.Module):
    """
    Upsample class implements a module for upsampling input tensors using transposed 2D convolution.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - stride (int): Stride value for the transposed convolution operation.
    - output_padding (int): Output padding value for the transposed convolution operation.

    Shapes:
    - Input: (B, C_in, F, T) where
        B is batch size,
        C_in is the number of input channels,
        F is the frequency dimension,
        T is the time dimension.
    - Output: (B, C_out, (F - 1) * stride + 1 + output_padding, T) where
        B is batch size,
        C_out is the number of output channels,
        (F - 1) * stride + 1 + output_padding is the upsampled frequency dimension.
    """

    def __init__(
        self, input_dim: int, output_dim: int, stride: int, output_padding: int
    ):
        """
        Initializes Upsample with input dimension, output dimension, stride, and output padding.
        """
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            input_dim, output_dim, 1, (stride, 1), output_padding=(output_padding, 0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the Upsample module.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).

        Returns:
        - torch.Tensor: Output tensor of shape (B, C_out, (F - 1) * stride + 1 + output_padding, T).
        """
        return self.conv(x)


class SULayer(nn.Module):
    """
    SULayer class implements a subband upsampling layer using transposed convolution.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - upsample_stride (int): Stride value for the upsampling operation.
    - subband_shape (int): Number of output frequency bins for the subband.
    - sd_interval (Tuple[int, int]): Start and end indices of the subband interval.

    Shapes:
    - Input: (B, F, T, C_in) where
        B is batch size,
        F is the number of input features,
        T is sequence length,
        C_in is the input dimensionality.
    - Output: (B, subband_shape, T, C_out) where
        B is batch size,
        subband_shape is the number of upsampled features in the subband,
        T is sequence length,
        C_out is the output dimensionality.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_stride: int,
        subband_shape: int,
        sd_interval: Tuple[int, int],
    ):
        """
        Initializes SULayer with input dimension, output dimension, upsample stride,
        subband shape, and subband interval.
        """
        super().__init__()
        sd_shape = sd_interval[1] - sd_interval[0]
        upsample_output_padding = get_convtranspose_output_padding(
            input_shape=sd_shape, output_shape=subband_shape, stride=upsample_stride
        )
        self.upsample = Upsample(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=upsample_stride,
            output_padding=upsample_output_padding,
        )
        self.sd_interval = sd_interval

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SULayer.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, C_in).

        Returns:
        - torch.Tensor: Output tensor of shape (B, subband_shape, T, C_out).
        """
        x = x[:, self.sd_interval[0] : self.sd_interval[1]]
        x = x.permute(0, 3, 1, 2)
        x = self.upsample(x)
        x = x.permute(0, 2, 3, 1)
        return x


class SUBlock(nn.Module):
    """
    SUBlock class implements a block with a fusion layer and subband upsampling layers.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - upsample_strides (List[int]): List of stride values for the upsampling operations.
    - subband_shapes (List[int]): List of shapes for the subbands.
    - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition.

    Shapes:
    - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where
        B is batch size,
        Fi-1 is the number of input subbands,
        T is sequence length,
        Ci-1 is the number of input channels.
    - Output: (B, Fi, T, Ci) where
        B is batch size,
        Fi is the number of output subbands,
        T is sequence length,
        Ci is the number of output channels.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_strides: List[int],
        subband_shapes: List[int],
        sd_intervals: List[Tuple[int, int]],
    ):
        """
        Initializes SUBlock with input dimension, output dimension, upsample strides,
        subband shapes, and subband intervals.
        """
        super().__init__()
        self.fusion_layer = FusionLayer(input_dim=input_dim)
        self.su_layers = nn.ModuleList(
            SULayer(
                input_dim=input_dim,
                output_dim=output_dim,
                upsample_stride=uss,
                subband_shape=sbs,
                sd_interval=sdi,
            )
            for uss, sbs, sdi in zip(upsample_strides, subband_shapes, sd_intervals)
        )

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SUBlock.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1).
        - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1).

        Returns:
        - torch.Tensor: Output tensor of shape (B, Fi, T, Ci).
        """
        x = self.fusion_layer(x, x_skip)
        x = torch.concat([layer(x) for layer in self.su_layers], dim=1)
        return x
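

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the original SCNet code.
    # The batch size, band intervals, strides, and subband shapes below are
    # illustrative only: they are chosen so that
    # `get_convtranspose_output_padding` should resolve to zero output padding
    # (each subband_shape equals (interval_length - 1) * stride + 1), and are
    # not taken from any SCNet configuration.
    batch_size, n_freqs, n_frames, n_channels = 2, 32, 10, 4

    x = torch.randn(batch_size, n_freqs, n_frames, n_channels)
    x_skip = torch.randn(batch_size, n_freqs, n_frames, n_channels)

    # FusionLayer keeps the (B, F, T, C) shape: expects (2, 32, 10, 4).
    fusion = FusionLayer(input_dim=n_channels)
    print(fusion(x, x_skip).shape)

    # SUBlock fuses the two inputs, upsamples three frequency subbands
    # (16 -> 31, 8 -> 15, 8 -> 8 bins) and halves the channel count,
    # giving (2, 31 + 15 + 8, 10, 2) = (2, 54, 10, 2).
    su_block = SUBlock(
        input_dim=n_channels,
        output_dim=n_channels // 2,
        upsample_strides=[2, 2, 1],
        subband_shapes=[31, 15, 8],
        sd_intervals=[(0, 16), (16, 24), (24, 32)],
    )
    print(su_block(x, x_skip).shape)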