# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import overlap_and_add
from ..utils import capture_init
class MulCatBlock(nn.Module):
def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
super(MulCatBlock, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_direction = int(bidirectional) + 1
self.rnn = nn.LSTM(input_size, hidden_size, 1, dropout=dropout,
batch_first=True, bidirectional=bidirectional)
self.rnn_proj = nn.Linear(hidden_size * self.num_direction, input_size)
self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1,
batch_first=True, dropout=dropout, bidirectional=bidirectional)
self.gate_rnn_proj = nn.Linear(
hidden_size * self.num_direction, input_size)
self.block_projection = nn.Linear(input_size * 2, input_size)
def forward(self, input):
output = input
        # run the main RNN and project its output back to the input size
        rnn_output, _ = self.rnn(output)
        rnn_output = self.rnn_proj(
            rnn_output.contiguous().view(-1, rnn_output.shape[2])
        ).view(output.shape).contiguous()
        # run the gate RNN and project its output back to the input size
        gate_rnn_output, _ = self.gate_rnn(output)
        gate_rnn_output = self.gate_rnn_proj(
            gate_rnn_output.contiguous().view(-1, gate_rnn_output.shape[2])
        ).view(output.shape).contiguous()
        # gate the RNN output, concatenate the block input, and project back
        gated_output = torch.mul(rnn_output, gate_rnn_output)
        gated_output = torch.cat([gated_output, output], 2)
        gated_output = self.block_projection(
            gated_output.contiguous().view(-1, gated_output.shape[2])
        ).view(output.shape)
return gated_output
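

# A minimal shape-check sketch, not part of the original model; the sizes and
# the helper name `_demo_mulcat_block` are illustrative assumptions.
def _demo_mulcat_block():
    block = MulCatBlock(input_size=64, hidden_size=128, bidirectional=True)
    x = torch.randn(2, 50, 64)  # (batch, time, features)
    y = block(x)
    # the multiply-and-concat block preserves the input shape
    assert y.shape == x.shape
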
class ByPass(nn.Module):
def __init__(self):
super(ByPass, self).__init__()
def forward(self, input):
return input
class DPMulCat(nn.Module):
def __init__(self, input_size, hidden_size, output_size, num_spk,
dropout=0, num_layers=1, bidirectional=True, input_normalize=False):
super(DPMulCat, self).__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.in_norm = input_normalize
self.num_layers = num_layers
self.rows_grnn = nn.ModuleList([])
self.cols_grnn = nn.ModuleList([])
self.rows_normalization = nn.ModuleList([])
self.cols_normalization = nn.ModuleList([])
# create the dual path pipeline
for i in range(num_layers):
self.rows_grnn.append(MulCatBlock(
input_size, hidden_size, dropout, bidirectional=bidirectional))
self.cols_grnn.append(MulCatBlock(
input_size, hidden_size, dropout, bidirectional=bidirectional))
if self.in_norm:
self.rows_normalization.append(
nn.GroupNorm(1, input_size, eps=1e-8))
self.cols_normalization.append(
nn.GroupNorm(1, input_size, eps=1e-8))
else:
# used to disable normalization
self.rows_normalization.append(ByPass())
self.cols_normalization.append(ByPass())
self.output = nn.Sequential(
nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1))
def forward(self, input):
        batch_size, _, d1, d2 = input.shape  # (B, N, segment_size, n_chunks)
output = input
output_all = []
for i in range(self.num_layers):
row_input = output.permute(0, 3, 2, 1).contiguous().view(
batch_size * d2, d1, -1)
row_output = self.rows_grnn[i](row_input)
row_output = row_output.view(
batch_size, d2, d1, -1).permute(0, 3, 2, 1).contiguous()
row_output = self.rows_normalization[i](row_output)
            # apply a skip connection: out-of-place in training so autograd
            # can track it, in-place at inference to save memory
            if self.training:
                output = output + row_output
            else:
                output += row_output
col_input = output.permute(0, 2, 3, 1).contiguous().view(
batch_size * d1, d2, -1)
col_output = self.cols_grnn[i](col_input)
col_output = col_output.view(
batch_size, d1, d2, -1).permute(0, 3, 1, 2).contiguous()
col_output = self.cols_normalization[i](col_output).contiguous()
            # apply a skip connection: out-of-place in training so autograd
            # can track it, in-place at inference to save memory
            if self.training:
                output = output + col_output
            else:
                output += col_output
            # compute a separation estimate after every block during training
            # (deep supervision); only after the last block at inference
            if self.training or i == (self.num_layers - 1):
                output_i = self.output(output)
                output_all.append(output_i)
return output_all
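

# A minimal sketch of the dual-path stack, not part of the original model;
# all sizes and the helper name `_demo_dpmulcat` are illustrative assumptions.
def _demo_dpmulcat():
    net = DPMulCat(input_size=64, hidden_size=128, output_size=64,
                   num_spk=2, num_layers=2, input_normalize=True)
    # chunked input: (batch, features, segment_size, n_chunks)
    x = torch.randn(2, 64, 20, 36)
    outs = net(x)  # training mode: one estimate per layer
    assert len(outs) == 2
    net.eval()
    with torch.no_grad():
        outs = net(x)  # eval mode: only the final layer's estimate
    assert len(outs) == 1 and outs[0].shape == (2, 64 * 2, 20, 36)
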
class Separator(nn.Module):
def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
layer=4, segment_size=100, input_normalize=False, bidirectional=True):
super(Separator, self).__init__()
self.input_dim = input_dim
self.feature_dim = feature_dim
self.hidden_dim = hidden_dim
self.output_dim = output_dim
self.layer = layer
self.segment_size = segment_size
self.num_spk = num_spk
self.input_normalize = input_normalize
        self.rnn_model = DPMulCat(
            self.feature_dim, self.hidden_dim, self.feature_dim, self.num_spk,
            num_layers=layer, bidirectional=bidirectional,
            input_normalize=input_normalize)
# ======================================= #
# The following code block was borrowed and modified from https://github.com/yluo42/TAC
# ================ BEGIN ================ #
def pad_segment(self, input, segment_size):
# input is the features: (B, N, T)
batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
        if rest > 0:
            pad = torch.zeros(batch_size, dim, rest,
                              dtype=input.dtype, device=input.device)
            input = torch.cat([input, pad], 2)
        pad_aux = torch.zeros(batch_size, dim, segment_stride,
                              dtype=input.dtype, device=input.device)
        input = torch.cat([pad_aux, input, pad_aux], 2)
        return input, rest
def create_chuncks(self, input, segment_size):
# split the feature into chunks of segment size
# input is the features: (B, N, T)
input, rest = self.pad_segment(input, segment_size)
batch_size, dim, seq_len = input.shape
segment_stride = segment_size // 2
        segments1 = input[:, :, :-segment_stride].contiguous().view(
            batch_size, dim, -1, segment_size)
        segments2 = input[:, :, segment_stride:].contiguous().view(
            batch_size, dim, -1, segment_size)
        segments = torch.cat([segments1, segments2], 3).view(
            batch_size, dim, -1, segment_size).transpose(2, 3)
return segments.contiguous(), rest
def merge_chuncks(self, input, rest):
        # merge the split features back into a full utterance
        # input is the chunked features: (B, N, segment_size, n_chunks)
        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose(2, 3).contiguous().view(
            batch_size, dim, -1, segment_size * 2)  # (B, N, n_chunks // 2, 2 * segment_size)
input1 = input[:, :, :, :segment_size].contiguous().view(
batch_size, dim, -1)[:, :, segment_stride:]
input2 = input[:, :, :, segment_size:].contiguous().view(
batch_size, dim, -1)[:, :, :-segment_stride]
output = input1 + input2
if rest > 0:
output = output[:, :, :-rest]
return output.contiguous() # B, N, T
# ================= END ================= #
def forward(self, input):
# create chunks
enc_segments, enc_rest = self.create_chuncks(
input, self.segment_size)
# separate
output_all = self.rnn_model(enc_segments)
        # merge the chunks back into full-length feature estimates
output_all_wav = []
for ii in range(len(output_all)):
output_ii = self.merge_chuncks(
output_all[ii], enc_rest)
output_all_wav.append(output_ii)
return output_all_wav
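

# Two minimal sketches, not part of the original code; all dimensions and the
# `_demo_*` helper names are illustrative assumptions.
def _demo_chunk_round_trip():
    # with 50% overlap every frame lands in exactly two chunks, so merging
    # untouched chunks returns the input doubled
    sep = Separator(input_dim=129, feature_dim=64, hidden_dim=128,
                    output_dim=64, num_spk=2, layer=2, segment_size=20)
    x = torch.randn(2, 64, 331)  # (B, N, T)
    chunks, rest = sep.create_chuncks(x, sep.segment_size)
    assert chunks.shape == (2, 64, 20, 36)  # (B, N, segment_size, n_chunks)
    assert torch.allclose(sep.merge_chuncks(chunks, rest), 2 * x)


def _demo_separator():
    # in training mode the separator returns one merged estimate per
    # DPMulCat layer, each with num_spk * feature_dim channels
    sep = Separator(input_dim=129, feature_dim=64, hidden_dim=128,
                    output_dim=64, num_spk=2, layer=2, segment_size=20)
    x = torch.randn(2, 64, 331)
    outs = sep(x)
    assert len(outs) == 2 and outs[0].shape == (2, 2 * 64, 331)
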
class SWave(nn.Module):
@capture_init
def __init__(self, N, L, H, R, C, sr, segment, input_normalize):
super(SWave, self).__init__()
# hyper-parameter
self.N, self.L, self.H, self.R, self.C, self.sr, self.segment = N, L, H, R, C, sr, segment
self.input_normalize = input_normalize
self.context_len = 2 * self.sr / 1000
self.context = int(self.sr * self.context_len / 1000)
self.layer = self.R
self.filter_dim = self.context * 2 + 1
self.num_spk = self.C
        # as in the DPRNN paper, set the chunk size to sqrt(2 * K), where
        # K = sr * segment / (L / 2) approximates the number of encoder frames
        self.segment_size = int(
            np.sqrt(2 * self.sr * self.segment / (self.L / 2)))
# model sub-networks
self.encoder = Encoder(L, N)
self.decoder = Decoder(L)
        self.separator = Separator(
            self.filter_dim + self.N, self.N, self.H, self.filter_dim,
            self.num_spk, self.layer, self.segment_size, self.input_normalize)
# init
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_normal_(p)
def forward(self, mixture):
mixture_w = self.encoder(mixture)
output_all = self.separator(mixture_w)
        # fix the time dimension; it may change due to the convolution operations
        T_mix = mixture.size(-1)
        # generate a waveform after each RNN block so the loss can supervise each one
outputs = []
for ii in range(len(output_all)):
output_ii = output_all[ii].view(
mixture.shape[0], self.C, self.N, mixture_w.shape[2])
output_ii = self.decoder(output_ii)
T_est = output_ii.size(-1)
output_ii = F.pad(output_ii, (0, T_mix - T_est))
outputs.append(output_ii)
return torch.stack(outputs)
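

# A minimal forward-pass sketch for the full model, not part of the original
# code; the hyper-parameter values and `_demo_swave` are illustrative assumptions.
def _demo_swave():
    model = SWave(N=128, L=8, H=128, R=2, C=2, sr=8000, segment=1,
                  input_normalize=False)
    mixture = torch.randn(3, 8000)  # one second of audio at 8 kHz
    model.eval()
    with torch.no_grad():
        estimates = model(mixture)
    # eval mode keeps only the last block's estimate for each of the C sources
    assert estimates.shape == (1, 3, 2, 8000)
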
class Encoder(nn.Module):
def __init__(self, L, N):
super(Encoder, self).__init__()
self.L, self.N = L, N
        # 50% overlap between successive windows
        self.conv = nn.Conv1d(1, N, kernel_size=L, stride=L // 2, bias=False)
def forward(self, mixture):
mixture = torch.unsqueeze(mixture, 1)
mixture_w = F.relu(self.conv(mixture))
return mixture_w
class Decoder(nn.Module):
def __init__(self, L):
super(Decoder, self).__init__()
self.L = L
    def forward(self, est_source):
        est_source = torch.transpose(est_source, 2, 3)
        # average each frame's N features down to N // L samples
        est_source = nn.AvgPool2d((1, self.L))(est_source)
        # overlap-and-add with 50% overlap to rebuild the waveform
        est_source = overlap_and_add(est_source, self.L // 2)
        return est_source
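

# A minimal sketch of the encoder/decoder frame arithmetic, not part of the
# original code; N, L, and `_demo_encoder_decoder` are illustrative assumptions.
def _demo_encoder_decoder():
    N, L = 128, 8
    encoder, decoder = Encoder(L, N), Decoder(L)
    mixture = torch.randn(3, 8000)
    w = encoder(mixture)
    # kernel L with stride L // 2 yields (T - L) // (L // 2) + 1 frames
    assert w.shape == (3, N, (8000 - L) // (L // 2) + 1)
    # pooling leaves frames of length N // L, overlap-added with step L // 2
    wav = decoder(w.view(3, 1, N, -1))
    assert wav.shape[-1] == (w.shape[-1] - 1) * (L // 2) + N // L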