import math
import typing as tp

import torch
from torch import nn
from torch.nn import functional as F


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size `[*OT, T]`, output a Tensor of size `[*OT, F, K]`,
    with `K` the kernel size, by extracting frames with the given stride.

    This will pad the input so that `F = ceil(T / stride)`.
    See https://github.com/pytorch/pytorch/issues/60466
    """
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    a = F.pad(a, (0, tgt_length - length))
    strides = list(a.stride())
    assert strides[-1] == 1, "data should be contiguous"
    strides = strides[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)
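

# Illustration of the framing above for a hypothetical 1x1x10 input: with
# `kernel_size=4` and `stride=4`, the input is zero-padded to length
# (3 - 1) * 4 + 4 = 12 and split into ceil(10 / 4) = 3 frames of 4 samples each.
#
#     x = torch.arange(10.).view(1, 1, 10)
#     frames = _unfold(x, kernel_size=4, stride=4)
#     assert frames.shape == (1, 1, 3, 4)
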
def _center(x: torch.Tensor) -> torch.Tensor:
    """Remove the mean over the last dimension, keeping that dimension."""
    return x - x.mean(-1, True)


def _norm2(x: torch.Tensor) -> torch.Tensor:
    """Squared L2 norm over the last dimension, keeping that dimension."""
    return x.pow(2).sum(-1, True)
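

# For reference, `SISNR.forward` below computes the standard SI-SNR: with `s`
# the mean-centered reference and `s_hat` the mean-centered estimate,
#
#     s_target = <s_hat, s> * s / ||s||^2
#     e_noise  = s_hat - s_target
#     SI-SNR   = 10 * log10(||s_target||^2 / ||e_noise||^2)
#
# averaged over batch, channels and frames, and returned with its sign flipped
# so it can be minimized as a loss.
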
class SISNR(nn.Module):
    """SISNR loss.

    Input should be [B, C, T], output is a scalar.

    .. warning:: This module returns the opposite of the SI-SNR (i.e. `-1 * regular_SI_SNR`).
        Consequently, lower scores are better in terms of reconstruction quality;
        in particular, the value should be negative if training goes well. It is done this
        way so that this module can also be used as a loss function for training models.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            the entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50% overlap.
        epsilon (float): Epsilon value for numerical stability.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape

        if self.segment is None:
            frame = T
            stride = T
        else:
            frame = int(self.segment * self.sample_rate)
            stride = int(frame * (1 - self.overlap))

        epsilon = self.epsilon * frame  # make epsilon proportional to the frame size.

        # Cut both signals into (possibly overlapping) frames of `frame` samples.
        gt = _unfold(ref_sig, frame, stride)
        est = _unfold(out_sig, frame, stride)
        if self.segment is None:
            assert gt.shape[-2] == 1  # a single frame covering the whole signal.

        # Remove the per-frame mean, then project the estimate onto the reference
        # to split it into a target component and a residual noise component.
        gt = _center(gt)
        est = _center(est)
        dot = torch.einsum("bcft,bcft->bcf", gt, est)

        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
        noise = est - proj

        sisnr = 10 * (
            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
        )
        return -1 * sisnr[..., 0].mean()
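

if __name__ == "__main__":
    # Minimal usage sketch: compare a mildly corrupted estimate against a
    # reference signal, evaluating on the full signal (segment=None). The
    # printed value is illustrative only.
    torch.manual_seed(0)
    loss_fn = SISNR(sample_rate=16000, segment=None)
    ref = torch.randn(2, 1, 16000)            # [B, C, T] reference
    est = ref + 0.1 * torch.randn_like(ref)   # estimate with 10% added noise
    loss = loss_fn(est, ref)
    # The module returns the negated SI-SNR, so a good reconstruction gives a
    # clearly negative value (around -20 for this amount of noise).
    print(loss.item())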