|
import torch
|
|
from torch.nn.modules.loss import _Loss
|
|
from torch.nn import functional as F
|
|
|
|
class SignalNoisePNormRatio(_Loss):
    """Error-to-target magnitude ratio, usable as a training loss.

    For every batch item the estimate and the target are flattened and the
    element-wise magnitude (absolute value for ``p == 1``, square for
    ``p == 2``) of ``est_target - target`` and of ``target`` is averaged.
    The loss is the ratio of those two averages (each offset by ``EPS``),
    either as ``10 * log10`` of the ratio or as the plain ratio.

    Args:
        p: Magnitude exponent. Only ``1`` and ``2`` are handled; any other
            value makes ``forward`` raise ``NotImplementedError``.
        scale_invariant: When true, the target is first multiplied by the
            per-item projection coefficient of the estimate onto the target.
        zero_mean: Must be falsy; enforced with an ``assert``.
        take_log: Return ``10 * log10`` of the ratio instead of the ratio.
        reduction: ``"mean"`` averages over the batch; ``"sum"`` is rejected
            with an ``assert``; other values leave the per-item losses
            unreduced.
        EPS: Constant added to both averages before the ratio / logarithm.
    """

    def __init__(
        self,
        p: float = 1.0,
        scale_invariant: bool = False,
        zero_mean: bool = False,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-3,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)
        assert not zero_mean

        self.scale_invariant = scale_invariant
        self.take_log = take_log
        self.EPS = EPS
        self.p = p

    @staticmethod
    def _rescale_reference(
        estimate: torch.Tensor, reference: torch.Tensor
    ) -> torch.Tensor:
        """Scale ``reference`` by the projection coefficient of ``estimate``."""
        correlation = torch.sum(
            estimate * torch.conj(reference), dim=-1, keepdim=True
        )
        energy = torch.sum(
            reference * torch.conj(reference), dim=-1, keepdim=True
        )

        rank = reference.ndim
        if rank > 2:
            # Pool every non-batch axis so one coefficient is used per item.
            pooled_axes = list(range(1, rank))
            correlation = torch.sum(correlation, dim=pooled_axes, keepdim=True)
            energy = torch.sum(energy, dim=pooled_axes, keepdim=True)

        coefficient = (correlation + 1e-8) / (energy + 1e-8)
        return reference * coefficient

    def forward(
        self,
        est_target: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """Return the (optionally reduced) error-to-target ratio.

        Args:
            est_target: Estimated signal; axis 0 is treated as the batch axis.
            target: Reference signal that is compared against ``est_target``.

        Returns:
            A scalar when ``reduction`` is ``"mean"`` (or ``"sum"``),
            otherwise one value per batch item.

        Raises:
            NotImplementedError: If ``p`` is neither 1 nor 2.
        """
        estimate = est_target
        reference = target
        if self.scale_invariant:
            reference = self._rescale_reference(estimate, reference)

        # Complex tensors are unpacked into trailing (real, imag) pairs so the
        # magnitude statistics below only ever see real values.
        if torch.is_complex(estimate):
            estimate = torch.view_as_real(estimate)
            reference = torch.view_as_real(reference)

        n_items = estimate.shape[0]
        estimate = estimate.reshape(n_items, -1)
        reference = reference.reshape(n_items, -1)

        if self.p == 1:
            magnitude = torch.abs
        elif self.p == 2:
            magnitude = torch.square
        else:
            raise NotImplementedError

        residual_level = magnitude(estimate - reference).mean(dim=-1)
        reference_level = magnitude(reference).mean(dim=-1)

        if self.take_log:
            log_residual = torch.log10(residual_level + self.EPS)
            log_reference = torch.log10(reference_level + self.EPS)
            ratio = 10 * (log_residual - log_reference)
        else:
            ratio = (residual_level + self.EPS) / (reference_level + self.EPS)

        if self.reduction == "mean":
            return ratio.mean()
        if self.reduction == "sum":
            return ratio.sum()
        return ratio
|
|
|
|
|
|
|
|
class MultichannelSingleSrcNegSDR(_Loss):
    """Negative SNR / SI-SDR / SD-SDR reduced jointly over axes 1 and 2.

    All statistics (means, inner products, energies, norms) are taken over
    axes 1 and 2 of the 3-D inputs together, giving one value per batch item.
    The axis meaning is presumably ``[batch, channel, time]`` (inferred from
    the class name; confirm against callers).

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"`` or ``"sdsdr"``.
        p: Norm order. ``2.0`` uses the ratio of summed squares; any other
            value uses the ratio of vector ``p``-norms (not raised to ``p``).
        zero_mean: Subtract the per-item mean (over axes 1 and 2) from both
            inputs before computing the ratio.
        take_log: Return ``10 * log10`` of the ratio instead of the ratio.
        reduction: ``"mean"`` averages over the batch; ``"sum"`` is rejected
            with an ``assert``; other values leave the per-item losses
            unreduced.
        EPS: Stabiliser added to denominators and to the ``log10`` argument.
    """

    def __init__(
        self,
        sdr_type: str,
        p: float = 2.0,
        zero_mean: bool = True,
        take_log: bool = True,
        reduction: str = "mean",
        EPS: float = 1e-8,
    ) -> None:
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)

        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Store the caller-supplied stabiliser rather than a fixed constant.
        self.EPS = EPS
        self.p = p

    def forward(
        self,
        est_target: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """Return the negated signal-to-distortion ratio.

        Args:
            est_target: Estimated signal, a 3-D tensor.
            target: Reference signal with exactly the same shape.

        Returns:
            A scalar when ``reduction`` is ``"mean"``, otherwise one value per
            batch item. The sign is flipped so that minimising the loss
            maximises the ratio.

        Raises:
            TypeError: If the shapes differ or the inputs are not 3-D.
        """
        if target.size() != est_target.size() or target.ndim != 3:
            raise TypeError(
                "Inputs must be 3-D tensors of identical shape "
                "(e.g. [batch, channel, time]), "
                f"got {target.size()} and {est_target.size()} instead"
            )

        reduce_dims = (1, 2)

        if self.zero_mean:
            target = target - torch.mean(target, dim=reduce_dims, keepdim=True)
            est_target = est_target - torch.mean(
                est_target, dim=reduce_dims, keepdim=True
            )

        if self.sdr_type in ["sisdr", "sdsdr"]:
            # Project the estimate onto the target: <est, tgt> * tgt / ||tgt||^2.
            dot = torch.sum(est_target * target, dim=reduce_dims, keepdim=True)
            s_target_energy = (
                torch.sum(target ** 2, dim=reduce_dims, keepdim=True) + self.EPS
            )
            scaled_target = dot * target / s_target_energy
        else:
            scaled_target = target

        # "sdsdr" and "snr" measure the error against the unscaled target;
        # "sisdr" measures it against the projected target.
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = est_target - target
        else:
            e_noise = est_target - scaled_target

        if self.p == 2.0:
            losses = torch.sum(scaled_target ** 2, dim=reduce_dims) / (
                torch.sum(e_noise ** 2, dim=reduce_dims) + self.EPS
            )
        else:
            # torch.linalg.vector_norm takes the order as ``ord``.
            signal_norm = torch.linalg.vector_norm(
                scaled_target, ord=self.p, dim=reduce_dims
            )
            noise_norm = torch.linalg.vector_norm(
                e_noise, ord=self.p, dim=reduce_dims
            )
            losses = signal_norm / (noise_norm + self.EPS)

        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)

        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses
|
|
|