|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from . import normalizations, activations |
|
|
|
|
|
|
|
|
class _Chop1d(nn.Module): |
|
|
"""To ensure the output length is the same as the input.""" |
|
|
|
|
|
def __init__(self, chop_size): |
|
|
super().__init__() |
|
|
self.chop_size = chop_size |
|
|
|
|
|
def forward(self, x): |
|
|
return x[..., : -self.chop_size].contiguous() |
|
|
|
|
|
|
|
|
class Conv1DBlock(nn.Module):
    """One Conv-TasNet TCN block: 1x1 bottleneck then depthwise conv.

    Produces a residual output (projected back to ``in_chan``) and, when
    ``skip_out_chan`` is non-zero, an additional skip-connection output.

    Args:
        in_chan: number of input (and residual output) channels.
        hid_chan: number of hidden channels inside the block.
        skip_out_chan: skip-connection channels; 0/None disables the skip path.
        kernel_size: depthwise conv kernel size.
        padding: depthwise conv padding (trimmed again when ``causal``).
        dilation: depthwise conv dilation.
        norm_type: normalization layer name (see ``normalizations``).
        causal: if True, chop the trailing padded frames to stay causal.
    """

    def __init__(
        self,
        in_chan,
        hid_chan,
        skip_out_chan,
        kernel_size,
        padding,
        dilation,
        norm_type="gLN",
        causal=False,
    ):
        super().__init__()
        self.skip_out_chan = skip_out_chan
        norm_cls = normalizations.get(norm_type)
        bottleneck = nn.Conv1d(in_chan, hid_chan, 1)
        depthwise = nn.Conv1d(
            hid_chan,
            hid_chan,
            kernel_size,
            padding=padding,
            dilation=dilation,
            groups=hid_chan,
        )
        if causal:
            # Remove the trailing frames introduced by the padding so the
            # block stays causal and length preserving.
            depthwise = nn.Sequential(depthwise, _Chop1d(padding))
        layers = [
            bottleneck,
            nn.PReLU(),
            norm_cls(hid_chan),
            depthwise,
            nn.PReLU(),
            norm_cls(hid_chan),
        ]
        self.shared_block = nn.Sequential(*layers)
        self.res_conv = nn.Conv1d(hid_chan, in_chan, 1)
        if skip_out_chan:
            self.skip_conv = nn.Conv1d(hid_chan, skip_out_chan, 1)

    def forward(self, x):
        r"""Input shape $(batch, feats, seq)$; returns residual (and skip) maps."""
        shared = self.shared_block(x)
        res_out = self.res_conv(shared)
        if not self.skip_out_chan:
            return res_out
        return res_out, self.skip_conv(shared)
|
|
|
|
|
|
|
|
class ConvNormAct(nn.Module):
    """Conv1d followed by a normalization layer and an activation.

    Args:
        in_chan: input channels.
        out_chan: output channels.
        kernel_size: conv kernel size.
        stride / groups / dilation / padding: forwarded to ``nn.Conv1d``.
        norm_type: normalization layer name (see ``normalizations``).
        act_type: activation name (see ``activations``), PReLU by default.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        stride=1,
        groups=1,
        dilation=1,
        padding=0,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_chan,
            out_chan,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            bias=True,
            groups=groups,
        )
        self.norm = normalizations.get(norm_type)(out_chan)
        self.act = activations.get(act_type)()

    def forward(self, x):
        # conv -> norm -> activation
        return self.act(self.norm(self.conv(x)))
|
|
|
|
|
|
|
|
class ConvNorm(nn.Module):
    """Conv1d followed by a normalization layer (no activation).

    Args:
        in_chan: input channels.
        out_chan: output channels.
        kernel_size: conv kernel size.
        stride / groups / dilation / padding: forwarded to ``nn.Conv1d``.
        norm_type: normalization layer name (see ``normalizations``).
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        stride=1,
        groups=1,
        dilation=1,
        padding=0,
        norm_type="gLN",
    ):
        super().__init__()
        # Keyword arguments make the stride/padding/dilation order explicit
        # (the original passed them positionally in the same order).
        self.conv = nn.Conv1d(
            in_chan,
            out_chan,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True,
            groups=groups,
        )
        self.norm = normalizations.get(norm_type)(out_chan)

    def forward(self, x):
        # conv -> norm
        return self.norm(self.conv(x))
