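# NOTE(review): the three lines below appear to be web-page residue from where this
# file was copied (not Python code); kept as comments so the module stays importable.
# xjsc0's picture
# 1
# 64ec292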
import argparse
from typing import Any, Callable, Optional, Union
import auraloss
import torch
import torch.nn.functional as F
from ml_collections import ConfigDict
from torch import nn
from torch_log_wmse import LogWMSE
def multistft_loss(
    y_: torch.Tensor,
    y: torch.Tensor,
    loss_multistft: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    """
    Compute a (multi-resolution) STFT-based loss on waveforms.

    4-D inputs are reshaped from (d0, d1, d2, d3) to (d0, d1 * d2, d3) before
    being handed to the provided criterion (e.g.,
    `auraloss.freq.MultiResolutionSTFTLoss`). Prediction and target are
    flattened independently, so a 4-D prediction may be paired with an
    already-flattened 3-D target (and vice versa).

    Args:
        y_ (torch.Tensor): Predicted waveform tensor with 3 or 4 dimensions,
            e.g. (B, C, T) or (B, S, C, T).
        y (torch.Tensor): Target waveform tensor with a compatible shape. A 4-D
            target is flattened like the prediction; any other rank is forwarded
            unchanged and left for the criterion to validate.
        loss_multistft (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
            A callable implementing the MR-STFT loss.

    Returns:
        torch.Tensor: Whatever `loss_multistft` returns (a scalar loss for the
        usual criteria).

    Raises:
        ValueError: If `y_` does not have 3 or 4 dimensions.
    """
    # Validate the prediction up front so the error is raised before any work.
    if y_.dim() not in (3, 4):
        raise ValueError(
            f"Invalid shape for predicted array: {y_.shape}. Expected 3 or 4 dimensions."
        )

    if y_.dim() == 4:
        pred = y_.reshape(y_.shape[0], y_.shape[1] * y_.shape[2], y_.shape[3])
    else:
        pred = y_

    # The target is handled on its own rank (the previous code inspected `y_`
    # here, which left the target unbound for a 4-D prediction / 3-D target).
    if y.dim() == 4:
        target = y.reshape(y.shape[0], y.shape[1] * y.shape[2], y.shape[3])
    else:
        target = y

    return loss_multistft(pred, target)
def masked_loss(
    y_: torch.Tensor, y: torch.Tensor, q: float, coarse: bool = True
) -> torch.Tensor:
    """
    Quantile-masked mean squared error.

    The elementwise squared error is computed, its first two axes are swapped,
    and (when `coarse` is set) the last two axes are averaged away. The result
    is flattened to one row per leading index; within each row only entries
    strictly below that row's `q`-quantile are kept, the rest are zeroed.

    Args:
        y_ (torch.Tensor): Predicted tensor matching `y`'s shape.
        y (torch.Tensor): Ground-truth tensor.
        q (float): Quantile used as the per-row keep threshold.
        coarse (bool, optional): If True, average over the last two dims before
            masking. Defaults to True.

    Returns:
        torch.Tensor: Scalar loss. The mean runs over all entries, including
        the ones that were masked to zero.
    """
    squared_error = F.mse_loss(y_, y, reduction="none").transpose(0, 1)
    if coarse:
        squared_error = squared_error.mean(dim=(-1, -2))

    rows = squared_error.reshape(squared_error.shape[0], -1)

    # The threshold is computed on a detached copy, so no gradient flows
    # through the quantile itself.
    threshold = torch.quantile(
        rows.detach(), q, dim=1, keepdim=True, interpolation="linear"
    )
    keep = rows < threshold
    return (rows * keep).mean()
def spec_rmse_loss(
    estimate: torch.Tensor, sources: torch.Tensor, stft_config: dict, eps: float = 1e-8
) -> torch.Tensor:
    """
    RMSE in the complex STFT domain.

    Both signals are flattened to (-1, T), transformed with `torch.stft`,
    converted to real/imag pairs and restored to their leading shape. The
    squared error is averaged over every axis from index 2 onward, `eps` is
    added, the square root is taken, and the result is averaged over axes 0
    and 1.

    Args:
        estimate (torch.Tensor): Predicted time-domain signal(s), e.g., (B, S, C, T).
            May be non-contiguous (e.g. the result of a permute).
        sources (torch.Tensor): Target time-domain signal(s), matching shape.
        stft_config (dict): Keyword arguments for `torch.stft`
            (e.g., n_fft, hop_length, win_length). No window is added here, so
            unless the dict supplies one `torch.stft` uses its default window.
        eps (float, optional): Constant added under the square root so the
            gradient stays finite when the error is zero. Defaults to 1e-8.

    Returns:
        torch.Tensor: Scalar loss tensor.
    """
    leading_shape = estimate.shape[:-1]
    num_samples = estimate.shape[-1]

    # `reshape` (not `view`) so non-contiguous inputs are accepted; for
    # contiguous inputs it is still a zero-copy view.
    flat_estimate = estimate.reshape(-1, num_samples)
    flat_sources = sources.reshape(-1, num_samples)

    spec_estimate = torch.view_as_real(
        torch.stft(flat_estimate, **stft_config, return_complex=True)
    )
    spec_sources = torch.view_as_real(
        torch.stft(flat_sources, **stft_config, return_complex=True)
    )

    # Restore the leading axes; the trailing three are (freq, frames, real/imag).
    spec_shape = leading_shape + spec_estimate.shape[-3:]
    spec_estimate = spec_estimate.reshape(spec_shape)
    spec_sources = spec_sources.reshape(spec_shape)

    squared_error = F.mse_loss(spec_estimate, spec_sources, reduction="none")
    reduce_dims = tuple(range(2, squared_error.dim()))
    return (squared_error.mean(reduce_dims) + eps).sqrt().mean(dim=(0, 1))
def spec_masked_loss(
    estimate: torch.Tensor,
    sources: torch.Tensor,
    stft_config: dict,
    q: float = 0.9,
    coarse: bool = True,
) -> torch.Tensor:
    """
    Quantile-masked MSE in the complex STFT domain.

    Both signals are flattened to (-1, T), transformed with `torch.stft`,
    converted to real/imag pairs and restored to their leading shape. The
    elementwise squared error is optionally averaged over the frequency and
    frame axes (`coarse`), flattened to one row per index of axis 0, and within
    each row only entries strictly below that row's `q`-quantile are kept.

    Args:
        estimate (torch.Tensor): Predicted time-domain signal(s), e.g., (B, S, C, T).
            May be non-contiguous (e.g. the result of a permute).
        sources (torch.Tensor): Target time-domain signal(s), matching shape.
        stft_config (dict): Keyword arguments for `torch.stft`.
        q (float, optional): Quantile used as the per-row keep threshold.
            Defaults to 0.9.
        coarse (bool, optional): If True, average over the frequency and frame
            dims before masking. Defaults to True.

    Returns:
        torch.Tensor: Scalar loss. The mean runs over all entries, including
        the ones that were masked to zero.
    """
    leading_shape = estimate.shape[:-1]
    num_samples = estimate.shape[-1]

    # `reshape` (not `view`) so non-contiguous inputs are accepted; for
    # contiguous inputs it is still a zero-copy view.
    flat_estimate = estimate.reshape(-1, num_samples)
    flat_sources = sources.reshape(-1, num_samples)

    spec_estimate = torch.view_as_real(
        torch.stft(flat_estimate, **stft_config, return_complex=True)
    )
    spec_sources = torch.view_as_real(
        torch.stft(flat_sources, **stft_config, return_complex=True)
    )

    # Restore the leading axes; the trailing three are (freq, frames, real/imag).
    spec_shape = leading_shape + spec_estimate.shape[-3:]
    spec_estimate = spec_estimate.reshape(spec_shape)
    spec_sources = spec_sources.reshape(spec_shape)

    loss = F.mse_loss(spec_estimate, spec_sources, reduction="none")
    if coarse:
        # Collapse (freq, frames) and keep the real/imag axis.
        loss = loss.mean(dim=(-3, -2))
    loss = loss.reshape(loss.shape[0], -1)

    # Threshold from a detached copy: no gradient flows through the quantile.
    threshold = torch.quantile(
        loss.detach(), q, interpolation="linear", dim=1, keepdim=True
    )
    keep = loss < threshold
    return (loss * keep).mean()
def l1_snr_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Time-domain L1-SNR loss from the torch-l1-snr package.

    Builds an `L1SNRLoss` named "l1_snr" and applies it to the pair.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Loss value returned by `L1SNRLoss`.
    """
    # Function-local import: torch_l1_snr is only loaded once this loss is called.
    from torch_l1_snr import L1SNRLoss

    return L1SNRLoss(name="l1_snr")(y_, y)
def l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    L1-SNR loss with dB-scale level regularization, from the torch-l1-snr package.

    Builds an `L1SNRDBLoss` named "l1_snr_db" and applies it to the pair.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Loss value returned by `L1SNRDBLoss`.
    """
    # Function-local import: torch_l1_snr is only loaded once this loss is called.
    from torch_l1_snr import L1SNRDBLoss

    return L1SNRDBLoss(name="l1_snr_db")(y_, y)
def stft_l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Multi-resolution STFT-domain L1-SNR loss, from the torch-l1-snr package.

    Builds an `STFTL1SNRDBLoss` named "stft_l1_snr_db" and applies it to the pair.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Loss value returned by `STFTL1SNRDBLoss`.
    """
    # Function-local import: torch_l1_snr is only loaded once this loss is called.
    from torch_l1_snr import STFTL1SNRDBLoss

    return STFTL1SNRDBLoss(name="stft_l1_snr_db")(y_, y)
def multi_l1_snr_db_loss(y_: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Combined time-domain + STFT-domain L1-SNR loss, from the torch-l1-snr package.

    Builds a `MultiL1SNRDBLoss` named "multi_l1_snr_db" and applies it to the pair.

    Args:
        y_ (torch.Tensor): Predicted waveform tensor.
        y (torch.Tensor): Target waveform tensor.

    Returns:
        torch.Tensor: Loss value returned by `MultiL1SNRDBLoss`.
    """
    # Function-local import: torch_l1_snr is only loaded once this loss is called.
    from torch_l1_snr import MultiL1SNRDBLoss

    return MultiL1SNRDBLoss(name="multi_l1_snr_db")(y_, y)
def _build_stft_config(config: ConfigDict) -> dict:
    """Assemble the `torch.stft` keyword arguments from `config.model`.

    Attributes missing on `config.model` fall back to n_fft=4096,
    hop_length=1024, win_length=4096 and normalized=True; `center` is always True.
    """
    model_cfg = config.model
    return {
        "n_fft": getattr(model_cfg, "nfft", 4096),
        "hop_length": getattr(model_cfg, "hop_size", 1024),
        "win_length": getattr(model_cfg, "win_size", 4096),
        "center": True,
        "normalized": getattr(model_cfg, "normalized", True),
    }


def choice_loss(
    args: argparse.Namespace, config: ConfigDict
) -> Callable[[Any, Any, Union[Any, None]], torch.Tensor]:
    """
    Build a composite loss from CLI/config options.

    Returns a callable that sums the enabled terms, each multiplied by its
    `args.<name>_coef` coefficient (read at call time):
    - `masked_loss`: quantile-masked MSE.
    - `mse_loss`: standard mean squared error.
    - `l1_loss`: mean absolute error.
    - `multistft_loss`: multi-resolution STFT loss (auraloss).
    - `log_wmse_loss`: `LogWMSE` from torch_log_wmse; needs the extra input `x`.
    - `l1_snr_loss`: L1-SNR loss in time domain.
    - `l1_snr_db_loss`: L1-SNR with dB-scale level regularization.
    - `stft_l1_snr_db_loss`: L1-SNR in multi-resolution STFT domain.
    - `multi_l1_snr_db_loss`: combined time + STFT domain L1-SNR.
    - `spec_rmse_loss`: RMSE in complex STFT domain.
    - `spec_masked_loss`: quantile-masked spectral MSE.

    Args:
        args (argparse.Namespace): Parsed arguments. `args.loss` names the active
            terms (a sequence of names, or a single string of names separated by
            spaces/commas); `args.<name>_coef` holds each term's weight.
        config (ConfigDict): Configuration with loss hyperparameters (e.g., STFT
            settings, quantile `q`, coarse masking flag).

    Returns:
        Callable[[Any, Any, Optional[Any]], torch.Tensor]: A function
        `loss(y_pred, y_true, x=None)` that computes the weighted sum of the
        selected loss terms.

    Raises:
        ValueError: If `args.loss` selects none of the supported terms (the
            returned callable could otherwise only produce the int 0), or, at
            call time, if `log_wmse_loss` is active and `x` is None.
    """
    # A bare string would turn the `in` checks below into substring tests
    # ("mse_loss" is contained in "log_wmse_loss", "masked_loss" in
    # "spec_masked_loss"), so split it into exact names first.
    selected = args.loss
    if isinstance(selected, str):
        selected = selected.replace(",", " ").split()

    loss_fns = []

    if "masked_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: masked_loss(
                y_pred,
                y_true,
                q=config["training"]["q"],
                coarse=config["training"]["coarse_loss_clip"],
            )
            * args.masked_loss_coef
        )

    if "mse_loss" in selected:
        mse = nn.MSELoss()
        loss_fns.append(
            lambda y_pred, y_true, x=None: mse(y_pred, y_true) * args.mse_loss_coef
        )

    if "l1_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: F.l1_loss(y_pred, y_true) * args.l1_loss_coef
        )

    if "multistft_loss" in selected:
        loss_options = dict(config.get("loss_multistft", {}))
        stft_loss = auraloss.freq.MultiResolutionSTFTLoss(**loss_options)
        loss_fns.append(
            lambda y_pred, y_true, x=None: multistft_loss(y_pred, y_true, stft_loss)
            * args.multistft_loss_coef
        )

    if "log_wmse_loss" in selected:
        sample_rate = int(getattr(config.audio, "sample_rate", 44100))
        chunk_size = int(getattr(config.audio, "chunk_size", 485100))
        log_wmse = LogWMSE(
            # Whole seconds of audio per chunk (integer division).
            audio_length=chunk_size // sample_rate,
            sample_rate=sample_rate,
            return_as_loss=True,
            bypass_filter=getattr(config.training, "bypass_filter", False),
        )

        def log_wmse_term(y_pred, y_true, x=None):
            # This is the only term that consumes `x`; fail with a clear message
            # instead of an obscure error from inside LogWMSE.
            if x is None:
                raise ValueError(
                    "log_wmse_loss requires the input `x` (third argument of the loss)."
                )
            return log_wmse(x, y_pred, y_true) * args.log_wmse_loss_coef

        loss_fns.append(log_wmse_term)

    if "l1_snr_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: l1_snr_loss(y_pred, y_true)
            * args.l1_snr_loss_coef
        )

    if "l1_snr_db_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: l1_snr_db_loss(y_pred, y_true)
            * args.l1_snr_db_loss_coef
        )

    if "stft_l1_snr_db_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: stft_l1_snr_db_loss(y_pred, y_true)
            * args.stft_l1_snr_db_loss_coef
        )

    if "multi_l1_snr_db_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: multi_l1_snr_db_loss(y_pred, y_true)
            * args.multi_l1_snr_db_loss_coef
        )

    # Both spectral terms use the same STFT settings; build them once and only
    # when needed, so `config.model` is not touched for other selections.
    if "spec_rmse_loss" in selected or "spec_masked_loss" in selected:
        stft_config = _build_stft_config(config)

    if "spec_rmse_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: spec_rmse_loss(y_pred, y_true, stft_config)
            * args.spec_rmse_loss_coef
        )

    if "spec_masked_loss" in selected:
        loss_fns.append(
            lambda y_pred, y_true, x=None: spec_masked_loss(
                y_pred,
                y_true,
                stft_config,
                q=config["training"]["q"],
                coarse=config["training"]["coarse_loss_clip"],
            )
            * args.spec_masked_loss_coef
        )

    if not loss_fns:
        raise ValueError(
            f"No supported loss term selected: args.loss={args.loss!r}"
        )

    def multi_loss(y_pred: Any, y_true: Any, x: Optional[Any] = None) -> torch.Tensor:
        total = 0
        for fn in loss_fns:
            total = total + fn(y_pred, y_true, x)
        return total

    return multi_loss