|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""STFT-based loss modules.""" |
|
|
|
import tensorflow as tf |
|
|
|
|
|
class TFSpectralConvergence(tf.keras.layers.Layer):
    """Spectral convergence loss layer.

    Measures the Frobenius norm of the spectrogram error relative to the
    Frobenius norm of the reference spectrogram.
    """

    def __init__(self):
        """Build the layer; it holds no parameters of its own."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Compute the spectral convergence between two magnitude spectrograms.

        Args:
            y_mag (Tensor): Magnitude spectrogram of the groundtruth signal,
                expected as (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of the predicted signal,
                expected as (B, #frames, #freq_bins).

        Returns:
            Tensor: ``||y_mag - x_mag||_F / ||y_mag||_F`` with the norms taken
                over the last two axes, i.e. one value per leading index.

        """
        # Both norms are reduced over the trailing (frames, freq_bins) axes.
        spectrogram_axes = (-2, -1)
        error_norm = tf.norm(y_mag - x_mag, ord="fro", axis=spectrogram_axes)
        reference_norm = tf.norm(y_mag, ord="fro", axis=spectrogram_axes)
        return error_norm / reference_norm
|
|
|
|
|
class TFLogSTFTMagnitude(tf.keras.layers.Layer):
    """Log STFT magnitude loss layer.

    Produces the element-wise absolute difference between the natural logs of
    two magnitude spectrograms; no reduction is applied here.
    """

    def __init__(self):
        """Build the layer; it holds no parameters of its own."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Compute the element-wise log-magnitude distance.

        Args:
            y_mag (Tensor): Magnitude spectrogram of the groundtruth signal,
                expected as (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of the predicted signal,
                expected as (B, #frames, #freq_bins).

        Returns:
            Tensor: ``|log(y_mag) - log(x_mag)|``, element-wise (unreduced).

        """
        log_reference = tf.math.log(y_mag)
        log_estimate = tf.math.log(x_mag)
        return tf.abs(log_reference - log_estimate)
|
|
|
|
|
class TFSTFT(tf.keras.layers.Layer):
    """Single-resolution STFT loss layer.

    Converts both signals to clipped STFT magnitudes and evaluates the
    spectral convergence and log STFT magnitude losses on them.
    """

    def __init__(self, frame_length=600, frame_step=120, fft_length=1024):
        """Store the STFT configuration and create the two sub-losses.

        Args:
            frame_length (int): Passed to ``tf.signal.stft`` as ``frame_length``.
            frame_step (int): Passed to ``tf.signal.stft`` as ``frame_step``.
            fft_length (int): Passed to ``tf.signal.stft`` as ``fft_length``.

        """
        super().__init__()
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        # NOTE: the attribute name keeps its historical spelling so that
        # existing code referring to it continues to work.
        self.spectral_convergenge_loss = TFSpectralConvergence()
        self.log_stft_magnitude_loss = TFLogSTFTMagnitude()

    def _clipped_magnitude(self, signal):
        """Return the smoothed, clipped STFT magnitude of ``signal``."""
        magnitude = tf.abs(
            tf.signal.stft(
                signals=signal,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )
        # The 1e-7 offset under the square root and the [1e-7, 1e3] clip keep
        # the magnitude strictly positive and bounded before the log-based
        # loss is applied to it.
        smoothed = tf.math.sqrt(magnitude ** 2 + 1e-7)
        return tf.clip_by_value(smoothed, 1e-7, 1e3)

    def call(self, y, x):
        """Compute both STFT losses for a groundtruth/prediction pair.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value (pre-reduce).
            Tensor: Log STFT magnitude loss value (pre-reduce).

        """
        predicted_mag = self._clipped_magnitude(x)
        reference_mag = self._clipped_magnitude(y)

        convergence = self.spectral_convergenge_loss(reference_mag, predicted_mag)
        log_magnitude = self.log_stft_magnitude_loss(reference_mag, predicted_mag)
        return convergence, log_magnitude
|
|
|
|
|
class TFMultiResolutionSTFT(tf.keras.layers.Layer):
    """Multi resolution STFT loss module.

    Averages the spectral convergence and log STFT magnitude losses of several
    ``TFSTFT`` layers, each configured with a different STFT resolution.
    """

    def __init__(
        self,
        fft_lengths=(1024, 2048, 512),
        frame_lengths=(600, 1200, 240),
        frame_steps=(120, 240, 50),
    ):
        """Initialize Multi resolution STFT loss module.

        The three sequences are consumed in lockstep: the i-th entries of each
        configure the i-th ``TFSTFT`` resolution.

        Args:
            fft_lengths (Sequence[int]): FFT sizes, one per resolution.
            frame_lengths (Sequence[int]): Window lengths, one per resolution.
            frame_steps (Sequence[int]): Hop sizes, one per resolution.

        Raises:
            AssertionError: If the three sequences differ in length.

        """
        super().__init__()
        # Defaults are immutable tuples so no list object is shared between
        # instances; any sequence (including lists) is still accepted.
        assert len(frame_lengths) == len(frame_steps) == len(fft_lengths), (
            "frame_lengths, frame_steps and fft_lengths must have the same "
            "number of entries"
        )
        self.stft_losses = [
            TFSTFT(frame_length, frame_step, fft_length)
            for frame_length, frame_step, fft_length in zip(
                frame_lengths, frame_steps, fft_lengths
            )
        ]

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value, averaged
                over all non-leading axes and over the resolutions.
            Tensor: Multi resolution log STFT magnitude loss value, averaged
                over all non-leading axes and over the resolutions.

        """
        sc_loss = 0.0
        mag_loss = 0.0
        for stft_loss in self.stft_losses:
            sc_l, mag_l = stft_loss(y, x)
            # Reduce every axis except the leading one; for a rank-1 tensor
            # the axis list is empty and the values pass through unchanged.
            sc_loss += tf.reduce_mean(sc_l, axis=list(range(1, len(sc_l.shape))))
            mag_loss += tf.reduce_mean(mag_l, axis=list(range(1, len(mag_l.shape))))

        num_resolutions = len(self.stft_losses)
        sc_loss /= num_resolutions
        mag_loss /= num_resolutions

        return sc_loss, mag_loss
|
|