# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Authors: Eliya Nachmani (enk100), Yossi Adi (adiyoss), Lior Wolf
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import overlap_and_add
from ..utils import capture_init


class MulCatBlock(nn.Module):
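    """Multiply-and-concatenate (MulCat) block.

    Two parallel LSTMs process the same (batch, seq_len, input_size) input:
    the output of a "gate" LSTM is multiplied element-wise with the output
    of the main LSTM, the original input is concatenated to the result, and
    a final linear layer projects back to `input_size`.
    """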

    def __init__(self, input_size, hidden_size, dropout=0, bidirectional=False):
        super(MulCatBlock, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_direction = int(bidirectional) + 1

        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, dropout=dropout,
                           batch_first=True, bidirectional=bidirectional)
        self.rnn_proj = nn.Linear(hidden_size * self.num_direction, input_size)
        self.gate_rnn = nn.LSTM(input_size, hidden_size, num_layers=1, dropout=dropout,
                                batch_first=True, bidirectional=bidirectional)
        self.gate_rnn_proj = nn.Linear(hidden_size * self.num_direction, input_size)
        self.block_projection = nn.Linear(input_size * 2, input_size)

    def forward(self, input):
        output = input
        # run the main rnn module
        rnn_output, _ = self.rnn(output)
        rnn_output = self.rnn_proj(
            rnn_output.contiguous().view(-1, rnn_output.shape[2])
        ).view(output.shape).contiguous()
        # run the gate rnn module
        gate_rnn_output, _ = self.gate_rnn(output)
        gate_rnn_output = self.gate_rnn_proj(
            gate_rnn_output.contiguous().view(-1, gate_rnn_output.shape[2])
        ).view(output.shape).contiguous()
        # gate the main rnn output, re-attach the input, and project back
        gated_output = torch.mul(rnn_output, gate_rnn_output)
        gated_output = torch.cat([gated_output, output], 2)
        gated_output = self.block_projection(
            gated_output.contiguous().view(-1, gated_output.shape[2])
        ).view(output.shape)
        return gated_output


class ByPass(nn.Module):
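    """Identity module, used to disable the normalization layers in DPMulCat."""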

    def __init__(self):
        super(ByPass, self).__init__()

    def forward(self, input):
        return input


class DPMulCat(nn.Module):
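    """Dual-path stack of MulCat blocks.

    Each layer applies a MulCat block along the intra-chunk (rows) axis and
    then along the inter-chunk (cols) axis, with optional GroupNorm and a
    residual connection around each. A PReLU + 1x1 Conv2d head maps the
    features to `output_size * num_spk` channels; every layer's output is
    returned during training so a loss can be applied at each scale.
    """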

    def __init__(self, input_size, hidden_size, output_size, num_spk,
                 dropout=0, num_layers=1, bidirectional=True, input_normalize=False):
        super(DPMulCat, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.in_norm = input_normalize
        self.num_layers = num_layers

        self.rows_grnn = nn.ModuleList([])
        self.cols_grnn = nn.ModuleList([])
        self.rows_normalization = nn.ModuleList([])
        self.cols_normalization = nn.ModuleList([])

        # create the dual-path pipeline
        for i in range(num_layers):
            self.rows_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            self.cols_grnn.append(MulCatBlock(
                input_size, hidden_size, dropout, bidirectional=bidirectional))
            if self.in_norm:
                self.rows_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
                self.cols_normalization.append(
                    nn.GroupNorm(1, input_size, eps=1e-8))
            else:
                # used to disable normalization
                self.rows_normalization.append(ByPass())
                self.cols_normalization.append(ByPass())

        self.output = nn.Sequential(
            nn.PReLU(), nn.Conv2d(input_size, output_size * num_spk, 1))

    def forward(self, input):
        batch_size, _, d1, d2 = input.shape
        output = input
        output_all = []
        for i in range(self.num_layers):
            # intra-chunk processing: run a MulCat block over the rows
            row_input = output.permute(0, 3, 2, 1).contiguous().view(
                batch_size * d2, d1, -1)
            row_output = self.rows_grnn[i](row_input)
            row_output = row_output.view(
                batch_size, d2, d1, -1).permute(0, 3, 2, 1).contiguous()
            row_output = self.rows_normalization[i](row_output)
            # apply a skip connection; the in-place add saves memory but is
            # only safe when no gradients are needed, so use an out-of-place
            # add during training
            if self.training:
                output = output + row_output
            else:
                output += row_output

            # inter-chunk processing: run a MulCat block over the columns
            col_input = output.permute(0, 2, 3, 1).contiguous().view(
                batch_size * d1, d2, -1)
            col_output = self.cols_grnn[i](col_input)
            col_output = col_output.view(
                batch_size, d1, d2, -1).permute(0, 3, 1, 2).contiguous()
            col_output = self.cols_normalization[i](col_output).contiguous()
            # apply a skip connection
            if self.training:
                output = output + col_output
            else:
                output += col_output

            # collect every layer's output during training (multi-scale loss);
            # at inference only the last layer's output is needed
            if self.training or i == (self.num_layers - 1):
                output_i = self.output(output)
                output_all.append(output_i)
        return output_all


class Separator(nn.Module):
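    """Splits encoder features into half-overlapping chunks, separates them
    with a DPMulCat stack, and merges each layer's output back into a
    full-length feature sequence per speaker.
    """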

    def __init__(self, input_dim, feature_dim, hidden_dim, output_dim, num_spk=2,
                 layer=4, segment_size=100, input_normalize=False, bidirectional=True):
        super(Separator, self).__init__()
        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk
        self.input_normalize = input_normalize

        self.rnn_model = DPMulCat(self.feature_dim, self.hidden_dim,
                                  self.feature_dim, self.num_spk,
                                  num_layers=layer, bidirectional=bidirectional,
                                  input_normalize=input_normalize)

    # ======================================= #
    # The following code block was borrowed and modified from
    # https://github.com/yluo42/TAC
    # ================ BEGIN ================ #
    def pad_segment(self, input, segment_size):
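        """Zero-pad (B, N, T) features so T splits evenly into
        half-overlapping segments of length `segment_size`; returns the
        padded tensor and the number of padding frames added at the end."""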
        # input is the features: (B, N, T)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2
        rest = segment_size - (segment_stride + seq_len %
                               segment_size) % segment_size
        if rest > 0:
            pad = torch.zeros(batch_size, dim, rest,
                              device=input.device, dtype=input.dtype)
            input = torch.cat([input, pad], 2)
        pad_aux = torch.zeros(batch_size, dim, segment_stride,
                              device=input.device, dtype=input.dtype)
        input = torch.cat([pad_aux, input, pad_aux], 2)
        return input, rest

    def create_chuncks(self, input, segment_size):
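        """Split (B, N, T) features into half-overlapping chunks of length
        `segment_size`, returning a (B, N, segment_size, K) tensor of chunks
        and the amount of padding that `pad_segment` added."""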
        # split the features into chunks of segment size
        # input is the features: (B, N, T)
        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        segments1 = input[:, :, :-segment_stride].contiguous().view(
            batch_size, dim, -1, segment_size)
        segments2 = input[:, :, segment_stride:].contiguous().view(
            batch_size, dim, -1, segment_size)
        segments = torch.cat([segments1, segments2], 3).view(
            batch_size, dim, -1, segment_size).transpose(2, 3)
        return segments.contiguous(), rest

    def merge_chuncks(self, input, rest):
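        """Overlap-add the (B, N, L, K) chunks back into a (B, N, T) feature
        sequence, dropping the padding that `pad_segment` introduced."""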
        # merge the split chunks back into a full utterance
        # input is the features: (B, N, L, K)
        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose(2, 3).contiguous().view(
            batch_size, dim, -1, segment_size * 2)  # B, N, K, L

        input1 = input[:, :, :, :segment_size].contiguous().view(
            batch_size, dim, -1)[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].contiguous().view(
            batch_size, dim, -1)[:, :, :-segment_stride]
        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]
        return output.contiguous()  # B, N, T

    # ================= END ================= #

    def forward(self, input):
        # create half-overlapping chunks
        enc_segments, enc_rest = self.create_chuncks(
            input, self.segment_size)
        # separate
        output_all = self.rnn_model(enc_segments)
        # merge each layer's chunks back into a full-length feature sequence
        output_all_wav = []
        for ii in range(len(output_all)):
            output_ii = self.merge_chuncks(
                output_all[ii], enc_rest)
            output_all_wav.append(output_ii)
        return output_all_wav


class SWave(nn.Module):
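    """SWave separation model: encoder -> dual-path MulCat separator ->
    overlap-and-add decoder. Returns one waveform estimate per speaker for
    every separator layer during training (stacked along dim 0), and only
    the final layer's estimate at inference.
    """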

    # capture_init (imported above) records the constructor arguments so the
    # model can be re-instantiated when serialized
    @capture_init
    def __init__(self, N, L, H, R, C, sr, segment, input_normalize):
        super(SWave, self).__init__()
        # hyper-parameters
        self.N, self.L, self.H, self.R, self.C, self.sr, self.segment = \
            N, L, H, R, C, sr, segment
        self.input_normalize = input_normalize
        self.context_len = 2 * self.sr / 1000
        self.context = int(self.sr * self.context_len / 1000)
        self.layer = self.R
        self.num_spk = self.C
        self.filter_dim = self.context * 2 + 1
        # similar to the dual-path RNN paper, set the chunk size to
        # sqrt(2 * number of encoder frames)
        self.segment_size = int(
            np.sqrt(2 * self.sr * self.segment / (self.L / 2)))

        # model sub-networks
        self.encoder = Encoder(L, N)
        self.decoder = Decoder(L)
        self.separator = Separator(self.filter_dim + self.N, self.N, self.H,
                                   self.filter_dim, self.num_spk, self.layer,
                                   self.segment_size, self.input_normalize)
        # init
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

    def forward(self, mixture):
        mixture_w = self.encoder(mixture)
        output_all = self.separator(mixture_w)

        # fix the time dimension, which might change due to convolution operations
        T_mix = mixture.size(-1)
        # generate a wav estimate after each separator block so the loss can be
        # applied at every scale
        outputs = []
        for ii in range(len(output_all)):
            output_ii = output_all[ii].view(
                mixture.shape[0], self.C, self.N, mixture_w.shape[2])
            output_ii = self.decoder(output_ii)
            T_est = output_ii.size(-1)
            # pad (or crop, via negative padding, if T_est > T_mix) to the
            # mixture length
            output_ii = F.pad(output_ii, (0, T_mix - T_est))
            outputs.append(output_ii)
        return torch.stack(outputs)


class Encoder(nn.Module):
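    """1-D convolutional encoder mapping a raw waveform (B, T) to a
    non-negative feature representation (B, N, K) with 50% frame overlap."""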

    def __init__(self, L, N):
        super(Encoder, self).__init__()
        self.L, self.N = L, N
        # stride of L // 2 gives 50% overlap between successive frames
        self.conv = nn.Conv1d(
            1, N, kernel_size=L, stride=L // 2, bias=False)

    def forward(self, mixture):
        mixture = torch.unsqueeze(mixture, 1)  # (B, 1, T)
        mixture_w = F.relu(self.conv(mixture))  # (B, N, K)
        return mixture_w


class Decoder(nn.Module):
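    """Decoder mapping per-speaker feature frames back to waveforms via
    average pooling followed by overlap-and-add."""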

    def __init__(self, L):
        super(Decoder, self).__init__()
        self.L = L

    def forward(self, est_source):
        est_source = torch.transpose(est_source, 2, 3)
        est_source = nn.AvgPool2d((1, self.L))(est_source)
        est_source = overlap_and_add(est_source, self.L // 2)
        return est_source
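

if __name__ == "__main__":
    # Minimal smoke test -- a sketch only. The hyper-parameter values below
    # are illustrative, not the published configuration, and this must be run
    # from inside the package (e.g. `python -m <package>.models.swave`, the
    # exact package name being an assumption) so the relative `..utils`
    # imports resolve.
    model = SWave(N=128, L=8, H=128, R=2, C=2, sr=8000, segment=1,
                  input_normalize=False)
    model.eval()
    mixture = torch.randn(1, 8000)  # (batch, time): one second at 8 kHz
    with torch.no_grad():
        estimates = model(mixture)
    # at eval time only the last layer's estimate is returned:
    # (1, batch, num_spk, time)
    print(estimates.shape)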