File size: 5,066 Bytes
51e2f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
from torch.nn.modules.loss import _Loss
from torch.nn import functional as F

class SignalNoisePNormRatio(_Loss):
    """Error-to-target ratio of mean p-th powers, optionally in decibels.

    Each batch item (axis 0) is flattened and reduced to

        num = mean(|est_target - target| ** p) + EPS
        den = mean(|target| ** p) + EPS

    The loss is ``10 * (log10(num) - log10(den))`` when ``take_log`` is True,
    otherwise ``num / den``. Lower values mean a closer estimate.

    Complex inputs are converted with ``torch.view_as_real`` (real and
    imaginary parts become separate elements) before the powers are taken.

    Args:
        p: Exponent. ``1`` and ``2`` use dedicated ``abs`` / ``square`` paths;
            any other finite ``p > 1`` is also accepted. Other values raise
            ``NotImplementedError`` in ``forward``.
        scale_invariant: If True, the target is first multiplied by
            ``(sum(est * conj(target)) + 1e-8) / (sum(target * conj(target)) + 1e-8)``,
            with the sums taken over all non-batch axes.
        zero_mean: Not supported; must be False.
        take_log: Return the log-ratio in dB instead of the plain ratio.
        reduction: ``"mean"`` averages over the batch; ``"sum"`` is rejected at
            construction; any other value returns one loss per batch item.
        EPS: Stabiliser added to both numerator and denominator.
    """

    def __init__(
        self,
        p: float = 1.0,
        scale_invariant: bool = False,
        zero_mean: bool = False,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-3,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)
        # Zero-mean normalisation is not implemented for this loss.
        assert not zero_mean

        self.p = p
        self.EPS = EPS
        self.take_log = take_log
        self.scale_invariant = scale_invariant

    def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the loss; axis 0 of both inputs is treated as the batch axis.

        Raises:
            NotImplementedError: if ``self.p`` is not a finite value >= 1.
        """
        if self.scale_invariant:
            ndim = target.ndim
            dot = torch.sum(est_target * torch.conj(target), dim=-1, keepdim=True)
            s_target_energy = torch.sum(target * torch.conj(target), dim=-1, keepdim=True)
            if ndim > 2:
                # Fold the remaining non-batch axes so a single scale is used per item.
                rest = list(range(1, ndim))
                dot = torch.sum(dot, dim=rest, keepdim=True)
                s_target_energy = torch.sum(s_target_energy, dim=rest, keepdim=True)
            target = target * ((dot + 1e-8) / (s_target_energy + 1e-8))

        if torch.is_complex(est_target):
            # Treat real and imaginary parts as independent real samples.
            est_target = torch.view_as_real(est_target)
            target = torch.view_as_real(target)

        batch_size = est_target.shape[0]
        est_target = est_target.reshape(batch_size, -1)
        target = target.reshape(batch_size, -1)
        error = est_target - target

        if self.p == 1:
            e_error = torch.abs(error).mean(dim=-1)
            e_target = torch.abs(target).mean(dim=-1)
        elif self.p == 2:
            e_error = torch.square(error).mean(dim=-1)
            e_target = torch.square(target).mean(dim=-1)
        elif 1 < self.p < float("inf"):
            # General case; restricting to p >= 1 keeps the gradient of
            # |x| ** p finite at x == 0.
            e_error = torch.abs(error).pow(self.p).mean(dim=-1)
            e_target = torch.abs(target).pow(self.p).mean(dim=-1)
        else:
            raise NotImplementedError(f"p={self.p!r} is not supported; use a finite p >= 1")

        if self.take_log:
            loss = 10 * (torch.log10(e_error + self.EPS) - torch.log10(e_target + self.EPS))
        else:
            loss = (e_error + self.EPS) / (e_target + self.EPS)

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            # Only reachable when assertions are disabled (python -O), because
            # __init__ asserts reduction != "sum".
            loss = loss.sum()
        return loss

        

class MultichannelSingleSrcNegSDR(_Loss):
    """Negative SNR / SI-SDR / SD-SDR computed jointly over axes 1 and 2.

    Both inputs must be 3-D tensors of identical shape. Axes 1 and 2 are reduced
    together, so one value is produced per batch item (axis 0). The returned
    value is the negated ratio (or dB), so minimising it maximises the SDR.

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"``, ``"sdsdr"``.
        p: With ``p == 2`` the ratio is ``sum(s**2) / (sum(e**2) + EPS)``.
            Any other ``p`` uses the ratio of vector p-norms
            ``||s||_p / (||e||_p + EPS)``.
        zero_mean: Subtract each item's mean (over axes 1 and 2) from both
            inputs before computing the ratio.
        take_log: Convert the ratio to dB with ``10 * log10(ratio + EPS)``.
        reduction: ``"mean"`` averages over the batch; ``"sum"`` is rejected at
            construction; any other value returns one loss per batch item.
        EPS: Small constant used for numerical stability.
    """

    def __init__(
        self,
        sdr_type: str,
        p: float = 2.0,
        zero_mean: bool = True,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-8,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)

        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honour the caller-supplied stabiliser (default 1e-8).
        self.EPS = EPS

        self.p = p

    def forward(self, est_target: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the negated SDR-style ratio (or dB) for ``est_target`` vs ``target``.

        Raises:
            TypeError: if the shapes differ or the inputs are not 3-D.
        """
        if target.size() != est_target.size() or target.ndim != 3:
            raise TypeError(
                "Inputs must be of shape [batch, channels, time], "
                f"got {target.size()} and {est_target.size()} instead"
            )
        reduce_dims = [1, 2]

        # Step 1. Zero-mean normalisation.
        if self.zero_mean:
            mean_source = torch.mean(target, dim=reduce_dims, keepdim=True)
            mean_estimate = torch.mean(est_target, dim=reduce_dims, keepdim=True)
            target = target - mean_source
            est_target = est_target - mean_estimate

        # Step 2. Optionally project the estimate onto the target.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1, 1]
            dot = torch.sum(est_target * target, dim=reduce_dims, keepdim=True)
            # [batch, 1, 1]
            s_target_energy = torch.sum(target ** 2, dim=reduce_dims, keepdim=True) + self.EPS
            # Same shape as the inputs.
            scaled_target = dot * target / s_target_energy
        else:
            scaled_target = target

        # Step 3. Distortion term: SNR and SD-SDR compare against the unscaled
        # target, SI-SDR against the projection.
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = est_target - target
        else:
            e_noise = est_target - scaled_target

        # Step 4. Per-item ratio, shape [batch].
        if self.p == 2.0:
            losses = torch.sum(scaled_target ** 2, dim=reduce_dims) / (
                torch.sum(e_noise ** 2, dim=reduce_dims) + self.EPS
            )
        else:
            # torch.linalg.vector_norm takes the order as `ord` (not `p`).
            # NOTE(review): this is a ratio of norms, not of p-th powers, so its
            # scale does not line up with the p == 2 energy-ratio path above
            # (e.g. p -> 2 gives half the dB). Confirm this is intended.
            losses = torch.linalg.vector_norm(scaled_target, ord=self.p, dim=reduce_dims) / (
                torch.linalg.vector_norm(e_noise, ord=self.p, dim=reduce_dims) + self.EPS
            )

        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses