from typing import Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import ISTFT, STFT, magphase

from bytesep.models.pytorch_modules import Base, Subband, act, init_bn, init_layer
class ConvBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple,
activation: str,
momentum: float,
):
r"""Convolutional block."""
super(ConvBlock, self).__init__()
self.activation = activation
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=(1, 1),
dilation=(1, 1),
padding=padding,
bias=False,
)
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
self.conv2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=(1, 1),
dilation=(1, 1),
padding=padding,
bias=False,
)
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
self.init_weights()
    def init_weights(self) -> None:
r"""Initialize weights."""
init_layer(self.conv1)
init_layer(self.conv2)
init_bn(self.bn1)
init_bn(self.bn2)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
r"""Forward data into the module.
Args:
input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
Returns:
output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
"""
x = act(self.bn1(self.conv1(input_tensor)), self.activation)
x = act(self.bn2(self.conv2(x)), self.activation)
output_tensor = x
return output_tensor
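
# Example (a minimal sketch with illustrative shapes, not taken from the
# original file): because padding = kernel_size // 2 and stride = (1, 1), a
# ConvBlock preserves the time/frequency dimensions and only changes the
# channel count.
#
#   block = ConvBlock(in_channels=1, out_channels=32, kernel_size=(3, 3),
#                     activation="leaky_relu", momentum=0.01)
#   y = block(torch.zeros(2, 1, 100, 257))  # y.shape == (2, 32, 100, 257)
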
class EncoderBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple,
downsample: Tuple,
activation: str,
momentum: float,
):
r"""Encoder block."""
super(EncoderBlock, self).__init__()
self.conv_block = ConvBlock(
in_channels, out_channels, kernel_size, activation, momentum
)
self.downsample = downsample
    def forward(self, input_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Forward data into the module.
Args:
input_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
Returns:
encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
"""
encoder_tensor = self.conv_block(input_tensor)
# encoder: (batch_size, out_feature_maps, time_steps, freq_bins)
encoder_pool = F.avg_pool2d(encoder_tensor, kernel_size=self.downsample)
# encoder_pool: (batch_size, out_feature_maps, downsampled_time_steps, downsampled_freq_bins)
return encoder_pool, encoder_tensor
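
# Example (a minimal sketch with illustrative shapes): an EncoderBlock returns
# both the pooled tensor, which feeds the next encoder, and the pre-pool
# tensor, which is kept as the skip connection for the matching DecoderBlock.
#
#   enc = EncoderBlock(in_channels=1, out_channels=32, kernel_size=(3, 3),
#                      downsample=(2, 2), activation="leaky_relu", momentum=0.01)
#   pooled, skip = enc(torch.zeros(2, 1, 100, 256))
#   # pooled.shape == (2, 32, 50, 128); skip.shape == (2, 32, 100, 256)
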
class DecoderBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple,
upsample: Tuple,
activation: str,
momentum: float,
):
r"""Decoder block."""
super(DecoderBlock, self).__init__()
self.kernel_size = kernel_size
self.stride = upsample
self.activation = activation
self.conv1 = torch.nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=self.stride,
stride=self.stride,
padding=(0, 0),
bias=False,
dilation=(1, 1),
)
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
self.conv_block2 = ConvBlock(
out_channels * 2, out_channels, kernel_size, activation, momentum
)
self.init_weights()
def init_weights(self):
r"""Initialize weights."""
init_layer(self.conv1)
init_bn(self.bn1)
def forward(
self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor
) -> torch.Tensor:
r"""Forward data into the module.
Args:
            input_tensor: (batch_size, in_feature_maps, downsampled_time_steps, downsampled_freq_bins)
concat_tensor: (batch_size, in_feature_maps, time_steps, freq_bins)
Returns:
output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
"""
x = act(self.bn1(self.conv1(input_tensor)), self.activation)
# (batch_size, in_feature_maps, time_steps, freq_bins)
x = torch.cat((x, concat_tensor), dim=1)
# (batch_size, in_feature_maps * 2, time_steps, freq_bins)
output_tensor = self.conv_block2(x)
# output_tensor: (batch_size, out_feature_maps, time_steps, freq_bins)
return output_tensor
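
# Example (a minimal sketch with illustrative shapes): a DecoderBlock first
# doubles the time/frequency dimensions with the transposed convolution, then
# concatenates the skip tensor along the channel axis, and finally maps back
# down to out_channels with a ConvBlock.
#
#   dec = DecoderBlock(in_channels=64, out_channels=32, kernel_size=(3, 3),
#                      upsample=(2, 2), activation="leaky_relu", momentum=0.01)
#   y = dec(torch.zeros(2, 64, 50, 128), torch.zeros(2, 32, 100, 256))
#   # y.shape == (2, 32, 100, 256)
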
class UNet(nn.Module, Base):
def __init__(self, input_channels: int, target_sources_num: int):
r"""UNet."""
super(UNet, self).__init__()
self.input_channels = input_channels
self.target_sources_num = target_sources_num
window_size = 2048
hop_size = 441
center = True
pad_mode = "reflect"
window = "hann"
activation = "leaky_relu"
momentum = 0.01
self.subbands_num = 1
        assert (
            self.subbands_num == 1
        ), "Using subbands_num > 1 on the spectrogram can lead to unexpected \
            performance; use the subband method on the waveform instead."
self.K = 3 # outputs: |M|, cos∠M, sin∠M
        self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}.
self.stft = STFT(
n_fft=window_size,
hop_length=hop_size,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode,
freeze_parameters=True,
)
self.istft = ISTFT(
n_fft=window_size,
hop_length=hop_size,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode,
freeze_parameters=True,
)
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
self.subband = Subband(subbands_num=self.subbands_num)
self.encoder_block1 = EncoderBlock(
in_channels=input_channels * self.subbands_num,
out_channels=32,
kernel_size=(3, 3),
downsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.encoder_block2 = EncoderBlock(
in_channels=32,
out_channels=64,
kernel_size=(3, 3),
downsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.encoder_block3 = EncoderBlock(
in_channels=64,
out_channels=128,
kernel_size=(3, 3),
downsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.encoder_block4 = EncoderBlock(
in_channels=128,
out_channels=256,
kernel_size=(3, 3),
downsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.encoder_block5 = EncoderBlock(
in_channels=256,
out_channels=384,
kernel_size=(3, 3),
downsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.encoder_block6 = EncoderBlock(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
downsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.conv_block7 = ConvBlock(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
activation=activation,
momentum=momentum,
)
self.decoder_block1 = DecoderBlock(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
upsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.decoder_block2 = DecoderBlock(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
upsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.decoder_block3 = DecoderBlock(
in_channels=384,
out_channels=256,
kernel_size=(3, 3),
upsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.decoder_block4 = DecoderBlock(
in_channels=256,
out_channels=128,
kernel_size=(3, 3),
upsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.decoder_block5 = DecoderBlock(
in_channels=128,
out_channels=64,
kernel_size=(3, 3),
upsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.decoder_block6 = DecoderBlock(
in_channels=64,
out_channels=32,
kernel_size=(3, 3),
upsample=(2, 2),
activation=activation,
momentum=momentum,
)
self.after_conv_block1 = ConvBlock(
in_channels=32,
out_channels=32,
kernel_size=(3, 3),
activation=activation,
momentum=momentum,
)
self.after_conv2 = nn.Conv2d(
in_channels=32,
out_channels=target_sources_num
* input_channels
* self.K
* self.subbands_num,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
)
self.init_weights()
def init_weights(self):
r"""Initialize weights."""
init_bn(self.bn0)
init_layer(self.after_conv2)
def feature_maps_to_wav(
self,
input_tensor: torch.Tensor,
sp: torch.Tensor,
sin_in: torch.Tensor,
cos_in: torch.Tensor,
audio_length: int,
) -> torch.Tensor:
r"""Convert feature maps to waveform.
Args:
input_tensor: (batch_size, target_sources_num * input_channels * self.K, time_steps, freq_bins)
sp: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
sin_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
cos_in: (batch_size, target_sources_num * input_channels, time_steps, freq_bins)
Outputs:
waveform: (batch_size, target_sources_num * input_channels, segment_samples)
"""
batch_size, _, time_steps, freq_bins = input_tensor.shape
x = input_tensor.reshape(
batch_size,
self.target_sources_num,
self.input_channels,
self.K,
time_steps,
freq_bins,
)
        # x: (batch_size, target_sources_num, input_channels, K, time_steps, freq_bins)
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
        # mask_cos, mask_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
# Y = |Y|cos∠Y + j|Y|sin∠Y
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
out_cos = (
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
)
out_sin = (
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
)
        # out_cos: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
        # out_sin: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
# Calculate |Y|.
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
        # out_mag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
# Calculate Y_{real} and Y_{imag} for ISTFT.
out_real = out_mag * out_cos
out_imag = out_mag * out_sin
        # out_real, out_imag: (batch_size, target_sources_num, input_channels, time_steps, freq_bins)
# Reformat shape to (n, 1, time_steps, freq_bins) for ISTFT.
shape = (
batch_size * self.target_sources_num * self.input_channels,
1,
time_steps,
freq_bins,
)
out_real = out_real.reshape(shape)
out_imag = out_imag.reshape(shape)
# ISTFT.
x = self.istft(out_real, out_imag, audio_length)
        # x: (batch_size * target_sources_num * input_channels, audio_length)
# Reshape.
waveform = x.reshape(
batch_size, self.target_sources_num * self.input_channels, audio_length
)
        # waveform: (batch_size, target_sources_num * input_channels, audio_length)
return waveform
def forward(self, input_dict: Dict) -> Dict:
r"""Forward data into the module.
Args:
input_dict: dict, e.g., {
waveform: (batch_size, input_channels, segment_samples),
...,
}
Outputs:
output_dict: dict, e.g., {
                'waveform': (batch_size, target_sources_num * input_channels, segment_samples),
...,
}
"""
mixtures = input_dict['waveform']
# (batch_size, input_channels, segment_samples)
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
# mag, cos_in, sin_in: (batch_size, input_channels, time_steps, freq_bins)
# Batch normalize on individual frequency bins.
x = mag.transpose(1, 3)
x = self.bn0(x)
x = x.transpose(1, 3)
# x: (batch_size, input_channels, time_steps, freq_bins)
# Pad spectrogram to be evenly divided by downsample ratio.
origin_len = x.shape[2]
pad_len = (
int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
- origin_len
)
x = F.pad(x, pad=(0, 0, 0, pad_len))
# x: (batch_size, input_channels, padded_time_steps, freq_bins)
        # Drop the Nyquist bin so that freq_bins is evenly divisible by 2, e.g., 1025 -> 1024.
x = x[..., 0 : x.shape[-1] - 1] # (bs, input_channels, T, F)
if self.subbands_num > 1:
x = self.subband.analysis(x)
# (bs, input_channels, T, F'), where F' = F // subbands_num
# UNet
(x1_pool, x1) = self.encoder_block1(x) # x1_pool: (bs, 32, T / 2, F' / 2)
(x2_pool, x2) = self.encoder_block2(x1_pool) # x2_pool: (bs, 64, T / 4, F' / 4)
(x3_pool, x3) = self.encoder_block3(
x2_pool
) # x3_pool: (bs, 128, T / 8, F' / 8)
(x4_pool, x4) = self.encoder_block4(
x3_pool
) # x4_pool: (bs, 256, T / 16, F' / 16)
(x5_pool, x5) = self.encoder_block5(
x4_pool
) # x5_pool: (bs, 384, T / 32, F' / 32)
(x6_pool, x6) = self.encoder_block6(
x5_pool
) # x6_pool: (bs, 384, T / 64, F' / 64)
x_center = self.conv_block7(x6_pool) # (bs, 384, T / 64, F' / 64)
x7 = self.decoder_block1(x_center, x6) # (bs, 384, T / 32, F' / 32)
x8 = self.decoder_block2(x7, x5) # (bs, 384, T / 16, F' / 16)
x9 = self.decoder_block3(x8, x4) # (bs, 256, T / 8, F' / 8)
x10 = self.decoder_block4(x9, x3) # (bs, 128, T / 4, F' / 4)
x11 = self.decoder_block5(x10, x2) # (bs, 64, T / 2, F' / 2)
x12 = self.decoder_block6(x11, x1) # (bs, 32, T, F')
x = self.after_conv_block1(x12) # (bs, 32, T, F')
x = self.after_conv2(x)
        # x: (batch_size, target_sources_num * input_channels * self.K * subbands_num, T, F')
if self.subbands_num > 1:
x = self.subband.synthesis(x)
            # x: (batch_size, target_sources_num * input_channels * self.K, T, F)
        # Recover the original time and frequency shapes.
        x = F.pad(x, pad=(0, 1))  # Pad the frequency dimension, e.g., 1024 -> 1025.
        x = x[:, :, 0:origin_len, :]
        # x: (batch_size, target_sources_num * input_channels * self.K, T, F)
audio_length = mixtures.shape[2]
separated_audio = self.feature_maps_to_wav(x, mag, sin_in, cos_in, audio_length)
        # separated_audio: (batch_size, target_sources_num * input_channels, audio_length)
output_dict = {'waveform': separated_audio}
return output_dict
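

if __name__ == "__main__":
    # Smoke test (a minimal sketch; the stereo input, single target source,
    # sample rate, and segment length below are illustrative assumptions, not
    # fixed by the model). Any segment length works because forward() pads the
    # spectrogram so its time dimension is divisible by self.downsample_ratio.
    model = UNet(input_channels=2, target_sources_num=1)
    model.eval()

    # (batch_size, input_channels, segment_samples), e.g., 3 s at 44.1 kHz.
    mixture = torch.zeros(1, 2, 44100 * 3)

    with torch.no_grad():
        output_dict = model({'waveform': mixture})

    # (batch_size, target_sources_num * input_channels, segment_samples)
    print(output_dict['waveform'].shape)  # torch.Size([1, 2, 132300])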