Spaces:
Build error
Build error
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math | |
import typing as tp | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor: | |
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K] | |
with K the kernel size, by extracting frames with the given stride. | |
This will pad the input so that `F = ceil(T / K)`. | |
see https://github.com/pytorch/pytorch/issues/60466 | |
""" | |
*shape, length = a.shape | |
n_frames = math.ceil(length / stride) | |
tgt_length = (n_frames - 1) * stride + kernel_size | |
a = F.pad(a, (0, tgt_length - length)) | |
strides = list(a.stride()) | |
assert strides[-1] == 1, "data should be contiguous" | |
strides = strides[:-1] + [stride, 1] | |
return a.as_strided([*shape, n_frames, kernel_size], strides) | |
def _center(x: torch.Tensor) -> torch.Tensor: | |
return x - x.mean(-1, True) | |
def _norm2(x: torch.Tensor) -> torch.Tensor: | |
return x.pow(2).sum(-1, True) | |
class SISNR(nn.Module):
    """SISNR loss (negative scale-invariant signal-to-noise ratio).

    Input should be [B, C, T], output is scalar. Both signals are cut into
    frames of `segment` seconds (or kept whole when `segment` is None); the
    trailing frames are zero-padded by `_unfold`. Each frame is mean-centered,
    its SI-SNR in dB is computed, and the loss is minus the mean over batch,
    channels and frames.

    Args:
        sample_rate (int): Sample rate, used to convert `segment` to samples.
        segment (float or None): Evaluate on chunks of that many seconds. If None,
            evaluate on entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap. Only
            used when `segment` is not None.
        epsilon (float): Epsilon value for numerical stability. It is multiplied
            by the frame length (in samples) inside `forward`.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        """Return the negative mean SI-SNR (in dB) of `out_sig` w.r.t. `ref_sig`.

        Args:
            out_sig (torch.Tensor): Estimated signal, [B, C, T].
            ref_sig (torch.Tensor): Reference signal, same shape as `out_sig`.

        Returns:
            torch.Tensor: Scalar loss; lower means a better estimate.

        Raises:
            AssertionError: If the two signals do not have the same shape.
            ValueError: If the frame or hop size in samples is not positive
                (e.g. an empty signal with `segment=None`, a tiny `segment`,
                or `overlap >= 1`).
        """
        _, _, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape, (
            f"shape mismatch: out_sig {tuple(out_sig.shape)} "
            f"vs ref_sig {tuple(ref_sig.shape)}"
        )
        if self.segment is None:
            frame = T
            stride = T
        else:
            frame = int(self.segment * self.sample_rate)
            stride = int(frame * (1 - self.overlap))
        if frame <= 0 or stride <= 0:
            # Otherwise _unfold divides by zero or builds a strided view with a
            # non-positive size/stride, failing with an obscure error.
            raise ValueError(
                f"SISNR needs positive frame and hop sizes, got frame={frame} "
                f"and stride={stride} samples."
            )
        epsilon = self.epsilon * frame  # make epsilon prop to frame size.
        gt = _unfold(ref_sig, frame, stride)
        est = _unfold(out_sig, frame, stride)
        if self.segment is None:
            # _unfold returns [B, C, F, K] with K == frame == T here, so the
            # whole signal must fit in exactly one frame: check F, not K.
            assert gt.shape[-2] == 1
        gt = _center(gt)
        est = _center(est)
        # Projection of the estimate onto the reference: <gt, est> / ||gt||^2 * gt.
        dot = torch.einsum("bcft,bcft->bcf", gt, est)
        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
        noise = est - proj
        sisnr = 10 * (
            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
        )
        # sisnr is [B, C, F, 1] (keepdim sum); average over batch, channels, frames.
        return -1 * sisnr[..., 0].mean()