Spaces:
Running
Running
import torch | |
class WeightedSDR:
    """Callable object that forwards to the weighted SDR loss function.

    The functional loss is stored on the instance as ``self.loss`` when the
    object is constructed; calling the instance simply delegates to it.
    """

    def __init__(self):
        # Resolved at construction time, so the module-level function only
        # has to exist once an instance is actually created.
        self.loss = weighted_signal_distortion_ratio_loss

    def __call__(self, output, bd):
        """Return ``self.loss(output, bd)``.

        Args:
            output: Value passed through as the first argument of the loss.
            bd: Value passed through as the second argument of the loss
                (the loss function indexes it with string keys).
        """
        loss_fn = self.loss
        return loss_fn(output, bd)
def dotproduct(y, y_hat):
    """Return the per-batch inner product of ``y`` and ``y_hat``.

    ``y`` is viewed as a stack of row vectors ``(B, 1, N)`` and ``y_hat`` as
    a stack of column vectors ``(B, N, 1)``, where ``B`` is the first and
    ``N`` the last dimension of each input. Each input must therefore hold
    exactly ``B * N`` elements (e.g. batch x 1 channel x nsamples).

    Returns:
        A 1-D tensor with one inner product per batch entry.
    """
    # batch x channel x nsamples -> (B, 1, N) row vectors
    row_vectors = y.view(y.size(0), 1, y.size(-1))
    # batch x channel x nsamples -> (B, N, 1) column vectors
    col_vectors = y_hat.view(y_hat.size(0), y_hat.size(-1), 1)
    # (B, 1, N) @ (B, N, 1) -> (B, 1, 1); flatten to (B,)
    products = torch.bmm(row_vectors, col_vectors)
    return products.reshape(-1)
def weighted_signal_distortion_ratio_loss(output, bd, eps=1e-8):
    """Return the mean weighted SDR loss between an estimate and its targets.

    For every signal along the last axis the loss is

        -alpha * cos(y, y_hat) - (1 - alpha) * cos(z, z_hat)

    where ``cos`` is the inner product divided by the product of the L2
    norms, ``z_hat = bd['x'] - output`` and
    ``alpha = ||y||^2 / (||y||^2 + ||z||^2)``.

    All reductions run over the last axis only, so the function accepts any
    number of leading dimensions (``(B, N)``, ``(B, 1, N)``, ``(B, C, N)``,
    ...). With several channels each channel contributes its own term to the
    final mean.

    Args:
        output: Estimated target signal ``y_hat``.
        bd: Mapping with tensors under the keys ``'y'`` (target signal),
            ``'z'`` (noise signal) and ``'x'`` (the input the estimate is
            subtracted from; presumably the mixture ``y + z``).
        eps: Small constant added to each denominator to avoid division by
            zero for silent signals.

    Returns:
        A scalar tensor: the mean of the per-signal loss values.
    """
    y = bd['y']  # target signal
    z = bd['z']  # noise signal
    y_hat = output
    z_hat = bd['x'] - y_hat  # noise implied by the estimate

    def cosine(a, a_hat):
        # Element-wise product + sum over samples works for any number of
        # leading (batch / channel) dimensions, unlike a bmm on fixed views.
        inner = (a * a_hat).sum(dim=-1)
        scale = torch.norm(a, dim=-1) * torch.norm(a_hat, dim=-1)
        return inner / (scale + eps)

    # Energy-based weighting between the target term and the noise term.
    y_energy = y.pow(2).sum(dim=-1)
    z_energy = z.pow(2).sum(dim=-1)
    alpha = y_energy / (y_energy + z_energy + eps)

    loss_wSDR = -alpha * cosine(y, y_hat) - (1 - alpha) * cosine(z, z_hat)
    return loss_wSDR.mean()