# DeFTAN-II / DeFTAN2.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from einops import rearrange
class DeFTAN2(nn.Module):
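    """DeFTAN-II: multichannel speech enhancement/separation in the STFT domain.

    Parameter roles below are inferred from how they are used in this file
    (the original upload carries no documentation), so treat them as a sketch:
        n_srcs: number of estimated sources
        win: STFT window/FFT size; the hop is win // 2
        n_mics: number of input microphone channels
        n_layers: number of stacked DeFTAN2block stages
        att_dim: per-head attention dimension inside CEA
        hidden_dim: hidden width of the DPFN feed-forward module
        n_head: number of attention heads (also the SDB group count)
        emb_dim: channel width of the internal [B, C, T, F] feature maps
        emb_ks, emb_hs: kernel size and hop of the unfold in each block
        dropout, eps: dropout rate and normalization epsilon
    """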
def __init__(self, n_srcs=1, win=512, n_mics=4, n_layers=6, att_dim=64, hidden_dim=256, n_head=4, emb_dim=64, emb_ks=4, emb_hs=1, dropout=0.1, eps=1.0e-5):
super().__init__()
self.n_srcs = n_srcs
self.win = win
self.hop = win // 2
self.n_layers = n_layers
self.n_mics = n_mics
self.emb_dim = emb_dim
assert win % 2 == 0
t_ksize = 3
ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
self.up_conv = nn.Sequential(
nn.Conv2d(2 * n_mics, emb_dim * n_head, ks, padding=padding),
nn.GroupNorm(1, emb_dim * n_head, eps=eps),
SDB2d(emb_dim * n_head, emb_dim, n_head)
)
self.blocks = nn.ModuleList([])
for idx in range(n_layers):
self.blocks.append(DeFTAN2block(idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps))
self.down_conv = nn.Sequential(
nn.Conv2d(emb_dim, 2 * n_srcs * n_head, ks, padding=padding),
SDB2d(2 * n_srcs * n_head, 2 * n_srcs, n_head))
def pad_signal(self, input):
# input is the waveforms: (B, T) or (B, 1, T)
# reshape and padding
if input.dim() not in [2, 3]:
raise RuntimeError("Input can only be 2 or 3 dimensional.")
if input.dim() == 2:
input = input.unsqueeze(1)
batch_size = input.size(0)
nchannel = input.size(1)
nsample = input.size(2)
rest = self.win - (self.hop + nsample % self.win) % self.win
        if rest > 0:
            pad = torch.zeros(batch_size, nchannel, rest, dtype=input.dtype, device=input.device)
            input = torch.cat([input, pad], 2)
        pad_aux = torch.zeros(batch_size, nchannel, self.hop, dtype=input.dtype, device=input.device)
        input = torch.cat([pad_aux, input, pad_aux], 2)
return input, rest
    def forward(self, input: torch.Tensor) -> torch.Tensor:
input, rest = self.pad_signal(input)
B, M, N = input.size() # batch B, mic M, time samples N
mix_std_ = torch.std(input, dim=(1, 2), keepdim=True) # [B, 1, 1]
input = input / mix_std_ # RMS normalization
# Encoding
        # torch.nn.functional is aliased as F, so the spectrogram sizes are named
        # n_freq / n_frame to avoid shadowing it
        window = torch.hann_window(self.win).to(input)
        stft_input = torch.view_as_real(torch.stft(input.view([-1, N]), n_fft=self.win, hop_length=self.hop, window=window, return_complex=True))
        _, n_freq, n_frame, _ = stft_input.size()  # [B*M, n_freq, n_frame, 2] (real/imag)
        xi = stft_input.view([B, M, n_freq, n_frame, 2])  # [B*M, F, T, 2] -> [B, M, F, T, 2]
        xi = xi.permute(0, 1, 4, 3, 2).contiguous()  # [B, M, 2, T, F]
        xi = xi.view([B, M * 2, n_frame, n_freq])  # [B, 2*M, T, F]
# Separation
feature = self.up_conv(xi) # [B, C, T, F]
for ii in range(self.n_layers):
feature = self.blocks[ii](feature) # [B, C, T, F]
        xo = self.down_conv(feature).view([B, self.n_srcs, 2, n_frame, n_freq]).view([B * self.n_srcs, 2, n_frame, n_freq])
        # Decoding
        xo = xo.permute(0, 3, 2, 1)  # [B*n_srcs, 2, T, F] -> [B*n_srcs, F, T, 2]
        istft_input = torch.complex(xo[..., 0], xo[..., 1])
        istft_output = torch.istft(istft_input, n_fft=self.win, hop_length=self.hop, window=window, return_complex=False)
output = istft_output[:, self.hop:-(rest + self.hop)].unsqueeze(1) # [B*n_srcs, 1, N]
output = output.view([B, self.n_srcs, -1]) # [B, n_srcs, N]
output = output * mix_std_ # reverse the RMS normalization
return output
class SDB1d(nn.Module):
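    """Split dense block, 1-D (the SDB naming follows the DeFTAN-II paper; the
    expansion is not spelled out in this file). The input channels are split
    into `groups` interleaved subsets; the first stage refines subset 0, and
    each later stage refines the previous stage's output concatenated with the
    next subset, densely aggregating all subsets into `out_channels` channels.
    """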
def __init__(self, in_channels, out_channels, groups):
super().__init__()
        assert in_channels == out_channels * groups
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self.blocks = nn.ModuleList([])
for idx in range(groups):
self.blocks.append(nn.Sequential(
nn.Conv1d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=3, padding=1),
nn.GroupNorm(1, out_channels, 1e-5),
nn.PReLU(out_channels)
))
def forward(self, x):
B, C, L = x.size()
g = self.groups
skip = x[:, ::g, :]
for idx in range(g):
output = self.blocks[idx](skip)
skip = torch.cat([output, x[:, idx+1::g, :]], dim=1)
return output
class SDB2d(nn.Module):
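    """2-D counterpart of SDB1d, operating on [B, C, T, F] feature maps."""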
def __init__(self, in_channels, out_channels, groups):
super().__init__()
        assert in_channels == out_channels * groups
self.in_channels = in_channels
self.out_channels = out_channels
self.groups = groups
self.blocks = nn.ModuleList([])
for idx in range(groups):
self.blocks.append(nn.Sequential(
nn.Conv2d(out_channels * ((idx > 0) + 1), out_channels, kernel_size=(3, 3), padding=(1, 1)),
nn.GroupNorm(1, out_channels, 1e-5),
nn.PReLU(out_channels)
))
def forward(self, x):
B, C, T, Q = x.size()
g = self.groups
skip = x[:, ::g, :, :]
for idx in range(g):
output = self.blocks[idx](skip)
skip = torch.cat([output, x[:, idx+1::g, :, :]], dim=1)
return output
class PreNorm(nn.Module):
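    """Pre-norm wrapper: applies LayerNorm to the input before the wrapped module."""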
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class CEA(nn.Module):
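    """Linear-complexity (efficient) attention with conv-GLU query/key features:
    a global context softmax(K)^T V is formed once and softmax(Q) reads from it,
    avoiding the n x n attention map. The expansion of the acronym CEA is not
    given in this file.
    """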
def __init__(self, dim, heads, dim_head, dropout):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
        self.cv_qk = nn.Sequential(
            nn.Conv1d(dim, dim * 2, kernel_size=3, padding=1, bias=False),
            nn.GLU(dim=1))
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.att_drop = nn.Dropout(dropout)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
qk = self.cv_qk(x.transpose(1, 2)).transpose(1, 2)
        q = rearrange(self.to_q(qk), 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(self.to_k(qk), 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(self.to_v(x), 'b n (h d) -> b h n d', h=self.heads)
weight = torch.matmul(F.softmax(k, dim=2).transpose(-1, -2), v) * self.scale
out = torch.matmul(F.softmax(q, dim=3), self.att_drop(weight))
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class DPFN(nn.Module):
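    """Two-branch feed-forward module: a pointwise (Linear) branch and a branch
    with a dilated Conv1d whose dilation 2**idx grows with block depth; the
    branch outputs are concatenated and projected back to dim. The acronym DPFN
    is not expanded in this file.
    """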
def __init__(self, dim, hidden_dim, idx, dropout):
super().__init__()
self.PW1 = nn.Sequential(
nn.Linear(dim, hidden_dim//2),
nn.GELU(),
nn.Dropout(dropout)
)
self.PW2 = nn.Sequential(
nn.Linear(dim, hidden_dim//2),
nn.GELU(),
nn.Dropout(dropout)
)
self.DW_Conv = nn.Sequential(
nn.Conv1d(hidden_dim//2, hidden_dim//2, kernel_size=5, dilation=2**idx, padding='same'),
nn.GroupNorm(1, hidden_dim//2, 1e-5),
nn.PReLU(hidden_dim//2)
)
self.PW3 = nn.Sequential(
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
ffw_out = self.PW1(x)
dw_out = self.DW_Conv(self.PW2(x).transpose(1, 2)).transpose(1, 2)
out = self.PW3(torch.cat((ffw_out, dw_out), dim=2))
return out
class DeFTAN2block(nn.Module):
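    """One DeFTAN-II stage: an F-transformer pass along frequency followed by a
    T-transformer pass along time. Each pass is unfold -> SDB1d -> CEA attention
    -> DPFN feed-forward -> ConvTranspose1d, wrapped in residual connections.
    """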
def __getitem__(self, key):
return getattr(self, key)
def __init__(self, idx, emb_dim, emb_ks, emb_hs, att_dim, hidden_dim, n_head, dropout, eps):
super().__init__()
in_channels = emb_dim * emb_ks
self.F_norm = LayerNormalization4D(emb_dim, eps)
self.F_inv = SDB1d(in_channels, emb_dim, emb_ks)
self.F_att = PreNorm(emb_dim, CEA(emb_dim, n_head, att_dim, dropout))
self.F_ffw = PreNorm(emb_dim, DPFN(emb_dim, hidden_dim, idx, dropout))
self.F_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)
self.T_norm = LayerNormalization4D(emb_dim, eps)
self.T_inv = SDB1d(in_channels, emb_dim, emb_ks)
self.T_att = PreNorm(emb_dim, CEA(emb_dim, n_head, att_dim, dropout))
self.T_ffw = PreNorm(emb_dim, DPFN(emb_dim, hidden_dim, idx, dropout))
self.T_linear = nn.ConvTranspose1d(emb_dim, emb_dim, emb_ks, stride=emb_hs)
self.emb_dim = emb_dim
self.emb_ks = emb_ks
self.emb_hs = emb_hs
self.n_head = n_head
def forward(self, x):
B, C, old_T, old_Q = x.shape
T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks
x = F.pad(x, (0, Q - old_Q, 0, T - old_T))
# F-transformer
input_ = x
F_feat = self.F_norm(input_) # [B, C, T, Q]
F_feat = F_feat.transpose(1, 2).contiguous().view(B * T, C, Q) # [BT, C, Q]
F_feat = F.unfold(F_feat[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)) # [BT, C*emb_ks, -1]
F_feat = self.F_inv(F_feat) # [BT, C, -1]
F_feat = F_feat.transpose(1, 2) # [BT, -1, C]
F_feat = self.F_att(F_feat) + F_feat
F_feat = self.F_ffw(F_feat) + F_feat
        F_feat = F_feat.transpose(1, 2)  # [BT, C, -1]
F_feat = self.F_linear(F_feat) # [BT, C, Q]
F_feat = F_feat.view([B, T, C, Q])
F_feat = F_feat.transpose(1, 2).contiguous() # [B, C, T, Q]
F_feat = F_feat + input_ # [B, C, T, Q]
# T-transformer
input_ = F_feat
T_feat = self.T_norm(input_) # [B, C, T, F]
T_feat = T_feat.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T) # [BF, C, T]
T_feat = F.unfold(T_feat[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1)) # [BF, C*emb_ks, -1]
T_feat = self.T_inv(T_feat) # [BF, C, -1]
T_feat = T_feat.transpose(1, 2) # [BF, -1, C]
T_feat = self.T_att(T_feat) + T_feat
T_feat = self.T_ffw(T_feat) + T_feat
        T_feat = T_feat.transpose(1, 2)  # [BF, C, -1]
T_feat = self.T_linear(T_feat) # [BF, C, T]
T_feat = T_feat.view([B, Q, C, T])
T_feat = T_feat.permute(0, 2, 3, 1).contiguous() # [B, C, T, Q]
T_feat = T_feat + input_ # [B, C, T, Q]
return T_feat
class LayerNormalization4D(nn.Module):
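    """Layer normalization over the channel dimension of a [B, C, T, F] tensor."""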
def __init__(self, input_dimension, eps=1e-5):
super().__init__()
param_size = [1, input_dimension, 1, 1]
        self.gamma = Parameter(torch.ones(*param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(*param_size, dtype=torch.float32))
self.eps = eps
def forward(self, x):
if x.ndim == 4:
_, C, _, _ = x.shape
stat_dim = (1,)
else:
raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,1,T,F]
std_ = torch.sqrt(
x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
) # [B,1,T,F]
x_hat = ((x - mu_) / std_) * self.gamma + self.beta
return x_hat
class LayerNormalization4DCF(nn.Module):
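    """Layer normalization over the channel and frequency dimensions of a
    [B, C, T, F] tensor. Not referenced elsewhere in this file.
    """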
def __init__(self, input_dimension, eps=1e-5):
super().__init__()
assert len(input_dimension) == 2
param_size = [1, input_dimension[0], 1, input_dimension[1]]
        self.gamma = Parameter(torch.ones(*param_size, dtype=torch.float32))
        self.beta = Parameter(torch.zeros(*param_size, dtype=torch.float32))
self.eps = eps
def forward(self, x):
if x.ndim == 4:
stat_dim = (1, 3)
else:
raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim))
mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,1,T,1]
std_ = torch.sqrt(
x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps
        )  # [B,1,T,1]
x_hat = ((x - mu_) / std_) * self.gamma + self.beta
return x_hat
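if __name__ == "__main__":
    # Minimal smoke test, not part of the original upload: a sketch that only
    # checks input/output shapes, assuming a 2-item batch of 4-mic, 1-second
    # mixtures at 16 kHz and the default hyperparameters.
    model = DeFTAN2(n_srcs=1, win=512, n_mics=4, n_layers=6).eval()
    wav = torch.randn(2, 4, 16000)  # [batch, mics, samples]
    with torch.no_grad():
        est = model(wav)
    print(est.shape)  # expected: torch.Size([2, 1, 16000])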