|
|
|
|
|
|
|
|
class NormAct(nn.Module):
    """Normalization followed by an activation (PReLU by default).

    Args:
        out_chan: number of channels to normalize.
        norm_type: normalization layer name (see ``normalizations``).
        act_type: activation name (see ``activations``).
    """

    def __init__(self, out_chan, norm_type="gLN", act_type="prelu"):
        super().__init__()
        self.norm = normalizations.get(norm_type)(out_chan)
        self.act = activations.get(act_type)()

    def forward(self, input):
        # norm -> activation
        return self.act(self.norm(input))
|
|
|
|
|
|
|
|
class Video1DConv(nn.Module):
    """Video-stream 1-D conv block (pre-activation + depthwise conv + 1x1 convs).

    Args:
        in_chan: video encoder output channels.
        out_chan: channels produced by the 1x1 bottleneck / skip convs.
        kernel_size: depthwise conv kernel size.
        dilation: depthwise conv dilation.
        residual: whether to add a residual connection (never on the first block).
        skip_con: whether to also return a skip-connection tensor.
        first_block: first block of the stack; it skips the ReLU/BatchNorm
            pre-activation and the residual connection.
    """

    def __init__(
        self,
        in_chan,
        out_chan,
        kernel_size,
        dilation=1,
        residual=True,
        skip_con=True,
        first_block=True,
    ):
        super().__init__()
        self.first_block = first_block
        # The first block never uses the residual connection.
        self.residual = residual and not first_block
        self.bn = None if first_block else nn.BatchNorm1d(in_chan)
        self.relu = None if first_block else nn.ReLU()
        self.dconv = nn.Conv1d(
            in_chan,
            in_chan,
            kernel_size,
            groups=in_chan,
            dilation=dilation,
            padding=(dilation * (kernel_size - 1)) // 2,
            bias=True,
        )
        self.bconv = nn.Conv1d(in_chan, out_chan, 1)
        self.sconv = nn.Conv1d(in_chan, out_chan, 1)
        self.skip_con = skip_con

    def forward(self, x):
        """x: [B, N, T] -> out [B, N, T] (or ``(skip, out)`` when ``skip_con``)."""
        if self.first_block:
            y = self.dconv(x)
        else:
            # NOTE(review): ReLU is applied before BatchNorm here — confirm
            # this ordering is intentional.
            y = self.dconv(self.bn(self.relu(x)))

        if self.skip_con:
            skip = self.sconv(y)
            if self.residual:
                y = y + x
            return skip, y

        y = self.bconv(y)
        if self.residual:
            # assumes in_chan == out_chan when residual is used without
            # skip_con — TODO confirm against callers
            y = y + x
        return y
|
|
|
|
|
|
|
|
class Concat(nn.Module):
    """Fuse audio and video feature streams along the channel axis.

    The video stream is interpolated to the audio time resolution, both
    streams are concatenated on channels, and a 1x1 conv + PReLU projects
    the result to ``out_chan``.
    """

    def __init__(self, ain_chan, vin_chan, out_chan):
        super().__init__()
        self.ain_chan = ain_chan
        self.vin_chan = vin_chan
        self.conv1d = nn.Sequential(
            nn.Conv1d(ain_chan + vin_chan, out_chan, 1), nn.PReLU()
        )

    def forward(self, a, v):
        # Match the video frame rate to the audio one before fusing.
        v_up = torch.nn.functional.interpolate(v, size=a.size(-1))
        fused = torch.cat([a, v_up], dim=1)
        return self.conv1d(fused)
