OcTra / df_local /modules.py
arcan3's picture
adding rust
35916c5
raw
history blame
35 kB
import math
from collections import OrderedDict
from typing import Callable, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from typing_extensions import Final
from df_local.model import ModelParams
from df_local.utils import as_complex, as_real, get_device, get_norm_alpha
from libdf import unit_norm_init
class Conv2dNormAct(nn.Sequential):
    """Causal 2-D convolution, optionally followed by a norm layer and an activation.

    Only the time axis (dim 2) is padded, and only on the past (left) side, so the layer
    adds no lookahead; any lookahead has to be handled on the input feature side.
    Stride and dilation are applied on the frequency axis only.
    Expected input format: [B, C, T, F]
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: Union[int, Iterable[int]],
        fstride: int = 1,
        dilation: int = 1,
        fpad: bool = True,
        bias: bool = True,
        separable: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
    ):
        """Causal Conv2d by delaying the signal for any lookahead.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            kernel_size: (time, freq) kernel size, or one int used for both axes.
            fstride: Stride over frequency (the time stride is always 1).
            dilation: Dilation over frequency (the time dilation is always 1).
            fpad: Zero-pad frequency by ``kernel_size[1] // 2 + dilation - 1`` on each side.
            bias: Whether the main convolution has a bias term.
            separable: Use a grouped convolution with ``gcd(in_ch, out_ch)`` groups followed
                by a bias-free 1x1 convolution. No-op when the gcd is 1; the 1x1
                convolution is also omitted for a 1x1 kernel (the grouping is kept).
            norm_layer: Factory called with ``out_ch``; ``None`` disables normalization.
            activation_layer: Factory called without arguments; ``None`` disables it.
        """
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        else:
            kernel_size = tuple(kernel_size)

        lookahead = 0  # Lookahead must be handled on the input feature side.
        # ConstantPad2d order: (F_left, F_right, T_top, T_bottom) -> causal padding over time.
        time_pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
        freq_pad = kernel_size[1] // 2 + dilation - 1 if fpad else 0

        groups = math.gcd(in_ch, out_ch) if separable else 1
        # The trailing pointwise conv only makes sense for a real grouping and a kernel > 1x1.
        add_pointwise = separable and groups != 1 and max(kernel_size) != 1

        layers: List[nn.Module] = []
        if any(amount > 0 for amount in time_pad):
            layers.append(nn.ConstantPad2d(time_pad, 0.0))
        main_conv = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=kernel_size,
            padding=(0, freq_pad),
            stride=(1, fstride),  # Stride over time is always 1
            dilation=(1, dilation),  # Same for dilation
            groups=groups,
            bias=bias,
        )
        layers.append(main_conv)
        if add_pointwise:
            layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
        if norm_layer is not None:
            layers.append(norm_layer(out_ch))
        if activation_layer is not None:
            layers.append(activation_layer())
        super().__init__(*layers)
class ConvTranspose2dNormAct(nn.Sequential):
    """Causal transposed 2-D convolution, optionally followed by a norm layer and an activation.

    Stride and dilation act on the frequency axis only; the time axis keeps its length.
    Expected input format: [B, C, T, F]
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: Union[int, Tuple[int, int]],
        fstride: int = 1,
        dilation: int = 1,
        fpad: bool = True,
        bias: bool = True,
        separable: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
    ):
        """Causal ConvTranspose2d.

        Args:
            in_ch: Number of input channels.
            out_ch: Number of output channels.
            kernel_size: (time, freq) kernel size, or one int used for both axes.
            fstride: Stride (upsampling factor) over frequency; the time stride is always 1.
            dilation: Dilation over frequency.
            fpad: Use ``kernel_size[1] // 2`` as frequency padding base.
            bias: Whether the transposed convolution has a bias term.
            separable: Use ``gcd(in_ch, out_ch)`` groups followed by a bias-free 1x1 conv
                (no-op when the gcd is 1).
            norm_layer: Factory called with ``out_ch``; ``None`` disables normalization.
            activation_layer: Factory called without arguments; ``None`` disables it.
        """
        # Padding on time axis, with lookahead = 0
        lookahead = 0  # This needs to be handled on the input feature side
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        fpad_ = kernel_size[1] // 2 if fpad else 0
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))
        groups = math.gcd(in_ch, out_ch) if separable else 1
        if groups == 1:
            separable = False
        # PyTorch requires output_padding < max(stride, dilation). Using `fpad_` unconditionally
        # made e.g. the default fstride=1/dilation=1 layer fail at forward time; clamp it.
        # For every previously valid configuration this equals `fpad_`.
        out_fpad = min(fpad_, max(fstride, dilation) - 1)
        layers.append(
            nn.ConvTranspose2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                # Time padding k-1 undoes the causal left padding added above.
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, out_fpad),
                stride=(1, fstride),  # Stride over time is always 1
                dilation=(1, dilation),
                groups=groups,
                bias=bias,
            )
        )
        if separable:
            layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
        if norm_layer is not None:
            layers.append(norm_layer(out_ch))
        if activation_layer is not None:
            layers.append(activation_layer())
        super().__init__(*layers)
