Spaces:
Sleeping
Sleeping
| import torch | |
| from torch.nn.modules.loss import _Loss | |
class PairwiseNegSDR(_Loss):
    """Negative SDR-style loss evaluated for every estimate/target pairing.

    ``sdr_type`` selects the variant:

    * ``"snr"``: the reference is the target itself and the noise is
      ``estimate - target``.
    * ``"sisdr"``: the reference is the target rescaled by the projection of
      the estimate onto it, and the noise is ``estimate - rescaled target``.
    * ``"sdsdr"``: the reference is the rescaled target, but the noise is
      ``estimate - target`` (unscaled).

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"``, ``"sdsdr"``.
        zero_mean: Subtract the per-source time average from both inputs
            before computing the ratio.
        take_log: Return ``10 * log10(ratio)`` instead of the raw ratio.
        EPS: Small constant added to denominators and inside the log.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        self.EPS = EPS

    def forward(self, ests, targets):
        """Return the negated pairwise ratios.

        Args:
            ests: Estimates, shape ``[batch, n_src, time]``.
            targets: References, same shape as ``ests``.

        Returns:
            Tensor of shape ``[batch, n_src, n_src]`` where entry
            ``[b, i, j]`` compares ``ests[b, i]`` with ``targets[b, j]``.

        Raises:
            TypeError: If the shapes differ or the inputs are not 3-D.
        """
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )

        # Optional removal of the time-average of every source.
        if self.zero_mean:
            targets = targets - targets.mean(dim=2, keepdim=True)
            ests = ests - ests.mean(dim=2, keepdim=True)

        # Insert singleton axes so that broadcasting enumerates all pairs.
        ref = targets.unsqueeze(1)  # [batch, 1, n_src, time]
        est = ests.unsqueeze(2)  # [batch, n_src, 1, time]

        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_src, n_src, 1]
            dot = (est * ref).sum(dim=3, keepdim=True)
            # [batch, 1, n_src, 1]
            ref_energy = (ref ** 2).sum(dim=3, keepdim=True) + self.EPS
            # [batch, n_src, n_src, time]
            proj = dot * ref / ref_energy
        else:
            # [batch, n_src, n_src, time]
            proj = ref.repeat(1, ref.shape[2], 1, 1)

        if self.sdr_type in ["sdsdr", "snr"]:
            noise = est - ref
        else:
            noise = est - proj

        # [batch, n_src, n_src]
        ratio = torch.sum(proj ** 2, dim=3) / (torch.sum(noise ** 2, dim=3) + self.EPS)
        if not self.take_log:
            return -ratio
        return -(10 * torch.log10(ratio + self.EPS))
class SingleSrcNegSDR(_Loss):
    """Negative SDR-style loss for a single source per batch element.

    ``sdr_type`` selects the variant:

    * ``"snr"``: the reference is the target itself and the noise is
      ``estimate - target``.
    * ``"sisdr"``: the reference is the target rescaled by the projection of
      the estimate onto it, and the noise is ``estimate - rescaled target``.
    * ``"sdsdr"``: the reference is the rescaled target, but the noise is
      ``estimate - target`` (unscaled).

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"``, ``"sdsdr"``.
        zero_mean: Subtract the time average from both inputs first.
        take_log: Return ``10 * log10(ratio)`` instead of the raw ratio.
        reduction: ``"mean"`` averages over the batch; any other accepted
            value leaves the per-example losses untouched. ``"sum"`` is
            rejected.
        EPS: Small constant added to denominators and inside the log.
    """

    def __init__(
        self, sdr_type, zero_mean=True, take_log=True, reduction="none", EPS=1e-8
    ):
        # "sum" is not implemented by forward(); kept as an assert so the
        # raised exception type stays what existing callers see.
        assert reduction != "sum", NotImplementedError
        super().__init__(reduction=reduction)
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honour the caller-supplied epsilon (it used to be overwritten
        # with a hard-coded 1e-8).
        self.EPS = EPS

    def forward(self, ests, targets):
        """Return the negated ratio for each batch element.

        Args:
            ests: Estimates, shape ``[batch, time]``.
            targets: References, same shape as ``ests``.

        Returns:
            Tensor of shape ``[batch]``, or a scalar tensor when
            ``reduction == "mean"``.

        Raises:
            TypeError: If the shapes differ or the inputs are not 2-D.
        """
        if targets.size() != ests.size() or targets.ndim != 2:
            raise TypeError(
                f"Inputs must be of shape [batch, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=1, keepdim=True)
            mean_estimate = torch.mean(ests, dim=1, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_estimate
        # Step 2. Build the reference signal.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, 1]
            dot = torch.sum(ests * targets, dim=1, keepdim=True)
            # [batch, 1]
            s_target_energy = torch.sum(targets ** 2, dim=1, keepdim=True) + self.EPS
            # [batch, time]
            scaled_target = dot * targets / s_target_energy
        else:
            # [batch, time]
            scaled_target = targets
        # Step 3. Residual: against the raw target for "sdsdr"/"snr",
        # against the rescaled target for "sisdr".
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_target
        # [batch]
        losses = torch.sum(scaled_target ** 2, dim=1) / (
            torch.sum(e_noise ** 2, dim=1) + self.EPS
        )
        if self.take_log:
            losses = 10 * torch.log10(losses + self.EPS)
        losses = losses.mean() if self.reduction == "mean" else losses
        return -losses
class MultiSrcNegSDR(_Loss):
    """Negative SDR-style loss averaged over sources, without permutation.

    Estimate ``i`` is always compared with target ``i``. ``sdr_type``
    selects the variant:

    * ``"snr"``: the reference is the target itself and the noise is
      ``estimate - target``.
    * ``"sisdr"``: the reference is the target rescaled by the projection of
      the estimate onto it, and the noise is ``estimate - rescaled target``.
    * ``"sdsdr"``: the reference is the rescaled target, but the noise is
      ``estimate - target`` (unscaled).

    Args:
        sdr_type: One of ``"snr"``, ``"sisdr"``, ``"sdsdr"``.
        zero_mean: Subtract the per-source time average from both inputs.
        take_log: Use ``10 * log10(ratio)`` instead of the raw ratio.
        EPS: Small constant added to denominators and inside the log.
    """

    def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
        super().__init__()
        assert sdr_type in ["snr", "sisdr", "sdsdr"]
        self.sdr_type = sdr_type
        self.zero_mean = zero_mean
        self.take_log = take_log
        # Honour the caller-supplied epsilon (it used to be overwritten
        # with a hard-coded 1e-8).
        self.EPS = EPS

    def forward(self, ests, targets):
        """Return the negated ratio, averaged over the source axis.

        Args:
            ests: Estimates, shape ``[batch, n_src, time]``.
            targets: References, same shape as ``ests``.

        Returns:
            Tensor of shape ``[batch]``.

        Raises:
            TypeError: If the shapes differ or the inputs are not 3-D.
        """
        if targets.size() != ests.size() or targets.ndim != 3:
            raise TypeError(
                f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
            )
        # Step 1. Zero-mean norm
        if self.zero_mean:
            mean_source = torch.mean(targets, dim=2, keepdim=True)
            mean_est = torch.mean(ests, dim=2, keepdim=True)
            targets = targets - mean_source
            ests = ests - mean_est
        # Step 2. Build the per-source reference signal.
        if self.sdr_type in ["sisdr", "sdsdr"]:
            # [batch, n_src, 1]
            pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
            # [batch, n_src, 1]
            s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
            # [batch, n_src, time]
            scaled_targets = pair_wise_dot * targets / s_target_energy
        else:
            # [batch, n_src, time]
            scaled_targets = targets
        # Step 3. Residual: against the raw target for "sdsdr"/"snr",
        # against the rescaled target for "sisdr".
        if self.sdr_type in ["sdsdr", "snr"]:
            e_noise = ests - targets
        else:
            e_noise = ests - scaled_targets
        # [batch, n_src]
        pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
            torch.sum(e_noise ** 2, dim=2) + self.EPS
        )
        if self.take_log:
            pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
        # Average over sources, then negate so that lower is better.
        return -torch.mean(pair_wise_sdr, dim=-1)
class freq_MAE_WavL1Loss(_Loss):
    """Sum of an L1 loss on the STFT (real and imaginary parts) and an L1
    loss on the waveform, each averaged over sources.

    Args:
        win: STFT size (``n_fft``) and Hann window length.
        stride: STFT hop length.
    """

    def __init__(self, win=2048, stride=512):
        super().__init__()
        self.EPS = 1e-8  # not used by forward(); kept for compatibility
        self.win = win
        self.stride = stride

    def _stft(self, signals, window):
        """Complex STFT of ``signals`` flattened to ``[batch * n_src, time]``."""
        # reshape (not view) so that non-contiguous inputs are accepted.
        flat = signals.reshape(-1, signals.shape[-1])
        return torch.stft(
            flat,
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )

    def forward(self, ests, targets):
        """Return the combined loss.

        Args:
            ests: Estimates, shape ``[batch, n_src, time]``.
            targets: References, same shape as ``ests``.

        Returns:
            Tensor of shape ``[batch]``.
        """
        B, nsrc, _ = ests.shape
        # Build the analysis window once and share it between both STFTs.
        window = torch.hann_window(self.win).to(ests.device).float()
        est_spec = self._stft(ests, window)
        target_spec = self._stft(targets, window)

        # Mean absolute error over (freq, frame), real and imaginary parts.
        freq_L1 = (est_spec.real - target_spec.real).abs().mean((1, 2)) + (
            est_spec.imag - target_spec.imag
        ).abs().mean((1, 2))
        freq_L1 = freq_L1.reshape(B, nsrc).mean(-1)

        # Mean absolute error over time on the raw waveform.
        wave_l1 = (ests - targets).abs().mean(-1)
        wave_l1 = wave_l1.reshape(B, nsrc).mean(-1)
        return freq_L1 + wave_l1
class freq_MSE_Loss(_Loss):
    """Mean squared error between the STFTs of estimates and targets
    (real and imaginary parts summed), averaged over sources.

    Args:
        win: STFT size (``n_fft``) and Hann window length.
        stride: STFT hop length.
    """

    def __init__(self, win=640, stride=160):
        super().__init__()
        self.EPS = 1e-8  # not used by forward(); kept for compatibility
        self.win = win
        self.stride = stride

    def _stft(self, signals, window):
        """Complex STFT of ``signals`` flattened to ``[batch * n_src, time]``."""
        # reshape (not view) so that non-contiguous inputs are accepted.
        flat = signals.reshape(-1, signals.shape[-1])
        return torch.stft(
            flat,
            n_fft=self.win,
            hop_length=self.stride,
            window=window,
            return_complex=True,
        )

    def forward(self, ests, targets):
        """Return the spectral MSE.

        Args:
            ests: Estimates, shape ``[batch, n_src, time]``.
            targets: References, same shape as ``ests``.

        Returns:
            Tensor of shape ``[batch]``.
        """
        B, nsrc, _ = ests.shape
        # Build the analysis window once and share it between both STFTs.
        window = torch.hann_window(self.win).to(ests.device).float()
        est_spec = self._stft(ests, window)
        target_spec = self._stft(targets, window)

        # Mean squared error over (freq, frame), real and imaginary parts.
        freq_mse = (est_spec.real - target_spec.real).square().mean((1, 2)) + (
            est_spec.imag - target_spec.imag
        ).square().mean((1, 2))
        return freq_mse.reshape(B, nsrc).mean(-1)
# Ready-made module-level loss instances built with default settings.

# All estimate/target pairings.
pairwise_neg_sisdr = PairwiseNegSDR(sdr_type="sisdr")
pairwise_neg_sdsdr = PairwiseNegSDR(sdr_type="sdsdr")
pairwise_neg_snr = PairwiseNegSDR(sdr_type="snr")

# One source per batch element.
singlesrc_neg_sisdr = SingleSrcNegSDR(sdr_type="sisdr")
singlesrc_neg_sdsdr = SingleSrcNegSDR(sdr_type="sdsdr")
singlesrc_neg_snr = SingleSrcNegSDR(sdr_type="snr")

# Several sources in a fixed order.
multisrc_neg_sisdr = MultiSrcNegSDR(sdr_type="sisdr")
multisrc_neg_sdsdr = MultiSrcNegSDR(sdr_type="sdsdr")
multisrc_neg_snr = MultiSrcNegSDR(sdr_type="snr")

# Spectral + waveform L1.
freq_mae_wavl1loss = freq_MAE_WavL1Loss()

# Pairs of loss modules exposed as tuples; how each pair is consumed is
# decided by the caller and is not visible in this file.
pairwise_neg_sisdr_freq_mse = (PairwiseNegSDR(sdr_type="sisdr"), freq_MSE_Loss())
pairwise_neg_snr_multidecoder = (
    PairwiseNegSDR(sdr_type="snr"),
    MultiSrcNegSDR(sdr_type="snr"),
)