# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
# Adapted by Florian Lux 2021

import librosa
import torch
import torch.nn.functional as F


class MelSpectrogram(torch.nn.Module):
    """Compute a log-scaled Mel-spectrogram from a waveform tensor.

    The Mel filterbank is built once with ``librosa.filters.mel`` and stored
    as the buffer ``melmat`` (transposed, float32), so it follows the module
    across devices. The analysis window is created on every forward call with
    the dtype and device of the input.
    """

    def __init__(self,
                 fs=24000,
                 fft_size=1536,
                 hop_size=384,
                 win_length=None,
                 window="hann",
                 num_mels=100,
                 fmin=60,
                 fmax=None,
                 center=True,
                 normalized=False,
                 onesided=True,
                 eps=1e-10,
                 log_base=10.0,
                 ):
        """
        Initialize the Mel-spectrogram extractor.

        Args:
            fs (int): Sampling rate passed to the Mel filterbank.
            fft_size (int): FFT size.
            hop_size (int): Hop size between frames.
            win_length (int or None): Window length; ``fft_size`` if None.
            window (str or None): Window name; ``torch.<window>_window`` must
                exist. None passes no window to ``torch.stft``.
            num_mels (int): Number of Mel bands.
            fmin (float or None): Lowest filterbank frequency; 0 if None.
            fmax (float or None): Highest filterbank frequency; ``fs / 2`` if None.
            center (bool): Forwarded to ``torch.stft``.
            normalized (bool): Forwarded to ``torch.stft``.
            onesided (bool): Forwarded to ``torch.stft``.
            eps (float): Lower bound applied to the power spectrum before the
                square root and to the Mel energies before the logarithm.
            log_base (float or None): 2.0, 10.0, or None for the natural log.

        Raises:
            ValueError: If the window name is unknown to torch or the log
                base is not one of the supported values.
        """
        super().__init__()
        self.fft_size = fft_size
        if win_length is None:
            self.win_length = fft_size
        else:
            self.win_length = win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f"{window}_window"):
            raise ValueError(f"{window} window is not implemented")
        self.window = window
        self.eps = eps

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        melmat = librosa.filters.mel(sr=fs,
                                     n_fft=fft_size,
                                     n_mels=num_mels,
                                     fmin=fmin,
                                     fmax=fmax,
                                     )
        # Stored transposed, (#freqs, #mels), so forward() can right-multiply.
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

        self.stft_params = {
            "n_fft": self.fft_size,
            "win_length": self.win_length,
            "hop_length": self.hop_size,
            "center": self.center,
            "normalized": self.normalized,
            "onesided": self.onesided,
        }
        # return_complex=False is deprecated in torch.stft; request the complex
        # result and convert it back to the (..., 2) layout in forward().
        self.stft_params["return_complex"] = True

        self.log_base = log_base
        if self.log_base is None:
            self.log = torch.log
        elif self.log_base == 2.0:
            self.log = torch.log2
        elif self.log_base == 10.0:
            self.log = torch.log10
        else:
            raise ValueError(f"log_base: {log_base} is not supported.")

    def forward(self, x):
        """
        Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
                A 3-dim input (B, C, T) is flattened to (B*C, T) first.

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames).
        """
        if x.dim() == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape(-1, x.size(2))

        if self.window is not None:
            window_func = getattr(torch, f"{self.window}_window")
            window = window_func(self.win_length, dtype=x.dtype, device=x.device)
        else:
            window = None

        # complex (B, #freqs, #frames) -> real (B, #freqs, #frames, 2)
        x_stft = torch.view_as_real(torch.stft(x, window=window, **self.stft_params))
        # (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
        x_stft = x_stft.transpose(1, 2)
        x_power = x_stft[..., 0] ** 2 + x_stft[..., 1] ** 2
        # Clamp before sqrt so the gradient stays finite where the power is zero.
        x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))

        x_mel = torch.matmul(x_amp, self.melmat)
        # Clamp again so the logarithm never receives zero.
        x_mel = torch.clamp(x_mel, min=self.eps)

        return self.log(x_mel).transpose(1, 2)


class MelSpectrogramLoss(torch.nn.Module):
    """L1 loss between the log-Mel spectrograms of two waveforms."""

    def __init__(self,
                 fs=24000,
                 fft_size=1024,
                 hop_size=256,
                 win_length=None,
                 window="hann",
                 num_mels=128,
                 fmin=20,
                 fmax=None,
                 center=True,
                 normalized=False,
                 onesided=True,
                 eps=1e-10,
                 log_base=10.0,
                 ):
        """
        Initialize the loss; every argument is forwarded to ``MelSpectrogram``.
        """
        super().__init__()
        self.mel_spectrogram = MelSpectrogram(fs=fs,
                                              fft_size=fft_size,
                                              hop_size=hop_size,
                                              win_length=win_length,
                                              window=window,
                                              num_mels=num_mels,
                                              fmin=fmin,
                                              fmax=fmax,
                                              center=center,
                                              normalized=normalized,
                                              onesided=onesided,
                                              eps=eps,
                                              log_base=log_base,
                                              )

    def forward(self, y_hat, y):
        """
        Calculate Mel-spectrogram loss.

        Args:
            y_hat (Tensor): Generated single tensor (B, 1, T).
            y (Tensor): Groundtruth single tensor (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram loss value.
        """
        mel_hat = self.mel_spectrogram(y_hat)
        mel = self.mel_spectrogram(y)
        mel_loss = F.l1_loss(mel_hat, mel)
        return mel_loss