def convkxf(
    in_ch: int,
    out_ch: Optional[int] = None,
    k: int = 1,
    f: int = 3,
    fstride: int = 2,
    lookahead: int = 0,
    batch_norm: bool = False,
    act: Optional[nn.Module] = None,
    mode="normal",  # must be "normal", "transposed" or "upsample"
    depthwise: bool = True,
    complex_in: bool = False,
):
    """Build a causal ``k x f`` (time x frequency) convolution block.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels. Defaults to ``2 * in_ch`` for ``mode="normal"``
            and ``in_ch // 2`` otherwise.
        k: Kernel size over time; the time axis is padded on the left (causal).
        f: Kernel size over frequency; must be odd.
        fstride: Frequency stride, or the upsampling factor for ``mode="upsample"``.
        lookahead: Number of frames padded on the right of the time axis.
        batch_norm: Append ``BatchNorm2d`` (the convolution then has no bias).
        act: Activation module appended last. When ``None`` (default) a fresh
            ``nn.ReLU(inplace=True)`` is created for this block.
        mode: ``"normal"``, ``"transposed"`` or ``"upsample"``.
        depthwise: Use ``min(in_ch, out_ch)`` groups plus a 1x1 conv when the channel
            counts are divisible by it.
        complex_in: Halve the group count when it is even (presumably to keep real/imag
            channel pairs in one group — TODO confirm).

    Returns:
        ``nn.Sequential`` with named submodules.

    Raises:
        NotImplementedError: If ``mode`` is unknown.
    """
    # A module default argument would be one instance shared by every block built with
    # the default; create a new activation per call instead.
    if act is None:
        act = nn.ReLU(inplace=True)
    bias = batch_norm is False
    assert f % 2 == 1
    stride = 1 if f == 1 else (1, fstride)
    if out_ch is None:
        out_ch = in_ch * 2 if mode == "normal" else in_ch // 2
    fpad = (f - 1) // 2
    convpad = (0, fpad)
    modules = []
    # Manually pad for time axis kernel to not introduce delay
    pad = (0, 0, k - 1 - lookahead, lookahead)
    if any(p > 0 for p in pad):
        modules.append(("pad", nn.ConstantPad2d(pad, 0.0)))
    if depthwise:
        groups = min(in_ch, out_ch)
    else:
        groups = 1
    if in_ch % groups != 0 or out_ch % groups != 0:
        groups = 1
    if complex_in and groups % 2 == 0:
        groups //= 2
    convkwargs = {
        "in_channels": in_ch,
        "out_channels": out_ch,
        "kernel_size": (k, f),
        "stride": stride,
        "groups": groups,
        "bias": bias,
    }
    if mode == "normal":
        modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
    elif mode == "transposed":
        # Since pytorch's transposed conv padding does not correspond to the actual padding but
        # rather the padding that was used in the encoder conv, we need to set time axis padding
        # according to k. E.g., this disables the padding for k=2:
        # dilation - (k - 1) - padding
        # = 1 - (2 - 1) - 1 = 0; => padding = fpad (=1 for k=2)
        padding = (k - 1, fpad)
        modules.append(
            ("sconvt", nn.ConvTranspose2d(padding=padding, output_padding=convpad, **convkwargs))
        )
    elif mode == "upsample":
        modules.append(("upsample", FreqUpsample(fstride)))
        convkwargs["stride"] = 1
        modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
    else:
        raise NotImplementedError()
    if groups > 1:
        modules.append(("1x1conv", nn.Conv2d(out_ch, out_ch, 1, bias=False)))
    if batch_norm:
        modules.append(("norm", nn.BatchNorm2d(out_ch)))
    modules.append(("act", act))
    return nn.Sequential(OrderedDict(modules))
class FreqUpsample(nn.Module):
    """Upsample only the last (frequency) axis of a [B, C, T, F] tensor by ``factor``."""

    def __init__(self, factor: int, mode="nearest"):
        super().__init__()
        # Kept as a float so it can sit in the float scale list next to 1.0.
        self.f = float(factor)
        self.mode = mode

    def forward(self, x: Tensor) -> Tensor:
        # Scale 1.0 on the time axis leaves it untouched.
        scale: List[float] = [1.0, self.f]
        return F.interpolate(x, scale_factor=scale, mode=self.mode)
def erb_fb(widths: np.ndarray, sr: int, normalized: bool = True, inverse: bool = False) -> Tensor:
    """Build a rectangular ERB filter-bank matrix.

    Args:
        widths: Number of frequency bins in each band; the bands are contiguous and their
            sum is the total number of frequency bins.
        sr: Sample rate. Not needed to build the matrix; kept for API compatibility.
        normalized: Forward bank: divide each band column by its width. Inverse bank:
            the division by the band width is applied when this is ``False``.
        inverse: Return the transposed (band -> frequency) matrix.

    Returns:
        ``[n_freqs, n_bands]`` matrix (``[n_bands, n_freqs]`` if ``inverse``) placed on
        ``get_device()``.
    """
    band_widths = widths.tolist()
    n_freqs = int(np.sum(widths))
    fb = torch.zeros((n_freqs, len(band_widths)))
    start = 0
    for band, width in enumerate(band_widths):
        fb[start : start + width, band] = 1
        start += width
    # Normalize to constant energy per resulting band
    if inverse:
        fb = fb.t()
        if not normalized:
            fb /= fb.sum(dim=1, keepdim=True)
    elif normalized:
        fb /= fb.sum(dim=0)
    return fb.to(device=get_device())