|
|
|
|
|
|
|
|
class FRCNNBlock(nn.Module):
    """Multi-scale fusion block: bottom-up conv pyramid + neighbour fusion.

    A 1x1 projection is followed by ``upsampling_depth`` depthwise convs
    (the first keeps the resolution, later ones halve it). Each scale is
    then fused with its neighbours, all scales are upsampled back, merged
    by a last 1x1 layer, and added to the input residual.

    Args:
        in_chan: input/output channels of the residual path.
        out_chan: channels used inside the block.
        upsampling_depth: number of pyramid scales.
        norm_type: normalization layer name (see ``normalizations``).
        act_type: activation name (see ``activations``).
    """

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        # Bottom-up pyramid: first conv keeps the resolution, the rest halve it.
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            ConvNorm(
                out_chan,
                out_chan,
                kernel_size=5,
                stride=1,
                groups=out_chan,
                dilation=1,
                padding=((5 - 1) // 2) * 1,
                norm_type=norm_type,
            )
        )
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                ConvNorm(
                    out_chan,
                    out_chan,
                    kernel_size=5,
                    stride=2,
                    groups=out_chan,
                    dilation=1,
                    padding=((5 - 1) // 2) * 1,
                    norm_type=norm_type,
                )
            )

        # fuse_layers[i][0] downsamples scale i-1 to scale i; the scale
        # above (i+1) is handled by interpolation in forward, hence None.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j:
                    fuse_layer.append(None)
                elif j - i == 1:
                    fuse_layer.append(None)
                elif i - j == 1:
                    fuse_layer.append(
                        ConvNorm(
                            out_chan,
                            out_chan,
                            kernel_size=5,
                            stride=2,
                            groups=out_chan,
                            dilation=1,
                            padding=((5 - 1) // 2) * 1,
                            norm_type=norm_type,
                        )
                    )
            self.fuse_layers.append(fuse_layer)

        # Edge scales fuse 2 tensors, inner scales fuse 3.
        self.concat_layer = nn.ModuleList([])
        for i in range(upsampling_depth):
            n_inputs = 2 if i == 0 or i == upsampling_depth - 1 else 3
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

    def forward(self, x):
        """
        :param x: input feature map, shape [B, in_chan, T]
        :return: transformed feature map of the same shape
        """
        residual = x.clone()

        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = output[i].shape[-1]
            # Collect only the neighbour scales that exist, instead of
            # concatenating empty placeholder tensors (which relied on the
            # deprecated torch.cat 1-D-empty special case and allocated
            # device tensors for nothing).
            to_fuse = []
            if i - 1 >= 0:
                to_fuse.append(self.fuse_layers[i][0](output[i - 1]))
            to_fuse.append(output[i])
            if i + 1 < self.depth:
                to_fuse.append(
                    F.interpolate(output[i + 1], size=wav_length, mode="nearest")
                )
            y = torch.cat(to_fuse, dim=1)
            x_fuse.append(self.concat_layer[i](y))

        # Bring every fused scale back to the full resolution before merging.
        wav_length = output[0].shape[-1]
        for i in range(1, len(x_fuse)):
            x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")

        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
|
|
|
|
|
|
|
|
class Bottomup(nn.Module):
    """Bottom-up half of the FRCNN block: 1x1 projection + conv pyramid.

    Returns the input (as residual), the deepest-scale features, and the
    list of all scale features for a later top-down fusion stage.
    """

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        # First scale keeps the resolution, every later scale halves it.
        strides = [1] + [2] * (upsampling_depth - 1)
        self.spp_dw = nn.ModuleList(
            ConvNorm(
                out_chan,
                out_chan,
                kernel_size=5,
                stride=s,
                groups=out_chan,
                dilation=1,
                padding=((5 - 1) // 2) * 1,
                norm_type=norm_type,
            )
            for s in strides
        )

    def forward(self, x):
        """Return ``(residual, deepest scale, all scales)`` for input [B, N, T]."""
        residual = x.clone()
        feats = [self.spp_dw[0](self.proj_1x1(x))]
        for layer in self.spp_dw[1:]:
            feats.append(layer(feats[-1]))
        return residual, feats[-1], feats
|
|
|
|
|
|
|
|
class BottomupTCN(nn.Module):
    """Bottom-up stack built from TCN-style ``Video1DConv`` layers.

    Unlike :class:`Bottomup`, no downsampling is performed: every stage
    keeps the input time resolution.
    """

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        # Only the first stage is a "first block" (no pre-act, no residual).
        stages = [Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)]
        stages += [
            Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
            for _ in range(1, upsampling_depth)
        ]
        self.spp_dw = nn.ModuleList(stages)

    def forward(self, x):
        """Return ``(residual, last stage output, all stage outputs)``."""
        residual = x.clone()
        feats = [self.spp_dw[0](self.proj_1x1(x))]
        for layer in self.spp_dw[1:]:
            feats.append(layer(feats[-1]))
        return residual, feats[-1], feats
|
|
|
|
|
|
|
|
class Bottomup_Concat_Topdown(nn.Module):
    """Fuse bottom-up pyramid features with a top-down signal.

    For each scale ``i`` the block concatenates: the downsampled scale
    ``i-1`` (if any), scale ``i`` itself, the upsampled scale ``i+1``
    (if any), and the top-down tensor interpolated to the scale's length.
    The fused scales are upsampled to full resolution, merged, projected
    back to ``in_chan``, and added to ``residual``.

    Args:
        in_chan: channels of the residual path.
        out_chan: channels of the bottom-up / top-down features.
        upsampling_depth: number of pyramid scales.
        norm_type: normalization layer name (see ``normalizations``).
        act_type: activation name (see ``activations``).
    """

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()

        # fuse_layers[i][0] downsamples scale i-1 to scale i; the scale
        # above (i+1) is handled by interpolation in forward, hence None.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j:
                    fuse_layer.append(None)
                elif j - i == 1:
                    fuse_layer.append(None)
                elif i - j == 1:
                    fuse_layer.append(
                        ConvNorm(
                            out_chan,
                            out_chan,
                            kernel_size=5,
                            stride=2,
                            groups=out_chan,
                            dilation=1,
                            padding=((5 - 1) // 2) * 1,
                            norm_type=norm_type,
                        )
                    )
            self.fuse_layers.append(fuse_layer)

        # Edge scales fuse 3 tensors (2 neighbours + topdown), inner scales 4.
        self.concat_layer = nn.ModuleList([])
        for i in range(upsampling_depth):
            n_inputs = 3 if i == 0 or i == upsampling_depth - 1 else 4
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

        self.depth = upsampling_depth

    def forward(self, residual, bottomup, topdown):
        """Fuse ``bottomup`` scales with ``topdown`` and add ``residual``."""
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = bottomup[i].shape[-1]
            # Collect only the tensors that exist for this scale, instead of
            # concatenating empty placeholders (which relied on the
            # deprecated torch.cat 1-D-empty special case).
            to_fuse = []
            if i - 1 >= 0:
                to_fuse.append(self.fuse_layers[i][0](bottomup[i - 1]))
            to_fuse.append(bottomup[i])
            if i + 1 < self.depth:
                to_fuse.append(
                    F.interpolate(bottomup[i + 1], size=wav_length, mode="nearest")
                )
            to_fuse.append(F.interpolate(topdown, size=wav_length, mode="nearest"))
            y = torch.cat(to_fuse, dim=1)
            x_fuse.append(self.concat_layer[i](y))

        # Bring every fused scale back to the full resolution before merging.
        wav_length = bottomup[0].shape[-1]
        for i in range(1, len(x_fuse)):
            x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")

        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
|
|
|
|
|
|
|
|
class Bottomup_Concat_Topdown_TCN(nn.Module):
    """TCN variant of :class:`Bottomup_Concat_Topdown`.

    All bottom-up stages share the same time resolution, so neighbour
    stages are concatenated directly (no down/upsampling); only the
    top-down tensor is interpolated to the stage length.

    Args:
        in_chan: channels of the residual path.
        out_chan: channels of the bottom-up / top-down features.
        upsampling_depth: number of TCN stages.
        norm_type: normalization layer name (see ``normalizations``).
        act_type: activation name (see ``activations``).
    """

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()

        # Kept for structural parity with the non-TCN variant: no actual
        # modules are needed since all stages share the same resolution.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j:
                    fuse_layer.append(None)
                elif j - i == 1:
                    fuse_layer.append(None)
                elif i - j == 1:
                    fuse_layer.append(None)
            self.fuse_layers.append(fuse_layer)

        # Edge stages fuse 3 tensors (2 neighbours + topdown), inner stages 4.
        self.concat_layer = nn.ModuleList([])
        for i in range(upsampling_depth):
            n_inputs = 3 if i == 0 or i == upsampling_depth - 1 else 4
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

        self.depth = upsampling_depth

    def forward(self, residual, bottomup, topdown):
        """Fuse ``bottomup`` stages with ``topdown`` and add ``residual``."""
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            wav_length = bottomup[i].shape[-1]
            # Collect only the tensors that exist for this stage, instead of
            # concatenating empty placeholders (which relied on the
            # deprecated torch.cat 1-D-empty special case).
            to_fuse = []
            if i - 1 >= 0:
                to_fuse.append(bottomup[i - 1])
            to_fuse.append(bottomup[i])
            if i + 1 < self.depth:
                to_fuse.append(bottomup[i + 1])
            to_fuse.append(F.interpolate(topdown, size=wav_length, mode="nearest"))
            y = torch.cat(to_fuse, dim=1)
            x_fuse.append(self.concat_layer[i](y))

        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
|
|
|
|
|
|
|
|
class FRCNNBlockTCN(nn.Module):
    """TCN variant of :class:`FRCNNBlock`.

    The pyramid is replaced by a stack of ``Video1DConv`` stages that all
    keep the input resolution, so neighbour stages are concatenated
    directly without resampling.

    Args:
        in_chan: input/output channels of the residual path.
        out_chan: channels used inside the block.
        upsampling_depth: number of TCN stages.
        norm_type: normalization layer name (see ``normalizations``).
        act_type: activation name (see ``activations``).
    """

    def __init__(
        self,
        in_chan=128,
        out_chan=512,
        upsampling_depth=4,
        norm_type="gLN",
        act_type="prelu",
    ):
        super().__init__()
        self.proj_1x1 = ConvNormAct(
            in_chan,
            out_chan,
            kernel_size=1,
            stride=1,
            groups=1,
            dilation=1,
            padding=0,
            norm_type=norm_type,
            act_type=act_type,
        )
        self.depth = upsampling_depth
        # Only the first stage is a "first block" (no pre-act, no residual).
        self.spp_dw = nn.ModuleList([])
        self.spp_dw.append(
            Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)
        )
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
            )

        # Kept for structural parity with FRCNNBlock: no actual modules are
        # needed since all stages share the same resolution.
        self.fuse_layers = nn.ModuleList([])
        for i in range(upsampling_depth):
            fuse_layer = nn.ModuleList([])
            for j in range(upsampling_depth):
                if i == j:
                    fuse_layer.append(None)
                elif j - i == 1:
                    fuse_layer.append(None)
                elif i - j == 1:
                    fuse_layer.append(None)
            self.fuse_layers.append(fuse_layer)

        # Edge stages fuse 2 tensors, inner stages fuse 3.
        self.concat_layer = nn.ModuleList([])
        for i in range(upsampling_depth):
            n_inputs = 2 if i == 0 or i == upsampling_depth - 1 else 3
            self.concat_layer.append(
                ConvNormAct(
                    out_chan * n_inputs,
                    out_chan,
                    1,
                    1,
                    norm_type=norm_type,
                    act_type=act_type,
                )
            )
        self.last_layer = nn.Sequential(
            ConvNormAct(
                out_chan * upsampling_depth,
                out_chan,
                1,
                1,
                norm_type=norm_type,
                act_type=act_type,
            )
        )
        self.res_conv = nn.Conv1d(out_chan, in_chan, 1)

    def forward(self, x):
        """
        :param x: input feature map, shape [B, in_chan, T]
        :return: transformed feature map of the same shape
        """
        residual = x.clone()

        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            # Collect only the neighbour stages that exist, instead of
            # concatenating empty placeholders (which relied on the
            # deprecated torch.cat 1-D-empty special case).
            to_fuse = []
            if i - 1 >= 0:
                to_fuse.append(output[i - 1])
            to_fuse.append(output[i])
            if i + 1 < self.depth:
                to_fuse.append(output[i + 1])
            y = torch.cat(to_fuse, dim=1)
            x_fuse.append(self.concat_layer[i](y))

        concat = self.last_layer(torch.cat(x_fuse, dim=1))
        expanded = self.res_conv(concat)
        return expanded + residual
