"""Objective metrics comparing an estimated audio signal against a reference."""

import torch
import auraloss
import resampy
import torchaudio
from pesq import pesq
import pyloudnorm as pyln


def crest_factor(x):
    """Compute the crest factor of a waveform in dB.

    The crest factor is the ratio of the peak absolute amplitude to the RMS
    level, reduced over the last dimension and expressed as
    ``20 * log10(peak / rms)``.

    Args:
        x: Waveform tensor; statistics are taken over the last dimension.

    Returns:
        Tensor of crest factors in dB with the last dimension removed.
    """
    peak, _ = x.abs().max(dim=-1)
    rms = torch.sqrt((x ** 2).mean(dim=-1))
    # log10 (not the natural log) so that the 20x factor yields decibels.
    # Both terms are clamped so silence gives 0 dB instead of log10(0) = -inf.
    return 20 * torch.log10(peak.clamp(min=1e-8) / rms.clamp(min=1e-8))


def rms_energy(x):
    """Compute the RMS level of a waveform in dB (``20 * log10(rms)``).

    Args:
        x: Waveform tensor; the RMS is taken over the last dimension.

    Returns:
        Tensor of RMS levels in dB with the last dimension removed. The RMS is
        clamped to 1e-8, so silence maps to -160 dB rather than -inf.
    """
    rms = torch.sqrt((x ** 2).mean(dim=-1))
    return 20 * torch.log10(rms.clamp(min=1e-8))


def spectral_centroid(x):
    """Compute the normalized spectral centroid of a waveform.

    The magnitude of ``torch.fft.rfft(x)`` is treated as a distribution over
    frequency bins mapped linearly to [0, 1] (0 for the DC bin, 1 for the last
    rfft bin); the centroid is the mean of that distribution.

    Note: normalization and the final sum run over *all* elements of the
    spectrum, so a batched input yields a single scalar for the whole batch.

    See: https://gist.github.com/endolith/359724
    """
    spectrum = torch.fft.rfft(x).abs()
    normalized_spectrum = spectrum / spectrum.sum()
    # Create the frequency axis on the spectrum's device so GPU tensors work.
    normalized_frequencies = torch.linspace(
        0, 1, spectrum.shape[-1], device=spectrum.device
    )
    return torch.sum(normalized_frequencies * normalized_spectrum)


def loudness(x, sample_rate):
    """Compute the integrated loudness of a waveform in dB LUFS.

    Args:
        x: 2-D tensor shaped (channels, samples). A single-channel input is
            duplicated to two channels before metering.
        sample_rate: Sample rate of ``x`` in Hz.

    Returns:
        0-dim tensor holding the integrated loudness reported by pyloudnorm.
    """
    meter = pyln.Meter(sample_rate)
    # add stereo dim if needed
    if x.shape[0] < 2:
        x = x.repeat(2, 1)
    # pyloudnorm takes a NumPy array shaped (samples, channels); detach and
    # move to CPU so tensors that require grad or live on a GPU also work.
    samples = x.detach().cpu().permute(1, 0).numpy()
    return torch.tensor(meter.integrated_loudness(samples))


class MelSpectralDistance(torch.nn.Module):
    """Mel-spectrogram distance between two signals via auraloss.

    A single STFT frame of ``length`` samples (FFT size, hop and window all
    equal to ``length``) is used, with log- and linear-magnitude terms on 128
    mel bands and the spectral-convergence term disabled.
    """

    def __init__(self, sample_rate, length=65536):
        super().__init__()
        self.error = auraloss.freq.MelSTFTLoss(
            sample_rate,
            fft_size=length,
            hop_size=length,
            win_length=length,
            w_sc=0,
            w_log_mag=1,
            w_lin_mag=1,
            n_mels=128,
            scale_invariance=False,
        )
        # I think scale invariance may not work well,
        # since aspects of the phase may be considered?

    def forward(self, input, target):
        return self.error(input, target)


class PESQ(torch.nn.Module):
    """Wide-band PESQ score of ``input`` (degraded) against ``target`` (reference).

    Scoring always happens at 16 kHz in ``"wb"`` mode; signals at any other
    ``sample_rate`` are resampled with resampy first.
    """

    def __init__(self, sample_rate):
        super().__init__()
        self.sample_rate = sample_rate

    def forward(self, input, target):
        # pesq() needs 1-D NumPy arrays. Convert unconditionally: previously
        # this only happened on the resampling path, so 16 kHz tensors were
        # passed to pesq() unconverted.
        target = target.detach().cpu().reshape(-1).numpy()
        input = input.detach().cpu().reshape(-1).numpy()
        if self.sample_rate != 16000:
            target = resampy.resample(target, self.sample_rate, 16000)
            input = resampy.resample(input, self.sample_rate, 16000)
        return pesq(16000, target, input, "wb")


class CrestFactorError(torch.nn.Module):
    """Mean absolute difference of crest factors (dB), returned as a float."""

    def forward(self, input, target):
        return torch.nn.functional.l1_loss(
            crest_factor(input),
            crest_factor(target),
        ).item()


class RMSEnergyError(torch.nn.Module):
    """Mean absolute difference of RMS levels (dB), returned as a float."""

    def forward(self, input, target):
        return torch.nn.functional.l1_loss(
            rms_energy(input),
            rms_energy(target),
        ).item()


class SpectralCentroidError(torch.nn.Module):
    """Absolute difference of mean spectral centroids, returned as a float.

    Uses ``torchaudio.transforms.SpectralCentroid`` and averages the
    per-frame centroids of each signal before comparing.
    """

    def __init__(self, sample_rate, n_fft=2048, hop_length=512):
        super().__init__()
        self.spectral_centroid = torchaudio.transforms.SpectralCentroid(
            sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
        )

    def forward(self, input, target):
        # The tiny offset is presumably there to keep silent frames from
        # producing 0/0 in the centroid computation.
        return torch.nn.functional.l1_loss(
            self.spectral_centroid(input + 1e-16).mean(),
            self.spectral_centroid(target + 1e-16).mean(),
        ).item()


class LoudnessError(torch.nn.Module):
    """Absolute difference of integrated loudness (dB LUFS), as a float.

    Args:
        sample_rate: Sample rate of the compared signals in Hz.
        peak_normalize: If True, scale each signal to unit peak before
            metering, so only loudness differences not explained by overall
            gain are measured.
    """

    def __init__(self, sample_rate: int, peak_normalize: bool = False):
        super().__init__()
        self.sample_rate = sample_rate
        self.peak_normalize = peak_normalize

    def forward(self, input, target):
        if self.peak_normalize:
            # peak normalize; clamp so an all-zero signal does not give 0/0.
            x = input / input.abs().max().clamp(min=1e-8)
            y = target / target.abs().max().clamp(min=1e-8)
        else:
            x = input
            y = target
        # Flatten each signal to a single channel; reshape (unlike view)
        # also accepts non-contiguous tensors.
        return torch.nn.functional.l1_loss(
            loudness(x.reshape(1, -1), self.sample_rate),
            loudness(y.reshape(1, -1), self.sample_rate),
        ).item()