class Mask(nn.Module):
    """Apply an ERB-band gain mask to a spectrogram stored as real/imag pairs."""

    def __init__(self, erb_inv_fb: Tensor, post_filter: bool = False, eps: float = 1e-12):
        """
        Args:
            erb_inv_fb: Matrix mapping ERB bands to frequency bins (right-multiplied).
            post_filter: Apply :meth:`pf` to the mask in eval mode.
            eps: Lower clamp for the post-filter denominator.
        """
        super().__init__()
        self.erb_inv_fb: Tensor
        self.register_buffer("erb_inv_fb", erb_inv_fb)
        # Tensor-valued clamp bounds need torch >= 1.9 (`>=` spelled out, see forward).
        self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
        self.post_filter = post_filter
        self.eps = eps

    def pf(self, mask: Tensor, beta: float = 0.02) -> Tensor:
        """Post-Filter proposed by Valin et al. [1].

        Args:
            mask (Tensor): Real valued mask, typically of shape [B, C, T, F].
            beta: Global gain factor.

        Refs:
            [1]: Valin et al.: A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        """
        mask_sin = mask * torch.sin(np.pi * mask / 2)
        ratio = mask.div(mask_sin.clamp_min(self.eps))
        return (1 + beta) * mask / (1 + beta * ratio.pow(2))

    def forward(self, spec: Tensor, mask: Tensor, atten_lim: Optional[Tensor] = None) -> Tensor:
        """Scale ``spec`` by ``mask`` expanded from ERB bands to frequency bins.

        Args:
            spec: Real view of the spectrogram, [B, 1, T, F, 2].
            mask: ERB-band gains, [B, 1, T, Fe].
            atten_lim: Optional per-batch attenuation limit in dB, [B]; the mask is
                clamped from below at ``10 ** (-atten_lim / 20)``.
        """
        if self.post_filter and not self.training:
            mask = self.pf(mask)
        if atten_lim is not None:
            min_gain = 10 ** (-atten_lim / 20)  # dB -> amplitude
            # Greater equal (__ge__) not implemented for TorchVersion.
            if self.clamp_tensor:
                # Supported by torch >= 1.9
                mask = mask.clamp(min=min_gain.view(-1, 1, 1, 1))
            else:
                clamped = []
                for idx in range(min_gain.shape[0]):
                    clamped.append(mask[idx].clamp_min(min_gain[idx].item()))
                mask = torch.stack(clamped, dim=0)
        gains = mask.matmul(self.erb_inv_fb)  # [B, 1, T, F]
        return spec * gains.unsqueeze(4)
class ExponentialUnitNorm(nn.Module):
    """Unit norm for a complex spectrogram.

    Each bin is divided by the square root of an exponentially smoothed magnitude.
    This should match the rust code:
    ```rust
    for (x, s) in xs.iter_mut().zip(state.iter_mut()) {
        *s = x.norm() * (1. - alpha) + *s * alpha;
        *x /= s.sqrt();
    }
    ```
    """

    alpha: Final[float]
    eps: Final[float]

    def __init__(self, alpha: float, num_freq_bins: int, eps: float = 1e-14):
        """
        Args:
            alpha: Smoothing factor of the running state.
            num_freq_bins: Number of frequency bins F.
            eps: Lower clamp for the squared magnitude before the square root.
        """
        super().__init__()
        self.alpha = alpha
        self.eps = eps
        self.init_state: Tensor
        initial = unit_norm_init(num_freq_bins)
        self.register_buffer(
            "init_state", torch.from_numpy(initial).view(1, 1, num_freq_bins, 1)
        )

    def forward(self, x: Tensor) -> Tensor:
        # x: [B, C, T, F, 2] (real/imag in the last axis)
        batch, channels, n_frames, n_bins, _ = x.shape
        magnitude = x.square().sum(dim=-1, keepdim=True).clamp_min(self.eps).sqrt()
        state = self.init_state.clone().expand(batch, channels, n_bins, 1)
        history: List[Tensor] = []
        one_minus_alpha = 1 - self.alpha
        for frame in range(n_frames):
            state = magnitude[:, :, frame] * one_minus_alpha + state * self.alpha
            history.append(state)
        return x / torch.stack(history, 2).sqrt()
