File size: 1,145 Bytes
936f6fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch


class WeightedSDR:
    """Callable criterion object that forwards to the weighted SDR loss.

    The loss function is stored on the instance as ``loss`` and invoked
    whenever the object itself is called.
    """

    def __init__(self):
        """Bind the module-level weighted SDR loss function to ``self.loss``."""
        self.loss = weighted_signal_distortion_ratio_loss

    def __call__(self, output, bd):
        """Evaluate ``self.loss`` on ``output`` and ``bd`` and return its result.

        Args:
            output: Passed through unchanged as the first argument of the loss.
            bd: Passed through unchanged as the second argument of the loss.
        """
        criterion = self.loss
        return criterion(output, bd)


def dotproduct(y, y_hat):
    """Return one inner product per batch item as a 1-D tensor.

    Every item of ``y`` is viewed as a single row vector and every item of
    ``y_hat`` as a single column vector; a batched matrix multiply then
    produces one scalar per item.  Because of those views, each item must
    hold exactly ``shape[-1]`` elements (e.g. batch x 1 x nsamples).

    Args:
        y: Tensor with the batch on axis 0 and the samples on the last axis.
        y_hat: Tensor laid out like ``y``.

    Returns:
        Tensor of shape ``(batch,)`` containing the inner products.
    """
    # batch x channel x nsamples, where the channel axis has size one
    batch_size, n_samples = y.shape[0], y.shape[-1]
    rows = y.view(batch_size, 1, n_samples)
    cols = y_hat.view(y_hat.shape[0], y_hat.shape[-1], 1)
    # (batch, 1, n) @ (batch, n, 1) -> (batch, 1, 1), flattened to (batch,)
    return torch.bmm(rows, cols).reshape(-1)


def weighted_signal_distortion_ratio_loss(output, bd):
    """Compute the weighted signal-to-distortion ratio (wSDR) loss.

    The loss combines two negative cosine similarities taken along the last
    (sample) axis: one between the target and its estimate, and one between
    the noise and the residual ``bd['x'] - output``.  They are blended with
    ``alpha``, the share of the target's energy in the combined target and
    noise energy, so the value lies in ``[-1, 1]`` with ``-1`` being best.

    Every leading axis (batch, channel, ...) is treated as an independent
    signal.  Mono ``(batch, 1, nsamples)`` input is the usual layout, but
    multi-channel ``(batch, channels, nsamples)`` and channel-less
    ``(batch, nsamples)`` tensors are handled the same way.

    Args:
        output: Estimated target signal with the samples on the last axis.
        bd: Mapping holding tensors under the keys ``'x'`` (the signal the
            estimate is subtracted from to obtain the noise estimate),
            ``'y'`` (target signal) and ``'z'`` (noise signal).

    Returns:
        Zero-dimensional tensor: the loss averaged over all signals.

    Raises:
        ValueError: If target, noise, estimate and residual do not all have
            the same shape.
    """
    eps = 1e-8  # keeps the divisions finite for silent (all-zero) signals

    y = bd['y']  # target signal
    z = bd['z']  # noise signal

    y_hat = output
    z_hat = bd['x'] - y_hat  # expected noise signal

    # Mismatched shapes would broadcast silently and mix unrelated signals
    # into the loss, so reject them up front.
    if not (y.shape == z.shape == y_hat.shape == z_hat.shape):
        raise ValueError(
            "wSDR loss expects tensors of identical shape, got "
            f"y={tuple(y.shape)}, z={tuple(z.shape)}, "
            f"output={tuple(y_hat.shape)}, x - output={tuple(z_hat.shape)}"
        )

    # Reduce over the sample axis only; leading axes stay separate signals.
    y_norm = torch.norm(y, dim=-1)
    z_norm = torch.norm(z, dim=-1)
    y_hat_norm = torch.norm(y_hat, dim=-1)
    z_hat_norm = torch.norm(z_hat, dim=-1)

    def loss_sdr(a, a_hat, a_norm, a_hat_norm):
        # Cosine similarity along the sample axis; an element-wise product
        # and sum works for any number of leading axes.
        return (a * a_hat).sum(dim=-1) / (a_norm * a_hat_norm + eps)

    y_energy = y_norm.pow(2)
    z_energy = z_norm.pow(2)
    alpha = y_energy / (y_energy + z_energy + eps)
    loss_wSDR = (
        -alpha * loss_sdr(y, y_hat, y_norm, y_hat_norm)
        - (1 - alpha) * loss_sdr(z, z_hat, z_norm, z_hat_norm)
    )

    return loss_wSDR.mean()