# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""STFT-based loss modules."""

import tensorflow as tf


class TFSpectralConvergence(tf.keras.layers.Layer):
    """Spectral convergence loss."""

    def __init__(self):
        """Initialize."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Calculate forward propagation.

        Computes ``||y_mag - x_mag||_F / ||y_mag||_F`` where the Frobenius
        norm is taken over the last two axes.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value, with the last two axes
                of the inputs reduced away.

        """
        return tf.norm(y_mag - x_mag, ord="fro", axis=(-2, -1)) / tf.norm(
            y_mag, ord="fro", axis=(-2, -1)
        )


class TFLogSTFTMagnitude(tf.keras.layers.Layer):
    """Log STFT magnitude loss module."""

    def __init__(self):
        """Initialize."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Calculate forward propagation.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Element-wise log STFT magnitude loss value
                ``|log(y_mag) - log(x_mag)|`` (not reduced).

        """
        return tf.abs(tf.math.log(y_mag) - tf.math.log(x_mag))


class TFSTFT(tf.keras.layers.Layer):
    """STFT loss module."""

    def __init__(self, frame_length=600, frame_step=120, fft_length=1024):
        """Initialize.

        Args:
            frame_length (int): Passed to ``tf.signal.stft`` as ``frame_length``.
            frame_step (int): Passed to ``tf.signal.stft`` as ``frame_step``.
            fft_length (int): Passed to ``tf.signal.stft`` as ``fft_length``.

        """
        super().__init__()
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        # NOTE: the attribute name is misspelled ("convergenge"); it is kept
        # unchanged because it is visible to code outside this class.
        self.spectral_convergenge_loss = TFSpectralConvergence()
        self.log_stft_magnitude_loss = TFLogSTFTMagnitude()

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value (pre-reduce).
            Tensor: Log STFT magnitude loss value (pre-reduce).

        """
        x_mag = tf.abs(
            tf.signal.stft(
                signals=x,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )
        y_mag = tf.abs(
            tf.signal.stft(
                signals=y,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )

        # add small number to prevent nan value.
        # compatible with pytorch version.
        x_mag = tf.clip_by_value(tf.math.sqrt(x_mag ** 2 + 1e-7), 1e-7, 1e3)
        y_mag = tf.clip_by_value(tf.math.sqrt(y_mag ** 2 + 1e-7), 1e-7, 1e3)

        sc_loss = self.spectral_convergenge_loss(y_mag, x_mag)
        mag_loss = self.log_stft_magnitude_loss(y_mag, x_mag)

        return sc_loss, mag_loss


class TFMultiResolutionSTFT(tf.keras.layers.Layer):
    """Multi resolution STFT loss module."""

    def __init__(
        self,
        fft_lengths=(1024, 2048, 512),
        frame_lengths=(600, 1200, 240),
        frame_steps=(120, 240, 50),
    ):
        """Initialize Multi resolution STFT loss module.

        One ``TFSTFT`` loss is built per resolution; the i-th entries of the
        three sequences together describe the i-th resolution.

        Args:
            fft_lengths (list): List of FFT sizes.
            frame_lengths (list): List of window lengths.
            frame_steps (list): List of hop sizes.

        """
        super().__init__()
        # Defaults are tuples rather than lists so no mutable object is
        # shared between instances; any sequence is accepted.
        assert len(frame_lengths) == len(frame_steps) == len(fft_lengths), (
            "fft_lengths, frame_lengths and frame_steps must have equal length"
        )
        self.stft_losses = [
            TFSTFT(
                frame_length=frame_length,
                frame_step=frame_step,
                fft_length=fft_length,
            )
            for frame_length, frame_step, fft_length in zip(
                frame_lengths, frame_steps, fft_lengths
            )
        ]

    def call(self, y, x):
        """Calculate forward propagation.

        Each per-resolution loss is averaged over every axis except the
        first, then the results are averaged over the resolutions.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.

        """
        sc_loss = 0.0
        mag_loss = 0.0
        for f in self.stft_losses:
            sc_l, mag_l = f(y, x)
            # Reduce over all non-leading axes so the first axis is kept.
            sc_loss += tf.reduce_mean(sc_l, axis=list(range(1, len(sc_l.shape))))
            mag_loss += tf.reduce_mean(mag_l, axis=list(range(1, len(mag_l.shape))))

        sc_loss /= len(self.stft_losses)
        mag_loss /= len(self.stft_losses)

        return sc_loss, mag_loss