class DfOp(nn.Module):
    """Deep-filtering operator.

    Applies complex filter coefficients spanning ``df_order`` time frames to the lowest
    ``df_bins`` frequency bins of a spectrogram and writes the result into a copy of the
    spectrogram, optionally blended with the unfiltered input via ``alpha``.
    Several implementations are available; the active one is chosen with ``method`` /
    :meth:`set_forward`.
    """

    df_order: Final[int]
    df_bins: Final[int]
    df_lookahead: Final[int]
    freq_bins: Final[int]

    def __init__(
        self,
        df_bins: int,
        df_order: int = 5,
        df_lookahead: int = 0,
        method: str = "complex_strided",
        freq_bins: int = 0,
    ):
        """
        Args:
            df_bins: Number of low-frequency bins that are filtered.
            df_order: Number of filter taps over time.
            df_lookahead: Number of future frames inside the filter window.
            method: Forward implementation name (see :meth:`set_forward`).
            freq_bins: Total number of frequency bins; required (>= df_bins) only for
                ``"real_hidden_state_loop"``.
        """
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_lookahead = df_lookahead
        self.freq_bins = freq_bins
        self.set_forward(method)

    def set_forward(self, method: str):
        """Select the forward implementation.

        Raises:
            NotImplementedError: If ``method`` is unknown.
        """
        # All forward methods should be mathematically similar.
        # DeepFilterNet results are obtained with 'real_unfold'.
        forward_methods = {
            "real_loop": self.forward_real_loop,
            "real_strided": self.forward_real_strided,
            "real_unfold": self.forward_real_unfold,
            "complex_strided": self.forward_complex_strided,
            "real_one_step": self.forward_real_no_pad_one_step,
            "real_hidden_state_loop": self.forward_real_hidden_state_loop,
        }
        if method not in forward_methods.keys():
            raise NotImplementedError(f"`method` must be one of {forward_methods.keys()}")
        if method == "real_hidden_state_loop":
            assert self.freq_bins >= self.df_bins
            self.spec_buf: Tensor
            # Currently only designed for batch size of 1
            self.register_buffer(
                "spec_buf", torch.zeros(1, 1, self.df_order, self.freq_bins, 2), persistent=False
            )
        self.forward = forward_methods[method]

    def forward_real_loop(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version 0: Manual loop over df_order, maybe best for onnx export?
        b, _, t, _, _ = spec.shape
        f = self.df_bins
        padded = spec_pad(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        # Accumulate in the input dtype; the default dtype would silently downcast
        # e.g. float64 inputs.
        spec_f = torch.zeros((b, t, f, 2), device=spec.device, dtype=spec.dtype)
        for i in range(self.df_order):
            spec_f[..., 0] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 0]
            spec_f[..., 0] -= padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 1]
            spec_f[..., 1] += padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 0]
            spec_f[..., 1] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 1]
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_real_strided(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version1: Use as_strided instead of unfold
        # spec (real) [B, 1, T, F, 2], O: df_order
        # coefs (real) [B, T, O, F, 2]
        # alpha (real) [B, T, 1]
        padded = as_strided(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        # Complex numbers are not supported by onnx
        re = padded[..., 0] * coefs[..., 0]
        re -= padded[..., 1] * coefs[..., 1]
        im = padded[..., 1] * coefs[..., 0]
        im += padded[..., 0] * coefs[..., 1]
        spec_f = torch.stack((re, im), -1).sum(2)
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_real_unfold(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version2: Unfold
        # spec (real) [B, 1, T, F, 2], O: df_order
        # coefs (real) [B, T, O, F, 2]
        # alpha (real) [B, T, 1]
        padded = spec_pad(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        padded = padded.unfold(dimension=1, size=self.df_order, step=1)  # [B, T, F, 2, O]
        padded = padded.permute(0, 1, 4, 2, 3)
        spec_f = torch.empty_like(padded)
        spec_f[..., 0] = padded[..., 0] * coefs[..., 0]  # re1
        spec_f[..., 0] -= padded[..., 1] * coefs[..., 1]  # re2
        spec_f[..., 1] = padded[..., 1] * coefs[..., 0]  # im1
        spec_f[..., 1] += padded[..., 0] * coefs[..., 1]  # im2
        spec_f = spec_f.sum(dim=2)
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_complex_strided(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version3: Complex strided; definitely nicest, no permute, no indexing, but complex gradient
        # spec (real) [B, 1, T, F, 2], O: df_order
        # coefs (real) [B, T, O, F, 2]
        # alpha (real) [B, T, 1]
        padded = as_strided(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        spec_f = torch.sum(torch.view_as_complex(padded) * torch.view_as_complex(coefs), dim=2)
        spec_f = torch.view_as_real(spec_f)
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_real_no_pad_one_step(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version4: Only viable for onnx handling. `spec` needs external (ring-)buffer handling.
        # Thus, time steps `t` must be equal to `df_order`.
        # spec (real) [B, 1, O, F', 2]
        # coefs (real) [B, 1, O, F, 2]
        assert (
            spec.shape[2] == self.df_order
        ), "This forward method needs spectrogram buffer with `df_order` time steps as input"
        assert coefs.shape[1] == 1, "This forward method is only valid for 1 time step"
        sre, sim = spec[..., : self.df_bins, :].split(1, -1)
        cre, cim = coefs.split(1, -1)
        outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
        outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
        spec_f = torch.stack((outr, outi), dim=-1)
        return assign_df(
            spec[:, :, self.df_order - self.df_lookahead - 1],
            spec_f.unsqueeze(1),
            self.df_bins,
            alpha,
        )

    def forward_real_hidden_state_loop(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version5: Designed for onnx export. `spec` buffer handling is done via a torch buffer.
        # spec (real) [B, 1, T, F', 2]
        # coefs (real) [B, T, O, F, 2]
        # `alpha` is optional here as in every other forward method.
        b, _, t, _, _ = spec.shape
        # Match the input dtype so the output dtype does not depend on the chosen method.
        spec_out = torch.empty(
            (b, 1, t, self.freq_bins, 2), device=spec.device, dtype=spec.dtype
        )
        for t in range(spec.shape[2]):
            # Ring buffer holding the last `df_order` frames.
            self.spec_buf = self.spec_buf.roll(-1, dims=2)
            self.spec_buf[:, :, -1] = spec[:, :, t]
            sre, sim = self.spec_buf[..., : self.df_bins, :].split(1, -1)
            cre, cim = coefs[:, t : t + 1].split(1, -1)
            outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
            outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
            spec_f = torch.stack((outr, outi), dim=-1)
            alpha_t = None if alpha is None else alpha[:, t]
            spec_out[:, :, t] = assign_df(
                self.spec_buf[:, :, self.df_order - self.df_lookahead - 1].unsqueeze(2),
                spec_f.unsqueeze(1),
                self.df_bins,
                alpha_t,
            ).squeeze(2)
        return spec_out
def assign_df(spec: Tensor, spec_f: Tensor, df_bins: int, alpha: Optional[Tensor]):
    """Return a copy of ``spec`` whose lowest ``df_bins`` bins are replaced by ``spec_f``.

    When ``alpha`` is given, the filtered bins are blended instead:
    ``spec_f * alpha + spec * (1 - alpha)``, with ``alpha`` reshaped to
    ``[B, 1, -1, 1, 1]``. The input ``spec`` is not modified.
    """
    result = spec.clone()
    if alpha is None:
        result[..., :df_bins, :] = spec_f
        return result
    weight = alpha.view(spec.shape[0], 1, -1, 1, 1)
    unfiltered = spec[..., :df_bins, :]
    result[..., :df_bins, :] = spec_f * weight + unfiltered * (1 - weight)
    return result
def spec_pad(x: Tensor, window_size: int, lookahead: int, dim: int = 0) -> Tensor:
    """Zero-pad ``x`` along ``dim`` for a sliding window of ``window_size`` frames.

    ``window_size - lookahead - 1`` zeros are prepended (past context) and ``lookahead``
    zeros appended (future context). ``dim`` may be negative.
    """
    # F.pad takes (before, after) pairs starting from the last dimension.
    axis_from_end = -dim - 1 if dim < 0 else x.dim() - dim - 1
    pad = [0] * (2 * x.dim())
    pad[2 * axis_from_end] = window_size - lookahead - 1
    pad[2 * axis_from_end + 1] = lookahead
    return F.pad(x, pad)
def as_strided(x: Tensor, window_size: int, lookahead: int, step: int = 1, dim: int = 0) -> Tensor:
    """Sliding-window view of ``x`` along ``dim`` (no copy of the padded tensor).

    ``x`` is first padded with :func:`spec_pad`; a new axis of length ``window_size`` is
    inserted right after ``dim``, so e.g. ``[B, T, F, 2]`` with ``dim=-3`` becomes
    ``[B, T, window_size, F, 2]``.

    NOTE(review): only 4-D inputs are supported (strides 0-3 are read explicitly), the
    ``step`` argument is ignored (fixed to 1, torch.fx workaround), and ``dim=-1`` inserts
    the window axis at the front of the shape — callers should avoid it.
    """
    out_shape = list(x.shape)
    out_shape.insert(dim + 1, window_size)
    padded = spec_pad(x, window_size, lookahead, dim=dim)
    strides = [padded.stride(0), padded.stride(1), padded.stride(2), padded.stride(3)]
    # Both the frame axis and the window axis advance by one frame (step fixed to 1).
    strides.insert(dim, strides[dim])
    return torch.as_strided(padded, out_shape, strides)
class GroupedGRULayer(nn.Module):
    """One GRU layer split into ``groups`` independent GRUs over disjoint feature slices.

    Group ``i`` consumes input features ``[i * I/G, (i + 1) * I/G)`` and produces
    ``hidden_size / groups`` features; group outputs are concatenated on the last axis.
    """

    input_size: Final[int]
    hidden_size: Final[int]
    out_size: Final[int]
    bidirectional: Final[bool]
    num_directions: Final[int]
    groups: Final[int]
    batch_first: Final[bool]

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        groups: int,
        batch_first: bool = True,
        bias: bool = True,
        dropout: float = 0,
        bidirectional: bool = False,
    ):
        """
        Args:
            input_size: Total input features; must be divisible by ``groups``.
            hidden_size: Total hidden features; must be divisible by ``groups``.
            groups: Number of independent GRUs.
            batch_first, bias, dropout, bidirectional: Forwarded to each ``nn.GRU``.
        """
        super().__init__()
        assert input_size % groups == 0
        assert hidden_size % groups == 0
        kwargs = {
            "bias": bias,
            "batch_first": batch_first,
            "dropout": dropout,
            "bidirectional": bidirectional,
        }
        self.input_size = input_size // groups
        self.hidden_size = hidden_size // groups
        self.out_size = hidden_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.groups = groups
        self.batch_first = batch_first
        # The per-group hidden size does not itself need to be divisible by `groups`
        # (the previous check rejected valid sizes such as hidden_size=6, groups=2);
        # divisibility of the total hidden size is asserted above.
        self.layers = nn.ModuleList(
            (nn.GRU(self.input_size, self.hidden_size, **kwargs) for _ in range(groups))
        )

    def flatten_parameters(self):
        for layer in self.layers:
            layer.flatten_parameters()

    def get_h0(self, batch_size: int = 1, device: torch.device = torch.device("cpu")):
        """Zero initial state of shape [G*D, B, H/G]."""
        return torch.zeros(
            self.groups * self.num_directions,
            batch_size,
            self.hidden_size,
            device=device,
        )

    def forward(self, input: Tensor, h0: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # input shape: [B, T, I] if batch_first else [T, B, I], B: batch_size, I: input_size
        # state shape: [G*D, B, H], where G: groups, D: num_directions, H: hidden_size
        if h0 is None:
            dim0, dim1 = input.shape[:2]
            bs = dim0 if self.batch_first else dim1
            h0 = self.get_h0(bs, device=input.device)
        outputs: List[Tensor] = []
        outstates: List[Tensor] = []
        for i, layer in enumerate(self.layers):
            o, s = layer(
                input[..., i * self.input_size : (i + 1) * self.input_size],
                # The incoming state is detached, so gradients do not flow through it.
                h0[i * self.num_directions : (i + 1) * self.num_directions].detach(),
            )
            outputs.append(o)
            outstates.append(s)
        output = torch.cat(outputs, dim=-1)
        h = torch.cat(outstates, dim=0)
        return output, h
class GroupedGRU(nn.Module):
    """Stack of ``GroupedGRULayer`` s with optional feature shuffling between layers.

    With ``add_outputs`` the outputs of all layers are summed; otherwise the last layer's
    output is returned. The returned state concatenates all layer states along dim 0.
    """

    groups: Final[int]
    num_layers: Final[int]
    batch_first: Final[bool]
    hidden_size: Final[int]
    bidirectional: Final[bool]
    num_directions: Final[int]
    shuffle: Final[bool]
    add_outputs: Final[bool]

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        groups: int = 4,
        bias: bool = True,
        batch_first: bool = True,
        dropout: float = 0,
        bidirectional: bool = False,
        shuffle: bool = True,
        add_outputs: bool = False,
    ):
        super().__init__()
        kwargs = {
            "groups": groups,
            "bias": bias,
            "batch_first": batch_first,
            "dropout": dropout,
            "bidirectional": bidirectional,
        }
        assert input_size % groups == 0
        assert hidden_size % groups == 0
        assert num_layers > 0
        self.input_size = input_size
        self.groups = groups
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.hidden_size = hidden_size // groups  # per-group hidden size
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        if groups == 1:
            shuffle = False  # Fully connected, no need to shuffle
        self.shuffle = shuffle
        self.add_outputs = add_outputs
        self.grus: List[GroupedGRULayer] = nn.ModuleList()  # type: ignore
        self.grus.append(GroupedGRULayer(input_size, hidden_size, **kwargs))
        for _ in range(1, num_layers):
            self.grus.append(GroupedGRULayer(hidden_size, hidden_size, **kwargs))
        self.flatten_parameters()

    def flatten_parameters(self):
        for gru in self.grus:
            gru.flatten_parameters()

    def get_h0(self, batch_size: int, device: torch.device = torch.device("cpu")) -> Tensor:
        """Zero initial state of shape [L*G*D, B, H/G]."""
        return torch.zeros(
            (self.num_layers * self.groups * self.num_directions, batch_size, self.hidden_size),
            device=device,
        )

    def forward(self, input: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        dim0, dim1, _ = input.shape
        b = dim0 if self.batch_first else dim1
        if state is None:
            state = self.get_h0(b, input.device)
        # Accumulator for `add_outputs`; use the input dtype so the summed output is not
        # silently downcast to the default dtype.
        output = torch.zeros(
            dim0,
            dim1,
            self.hidden_size * self.num_directions * self.groups,
            device=input.device,
            dtype=input.dtype,
        )
        outstates = []
        h = self.groups * self.num_directions  # state rows per layer
        for i, gru in enumerate(self.grus):
            input, s = gru(input, state[i * h : (i + 1) * h])
            outstates.append(s)
            if self.shuffle and i < self.num_layers - 1:
                input = (
                    input.view(dim0, dim1, -1, self.groups).transpose(2, 3).reshape(dim0, dim1, -1)
                )
            if self.add_outputs:
                output += input
            else:
                output = input
        outstate = torch.cat(outstates, dim=0)
        return output, outstate
class SqueezedGRU(nn.Module):
    """GRU wrapped between grouped linear projections.

    The input is projected to ``hidden_size`` (``linear_in``), run through ``nn.GRU``,
    optionally summed with ``gru_skip`` applied to the projected input, and finally
    projected to ``output_size`` (``linear_out``, identity when ``output_size`` is None).
    """

    input_size: Final[int]
    hidden_size: Final[int]

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: Optional[int] = None,
        num_layers: int = 1,
        linear_groups: int = 8,
        batch_first: bool = True,
        gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None,
        linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.linear_in = nn.Sequential(
            GroupedLinearEinsum(input_size, hidden_size, linear_groups), linear_act_layer()
        )
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first)
        self.gru_skip = None if gru_skip_op is None else gru_skip_op()
        self.linear_out = (
            nn.Sequential(
                GroupedLinearEinsum(hidden_size, output_size, linear_groups), linear_act_layer()
            )
            if output_size is not None
            else nn.Identity()
        )

    def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]:
        squeezed = self.linear_in(input)
        recurrent, h = self.gru(squeezed, h)
        if self.gru_skip is not None:
            recurrent = recurrent + self.gru_skip(squeezed)
        return self.linear_out(recurrent), h
class GroupedLinearEinsum(nn.Module):
    """Grouped linear layer (no bias) computed with a single einsum.

    The last axis of size ``input_size`` is split into ``groups`` chunks; chunk ``g`` is
    multiplied by ``weight[g]`` (shape ``[I/G, H/G]``) and the results are concatenated.
    Any number of leading dimensions is supported.
    """

    input_size: Final[int]
    hidden_size: Final[int]
    groups: Final[int]

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        # self.weight: Tensor
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        assert input_size % groups == 0
        self.ws = input_size // groups  # input features per group
        self.register_parameter(
            "weight",
            Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: Tensor) -> Tensor:
        # x: [..., I]
        x = x.unflatten(-1, (self.groups, self.ws))  # [..., G, I/G]
        x = torch.einsum("...gi,...gih->...gh", x, self.weight)  # [..., G, H/G]
        # Merge only the two trailing axes; the previous flatten(2, 3) assumed a 3-D
        # input and failed for 2-D / merged the wrong axes for 4-D inputs.
        x = x.flatten(-2, -1)  # [..., H]
        return x
class GroupedLinear(nn.Module):
    """Grouped linear layer built from ``groups`` independent ``nn.Linear`` modules.

    Each group maps its ``input_size / groups`` slice of the last axis to
    ``hidden_size / groups`` features. With ``shuffle`` (forced off for one group) the
    concatenated output features are permuted across groups.
    """

    input_size: Final[int]
    hidden_size: Final[int]
    groups: Final[int]
    shuffle: Final[bool]

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1, shuffle: bool = True):
        super().__init__()
        assert input_size % groups == 0
        assert hidden_size % groups == 0
        self.groups = groups
        self.input_size = input_size // groups  # per-group input width
        self.hidden_size = hidden_size // groups  # per-group output width
        # A single group has nothing to shuffle with.
        self.shuffle = False if groups == 1 else shuffle
        self.layers = nn.ModuleList(
            nn.Linear(self.input_size, self.hidden_size) for _ in range(groups)
        )

    def forward(self, x: Tensor) -> Tensor:
        width = self.input_size
        pieces: List[Tensor] = []
        for g, linear in enumerate(self.layers):
            start = g * width
            pieces.append(linear(x[..., start : start + width]))
        out = torch.cat(pieces, dim=-1)
        if not self.shuffle:
            return out
        shape = out.shape
        return out.view(-1, self.hidden_size, self.groups).transpose(-1, -2).reshape(shape)
class LocalSnrTarget(nn.Module):
    """Compute a local SNR training target from clean speech and noise spectrograms."""

    def __init__(
        self, ws: int = 20, db: bool = True, ws_ns: Optional[int] = None, target_snr_range=None
    ):
        """
        Args:
            ws: Speech energy window length in ms.
            db: Return the SNR in dB.
            ws_ns: Noise energy window length in ms; defaults to twice the speech window
                (in frames).
            target_snr_range: Optional ``(min, max)`` clamp for the SNR; ``None`` disables
                clamping.
        """
        super().__init__()
        self.ws = self.calc_ws(ws)
        self.ws_ns = self.ws * 2 if ws_ns is None else self.calc_ws(ws_ns)
        self.db = db
        self.range = target_snr_range

    def calc_ws(self, ws_ms: int) -> int:
        """Convert a window length in ms into a number of STFT frames (at least 1)."""
        p = ModelParams()
        ws = ws_ms - p.fft_size / p.sr * 1000  # length ms of an fft_window
        ws = 1 + ws / (p.hop_size / p.sr * 1000)  # consider hop_size
        return max(int(round(ws)), 1)

    def forward(self, clean: Tensor, noise: Tensor, max_bin: Optional[int] = None) -> Tensor:
        # clean: [B, 1, T, F]
        # out: [B, T']
        if max_bin is not None:
            clean = as_complex(clean[..., :max_bin])
            noise = as_complex(noise[..., :max_bin])
        snr = local_snr(clean, noise, window_size=self.ws, db=self.db, window_size_ns=self.ws_ns)[0]
        # Only clamp when a range was configured; indexing the default `None` crashed.
        if self.range is not None:
            snr = snr.clamp(self.range[0], self.range[1])
        return snr.squeeze(1)
def _local_energy(x: Tensor, ws: int, device: torch.device) -> Tensor:
    """Hann-weighted moving average of per-frame energy.

    The energy of ``x`` is summed over its last two axes, then averaged over a centred
    window of ``ws`` frames along the remaining last axis (``ws`` is bumped to the next
    odd number). Frames beyond the edges count as zero.
    """
    win_len = ws + 1 if ws % 2 == 0 else ws
    half = win_len // 2
    energy = x.pow(2).sum(-1).sum(-1)
    energy = F.pad(energy, (half, half, 0, 0))
    window = torch.hann_window(win_len, device=device, dtype=energy.dtype)
    framed = energy.unfold(-1, size=win_len, step=1) * window
    return framed.sum(dim=-1).div(win_len)
def local_snr(
    clean: Tensor,
    noise: Tensor,
    window_size: int,
    db: bool = False,
    window_size_ns: Optional[int] = None,
    eps: float = 1e-12,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Ratio of locally averaged speech and noise energies.

    Args:
        clean: Clean spectrogram (complex or real view); must be 5-D after ``as_real``.
        noise: Noise spectrogram with the same layout.
        window_size: Averaging window (frames) for the speech energy.
        db: Return the SNR in dB (``10 * log10``).
        window_size_ns: Averaging window for the noise energy; defaults to ``window_size``.
        eps: Lower clamp for the noise energy and, in dB mode, for the SNR.

    Returns:
        ``(snr, speech_energy, noise_energy)``.
    """
    clean = as_real(clean)
    noise = as_real(noise)
    assert clean.dim() == 5
    noise_window = window_size if window_size_ns is None else window_size_ns
    e_speech = _local_energy(clean, window_size, clean.device)
    e_noise = _local_energy(noise, noise_window, clean.device)
    snr = e_speech / e_noise.clamp_min(eps)
    if db:
        snr = snr.clamp_min(eps).log10().mul(10)
    return snr, e_speech, e_noise
def test_grouped_gru():
    """Check shapes and ONNX exportability of GroupedGRULayer and GroupedGRU."""
    import os

    from icecream import ic

    # torch.onnx.export does not create missing directories.
    os.makedirs("out", exist_ok=True)
    onnx_path = "out/grouped.onnx"

    g = 2  # groups
    h = 4  # hidden_size
    i = 2  # input_size
    b = 1  # batch_size
    t = 5  # time_steps
    m = GroupedGRULayer(i, h, g, batch_first=True)
    ic(m)
    input = torch.randn((b, t, i))
    h0 = m.get_h0(b)
    assert list(h0.shape) == [g, b, h // g]
    out, hout = m(input, h0)
    assert list(out.shape) == [b, t, h]
    assert list(hout.shape) == [g, b, h // g]
    # `example_outputs` was deprecated and later removed from torch.onnx.export; the
    # outputs are obtained by running the model, so it is not passed anymore.
    # Should be exportable as raw nn.Module
    torch.onnx.export(m, (input, h0), onnx_path, opset_version=13)
    # Should be exportable as traced
    m = torch.jit.trace(m, (input, h0))
    torch.onnx.export(m, (input, h0), onnx_path, opset_version=13)
    # and as scripted module
    m = torch.jit.script(m)
    torch.onnx.export(m, (input, h0), onnx_path, opset_version=13)

    # now grouped gru
    num = 2
    m = GroupedGRU(i, h, num, g, batch_first=True, shuffle=True)
    ic(m)
    h0 = m.get_h0(b)
    assert list(h0.shape) == [num * g, b, h // g]
    out, hout = m(input, h0)
    assert list(out.shape) == [b, t, h]
    assert list(hout.shape) == [num * g, b, h // g]
    # Should be exportable as traced
    m = torch.jit.trace(m, (input, h0))
    torch.onnx.export(m, (input, h0), onnx_path, opset_version=13)
    # and scripted module
    m = torch.jit.script(m)
    torch.onnx.export(m, (input, h0), onnx_path, opset_version=13)
def test_erb():
    """Compare the torch ERB filter banks against libdf's erb / erb_inv."""
    import libdf
    from df_local.config import config

    config.use_defaults()
    p = ModelParams()
    n_freq = p.fft_size // 2 + 1
    df_state = libdf.DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
    erb_widths = df_state.erb_widths()
    fwd_fb = erb_fb(erb_widths, p.sr)
    inv_fb = erb_fb(erb_widths, p.sr, inverse=True)

    spectrum = torch.randn((1, 1, 1, n_freq), dtype=torch.complex64)
    power = spectrum.abs().square()

    # Forward: bins -> ERB bands
    df_erb = torch.from_numpy(libdf.erb(spectrum.numpy(), erb_widths, False))
    py_erb = torch.matmul(power, fwd_fb)
    assert torch.allclose(df_erb, py_erb)

    # Inverse: ERB bands -> bins
    df_out = torch.from_numpy(libdf.erb_inv(df_erb.numpy(), erb_widths))
    py_out = torch.matmul(py_erb, inv_fb)
    assert torch.allclose(df_out, py_out)
def test_unit_norm():
    """Compare ExponentialUnitNorm against libdf.unit_norm."""
    from df_local.config import config
    from libdf import unit_norm

    config.use_defaults()
    p = ModelParams()
    batch, n_frames, n_bins = 2, 100, p.nb_df
    spec = torch.randn(batch, 1, n_frames, n_bins, 2)
    alpha = get_norm_alpha(log=False)
    # libdf expects complex input of shape [C, T, F]
    expected = torch.as_tensor(unit_norm(torch.view_as_complex(spec).squeeze(1).numpy(), alpha))
    module = ExponentialUnitNorm(alpha, n_bins)
    actual = torch.view_as_complex(module(spec).squeeze(1))
    for part in (torch.real, torch.imag, torch.abs):
        assert torch.allclose(part(expected), part(actual))
def test_dfop():
    """Check that all DfOp forward implementations agree."""
    from df_local.config import config

    config.use_defaults()
    p = ModelParams()
    f = p.nb_df
    n_freqs = f * 2  # named to avoid shadowing the module-level `F` alias
    o = p.df_order
    d = p.df_lookahead
    t = 100
    spec = torch.randn(1, 1, t, n_freqs, 2)
    coefs = torch.randn(1, t, o, f, 2)
    alpha = torch.randn(1, t, 1)
    # Use the same order/lookahead as the generated coefficients and the padding below.
    dfop = DfOp(df_bins=f, df_order=o, df_lookahead=d)
    dfop.set_forward("real_loop")
    out1 = dfop(spec, coefs, alpha)
    dfop.set_forward("real_strided")
    out2 = dfop(spec, coefs, alpha)
    dfop.set_forward("real_unfold")
    out3 = dfop(spec, coefs, alpha)
    dfop.set_forward("complex_strided")
    out4 = dfop(spec, coefs, alpha)
    # torch.testing.assert_allclose is deprecated in favour of assert_close.
    torch.testing.assert_close(out1, out2)
    torch.testing.assert_close(out1, out3)
    torch.testing.assert_close(out1, out4)
    # This forward method requires external padding/lookahead as well as spectrogram buffer
    # handling, i.e. via a ring buffer. Could be used in real time usage.
    dfop.set_forward("real_one_step")
    spec_padded = spec_pad(spec, o, d, dim=-3)
    out5 = torch.zeros_like(out1)
    for i in range(t):
        out5[:, :, i] = dfop(
            spec_padded[:, :, i : i + o], coefs[:, i].unsqueeze(1), alpha[:, i].unsqueeze(1)
        )
    torch.testing.assert_close(out1, out5)
    # Forward method that does the padding/lookahead handling using an internal hidden state.
    dfop.freq_bins = n_freqs
    dfop.set_forward("real_hidden_state_loop")
    out6 = dfop(spec, coefs, alpha)
    torch.testing.assert_close(out1, out6)