import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .base_model import BaseModel


class RMSNorm(nn.Module):
    """Root-mean-square normalization over grouped channels, with a learnable per-channel gain."""
    def __init__(self, dimension, groups=1):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(dimension))
        self.groups = groups
        self.eps = 1e-5

    def forward(self, input):
        # input shape: (B, N, T)
        B, N, T = input.shape
        assert N % self.groups == 0

        # normalize by the RMS within each channel group, computed in float32 for stability
        input_float = input.reshape(B, self.groups, -1, T).float()
        input_norm = input_float * torch.rsqrt(input_float.pow(2).mean(-2, keepdim=True) + self.eps)

        return input_norm.type_as(input).reshape(B, N, T) * self.weight.reshape(1, -1, 1)
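

# Illustrative usage (a sketch, not part of the model; sizes are assumptions):
#   norm = RMSNorm(64, groups=4)
#   x = torch.randn(2, 64, 100)   # (B, N, T)
#   y = norm(x)                   # (2, 64, 100), RMS-normalized within each group of 16 channels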


class RMVN(nn.Module):
    """
    Rescaled MVN: per-group mean-variance normalization with learnable mean/std rescaling.
    """
    def __init__(self, dimension, groups=1):
        super().__init__()

        self.mean = nn.Parameter(torch.zeros(dimension))
        self.std = nn.Parameter(torch.ones(dimension))
        self.groups = groups
        self.eps = 1e-5

    def forward(self, input):
        # input shape: (B, N, ...); statistics are taken over the channels of each group
        B, N = input.shape[:2]
        assert N % self.groups == 0
        input_reshape = input.reshape(B, self.groups, N // self.groups, -1)
        T = input_reshape.shape[-1]

        input_norm = (input_reshape - input_reshape.mean(2).unsqueeze(2)) / (input_reshape.var(2).unsqueeze(2) + self.eps).sqrt()
        input_norm = input_norm.reshape(B, N, T) * self.std.reshape(1, -1, 1) + self.mean.reshape(1, -1, 1)

        return input_norm.reshape(input.shape)
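

# Illustrative usage (a sketch; note RMVN is defined here but not referenced elsewhere in this file):
#   mvn = RMVN(64)
#   x = torch.randn(2, 64, 100)   # (B, N, T)
#   y = mvn(x)                    # (2, 64, 100), standardized across channels at each frame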


class Roformer(nn.Module):
    """
    Transformer block with rotary positional embedding (RoPE).
    """
    def __init__(self, input_size, hidden_size, num_head=8, theta=10000, window=10000,
                 input_drop=0., attention_drop=0., causal=True):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size // num_head  # per-head dimension
        self.num_head = num_head
        self.theta = theta
        self.window = window  # maximum sequence length covered by the precomputed tables

        cos_freq, sin_freq = self._calc_rotary_emb()
        self.register_buffer("cos_freq", cos_freq)
        self.register_buffer("sin_freq", sin_freq)

        self.attention_drop = attention_drop
        self.causal = causal
        self.eps = 1e-5

        self.input_norm = RMSNorm(self.input_size)
        self.input_drop = nn.Dropout(p=input_drop)
        self.weight = nn.Conv1d(self.input_size, self.hidden_size*self.num_head*3, 1, bias=False)  # fused QKV projection
        self.output = nn.Conv1d(self.hidden_size*self.num_head, self.input_size, 1, bias=False)

        # gated feed-forward: the 8x expansion is later chunked into two 4x halves (gate and value)
        self.MLP = nn.Sequential(RMSNorm(self.input_size),
                                 nn.Conv1d(self.input_size, self.input_size*8, 1, bias=False),
                                 nn.SiLU()
                                 )
        self.MLP_output = nn.Conv1d(self.input_size*4, self.input_size, 1, bias=False)

    def _calc_rotary_emb(self):
        # standard RoPE frequency table: one rotation frequency per channel pair
        freq = 1. / (self.theta ** (torch.arange(0, self.hidden_size, 2)[:(self.hidden_size // 2)] / self.hidden_size))
        freq = freq.reshape(1, -1)  # (1, hidden_size//2)
        pos = torch.arange(0, self.window).reshape(-1, 1)  # (window, 1)
        cos_freq = torch.cos(pos*freq)
        sin_freq = torch.sin(pos*freq)
        # duplicate each frequency so it covers both channels of its pair
        cos_freq = torch.stack([cos_freq]*2, -1).reshape(self.window, self.hidden_size)
        sin_freq = torch.stack([sin_freq]*2, -1).reshape(self.window, self.hidden_size)

        return cos_freq, sin_freq

    def _add_rotary_emb(self, feature, pos):
        # rotate a single position, e.g. for step-wise (incremental) inference
        # feature shape: (..., N)
        N = feature.shape[-1]

        feature_reshape = feature.reshape(-1, N)
        pos = min(pos, self.window-1)
        cos_freq = self.cos_freq[pos]
        sin_freq = self.sin_freq[pos]
        # rotate-half: each interleaved pair (x1, x2) becomes (-x2, x1)
        reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
        feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, N)
        feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)

        return feature_rope.reshape(feature.shape)

    def _add_rotary_sequence(self, feature):
        # rotate all positions of a sequence at once
        # feature shape: (..., T, N), with T <= self.window
        T, N = feature.shape[-2:]
        feature_reshape = feature.reshape(-1, T, N)

        cos_freq = self.cos_freq[:T]
        sin_freq = self.sin_freq[:T]
        reverse_sign = torch.from_numpy(np.asarray([-1, 1])).to(feature.device).type(feature.dtype)
        feature_reshape_neg = (torch.flip(feature_reshape.reshape(-1, N//2, 2), [-1]) * reverse_sign.reshape(1, 1, 2)).reshape(-1, T, N)
        feature_rope = feature_reshape * cos_freq.unsqueeze(0) + feature_reshape_neg * sin_freq.unsqueeze(0)

        return feature_rope.reshape(feature.shape)

    def forward(self, input):
        # input shape: (B, N, T)
        B, _, T = input.shape

        weight = self.weight(self.input_drop(self.input_norm(input))).reshape(B, self.num_head, self.hidden_size*3, T).mT
        Q, K, V = torch.split(weight, self.hidden_size, dim=-1)  # each (B, num_head, T, hidden_size)

        # rotary positional embedding is applied to queries and keys only
        Q_rot = self._add_rotary_sequence(Q)
        K_rot = self._add_rotary_sequence(K)

        attention_output = F.scaled_dot_product_attention(Q_rot.contiguous(), K_rot.contiguous(), V.contiguous(),
                                                          dropout_p=self.attention_drop, is_causal=self.causal)
        attention_output = attention_output.mT.reshape(B, -1, T)
        output = self.output(attention_output) + input  # residual connection

        gate, z = self.MLP(output).chunk(2, dim=1)
        output = output + self.MLP_output(F.silu(gate) * z)

        return output, (K_rot, V)
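

# Illustrative usage (a sketch; the sizes are assumptions, not a published configuration):
#   att = Roformer(input_size=64, hidden_size=128, num_head=8, window=100, causal=False)
#   x = torch.randn(2, 64, 50)    # (B, N, T); T must not exceed window
#   y, (k, v) = att(x)            # y: (2, 64, 50); (K_rot, V) is presumably usable as a cache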


class ConvActNorm1d(nn.Module):
    """Residual depthwise-separable 1-D convolution block with RMSNorm and SiLU."""
    def __init__(self, in_channel, hidden_channel, kernel=7, causal=False):
        super().__init__()

        self.in_channel = in_channel
        self.kernel = kernel
        self.causal = causal
        if not causal:
            self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2, groups=in_channel),
                                      RMSNorm(in_channel),
                                      nn.Conv1d(in_channel, hidden_channel, 1),
                                      nn.SiLU(),
                                      nn.Conv1d(hidden_channel, in_channel, 1)
                                      )
        else:
            # pad by kernel-1 on both sides, then trim the tail in forward() so no future frames leak in
            self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1, groups=in_channel),
                                      RMSNorm(in_channel),
                                      nn.Conv1d(in_channel, hidden_channel, 1),
                                      nn.SiLU(),
                                      nn.Conv1d(hidden_channel, in_channel, 1)
                                      )

    def forward(self, input):

        output = self.conv(input)
        if self.causal:
            output = output[..., :-self.kernel+1]
        return input + output
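

# Illustrative shape check (a sketch): with causal=True the block stays length-preserving
# while only looking at past and current frames.
#   blk = ConvActNorm1d(64, 256, kernel=7, causal=True)
#   y = blk(torch.randn(2, 64, 100))   # (2, 64, 100)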


class ICB(nn.Module):
    """Stack of three residual ConvActNorm1d blocks with a 4x hidden expansion."""
    def __init__(self, in_channel, kernel=7, causal=False):
        super().__init__()

        self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal),
                                    ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal),
                                    ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal)
                                    )

    def forward(self, input):

        return self.blocks(input)


class BSNet(nn.Module):
    """One band-split processing stage: attention across bands, then convolution across time."""
    def __init__(self, feature_dim, kernel=7):
        super().__init__()

        self.feature_dim = feature_dim

        # the Roformer runs along the band axis, so its window only needs to cover the band count
        self.band_net = Roformer(self.feature_dim, self.feature_dim, num_head=8, window=100, causal=False)
        self.seq_net = ICB(self.feature_dim, kernel=kernel)

    def forward(self, input):
        # input shape: (B, nband, N, T)
        B, nband, N, T = input.shape

        # communicate across bands: fold time into the batch and attend over the band axis
        band_input = input.permute(0, 3, 2, 1).reshape(B*T, -1, nband)
        band_output, _ = self.band_net(band_input)
        band_output = band_output.reshape(B, T, -1, nband).permute(0, 3, 2, 1)

        # sequence modeling within each band: fold bands into the batch
        output = self.seq_net(band_output.reshape(B*nband, -1, T)).reshape(B, nband, -1, T)

        return output
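

# Illustrative usage (a sketch; the band count must stay within the Roformer window of 100):
#   bs = BSNet(64)
#   x = torch.randn(1, 80, 64, 50)   # (B, nband, N, T)
#   y = bs(x)                        # (1, 80, 64, 50)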


class Apollo(BaseModel):
    """
    Band-split audio restoration model: STFT -> per-band feature extraction ->
    stacked BSNet layers -> per-band complex spectrum estimation -> iSTFT.
    """
    def __init__(
        self,
        sr: int,
        win: int,
        feature_dim: int,
        layer: int
    ):
        super().__init__(sample_rate=sr)

        self.sr = sr
        self.win = int(sr * win // 1000)  # window length in samples (win is given in ms)
        self.stride = self.win // 2
        self.enc_dim = self.win // 2 + 1
        self.feature_dim = feature_dim
        self.eps = torch.finfo(torch.float32).eps

        # split the spectrum into 79 uniform subbands plus one remainder band
        bandwidth = int(self.win / 160)
        self.band_width = [bandwidth]*79
        self.band_width.append(self.enc_dim - np.sum(self.band_width))
        self.nband = len(self.band_width)
        print(self.band_width, self.nband)

        # per-band bottleneck: real, imag and log-power (band_width*2+1 channels) -> feature_dim
        self.BN = nn.ModuleList([])
        for i in range(self.nband):
            self.BN.append(nn.Sequential(RMSNorm(self.band_width[i]*2+1),
                                         nn.Conv1d(self.band_width[i]*2+1, self.feature_dim, 1))
                           )

        self.net = []
        for _ in range(layer):
            self.net.append(BSNet(self.feature_dim))
        self.net = nn.Sequential(*self.net)

        # per-band output heads: GLU halves the band_width*4 channels into real/imag pairs
        self.output = nn.ModuleList([])
        for i in range(self.nband):
            self.output.append(nn.Sequential(RMSNorm(self.feature_dim),
                                             nn.Conv1d(self.feature_dim, self.band_width[i]*4, 1),
                                             nn.GLU(dim=1)
                                             )
                               )

    def spec_band_split(self, input):
        # input shape: (B, nch, nsample)
        B, nch, nsample = input.shape

        spec = torch.stft(input.view(B*nch, nsample), n_fft=self.win, hop_length=self.stride,
                          window=torch.hann_window(self.win).to(input.device), return_complex=True)

        # split into subbands and normalize each by its power envelope
        subband_spec = []
        subband_spec_norm = []
        subband_power = []
        band_idx = 0
        for i in range(self.nband):
            this_spec = spec[:, band_idx:band_idx+self.band_width[i]]
            subband_spec.append(this_spec)
            subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1))  # (B*nch, 1, T)
            subband_spec_norm.append(torch.complex(this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1]))
            band_idx += self.band_width[i]
        subband_power = torch.cat(subband_power, 1)  # (B*nch, nband, T)

        return subband_spec_norm, subband_power

    def feature_extractor(self, input):

        subband_spec_norm, subband_power = self.spec_band_split(input)

        # concatenate normalized real/imag spectra with the log power envelope, then bottleneck
        subband_feature = []
        for i in range(self.nband):
            concat_spec = torch.cat([subband_spec_norm[i].real, subband_spec_norm[i].imag, torch.log(subband_power[:, i].unsqueeze(1))], 1)
            subband_feature.append(self.BN[i](concat_spec))
        subband_feature = torch.stack(subband_feature, 1)  # (B*nch, nband, N, T)

        return subband_feature

    def forward(self, input):
        # input shape: (B, nch, nsample)
        B, nch, nsample = input.shape

        subband_feature = self.feature_extractor(input)
        feature = self.net(subband_feature)

        # estimate complex spectra band by band, then resynthesize with the inverse STFT
        est_spec = []
        for i in range(self.nband):
            this_RI = self.output[i](feature[:, i]).view(B*nch, 2, self.band_width[i], -1)
            est_spec.append(torch.complex(this_RI[:, 0], this_RI[:, 1]))
        est_spec = torch.cat(est_spec, 1)
        output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride,
                             window=torch.hann_window(self.win).to(input.device), length=nsample).view(B, nch, -1)

        return output

    def get_model_args(self):
        model_args = {"n_sample_rate": 2}
        return model_args
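

if __name__ == "__main__":
    # Minimal smoke test (a sketch; these hyperparameters are assumptions, not a published
    # configuration). Because of the relative import of BaseModel above, run this as a
    # module from inside the package, e.g. `python -m <package>.<this_module>`.
    model = Apollo(sr=44100, win=20, feature_dim=64, layer=2)
    wav = torch.randn(1, 1, 44100)        # (batch, channels, samples): 1 s of noise
    with torch.no_grad():
        out = model(wav)
    print(out.shape)                      # expected: torch.Size([1, 1, 44100])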