# Author: roychao19477 — initial clean commit (7001051)
# Reference: https://github.com/yxlu-0102/MP-SENet/blob/main/models/generator.py
import torch
import torch.nn as nn
from einops import rearrange
from .lsigmoid import LearnableSigmoid2D
def get_padding(kernel_size, dilation=1):
    """
    Calculate the "same" padding size for one axis of a convolution.

    Args:
    - kernel_size (int): Size of the convolutional kernel.
    - dilation (int, optional): Dilation rate of the convolution. Defaults to 1.

    Returns:
    - int: Padding that preserves the input length for stride 1 and an odd
      effective kernel size.
    """
    # Pure integer arithmetic: the original int((k*d - d) / 2) round-trips
    # through float for no reason; floor division is exact and clearer.
    return dilation * (kernel_size - 1) // 2
def get_padding_2d(kernel_size, dilation=(1, 1)):
    """
    Calculate the "same" padding size for a 2D (or N-D) convolutional layer.

    Args:
    - kernel_size (tuple): Size of the convolutional kernel per axis
      (height, width).
    - dilation (tuple, optional): Dilation rate per axis (height, width).
      Defaults to (1, 1). Must have the same length as `kernel_size`.

    Returns:
    - tuple: Calculated padding size per axis (height, width).
    """
    # Integer floor division replaces the original int(float / 2) idiom;
    # zip generalizes the computation to any number of spatial axes while
    # remaining backward compatible for the 2D case.
    return tuple(d * (k - 1) // 2 for k, d in zip(kernel_size, dilation))
class DenseBlock(nn.Module):
    """
    Stack of dilated 2D convolutions with dense (concatenative) connectivity.

    Layer i convolves the concatenation of the block input and all previous
    layer outputs, with time-axis dilation 2**i; the final layer's output
    (hid_feature channels) is returned.
    """

    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.cfg = cfg
        self.depth = depth
        self.dense_block = nn.ModuleList()
        self.hid_feature = cfg['model_cfg']['hid_feature']
        for layer_idx in range(depth):
            dilation = 2 ** layer_idx
            # Input channels grow by hid_feature per layer due to the
            # dense concatenation in forward().
            in_channels = self.hid_feature * (layer_idx + 1)
            layer = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    self.hid_feature,
                    kernel_size,
                    dilation=(dilation, 1),
                    padding=get_padding_2d(kernel_size, (dilation, 1)),
                ),
                nn.InstanceNorm2d(self.hid_feature, affine=True),
                nn.PReLU(self.hid_feature),
            )
            self.dense_block.append(layer)

    def forward(self, x):
        """
        Forward pass for the DenseBlock module.

        Args:
        - x (torch.Tensor): Input tensor of shape [batch, hid_feature, time, freq].

        Returns:
        - torch.Tensor: Output of the final layer (same channel count as input).
        """
        features = x
        out = x
        for layer in self.dense_block:
            out = layer(features)
            # Accumulate every intermediate output along the channel axis.
            features = torch.cat([out, features], dim=1)
        return out
class DenseEncoder(nn.Module):
    """
    Encoder pipeline: pointwise conv, dense block, then a strided conv that
    halves the frequency axis.
    """

    def __init__(self, cfg):
        super(DenseEncoder, self).__init__()
        self.cfg = cfg
        model_cfg = cfg['model_cfg']
        self.input_channel = model_cfg['input_channel']
        self.hid_feature = model_cfg['hid_feature']
        hid = self.hid_feature

        # 1x1 conv lifts the input to the hidden channel width.
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, hid, (1, 1)),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )
        self.dense_block = DenseBlock(cfg, depth=4)
        # Stride (1, 2) downsamples only the frequency dimension.
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(hid, hid, (1, 3), stride=(1, 2)),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )

    def forward(self, x):
        """
        Forward pass for the DenseEncoder module.

        Args:
        - x (torch.Tensor): Input tensor [batch, input_channel, time, freq].

        Returns:
        - torch.Tensor: Encoded tensor [batch, hid_feature, time, freq//2].
        """
        for stage in (self.dense_conv_1, self.dense_block, self.dense_conv_2):
            x = stage(x)
        return x
class MagDecoder(nn.Module):
    """
    Decoder that produces a magnitude mask: dense block, transposed-conv
    upsampling of the frequency axis, and a learnable per-frequency sigmoid.
    """

    def __init__(self, cfg):
        super(MagDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = cfg['model_cfg']['beta']

        hid = self.hid_feature
        out_ch = self.output_channel
        self.mask_conv = nn.Sequential(
            # Transposed conv restores the frequency resolution halved by
            # the encoder's strided conv.
            nn.ConvTranspose2d(hid, hid, (1, 3), stride=(1, 2)),
            nn.Conv2d(hid, out_ch, (1, 1)),
            nn.InstanceNorm2d(out_ch, affine=True),
            nn.PReLU(out_ch),
            nn.Conv2d(out_ch, out_ch, (1, 1)),
        )
        # One learnable sigmoid slope per frequency bin (n_fft//2 + 1 bins).
        self.lsigmoid = LearnableSigmoid2D(self.n_fft // 2 + 1, beta=self.beta)

    def forward(self, x):
        """
        Forward pass for the MagDecoder module.

        Args:
        - x (torch.Tensor): Input tensor [batch, hid_feature, time, freq].

        Returns:
        - torch.Tensor: Magnitude mask [batch, 1, time, freq].
        """
        x = self.mask_conv(self.dense_block(x))
        # [b, c, t, f] -> [b, f, t]: the learnable sigmoid operates on a
        # per-frequency basis (c is expected to be 1 here).
        x = x.permute(0, 3, 2, 1).squeeze(-1)
        x = self.lsigmoid(x)
        # [b, f, t] -> [b, 1, t, f] to restore the conv layout.
        return x.transpose(1, 2).unsqueeze(1)
class PhaseDecoder(nn.Module):
    """
    Decoder that estimates phase: dense block, transposed-conv upsampling,
    then parallel real/imaginary 1x1 convs combined through atan2.
    """

    def __init__(self, cfg):
        super(PhaseDecoder, self).__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']

        hid = self.hid_feature
        # Transposed conv undoes the encoder's frequency downsampling.
        self.phase_conv = nn.Sequential(
            nn.ConvTranspose2d(hid, hid, (1, 3), stride=(1, 2)),
            nn.InstanceNorm2d(hid, affine=True),
            nn.PReLU(hid),
        )
        # Two heads predict pseudo real/imaginary components.
        self.phase_conv_r = nn.Conv2d(hid, self.output_channel, (1, 1))
        self.phase_conv_i = nn.Conv2d(hid, self.output_channel, (1, 1))

    def forward(self, x):
        """
        Forward pass for the PhaseDecoder module.

        Args:
        - x (torch.Tensor): Input tensor [batch, hid_feature, time, freq].

        Returns:
        - torch.Tensor: Phase angles in radians (range (-pi, pi]),
          shape [batch, output_channel, time, freq].
        """
        feats = self.phase_conv(self.dense_block(x))
        real_part = self.phase_conv_r(feats)
        imag_part = self.phase_conv_i(feats)
        # atan2 maps the (imag, real) pair to an unambiguous angle.
        return torch.atan2(imag_part, real_part)