File size: 2,914 Bytes
0d93e4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import torch
from torch import nn
from torch.nn import functional as F


def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.

    The input is zero-padded on the right so that `F = ceil(T / stride)` and the
    last frame is complete. Frames overlap when `stride < kernel_size`; when
    `stride > kernel_size`, trailing samples not covered by any frame may be
    cropped instead. The result is a strided view over the padded tensor.
    see https://github.com/pytorch/pytorch/issues/60466

    Args:
        a: Input tensor; the last dimension is the one being framed.
        kernel_size: Frame length K (must be positive).
        stride: Hop between the starts of consecutive frames (must be positive).

    Returns:
        Tensor of shape [*OT, F, K].

    Raises:
        ValueError: if `kernel_size` or `stride` is not positive.
    """
    if kernel_size <= 0 or stride <= 0:
        raise ValueError(
            f"kernel_size and stride must be positive, got {kernel_size} and {stride}"
        )
    *shape, length = a.shape
    n_frames = math.ceil(length / stride)
    tgt_length = (n_frames - 1) * stride + kernel_size
    # With a zero pad amount, F.pad may return a clone that keeps the input's
    # strides (e.g. for a transposed tensor). The stride arithmetic below needs
    # a unit stride on the last axis, so force a contiguous layout.
    a = F.pad(a, (0, tgt_length - length)).contiguous()
    # Reuse the outer strides unchanged; frames advance by `stride` elements
    # along time, and samples within a frame are adjacent.
    strides = list(a.stride())[:-1] + [stride, 1]
    return a.as_strided([*shape, n_frames, kernel_size], strides)


def _center(x: torch.Tensor) -> torch.Tensor:
    """Subtract the mean over the last dimension, making each trailing vector zero-mean."""
    mean_last = x.mean(dim=-1, keepdim=True)
    return x - mean_last


def _norm2(x: torch.Tensor) -> torch.Tensor:
    """Squared L2 norm over the last dimension, kept as a size-1 dimension."""
    squared = x ** 2
    return squared.sum(dim=-1, keepdim=True)


class SISNR(nn.Module):
    """SISNR loss.

    Input should be [B, C, T], output is scalar: the negative scale-invariant
    signal-to-noise ratio (in dB), averaged over batch, channels and frames.

    Args:
        sample_rate (int): Sample rate.
        segment (float or None): Evaluate on chunks of that many seconds. If None, evaluate on
            entire audio only.
        overlap (float): Overlap between chunks, i.e. 0.5 = 50 % overlap.
        epsilon (float): Epsilon value for numerical stability.
    """
    def __init__(
        self,
        sample_rate: int = 16000,
        segment: tp.Optional[float] = 20,
        overlap: float = 0.5,
        epsilon: float = torch.finfo(torch.float32).eps,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.segment = segment
        self.overlap = overlap
        self.epsilon = epsilon

    def forward(self, out_sig: torch.Tensor, ref_sig: torch.Tensor) -> torch.Tensor:
        """Compute the negative mean SI-SNR between an estimate and a reference.

        Args:
            out_sig: Estimated signal, shape [B, C, T].
            ref_sig: Reference signal, same shape as `out_sig`.

        Returns:
            Scalar tensor equal to `-mean(SI-SNR)` in dB (lower is better as a loss).

        Raises:
            ValueError: if the derived frame stride is not positive (e.g. `overlap >= 1`,
                a segment shorter than one sample, or an empty signal).
        """
        B, C, T = ref_sig.shape
        assert ref_sig.shape == out_sig.shape, "out_sig and ref_sig must have the same shape"

        if self.segment is None:
            # Whole-signal mode: a single frame spanning the entire input.
            frame = T
            stride = T
        else:
            frame = int(self.segment * self.sample_rate)
            stride = int(frame * (1 - self.overlap))

        if stride <= 0:
            raise ValueError(
                f"SISNR frame stride must be positive, got stride={stride} "
                f"(frame={frame}, overlap={self.overlap})"
            )

        # The energies below are sums over `frame` samples, so scale epsilon with it.
        epsilon = self.epsilon * frame

        # [B, C, F, frame]; trailing frames may include zero padding.
        gt = _unfold(ref_sig, frame, stride)
        est = _unfold(out_sig, frame, stride)
        if self.segment is None:
            # The frame-count axis is dim -2 (dim -1 is the frame length, == T).
            assert gt.shape[-2] == 1

        gt = _center(gt)
        est = _center(est)
        dot = torch.einsum("bcft,bcft->bcf", gt, est)

        # Projection of the estimate onto the reference (the "target" component);
        # the remainder is treated as noise.
        proj = dot[:, :, :, None] * gt / (epsilon + _norm2(gt))
        noise = est - proj

        sisnr = 10 * (
            torch.log10(epsilon + _norm2(proj)) - torch.log10(epsilon + _norm2(noise))
        )
        # Drop the keepdim axis from _norm2, then average over B, C and frames.
        return -1 * sisnr[..., 0].mean()