|
|
|
|
|
|
|
|
class TAC(nn.Module):
    """Transform-Average-Concatenate inter-microphone-channel permutation invariant communication block [1].

    Args:
        input_dim (int): Number of features of input representation.
        hidden_dim (int, optional): size of hidden layers in TAC operations.
        activation (str, optional): type of activation used. See asteroid.masknn.activations.
        norm_type (str, optional): type of normalization layer used. See asteroid.masknn.norms.

    .. note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`
        as in FasNet-TAC. The operations are applied for each element in ``chunk_size`` and ``n_chunks``.
        Output is of same shape as input.

    References
        [1] : Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel
        speech separation." ICASSP 2020.
    """

    def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"):
        super().__init__()
        self.hidden_dim = hidden_dim
        # "Transform": applied independently to each mic channel.
        self.input_tf = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), activations.get(activation)()
        )
        # "Average": applied to the mean over the (valid) mic channels.
        self.avg_tf = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), activations.get(activation)()
        )
        # "Concatenate": fuses per-channel features with the averaged ones
        # and projects back to ``input_dim``.
        self.concat_tf = nn.Sequential(
            nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)()
        )
        self.norm = normalizations.get(norm_type)(input_dim)

    def forward(self, x, valid_mics=None):
        """
        Args:
            x: (:class:`torch.Tensor`): Input multi-channel DPRNN features.
                Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
            valid_mics: (:class:`torch.LongTensor`): tensor containing effective number of microphones on each batch.
                Batches can be composed of examples coming from arrays with a different
                number of microphones and thus the ``mic_channels`` dimension is padded.
                E.g. torch.tensor([4, 3]) means first example has 4 channels and the second 3.
                Shape: :math:`(batch)`.
        Returns:
            output (:class:`torch.Tensor`): features for each mic_channel after TAC inter-channel processing.
                Shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
        """

        batch_size, nmics, channels, chunk_size, n_chunks = x.size()
        if valid_mics is None:
            # Assume every mic channel is valid for every batch element.
            valid_mics = torch.LongTensor([nmics] * batch_size)

        # Transform step: apply input_tf to each (mic, chunk-element) pair;
        # result shape (batch, chunk_size, n_chunks, nmics, hidden_dim).
        output = self.input_tf(
            x.permute(0, 3, 4, 1, 2).reshape(
                batch_size * nmics * chunk_size * n_chunks, channels
            )
        ).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim)

        # Average step: mean-pool across mic channels.
        if valid_mics.max() == 0:
            # NOTE(review): this averages over dim 1 (chunk_size), not the
            # mic dimension (3) — confirm intended; the branch only triggers
            # when valid_mics is all zeros (never with the default above).
            mics_mean = output.mean(1)
        else:
            # Per example, average only over its valid mic channels.
            mics_mean = [
                output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0)
                for b in range(batch_size)
            ]
            mics_mean = torch.cat(mics_mean, 0)  # (batch, chunk_size, n_chunks, hidden)

        # Transform the pooled features and broadcast them back to each mic.
        mics_mean = self.avg_tf(
            mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim)
        )
        mics_mean = (
            mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim)
            .unsqueeze(3)
            .expand_as(output)
        )

        # Concatenate step: per-channel features + averaged features ->
        # project back to input_dim, then normalize per mic channel.
        output = torch.cat([output, mics_mean], -1)
        output = self.concat_tf(
            output.reshape(batch_size * chunk_size * n_chunks * nmics, -1)
        ).reshape(batch_size, chunk_size, n_chunks, nmics, -1)
        output = self.norm(
            output.permute(0, 3, 4, 1, 2).reshape(
                batch_size * nmics, -1, chunk_size, n_chunks
            )
        ).reshape(batch_size, nmics, -1, chunk_size, n_chunks)

        # Residual connection over the whole block.
        output += x
        return output
|
|
|