# AudioSep/models/resunet.py
import numpy as np
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchlibrosa.stft import STFT, ISTFT, magphase
from models.base import Base, init_layer, init_bn, act
class FiLM(nn.Module):
def __init__(self, film_meta, condition_size):
super(FiLM, self).__init__()
self.condition_size = condition_size
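        # Note: assigning to self.modules shadows nn.Module.modules(); this is
        # kept as-is, since the FiLM layers themselves are registered through
        # add_module() in add_film_layer_to_module() below.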
self.modules, _ = self.create_film_modules(
film_meta=film_meta,
ancestor_names=[],
)
def create_film_modules(self, film_meta, ancestor_names):
modules = {}
# Pre-order traversal of modules
for module_name, value in film_meta.items():
if isinstance(value, int):
ancestor_names.append(module_name)
unique_module_name = '->'.join(ancestor_names)
modules[module_name] = self.add_film_layer_to_module(
num_features=value,
unique_module_name=unique_module_name,
)
elif isinstance(value, dict):
ancestor_names.append(module_name)
modules[module_name], _ = self.create_film_modules(
film_meta=value,
ancestor_names=ancestor_names,
)
ancestor_names.pop()
return modules, ancestor_names
def add_film_layer_to_module(self, num_features, unique_module_name):
layer = nn.Linear(self.condition_size, num_features)
init_layer(layer)
self.add_module(name=unique_module_name, module=layer)
return layer
def forward(self, conditions):
film_dict = self.calculate_film_data(
conditions=conditions,
modules=self.modules,
)
return film_dict
def calculate_film_data(self, conditions, modules):
film_data = {}
# Pre-order traversal of modules
for module_name, module in modules.items():
if isinstance(module, nn.Module):
film_data[module_name] = module(conditions)[:, :, None, None]
elif isinstance(module, dict):
film_data[module_name] = self.calculate_film_data(conditions, module)
return film_data
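# Data-flow sketch for FiLM (illustrative, not executed): given, say,
# film_meta = {'encoder_block1': {'conv_block1': {'beta1': 32, 'beta2': 32}}}
# and conditions of shape (batch_size, condition_size), forward() returns a
# nested dict mirroring film_meta, where e.g.
# film_dict['encoder_block1']['conv_block1']['beta1'] has shape
# (batch_size, 32, 1, 1), ready to broadcast-add onto BatchNorm2d outputs.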
class ConvBlockRes(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple,
momentum: float,
has_film,
):
r"""Residual block."""
super(ConvBlockRes, self).__init__()
padding = [kernel_size[0] // 2, kernel_size[1] // 2]
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
self.conv1 = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=(1, 1),
dilation=(1, 1),
padding=padding,
bias=False,
)
self.conv2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=(1, 1),
dilation=(1, 1),
padding=padding,
bias=False,
)
if in_channels != out_channels:
self.shortcut = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
)
self.is_shortcut = True
else:
self.is_shortcut = False
self.has_film = has_film
self.init_weights()
    def init_weights(self) -> None:
r"""Initialize weights."""
init_bn(self.bn1)
init_bn(self.bn2)
init_layer(self.conv1)
init_layer(self.conv2)
if self.is_shortcut:
init_layer(self.shortcut)
def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor:
r"""Forward data into the module.
Args:
input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)
Returns:
output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
"""
b1 = film_dict['beta1']
b2 = film_dict['beta2']
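        # This FiLM variant is bias-only: the condition adds a per-channel
        # shift after BatchNorm; there is no multiplicative gamma term as in
        # the original FiLM formulation.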
x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1, negative_slope=0.01))
x = self.conv2(F.leaky_relu_(self.bn2(x) + b2, negative_slope=0.01))
if self.is_shortcut:
return self.shortcut(input_tensor) + x
else:
return input_tensor + x
class EncoderBlockRes1B(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple,
downsample: Tuple,
momentum: float,
has_film,
):
r"""Encoder block, contains 8 convolutional layers."""
super(EncoderBlockRes1B, self).__init__()
self.conv_block1 = ConvBlockRes(
in_channels, out_channels, kernel_size, momentum, has_film,
)
self.downsample = downsample
def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor:
r"""Forward data into the module.
Args:
input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)
Returns:
encoder_pool: (batch_size, output_feature_maps, downsampled_time_steps, downsampled_freq_bins)
encoder: (batch_size, output_feature_maps, time_steps, freq_bins)
"""
encoder = self.conv_block1(input_tensor, film_dict['conv_block1'])
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
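        # The pre-pooling features are also returned, to be used later as the
        # U-Net skip connection.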
return encoder_pool, encoder
class DecoderBlockRes1B(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Tuple,
upsample: Tuple,
momentum: float,
has_film,
):
r"""Decoder block, contains 1 transposed convolutional and 8 convolutional layers."""
super(DecoderBlockRes1B, self).__init__()
self.kernel_size = kernel_size
self.stride = upsample
self.conv1 = torch.nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=self.stride,
stride=self.stride,
padding=(0, 0),
bias=False,
dilation=(1, 1),
)
self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
self.conv_block2 = ConvBlockRes(
out_channels * 2, out_channels, kernel_size, momentum, has_film,
)
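        # bn2 is never used in forward(); get_film_meta() below reads its
        # num_features to size this block's (unused) 'beta2' FiLM layer.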
self.bn2 = nn.BatchNorm2d(in_channels, momentum=momentum)
self.has_film = has_film
self.init_weights()
def init_weights(self):
r"""Initialize weights."""
init_bn(self.bn1)
init_layer(self.conv1)
def forward(
self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor, film_dict: Dict,
) -> torch.Tensor:
r"""Forward data into the module.
Args:
input_tensor: (batch_size, input_feature_maps, downsampled_time_steps, downsampled_freq_bins)
concat_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)
Returns:
output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
"""
        b1 = film_dict['beta1']
x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + b1))
# (batch_size, input_feature_maps, time_steps, freq_bins)
x = torch.cat((x, concat_tensor), dim=1)
# (batch_size, input_feature_maps * 2, time_steps, freq_bins)
x = self.conv_block2(x, film_dict['conv_block2'])
# output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
return x
class ResUNet30_Base(nn.Module, Base):
def __init__(self, input_channels, output_channels):
super(ResUNet30_Base, self).__init__()
window_size = 2048
hop_size = 320
center = True
pad_mode = "reflect"
window = "hann"
momentum = 0.01
self.output_channels = output_channels
self.target_sources_num = 1
self.K = 3
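        # K = 3 output maps per channel: a magnitude mask plus the real and
        # imaginary components of a phase mask (see feature_maps_to_wav).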
        self.time_downsample_ratio = 2 ** 5  # = 2^{number of encoder blocks that downsample along time}
self.stft = STFT(
n_fft=window_size,
hop_length=hop_size,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode,
freeze_parameters=True,
)
self.istft = ISTFT(
n_fft=window_size,
hop_length=hop_size,
win_length=window_size,
window=window,
center=center,
pad_mode=pad_mode,
freeze_parameters=True,
)
self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)
self.pre_conv = nn.Conv2d(
in_channels=input_channels,
out_channels=32,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
)
self.encoder_block1 = EncoderBlockRes1B(
in_channels=32,
out_channels=32,
kernel_size=(3, 3),
downsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.encoder_block2 = EncoderBlockRes1B(
in_channels=32,
out_channels=64,
kernel_size=(3, 3),
downsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.encoder_block3 = EncoderBlockRes1B(
in_channels=64,
out_channels=128,
kernel_size=(3, 3),
downsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.encoder_block4 = EncoderBlockRes1B(
in_channels=128,
out_channels=256,
kernel_size=(3, 3),
downsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.encoder_block5 = EncoderBlockRes1B(
in_channels=256,
out_channels=384,
kernel_size=(3, 3),
downsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.encoder_block6 = EncoderBlockRes1B(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
downsample=(1, 2),
momentum=momentum,
has_film=True,
)
self.conv_block7a = EncoderBlockRes1B(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
downsample=(1, 1),
momentum=momentum,
has_film=True,
)
self.decoder_block1 = DecoderBlockRes1B(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
upsample=(1, 2),
momentum=momentum,
has_film=True,
)
self.decoder_block2 = DecoderBlockRes1B(
in_channels=384,
out_channels=384,
kernel_size=(3, 3),
upsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.decoder_block3 = DecoderBlockRes1B(
in_channels=384,
out_channels=256,
kernel_size=(3, 3),
upsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.decoder_block4 = DecoderBlockRes1B(
in_channels=256,
out_channels=128,
kernel_size=(3, 3),
upsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.decoder_block5 = DecoderBlockRes1B(
in_channels=128,
out_channels=64,
kernel_size=(3, 3),
upsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.decoder_block6 = DecoderBlockRes1B(
in_channels=64,
out_channels=32,
kernel_size=(3, 3),
upsample=(2, 2),
momentum=momentum,
has_film=True,
)
self.after_conv = nn.Conv2d(
in_channels=32,
out_channels=output_channels * self.K,
kernel_size=(1, 1),
stride=(1, 1),
padding=(0, 0),
bias=True,
)
self.init_weights()
def init_weights(self):
init_bn(self.bn0)
init_layer(self.pre_conv)
init_layer(self.after_conv)
def feature_maps_to_wav(
self,
input_tensor: torch.Tensor,
sp: torch.Tensor,
sin_in: torch.Tensor,
cos_in: torch.Tensor,
audio_length: int,
) -> torch.Tensor:
r"""Convert feature maps to waveform.
Args:
input_tensor: (batch_size, target_sources_num * output_channels * self.K, time_steps, freq_bins)
sp: (batch_size, input_channels, time_steps, freq_bins)
sin_in: (batch_size, input_channels, time_steps, freq_bins)
cos_in: (batch_size, input_channels, time_steps, freq_bins)
            (Note: input_channels == output_channels for the source separation task.)
Outputs:
waveform: (batch_size, target_sources_num * output_channels, segment_samples)
"""
batch_size, _, time_steps, freq_bins = input_tensor.shape
x = input_tensor.reshape(
batch_size,
self.target_sources_num,
self.output_channels,
self.K,
time_steps,
freq_bins,
)
# x: (batch_size, target_sources_num, output_channels, self.K, time_steps, freq_bins)
mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
_mask_real = torch.tanh(x[:, :, :, 1, :, :])
_mask_imag = torch.tanh(x[:, :, :, 2, :, :])
# linear_mag = torch.tanh(x[:, :, :, 3, :, :])
_, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)
# mask_cos, mask_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)
# Y = |Y|cos∠Y + j|Y|sin∠Y
# = |Y|cos(∠X + ∠M) + j|Y|sin(∠X + ∠M)
# = |Y|(cos∠X cos∠M - sin∠X sin∠M) + j|Y|(sin∠X cos∠M + cos∠X sin∠M)
out_cos = (
cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
)
out_sin = (
sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
)
# out_cos: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)
# out_sin: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)
# Calculate |Y|.
out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)
# out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag + linear_mag)
# out_mag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)
# Calculate Y_{real} and Y_{imag} for ISTFT.
out_real = out_mag * out_cos
out_imag = out_mag * out_sin
# out_real, out_imag: (batch_size, target_sources_num, output_channels, time_steps, freq_bins)
# Reformat shape to (N, 1, time_steps, freq_bins) for ISTFT where
# N = batch_size * target_sources_num * output_channels
shape = (
batch_size * self.target_sources_num * self.output_channels,
1,
time_steps,
freq_bins,
)
out_real = out_real.reshape(shape)
out_imag = out_imag.reshape(shape)
# ISTFT.
x = self.istft(out_real, out_imag, audio_length)
        # (batch_size * target_sources_num * output_channels, audio_length)
# Reshape.
waveform = x.reshape(
batch_size, self.target_sources_num * self.output_channels, audio_length
)
        # (batch_size, target_sources_num * output_channels, audio_length)
return waveform
def forward(self, mixtures, film_dict):
"""
Args:
input: (batch_size, segment_samples, channels_num)
Outputs:
output_dict: {
'wav': (batch_size, segment_samples, channels_num),
'sp': (batch_size, channels_num, time_steps, freq_bins)}
"""
mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
x = mag
# Batch normalization
x = x.transpose(1, 3)
x = self.bn0(x)
x = x.transpose(1, 3)
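        # bn0 normalizes over the 1025 frequency bins (moved to the channel
        # axis by the first transpose), i.e., per-frequency-bin statistics.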
"""(batch_size, chanenls, time_steps, freq_bins)"""
# Pad spectrogram to be evenly divided by downsample ratio.
origin_len = x.shape[2]
pad_len = (
int(np.ceil(x.shape[2] / self.time_downsample_ratio)) * self.time_downsample_ratio
- origin_len
)
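        # Worked example (assuming 2 s of 32 kHz audio): 64000 samples give
        # 64000 / 320 + 1 = 201 STFT frames, padded up to 224 = 7 * 32 frames.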
x = F.pad(x, pad=(0, 0, 0, pad_len))
"""(batch_size, channels, padded_time_steps, freq_bins)"""
        # Drop the highest bin so freq_bins is evenly divisible by 2, e.g., 1025 -> 1024
x = x[..., 0 : x.shape[-1] - 1] # (bs, channels, T, F)
# UNet
x = self.pre_conv(x)
x1_pool, x1 = self.encoder_block1(x, film_dict['encoder_block1']) # x1_pool: (bs, 32, T / 2, F / 2)
x2_pool, x2 = self.encoder_block2(x1_pool, film_dict['encoder_block2']) # x2_pool: (bs, 64, T / 4, F / 4)
x3_pool, x3 = self.encoder_block3(x2_pool, film_dict['encoder_block3']) # x3_pool: (bs, 128, T / 8, F / 8)
x4_pool, x4 = self.encoder_block4(x3_pool, film_dict['encoder_block4']) # x4_pool: (bs, 256, T / 16, F / 16)
x5_pool, x5 = self.encoder_block5(x4_pool, film_dict['encoder_block5']) # x5_pool: (bs, 384, T / 32, F / 32)
x6_pool, x6 = self.encoder_block6(x5_pool, film_dict['encoder_block6']) # x6_pool: (bs, 384, T / 32, F / 64)
x_center, _ = self.conv_block7a(x6_pool, film_dict['conv_block7a']) # (bs, 384, T / 32, F / 64)
x7 = self.decoder_block1(x_center, x6, film_dict['decoder_block1']) # (bs, 384, T / 32, F / 32)
x8 = self.decoder_block2(x7, x5, film_dict['decoder_block2']) # (bs, 384, T / 16, F / 16)
x9 = self.decoder_block3(x8, x4, film_dict['decoder_block3']) # (bs, 256, T / 8, F / 8)
x10 = self.decoder_block4(x9, x3, film_dict['decoder_block4']) # (bs, 128, T / 4, F / 4)
x11 = self.decoder_block5(x10, x2, film_dict['decoder_block5']) # (bs, 64, T / 2, F / 2)
x12 = self.decoder_block6(x11, x1, film_dict['decoder_block6']) # (bs, 32, T, F)
x = self.after_conv(x12)
        # Recover shape: re-append the dropped frequency bin and crop the
        # time-axis padding added before the encoder.
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]
audio_length = mixtures.shape[2]
        # Convert the estimated masks and the mixture spectrogram back to a waveform.
separated_audio = self.feature_maps_to_wav(
input_tensor=x,
# input_tensor: (batch_size, target_sources_num * output_channels * self.K, T, F')
sp=mag,
# sp: (batch_size, input_channels, T, F')
sin_in=sin_in,
# sin_in: (batch_size, input_channels, T, F')
cos_in=cos_in,
# cos_in: (batch_size, input_channels, T, F')
audio_length=audio_length,
)
        # (batch_size, target_sources_num * output_channels, segment_samples)
output_dict = {'waveform': separated_audio}
return output_dict
def get_film_meta(module):
film_meta = {}
    if hasattr(module, 'has_film'):
if module.has_film:
film_meta['beta1'] = module.bn1.num_features
film_meta['beta2'] = module.bn2.num_features
else:
film_meta['beta1'] = 0
film_meta['beta2'] = 0
for child_name, child_module in module.named_children():
child_meta = get_film_meta(child_module)
if len(child_meta) > 0:
film_meta[child_name] = child_meta
return film_meta
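# For ResUNet30_Base, get_film_meta() produces a nested dict along the lines
# of (abbreviated sketch):
# {
#     'encoder_block1': {'conv_block1': {'beta1': 32, 'beta2': 32}},
#     ...
#     'decoder_block1': {'beta1': 384, 'beta2': 384,
#                        'conv_block2': {'beta1': 768, 'beta2': 384}},
#     ...
# }
# FiLM then mirrors this structure with one nn.Linear per 'beta' entry.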
class ResUNet30(nn.Module):
def __init__(self, input_channels, output_channels, condition_size):
super(ResUNet30, self).__init__()
self.base = ResUNet30_Base(
input_channels=input_channels,
output_channels=output_channels,
)
self.film_meta = get_film_meta(
module=self.base,
)
self.film = FiLM(
film_meta=self.film_meta,
condition_size=condition_size
)
def forward(self, input_dict):
mixtures = input_dict['mixture']
conditions = input_dict['condition']
film_dict = self.film(
conditions=conditions,
)
output_dict = self.base(
mixtures=mixtures,
film_dict=film_dict,
)
return output_dict
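if __name__ == "__main__":
    # Minimal smoke test (a sketch; the project normally drives this model
    # from separate training / inference scripts). The shapes below are
    # assumptions: mono audio, a 2 s / 16 kHz segment, and a 512-dimensional
    # condition embedding (e.g., a CLAP text embedding).
    model = ResUNet30(input_channels=1, output_channels=1, condition_size=512)
    model.eval()
    batch_size, segment_samples = 2, 32000
    input_dict = {
        'mixture': torch.randn(batch_size, 1, segment_samples),
        'condition': torch.randn(batch_size, 512),
    }
    with torch.no_grad():
        output_dict = model(input_dict)
    print(output_dict['waveform'].shape)  # expected: (2, 1, 32000)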