import math
from collections import OrderedDict
from typing import Callable, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn import init
from torch.nn.parameter import Parameter
from typing_extensions import Final

from df_local.model import ModelParams
from df_local.utils import as_complex, as_real, get_device, get_norm_alpha
from libdf import unit_norm_init


class Conv2dNormAct(nn.Sequential):
    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: Union[int, Iterable[int]],
        fstride: int = 1,
        dilation: int = 1,
        fpad: bool = True,
        bias: bool = True,
        separable: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
    ):
        """Causal Conv2d by delaying the signal for any lookahead.

        Expected input format: [B, C, T, F]
        """
        lookahead = 0  # This needs to be handled on the input feature side
        # Padding on time axis
        kernel_size = (
            (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        )
        if fpad:
            fpad_ = kernel_size[1] // 2 + dilation - 1
        else:
            fpad_ = 0
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))
        groups = math.gcd(in_ch, out_ch) if separable else 1
        if groups == 1:
            separable = False
        if max(kernel_size) == 1:
            separable = False
        layers.append(
            nn.Conv2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                padding=(0, fpad_),
                stride=(1, fstride),  # Stride over time is always 1
                dilation=(1, dilation),  # Same for dilation
                groups=groups,
                bias=bias,
            )
        )
        if separable:
            layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
        if norm_layer is not None:
            layers.append(norm_layer(out_ch))
        if activation_layer is not None:
            layers.append(activation_layer())
        super().__init__(*layers)
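

def _demo_conv2d_norm_act():
    # Hedged usage sketch (shapes are illustrative assumptions, not from the original file).
    # The time axis is padded causally, so T is preserved while F is strided.
    conv = Conv2dNormAct(in_ch=1, out_ch=16, kernel_size=(2, 3), fstride=2)
    x = torch.randn(1, 1, 100, 96)  # [B, C, T, F]
    y = conv(x)
    assert y.shape == (1, 16, 100, 48)  # T unchanged, F halved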


class ConvTranspose2dNormAct(nn.Sequential):
    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: Union[int, Tuple[int, int]],
        fstride: int = 1,
        dilation: int = 1,
        fpad: bool = True,
        bias: bool = True,
        separable: bool = False,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
    ):
        """Causal ConvTranspose2d.

        Expected input format: [B, C, T, F]
        """
        # Padding on time axis, with lookahead = 0
        lookahead = 0  # This needs to be handled on the input feature side
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        if fpad:
            fpad_ = kernel_size[1] // 2
        else:
            fpad_ = 0
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))
        groups = math.gcd(in_ch, out_ch) if separable else 1
        if groups == 1:
            separable = False
        layers.append(
            nn.ConvTranspose2d(
                in_ch,
                out_ch,
                kernel_size=kernel_size,
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # Stride over time is always 1
                dilation=(1, dilation),
                groups=groups,
                bias=bias,
            )
        )
        if separable:
            layers.append(nn.Conv2d(out_ch, out_ch, kernel_size=1, bias=False))
        if norm_layer is not None:
            layers.append(norm_layer(out_ch))
        if activation_layer is not None:
            layers.append(activation_layer())
        super().__init__(*layers)
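

def _demo_conv_transpose2d_norm_act():
    # Hedged usage sketch (illustrative shapes): the transposed conv undoes the
    # frequency striding of the matching Conv2dNormAct encoder layer above.
    deconv = ConvTranspose2dNormAct(in_ch=16, out_ch=1, kernel_size=(2, 3), fstride=2)
    y = torch.randn(1, 16, 100, 48)  # encoder output from _demo_conv2d_norm_act
    x_hat = deconv(y)
    assert x_hat.shape == (1, 1, 100, 96)  # original [B, C, T, F] resolution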


def convkxf(
    in_ch: int,
    out_ch: Optional[int] = None,
    k: int = 1,
    f: int = 3,
    fstride: int = 2,
    lookahead: int = 0,
    batch_norm: bool = False,
    act: nn.Module = nn.ReLU(inplace=True),
    mode="normal",  # must be "normal", "transposed" or "upsample"
    depthwise: bool = True,
    complex_in: bool = False,
):
    bias = batch_norm is False
    assert f % 2 == 1
    stride = 1 if f == 1 else (1, fstride)
    if out_ch is None:
        out_ch = in_ch * 2 if mode == "normal" else in_ch // 2
    fpad = (f - 1) // 2
    convpad = (0, fpad)
    modules = []
    # Manually pad for time axis kernel to not introduce delay
    pad = (0, 0, k - 1 - lookahead, lookahead)
    if any(p > 0 for p in pad):
        modules.append(("pad", nn.ConstantPad2d(pad, 0.0)))
    if depthwise:
        groups = min(in_ch, out_ch)
    else:
        groups = 1
    if in_ch % groups != 0 or out_ch % groups != 0:
        groups = 1
    if complex_in and groups % 2 == 0:
        groups //= 2
    convkwargs = {
        "in_channels": in_ch,
        "out_channels": out_ch,
        "kernel_size": (k, f),
        "stride": stride,
        "groups": groups,
        "bias": bias,
    }
    if mode == "normal":
        modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
    elif mode == "transposed":
        # Since pytorch's transposed conv padding does not correspond to the actual padding but
        # rather the padding that was used in the encoder conv, we need to set time axis padding
        # according to k. E.g., this disables the padding for k=2:
        #   dilation - (k - 1) - padding
        #   = 1 - (2 - 1) - 1 = 0; => padding = fpad (=1 for k=2)
        padding = (k - 1, fpad)
        modules.append(
            ("sconvt", nn.ConvTranspose2d(padding=padding, output_padding=convpad, **convkwargs))
        )
    elif mode == "upsample":
        modules.append(("upsample", FreqUpsample(fstride)))
        convkwargs["stride"] = 1
        modules.append(("sconv", nn.Conv2d(padding=convpad, **convkwargs)))
    else:
        raise NotImplementedError()
    if groups > 1:
        modules.append(("1x1conv", nn.Conv2d(out_ch, out_ch, 1, bias=False)))
    if batch_norm:
        modules.append(("norm", nn.BatchNorm2d(out_ch)))
    modules.append(("act", act))
    return nn.Sequential(OrderedDict(modules))
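

def _demo_convkxf():
    # Hedged usage sketch (illustrative shapes): a depthwise-separable encoder block.
    # With out_ch=None and mode="normal", the channel count doubles; the trailing
    # 1x1 conv mixes the depthwise groups.
    block = convkxf(in_ch=16, k=2, f=3, fstride=2, mode="normal")
    x = torch.randn(1, 16, 100, 48)  # [B, C, T, F]
    y = block(x)
    assert y.shape == (1, 32, 100, 24)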


class FreqUpsample(nn.Module):
    def __init__(self, factor: int, mode="nearest"):
        super().__init__()
        self.f = float(factor)
        self.mode = mode

    def forward(self, x: Tensor) -> Tensor:
        return F.interpolate(x, scale_factor=[1.0, self.f], mode=self.mode)


def erb_fb(widths: np.ndarray, sr: int, normalized: bool = True, inverse: bool = False) -> Tensor:
    n_freqs = int(np.sum(widths))
    all_freqs = torch.linspace(0, sr // 2, n_freqs + 1)[:-1]
    b_pts = np.cumsum([0] + widths.tolist()).astype(int)[:-1]
    fb = torch.zeros((all_freqs.shape[0], b_pts.shape[0]))
    for i, (b, w) in enumerate(zip(b_pts.tolist(), widths.tolist())):
        fb[b : b + w, i] = 1
    # Normalize to constant energy per resulting band
    if inverse:
        fb = fb.t()
        if not normalized:
            fb /= fb.sum(dim=1, keepdim=True)
    else:
        if normalized:
            fb /= fb.sum(dim=0)
    return fb.to(device=get_device())
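

def _demo_erb_fb():
    # Hedged usage sketch: hypothetical band widths summing to 16 frequency bins;
    # real widths come from the libdf DF state (see test_erb below). Assumes
    # get_device() has been configured.
    widths = np.array([2, 2, 4, 8])
    fb = erb_fb(widths, sr=24000)  # [16, 4]; each normalized column sums to 1
    fb_inv = erb_fb(widths, sr=24000, inverse=True)  # [4, 16], maps bands back to bins
    assert fb.shape == (16, 4) and fb_inv.shape == (4, 16)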


class Mask(nn.Module):
    def __init__(self, erb_inv_fb: Tensor, post_filter: bool = False, eps: float = 1e-12):
        super().__init__()
        self.erb_inv_fb: Tensor
        self.register_buffer("erb_inv_fb", erb_inv_fb)
        # Greater equal (__ge__) is not implemented for TorchVersion, hence the two comparisons.
        self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
        self.post_filter = post_filter
        self.eps = eps

    def pf(self, mask: Tensor, beta: float = 0.02) -> Tensor:
        """Post-Filter proposed by Valin et al. [1].

        Args:
            mask (Tensor): Real valued mask, typically of shape [B, C, T, F].
            beta: Global gain factor.

        Refs:
            [1]: Valin et al.: A Perceptually-Motivated Approach for Low-Complexity,
                 Real-Time Enhancement of Fullband Speech.
        """
        mask_sin = mask * torch.sin(np.pi * mask / 2)
        mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
        return mask_pf

    def forward(self, spec: Tensor, mask: Tensor, atten_lim: Optional[Tensor] = None) -> Tensor:
        # spec (real) [B, 1, T, F, 2], F: freq_bins
        # mask (real): [B, 1, T, Fe], Fe: erb_bins
        # atten_lim: [B]
        if not self.training and self.post_filter:
            mask = self.pf(mask)
        if atten_lim is not None:
            # dB to amplitude
            atten_lim = 10 ** (-atten_lim / 20)
            if self.clamp_tensor:
                # Tensor-valued clamp min is supported by torch >= 1.9
                mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1))
            else:
                m_out = []
                for i in range(atten_lim.shape[0]):
                    m_out.append(mask[i].clamp_min(atten_lim[i].item()))
                mask = torch.stack(m_out, dim=0)
        mask = mask.matmul(self.erb_inv_fb)  # [B, 1, T, F]
        return spec * mask.unsqueeze(4)
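

def _demo_mask():
    # Hedged usage sketch (illustrative shapes; assumes get_device() resolves to CPU
    # so that the filterbank buffer and the inputs share a device).
    widths = np.array([2, 2, 4, 8])
    m = Mask(erb_fb(widths, sr=24000, inverse=True))
    spec = torch.randn(1, 1, 10, 16, 2)  # [B, 1, T, F, 2]
    gains = torch.rand(1, 1, 10, 4)  # [B, 1, T, Fe], one gain per ERB band
    out = m(spec, gains)  # gains are projected to [B, 1, T, F] and applied
    assert out.shape == spec.shape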


class ExponentialUnitNorm(nn.Module):
    """Unit norm for a complex spectrogram.

    This should match the rust code:
    ```rust
    for (x, s) in xs.iter_mut().zip(state.iter_mut()) {
        *s = x.norm() * (1. - alpha) + *s * alpha;
        *x /= s.sqrt();
    }
    ```
    """

    alpha: Final[float]
    eps: Final[float]

    def __init__(self, alpha: float, num_freq_bins: int, eps: float = 1e-14):
        super().__init__()
        self.alpha = alpha
        self.eps = eps
        self.init_state: Tensor
        s = torch.from_numpy(unit_norm_init(num_freq_bins)).view(1, 1, num_freq_bins, 1)
        self.register_buffer("init_state", s)

    def forward(self, x: Tensor) -> Tensor:
        # x: [B, C, T, F, 2]
        b, c, t, f, _ = x.shape
        x_abs = x.square().sum(dim=-1, keepdim=True).clamp_min(self.eps).sqrt()
        state = self.init_state.clone().expand(b, c, f, 1)
        out_states: List[Tensor] = []
        for ti in range(t):  # don't shadow the time dimension `t`
            state = x_abs[:, :, ti] * (1 - self.alpha) + state * self.alpha
            out_states.append(state)
        return x / torch.stack(out_states, 2).sqrt()
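

def _demo_unit_norm():
    # Hedged shape check (see test_unit_norm below for a comparison against libdf).
    m = ExponentialUnitNorm(alpha=0.99, num_freq_bins=96)
    x = torch.randn(2, 1, 10, 96, 2)  # [B, C, T, F, 2]
    y = m(x)  # same shape, exponentially unit-normalized over time
    assert y.shape == x.shape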


class DfOp(nn.Module):
    df_order: Final[int]
    df_bins: Final[int]
    df_lookahead: Final[int]
    freq_bins: Final[int]

    def __init__(
        self,
        df_bins: int,
        df_order: int = 5,
        df_lookahead: int = 0,
        method: str = "complex_strided",
        freq_bins: int = 0,
    ):
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_lookahead = df_lookahead
        self.freq_bins = freq_bins
        self.set_forward(method)

    def set_forward(self, method: str):
        # All forward methods should be mathematically similar.
        # DeepFilterNet results are obtained with 'real_unfold'.
        forward_methods = {
            "real_loop": self.forward_real_loop,
            "real_strided": self.forward_real_strided,
            "real_unfold": self.forward_real_unfold,
            "complex_strided": self.forward_complex_strided,
            "real_one_step": self.forward_real_no_pad_one_step,
            "real_hidden_state_loop": self.forward_real_hidden_state_loop,
        }
        if method not in forward_methods.keys():
            raise NotImplementedError(f"`method` must be one of {forward_methods.keys()}")
        if method == "real_hidden_state_loop":
            assert self.freq_bins >= self.df_bins
            self.spec_buf: Tensor
            # Currently only designed for batch size of 1
            self.register_buffer(
                "spec_buf", torch.zeros(1, 1, self.df_order, self.freq_bins, 2), persistent=False
            )
        self.forward = forward_methods[method]

    def forward_real_loop(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version 0: Manual loop over df_order, maybe best for onnx export?
        b, _, t, _, _ = spec.shape
        f = self.df_bins
        padded = spec_pad(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        spec_f = torch.zeros((b, t, f, 2), device=spec.device)
        for i in range(self.df_order):
            spec_f[..., 0] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 0]
            spec_f[..., 0] -= padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 1]
            spec_f[..., 1] += padded[:, i : i + t, ..., 1] * coefs[:, :, i, :, 0]
            spec_f[..., 1] += padded[:, i : i + t, ..., 0] * coefs[:, :, i, :, 1]
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_real_strided(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version 1: Use as_strided instead of unfold
        # spec (real) [B, 1, T, F, 2], O: df_order
        # coefs (real) [B, T, O, F, 2]
        # alpha (real) [B, T, 1]
        padded = as_strided(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        # Complex numbers are not supported by onnx
        re = padded[..., 0] * coefs[..., 0]
        re -= padded[..., 1] * coefs[..., 1]
        im = padded[..., 1] * coefs[..., 0]
        im += padded[..., 0] * coefs[..., 1]
        spec_f = torch.stack((re, im), -1).sum(2)
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_real_unfold(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version 2: Unfold
        # spec (real) [B, 1, T, F, 2], O: df_order
        # coefs (real) [B, T, O, F, 2]
        # alpha (real) [B, T, 1]
        padded = spec_pad(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        padded = padded.unfold(dimension=1, size=self.df_order, step=1)  # [B, T, F, 2, O]
        padded = padded.permute(0, 1, 4, 2, 3)
        spec_f = torch.empty_like(padded)
        spec_f[..., 0] = padded[..., 0] * coefs[..., 0]  # re1
        spec_f[..., 0] -= padded[..., 1] * coefs[..., 1]  # re2
        spec_f[..., 1] = padded[..., 1] * coefs[..., 0]  # im1
        spec_f[..., 1] += padded[..., 0] * coefs[..., 1]  # im2
        spec_f = spec_f.sum(dim=2)
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_complex_strided(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version 3: Complex strided; definitely nicest: no permute, no indexing,
        # but a complex gradient
        # spec (real) [B, 1, T, F, 2], O: df_order
        # coefs (real) [B, T, O, F, 2]
        # alpha (real) [B, T, 1]
        padded = as_strided(
            spec[..., : self.df_bins, :].squeeze(1), self.df_order, self.df_lookahead, dim=-3
        )
        spec_f = torch.sum(torch.view_as_complex(padded) * torch.view_as_complex(coefs), dim=2)
        spec_f = torch.view_as_real(spec_f)
        return assign_df(spec, spec_f.unsqueeze(1), self.df_bins, alpha)

    def forward_real_no_pad_one_step(
        self, spec: Tensor, coefs: Tensor, alpha: Optional[Tensor] = None
    ) -> Tensor:
        # Version 4: Only viable for onnx handling. `spec` needs external (ring-)buffer handling.
        # Thus, time steps `t` must be equal to `df_order`.
        # spec (real) [B, 1, O, F', 2]
        # coefs (real) [B, 1, O, F, 2]
        assert (
            spec.shape[2] == self.df_order
        ), "This forward method needs a spectrogram buffer with `df_order` time steps as input"
        assert coefs.shape[1] == 1, "This forward method is only valid for 1 time step"
        sre, sim = spec[..., : self.df_bins, :].split(1, -1)
        cre, cim = coefs.split(1, -1)
        outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
        outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
        spec_f = torch.stack((outr, outi), dim=-1)
        return assign_df(
            spec[:, :, self.df_order - self.df_lookahead - 1],
            spec_f.unsqueeze(1),
            self.df_bins,
            alpha,
        )

    def forward_real_hidden_state_loop(self, spec: Tensor, coefs: Tensor, alpha: Tensor) -> Tensor:
        # Version 5: Designed for onnx export. `spec` buffer handling is done via a torch buffer.
        # spec (real) [B, 1, T, F', 2]
        # coefs (real) [B, T, O, F, 2]
        b, _, t, _, _ = spec.shape
        spec_out = torch.empty((b, 1, t, self.freq_bins, 2), device=spec.device)
        for ti in range(t):  # don't shadow the time dimension `t`
            self.spec_buf = self.spec_buf.roll(-1, dims=2)
            self.spec_buf[:, :, -1] = spec[:, :, ti]
            sre, sim = self.spec_buf[..., : self.df_bins, :].split(1, -1)
            cre, cim = coefs[:, ti : ti + 1].split(1, -1)
            outr = torch.sum(sre * cre - sim * cim, dim=2).squeeze(-1)
            outi = torch.sum(sre * cim + sim * cre, dim=2).squeeze(-1)
            spec_f = torch.stack((outr, outi), dim=-1)
            spec_out[:, :, ti] = assign_df(
                self.spec_buf[:, :, self.df_order - self.df_lookahead - 1].unsqueeze(2),
                spec_f.unsqueeze(1),
                self.df_bins,
                alpha[:, ti],
            ).squeeze(2)
        return spec_out


def assign_df(spec: Tensor, spec_f: Tensor, df_bins: int, alpha: Optional[Tensor]):
    spec_out = spec.clone()
    if alpha is not None:
        b = spec.shape[0]
        alpha = alpha.view(b, 1, -1, 1, 1)
        spec_out[..., :df_bins, :] = spec_f * alpha + spec[..., :df_bins, :] * (1 - alpha)
    else:
        spec_out[..., :df_bins, :] = spec_f
    return spec_out


def spec_pad(x: Tensor, window_size: int, lookahead: int, dim: int = 0) -> Tensor:
    pad = [0] * x.dim() * 2
    if dim >= 0:
        pad[(x.dim() - dim - 1) * 2] = window_size - lookahead - 1
        pad[(x.dim() - dim - 1) * 2 + 1] = lookahead
    else:
        pad[(-dim - 1) * 2] = window_size - lookahead - 1
        pad[(-dim - 1) * 2 + 1] = lookahead
    return F.pad(x, pad)


def as_strided(x: Tensor, window_size: int, lookahead: int, step: int = 1, dim: int = 0) -> Tensor:
    shape = list(x.shape)
    shape.insert(dim + 1, window_size)
    x = spec_pad(x, window_size, lookahead, dim=dim)
    # torch.fx workaround: the `step` argument is fixed to 1 here
    step = 1
    stride = [x.stride(0), x.stride(1), x.stride(2), x.stride(3)]
    stride.insert(dim, stride[dim] * step)
    return torch.as_strided(x, shape, stride)
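

def _demo_as_strided():
    # Hedged sketch: causal overlapping windows over the time axis of a [B, T, F, 2]
    # tensor (dim=-3), as used by DfOp. Each output frame t sees frames [t-2, t-1, t],
    # zero-padded at the sequence start.
    x = torch.arange(5.0).view(1, 5, 1, 1)  # [B, T, F, 2] with F = 1 and a dummy last dim
    w = as_strided(x, window_size=3, lookahead=0, dim=-3)  # [1, 5, 3, 1, 1]
    assert w[0, 2, :, 0, 0].tolist() == [0.0, 1.0, 2.0]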


class GroupedGRULayer(nn.Module):
    input_size: Final[int]
    hidden_size: Final[int]
    out_size: Final[int]
    bidirectional: Final[bool]
    num_directions: Final[int]
    groups: Final[int]
    batch_first: Final[bool]

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        groups: int,
        batch_first: bool = True,
        bias: bool = True,
        dropout: float = 0,
        bidirectional: bool = False,
    ):
        super().__init__()
        assert input_size % groups == 0
        assert hidden_size % groups == 0
        kwargs = {
            "bias": bias,
            "batch_first": batch_first,
            "dropout": dropout,
            "bidirectional": bidirectional,
        }
        self.input_size = input_size // groups
        self.hidden_size = hidden_size // groups
        self.out_size = hidden_size
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.groups = groups
        self.batch_first = batch_first
        assert (self.hidden_size % groups) == 0, "Hidden size must be divisible by groups"
        self.layers = nn.ModuleList(
            (nn.GRU(self.input_size, self.hidden_size, **kwargs) for _ in range(groups))
        )

    def flatten_parameters(self):
        for layer in self.layers:
            layer.flatten_parameters()

    def get_h0(self, batch_size: int = 1, device: torch.device = torch.device("cpu")):
        return torch.zeros(
            self.groups * self.num_directions,
            batch_size,
            self.hidden_size,
            device=device,
        )

    def forward(self, input: Tensor, h0: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        # input shape: [B, T, I] if batch_first else [T, B, I], B: batch_size, I: input_size
        # state shape: [G*D, B, H], where G: groups, D: num_directions, H: hidden_size
        if h0 is None:
            dim0, dim1 = input.shape[:2]
            bs = dim0 if self.batch_first else dim1
            h0 = self.get_h0(bs, device=input.device)
        outputs: List[Tensor] = []
        outstates: List[Tensor] = []
        for i, layer in enumerate(self.layers):
            o, s = layer(
                input[..., i * self.input_size : (i + 1) * self.input_size],
                h0[i * self.num_directions : (i + 1) * self.num_directions].detach(),
            )
            outputs.append(o)
            outstates.append(s)
        output = torch.cat(outputs, dim=-1)
        h = torch.cat(outstates, dim=0)
        return output, h


class GroupedGRU(nn.Module):
    groups: Final[int]
    num_layers: Final[int]
    batch_first: Final[bool]
    hidden_size: Final[int]
    bidirectional: Final[bool]
    num_directions: Final[int]
    shuffle: Final[bool]
    add_outputs: Final[bool]

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int = 1,
        groups: int = 4,
        bias: bool = True,
        batch_first: bool = True,
        dropout: float = 0,
        bidirectional: bool = False,
        shuffle: bool = True,
        add_outputs: bool = False,
    ):
        super().__init__()
        kwargs = {
            "groups": groups,
            "bias": bias,
            "batch_first": batch_first,
            "dropout": dropout,
            "bidirectional": bidirectional,
        }
        assert input_size % groups == 0
        assert hidden_size % groups == 0
        assert num_layers > 0
        self.input_size = input_size
        self.groups = groups
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.hidden_size = hidden_size // groups
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        if groups == 1:
            shuffle = False  # Fully connected, no need to shuffle
        self.shuffle = shuffle
        self.add_outputs = add_outputs
        self.grus: List[GroupedGRULayer] = nn.ModuleList()  # type: ignore
        self.grus.append(GroupedGRULayer(input_size, hidden_size, **kwargs))
        for _ in range(1, num_layers):
            self.grus.append(GroupedGRULayer(hidden_size, hidden_size, **kwargs))
        self.flatten_parameters()

    def flatten_parameters(self):
        for gru in self.grus:
            gru.flatten_parameters()

    def get_h0(self, batch_size: int, device: torch.device = torch.device("cpu")) -> Tensor:
        return torch.zeros(
            (self.num_layers * self.groups * self.num_directions, batch_size, self.hidden_size),
            device=device,
        )

    def forward(self, input: Tensor, state: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        dim0, dim1, _ = input.shape
        b = dim0 if self.batch_first else dim1
        if state is None:
            state = self.get_h0(b, input.device)
        output = torch.zeros(
            dim0, dim1, self.hidden_size * self.num_directions * self.groups, device=input.device
        )
        outstates = []
        h = self.groups * self.num_directions
        for i, gru in enumerate(self.grus):
            input, s = gru(input, state[i * h : (i + 1) * h])
            outstates.append(s)
            if self.shuffle and i < self.num_layers - 1:
                input = (
                    input.view(dim0, dim1, -1, self.groups).transpose(2, 3).reshape(dim0, dim1, -1)
                )
            if self.add_outputs:
                output += input
            else:
                output = input
        outstate = torch.cat(outstates, dim=0)
        return output, outstate


class SqueezedGRU(nn.Module):
    input_size: Final[int]
    hidden_size: Final[int]

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: Optional[int] = None,
        num_layers: int = 1,
        linear_groups: int = 8,
        batch_first: bool = True,
        gru_skip_op: Optional[Callable[..., torch.nn.Module]] = None,
        linear_act_layer: Callable[..., torch.nn.Module] = nn.Identity,
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.linear_in = nn.Sequential(
            GroupedLinearEinsum(input_size, hidden_size, linear_groups), linear_act_layer()
        )
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers=num_layers, batch_first=batch_first)
        self.gru_skip = gru_skip_op() if gru_skip_op is not None else None
        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinearEinsum(hidden_size, output_size, linear_groups), linear_act_layer()
            )
        else:
            self.linear_out = nn.Identity()

    def forward(self, input: Tensor, h=None) -> Tuple[Tensor, Tensor]:
        input = self.linear_in(input)
        x, h = self.gru(input, h)
        if self.gru_skip is not None:
            x = x + self.gru_skip(input)
        x = self.linear_out(x)
        return x, h
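

def _demo_squeezed_gru():
    # Hedged usage sketch (illustrative sizes): grouped linear layers "squeeze" the
    # feature dimension around a plain GRU to save parameters.
    m = SqueezedGRU(input_size=64, hidden_size=32, output_size=64, linear_groups=8)
    x = torch.randn(1, 10, 64)  # [B, T, I]
    y, h = m(x)
    assert y.shape == (1, 10, 64) and h.shape == (1, 1, 32)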


class GroupedLinearEinsum(nn.Module):
    input_size: Final[int]
    hidden_size: Final[int]
    groups: Final[int]

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        # self.weight: Tensor
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        assert input_size % groups == 0
        self.ws = input_size // groups
        self.register_parameter(
            "weight",
            Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: Tensor) -> Tensor:
        # x: [B, T, I]; the flatten(2, 3) below assumes exactly these three leading dims
        x = x.unflatten(-1, (self.groups, self.ws))  # [B, T, G, I/G]
        x = torch.einsum("...gi,...gih->...gh", x, self.weight)  # [B, T, G, H/G]
        x = x.flatten(2, 3)  # [B, T, H]
        return x


class GroupedLinear(nn.Module):
    input_size: Final[int]
    hidden_size: Final[int]
    groups: Final[int]
    shuffle: Final[bool]

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1, shuffle: bool = True):
        super().__init__()
        assert input_size % groups == 0
        assert hidden_size % groups == 0
        self.groups = groups
        self.input_size = input_size // groups
        self.hidden_size = hidden_size // groups
        if groups == 1:
            shuffle = False
        self.shuffle = shuffle
        self.layers = nn.ModuleList(
            nn.Linear(self.input_size, self.hidden_size) for _ in range(groups)
        )

    def forward(self, x: Tensor) -> Tensor:
        outputs: List[Tensor] = []
        for i, layer in enumerate(self.layers):
            outputs.append(layer(x[..., i * self.input_size : (i + 1) * self.input_size]))
        output = torch.cat(outputs, dim=-1)
        if self.shuffle:
            orig_shape = output.shape
            output = (
                output.view(-1, self.hidden_size, self.groups).transpose(-1, -2).reshape(orig_shape)
            )
        return output
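

def _demo_grouped_linear():
    # Hedged usage sketch (illustrative sizes): each of the two groups only sees half
    # of the input features; shuffle interleaves the group outputs afterwards.
    m = GroupedLinear(input_size=8, hidden_size=8, groups=2)
    x = torch.randn(1, 10, 8)  # [B, T, I]
    y = m(x)
    assert y.shape == x.shape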


class LocalSnrTarget(nn.Module):
    def __init__(
        self, ws: int = 20, db: bool = True, ws_ns: Optional[int] = None, target_snr_range=None
    ):
        super().__init__()
        self.ws = self.calc_ws(ws)
        self.ws_ns = self.ws * 2 if ws_ns is None else self.calc_ws(ws_ns)
        self.db = db
        self.range = target_snr_range

    def calc_ws(self, ws_ms: int) -> int:
        # Calculates the window size in the STFT domain given a window size in ms
        p = ModelParams()
        ws = ws_ms - p.fft_size / p.sr * 1000  # length [ms] of an fft window
        ws = 1 + ws / (p.hop_size / p.sr * 1000)  # consider hop_size
        return max(int(round(ws)), 1)

    def forward(self, clean: Tensor, noise: Tensor, max_bin: Optional[int] = None) -> Tensor:
        # clean: [B, 1, T, F]
        # out: [B, T']
        if max_bin is not None:
            clean = as_complex(clean[..., :max_bin])
            noise = as_complex(noise[..., :max_bin])
        return (
            local_snr(clean, noise, window_size=self.ws, db=self.db, window_size_ns=self.ws_ns)[0]
            .clamp(self.range[0], self.range[1])
            .squeeze(1)
        )


def _local_energy(x: Tensor, ws: int, device: torch.device) -> Tensor:
    if (ws % 2) == 0:
        ws += 1
    ws_half = ws // 2
    x = F.pad(x.pow(2).sum(-1).sum(-1), (ws_half, ws_half, 0, 0))
    w = torch.hann_window(ws, device=device, dtype=x.dtype)
    x = x.unfold(-1, size=ws, step=1) * w
    return torch.sum(x, dim=-1).div(ws)


def local_snr(
    clean: Tensor,
    noise: Tensor,
    window_size: int,
    db: bool = False,
    window_size_ns: Optional[int] = None,
    eps: float = 1e-12,
) -> Tuple[Tensor, Tensor, Tensor]:
    # clean shape: [B, C, T, F]
    clean = as_real(clean)
    noise = as_real(noise)
    assert clean.dim() == 5
    E_speech = _local_energy(clean, window_size, clean.device)
    window_size_ns = window_size if window_size_ns is None else window_size_ns
    E_noise = _local_energy(noise, window_size_ns, clean.device)
    snr = E_speech / E_noise.clamp_min(eps)
    if db:
        snr = snr.clamp_min(eps).log10().mul(10)
    return snr, E_speech, E_noise
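

def _demo_local_snr():
    # Hedged usage sketch (illustrative shapes; assumes as_real converts complex input
    # to a trailing real/imag dim): per-frame SNR from complex spectrograms, smoothed
    # with a Hann window of 5 frames.
    clean = torch.randn(1, 1, 50, 96, dtype=torch.complex64)
    noise = torch.randn(1, 1, 50, 96, dtype=torch.complex64)
    snr, e_speech, e_noise = local_snr(clean, noise, window_size=5, db=True)
    assert snr.shape == (1, 1, 50)  # [B, C, T], local SNR in dB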


def test_grouped_gru():
    from icecream import ic

    g = 2  # groups
    h = 4  # hidden_size
    i = 2  # input_size
    b = 1  # batch_size
    t = 5  # time_steps
    m = GroupedGRULayer(i, h, g, batch_first=True)
    ic(m)
    input = torch.randn((b, t, i))
    h0 = m.get_h0(b)
    assert list(h0.shape) == [g, b, h // g]
    out, hout = m(input, h0)
    # Should be exportable as raw nn.Module
    torch.onnx.export(
        m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
    )
    # Should be exportable as traced
    m = torch.jit.trace(m, (input, h0))
    torch.onnx.export(
        m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
    )
    # and as scripted module
    m = torch.jit.script(m)
    torch.onnx.export(
        m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
    )
    # Now the grouped GRU
    num = 2
    m = GroupedGRU(i, h, num, g, batch_first=True, shuffle=True)
    ic(m)
    h0 = m.get_h0(b)
    assert list(h0.shape) == [num * g, b, h // g]
    out, hout = m(input, h0)
    # Should be exportable as traced
    m = torch.jit.trace(m, (input, h0))
    torch.onnx.export(
        m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
    )
    # and as scripted module
    m = torch.jit.script(m)
    torch.onnx.export(
        m, (input, h0), "out/grouped.onnx", example_outputs=(out, hout), opset_version=13
    )


def test_erb():
    import libdf
    from df_local.config import config

    config.use_defaults()
    p = ModelParams()
    n_freq = p.fft_size // 2 + 1
    df_state = libdf.DF(sr=p.sr, fft_size=p.fft_size, hop_size=p.hop_size, nb_bands=p.nb_erb)
    erb = erb_fb(df_state.erb_widths(), p.sr)
    erb_inverse = erb_fb(df_state.erb_widths(), p.sr, inverse=True)
    input = torch.randn((1, 1, 1, n_freq), dtype=torch.complex64)
    input_abs = input.abs().square()
    erb_widths = df_state.erb_widths()
    df_erb = torch.from_numpy(libdf.erb(input.numpy(), erb_widths, False))
    py_erb = torch.matmul(input_abs, erb)
    assert torch.allclose(df_erb, py_erb)
    df_out = torch.from_numpy(libdf.erb_inv(df_erb.numpy(), erb_widths))
    py_out = torch.matmul(py_erb, erb_inverse)
    assert torch.allclose(df_out, py_out)


def test_unit_norm():
    from df_local.config import config
    from libdf import unit_norm

    config.use_defaults()
    p = ModelParams()
    b = 2
    F = p.nb_df
    t = 100
    spec = torch.randn(b, 1, t, F, 2)
    alpha = get_norm_alpha(log=False)
    # unit_norm expects complex input of shape [C, T, F]
    norm_lib = torch.as_tensor(unit_norm(torch.view_as_complex(spec).squeeze(1).numpy(), alpha))
    m = ExponentialUnitNorm(alpha, F)
    norm_torch = torch.view_as_complex(m(spec).squeeze(1))
    assert torch.allclose(norm_lib.real, norm_torch.real)
    assert torch.allclose(norm_lib.imag, norm_torch.imag)
    assert torch.allclose(norm_lib.abs(), norm_torch.abs())


def test_dfop():
    from df_local.config import config

    config.use_defaults()
    p = ModelParams()
    f = p.nb_df
    F = f * 2
    o = p.df_order
    d = p.df_lookahead
    t = 100
    spec = torch.randn(1, 1, t, F, 2)
    coefs = torch.randn(1, t, o, f, 2)
    alpha = torch.randn(1, t, 1)
    dfop = DfOp(df_bins=p.nb_df)
    dfop.set_forward("real_loop")
    out1 = dfop(spec, coefs, alpha)
    dfop.set_forward("real_strided")
    out2 = dfop(spec, coefs, alpha)
    dfop.set_forward("real_unfold")
    out3 = dfop(spec, coefs, alpha)
    dfop.set_forward("complex_strided")
    out4 = dfop(spec, coefs, alpha)
    torch.testing.assert_allclose(out1, out2)
    torch.testing.assert_allclose(out1, out3)
    torch.testing.assert_allclose(out1, out4)
    # This forward method requires external padding/lookahead as well as spectrogram buffer
    # handling, e.g., via a ring buffer. Could be used for real-time processing.
    dfop.set_forward("real_one_step")
    spec_padded = spec_pad(spec, o, d, dim=-3)
    out5 = torch.zeros_like(out1)
    for i in range(t):
        out5[:, :, i] = dfop(
            spec_padded[:, :, i : i + o], coefs[:, i].unsqueeze(1), alpha[:, i].unsqueeze(1)
        )
    torch.testing.assert_allclose(out1, out5)
    # Forward method that does the padding/lookahead handling using an internal hidden state.
    dfop.freq_bins = F
    dfop.set_forward("real_hidden_state_loop")
    out6 = dfop(spec, coefs, alpha)
    torch.testing.assert_allclose(out1, out6)