|
from typing import List, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from models.scnet_unofficial.utils import get_convtranspose_output_padding
|
|
|
|
|
|
class FusionLayer(nn.Module):
    """
    FusionLayer fuses two equally shaped tensors into one: the inputs are summed,
    the sum is duplicated along the channel axis, convolved along the F axis and
    finally gated by a GLU, which halves the channel axis again.

    Args:
    - input_dim (int): Number of channels C carried by each input tensor.
    - kernel_size (int, optional): Kernel size of the convolution along the F axis. Default is 3.
    - stride (int, optional): Stride of the convolution along the F axis. Default is 1.
    - padding (int, optional): Zero-padding of the convolution along the F axis. Default is 1.

    Shapes:
    - Input: (B, F, T, C) and (B, F, T, C) where
        B is batch size,
        F is the number of features,
        T is sequence length,
        C is input dimensionality (must equal input_dim).
    - Output: (B, F', T, C) where
        F' = floor((F + 2 * padding - kernel_size) / stride) + 1,
        which equals F for the default kernel_size=3, stride=1, padding=1.
    """

    def __init__(
        self, input_dim: int, kernel_size: int = 3, stride: int = 1, padding: int = 1
    ):
        """
        Builds the convolution (2 * input_dim -> 2 * input_dim channels) and the GLU gate.
        """
        super().__init__()
        doubled_dim = input_dim * 2
        # The kernel, stride and padding only act on the first spatial axis (F);
        # the second spatial axis (T) is processed point-wise.
        self.conv = nn.Conv2d(
            in_channels=doubled_dim,
            out_channels=doubled_dim,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=(padding, 0),
        )
        # nn.GLU() gates along the last axis by default.
        self.activation = nn.GLU()

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Fuses the two inputs.

        Args:
        - x1 (torch.Tensor): First input tensor of shape (B, F, T, C).
        - x2 (torch.Tensor): Second input tensor of shape (B, F, T, C).

        Returns:
        - torch.Tensor: Fused tensor of shape (B, F', T, C); F' == F with the default settings.
        """
        fused = x1 + x2
        # Duplicate the channels (C -> 2C) so that the GLU can halve them back to C.
        doubled = fused.repeat(1, 1, 1, 2)
        # Conv2d expects channels first: (B, F, T, 2C) -> (B, 2C, F, T).
        channels_first = doubled.permute(0, 3, 1, 2)
        convolved = self.conv(channels_first)
        # Back to channels last so the GLU splits the channel axis.
        channels_last = convolved.permute(0, 2, 3, 1)
        return self.activation(channels_last)
|
|
|
|
|
|
class Upsample(nn.Module):
    """
    Upsample class implements a module for upsampling input tensors along the
    frequency axis using a transposed 2D convolution with kernel size 1.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - stride (int): Stride value for the transposed convolution along the frequency axis.
    - output_padding (int): Output padding value for the transposed convolution along
        the frequency axis. PyTorch requires it to be smaller than the stride.

    Shapes:
    - Input: (B, C_in, F, T) where
        B is batch size,
        C_in is the number of input channels,
        F is the frequency dimension,
        T is the time dimension.
    - Output: (B, C_out, (F - 1) * stride + output_padding + 1, T) where
        B is batch size,
        C_out is the number of output channels,
        (F - 1) * stride + output_padding + 1 is the upsampled frequency dimension
        (kernel size 1, no padding, no dilation),
        T is the unchanged time dimension.
    """

    def __init__(
        self, input_dim: int, output_dim: int, stride: int, output_padding: int
    ):
        """
        Initializes Upsample with input dimension, output dimension, stride, and output padding.
        """
        super().__init__()
        # Stride and output padding only apply to the frequency axis; the time
        # axis uses stride 1 and no output padding, so its length is preserved.
        self.conv = nn.ConvTranspose2d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=1,
            stride=(stride, 1),
            output_padding=(output_padding, 0),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the Upsample module.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, C_in, F, T).

        Returns:
        - torch.Tensor: Output tensor of shape
            (B, C_out, (F - 1) * stride + output_padding + 1, T).
        """
        return self.conv(x)
|
|
|
|
|
|
class SULayer(nn.Module):
    """
    SULayer cuts one subband out of the feature axis of its input and upsamples
    that slice along the feature axis with an Upsample module.

    Args:
    - input_dim (int): Number of channels of the input tensor.
    - output_dim (int): Number of channels of the output tensor.
    - upsample_stride (int): Stride handed to the Upsample module.
    - subband_shape (int): Target size passed as ``output_shape`` to
        ``get_convtranspose_output_padding`` (presumably the desired subband size
        after upsampling - confirm against the helper).
    - sd_interval (Tuple[int, int]): Start (inclusive) and end (exclusive) indices of the
        slice taken along the feature axis.

    Shapes:
    - Input: (B, F, T, C_in) where
        B is batch size,
        F is the number of features,
        T is sequence length,
        C_in is input dimensionality.
    - Output: (B, F_up, T, C_out) where
        F_up is the feature size produced by upsampling the selected slice,
        C_out is output_dim.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_stride: int,
        subband_shape: int,
        sd_interval: Tuple[int, int],
    ):
        """
        Stores the slice interval and builds the Upsample module for the subband.
        """
        super().__init__()
        self.sd_interval = sd_interval
        # Width of the slice that forward() cuts out of the feature axis.
        band_width = sd_interval[1] - sd_interval[0]
        self.upsample = Upsample(
            input_dim=input_dim,
            output_dim=output_dim,
            stride=upsample_stride,
            output_padding=get_convtranspose_output_padding(
                input_shape=band_width,
                output_shape=subband_shape,
                stride=upsample_stride,
            ),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Slices the subband and upsamples it.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, F, T, C_in).

        Returns:
        - torch.Tensor: Output tensor of shape (B, F_up, T, C_out).
        """
        start, end = self.sd_interval[0], self.sd_interval[1]
        band = x[:, start:end]
        # The upsampler works channels-first: (B, F, T, C) -> (B, C, F, T).
        upsampled = self.upsample(band.permute(0, 3, 1, 2))
        # Restore the channels-last layout used by the surrounding blocks.
        return upsampled.permute(0, 2, 3, 1)
|
|
|
|
|
|
class SUBlock(nn.Module):
    """
    SUBlock class implements a block with fusion layer and subband upsampling layers.

    The input and its skip connection are fused by a FusionLayer; every SULayer
    then processes the fused tensor and the results are concatenated along dim 1.

    Args:
    - input_dim (int): Dimensionality of the input channels.
    - output_dim (int): Dimensionality of the output channels.
    - upsample_strides (List[int]): List of stride values for the upsampling operations.
    - subband_shapes (List[int]): List of shapes for the subbands.
    - sd_intervals (List[Tuple[int, int]]): List of intervals for subband decomposition.

    The three lists are consumed in lockstep, one SULayer per triple; if their
    lengths differ, the extra entries of the longer lists are ignored.

    Shapes:
    - Input: (B, Fi-1, T, Ci-1) and (B, Fi-1, T, Ci-1) where
        B is batch size,
        Fi-1 is the number of input subbands,
        T is sequence length,
        Ci-1 is the number of input channels.
    - Output: (B, Fi, T, Ci) where
        B is batch size,
        Fi is the number of output subbands,
        T is sequence length,
        Ci is the number of output channels.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        upsample_strides: List[int],
        subband_shapes: List[int],
        sd_intervals: List[Tuple[int, int]],
    ):
        """
        Initializes SUBlock with input dimension, output dimension,
        upsample strides, subband shapes, and subband intervals.
        """
        super().__init__()
        self.fusion_layer = FusionLayer(input_dim=input_dim)
        # One SULayer per (stride, subband shape, interval) triple.
        self.su_layers = nn.ModuleList(
            SULayer(
                input_dim=input_dim,
                output_dim=output_dim,
                upsample_stride=uss,
                subband_shape=sbs,
                sd_interval=sdi,
            )
            for uss, sbs, sdi in zip(upsample_strides, subband_shapes, sd_intervals)
        )

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        """
        Performs forward pass through the SUBlock.

        Args:
        - x (torch.Tensor): Input tensor of shape (B, Fi-1, T, Ci-1).
        - x_skip (torch.Tensor): Input skip connection tensor of shape (B, Fi-1, T, Ci-1).

        Returns:
        - torch.Tensor: Output tensor of shape (B, Fi, T, Ci).
        """
        x = self.fusion_layer(x, x_skip)
        # Every SULayer sees the full fused tensor; the upsampled subbands are
        # joined along dim 1.
        return torch.cat([layer(x) for layer in self.su_layers], dim=1)
|
|
|