SuYuanS's picture
Duplicate from GrandaddyShmax/AudioCraft_Plus
a93c016
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import torch
from torch import nn
from torch.nn import functional as F
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
"""Given input of size [*OT, T], output Tensor of size [*OT, F, K]
with K the kernel size, by extracting frames with the given stride.
This will pad the input so that `F = ceil(T / K)`.
see https://github.com/pytorch/pytorch/issues/60466
"""
*shape, length = a.shape
n_frames = math.ceil(length / stride)
tgt_length = (n_frames - 1) * stride + kernel_size
a = F.pad(a, (0, tgt_length - length))
strides = list(a.stride())
assert strides[-1] == 1, "data should be contiguous"
strides = strides[:-1] + [stride, 1]
return a.as_strided([*shape, n_frames, kernel_size], strides)
def _center(x: torch.Tensor) -> torch.Tensor:
return x - x.mean(-1, True)
def _norm2(x: torch.Tensor) -> torch.Tensor:
return x.pow(2).sum(-1, True)
class SISNR(nn.Module):
"""SISNR loss.
Input should be [B, C, T], output is scalar.
Args:
sample_rate (int): Sample rate.
segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
entire audio only.
overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
epsilon (float): Epsilon value for numerical stability.
"""
def __init__(
self,
sample_rate: int = 16000,
segment: tp.Optional[float] = 20,
overlap: float = 0.5,
epsilon: float = torch.finfo(torch.float32).eps,
):
super().__init__()
self.sample_rate = sample_rate
self.segment = segment
self.overlap = overlap
self.epsilon = epsilon
def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
B, C, T = ref_sig.shape
assert ref_sig.shape == out_sig.shape
if self.segment is None:
frame = T
stride = T
else:
frame = int(self.segment * self.sample_rate)
stride = int(frame * (1 - self.overlap))
epsilon = self.epsilon * frame # make epsilon prop to frame size.
gt = _unfold(ref_sig, frame, stride)
est = _unfold(out_sig, frame, stride)
if self.segment is None:
assert gt.shape[-1] == 1
gt = _center(gt)
est = _center(est)
dot = torch.einsum("bcft,bcft->bcf", gt, est)
proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
noise = est - proj
sisnr = 10 * (
torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
)
return -1 * sisnr[..., 0].mean()