vishred18's picture
Upload 364 files
d5ee97c
raw
history blame contribute delete
No virus
5.18 kB
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""STFT-based loss modules."""
import tensorflow as tf
class TFSpectralConvergence(tf.keras.layers.Layer):
    """Spectral convergence loss layer.

    Measures the Frobenius norm of the magnitude error relative to the
    Frobenius norm of the groundtruth magnitude.
    """

    def __init__(self):
        """Initialize the layer (it holds no weights or configuration)."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Compute the spectral convergence between two magnitude spectrograms.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value; the last two axes are
                reduced by the norm, leaving one value per leading index.
        """
        # Both norms are taken over the trailing (frame, frequency) axes.
        spectrogram_axes = (-2, -1)
        error_norm = tf.norm(y_mag - x_mag, ord="fro", axis=spectrogram_axes)
        reference_norm = tf.norm(y_mag, ord="fro", axis=spectrogram_axes)
        return error_norm / reference_norm
class TFLogSTFTMagnitude(tf.keras.layers.Layer):
    """Log STFT magnitude loss layer.

    Produces the element-wise absolute difference between the natural
    logarithms of two magnitude spectrograms.
    """

    def __init__(self):
        """Initialize the layer (it holds no weights or configuration)."""
        super().__init__()

    def call(self, y_mag, x_mag):
        """Compute the element-wise log-magnitude distance.

        Args:
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal
                (B, #frames, #freq_bins).
            x_mag (Tensor): Magnitude spectrogram of predicted signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value, un-reduced (same shape as
                the inputs).
        """
        log_reference = tf.math.log(y_mag)
        log_predicted = tf.math.log(x_mag)
        return tf.abs(log_reference - log_predicted)
class TFSTFT(tf.keras.layers.Layer):
    """Single-resolution STFT loss module.

    Converts both signals to clipped magnitude spectrograms and evaluates
    the spectral convergence and log STFT magnitude losses on them.
    """

    def __init__(self, frame_length=600, frame_step=120, fft_length=1024):
        """Initialize.

        Args:
            frame_length (int): Forwarded to ``tf.signal.stft`` as
                ``frame_length``.
            frame_step (int): Forwarded to ``tf.signal.stft`` as
                ``frame_step``.
            fft_length (int): Forwarded to ``tf.signal.stft`` as
                ``fft_length``.
        """
        super().__init__()
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        # NOTE: the attribute spelling ("convergenge") is kept unchanged so
        # that code referring to it by name keeps working.
        self.spectral_convergenge_loss = TFSpectralConvergence()
        self.log_stft_magnitude_loss = TFLogSTFTMagnitude()

    def _stabilized_magnitude(self, signals):
        """Return the clipped STFT magnitude of ``signals``."""
        magnitude = tf.abs(
            tf.signal.stft(
                signals=signals,
                frame_length=self.frame_length,
                frame_step=self.frame_step,
                fft_length=self.fft_length,
            )
        )
        # A small constant is added before the square root to prevent nan
        # values; the clipping range is kept compatible with the pytorch
        # version.
        smoothed = tf.math.sqrt(magnitude ** 2 + 1e-7)
        return tf.clip_by_value(smoothed, 1e-7, 1e3)

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value (pre-reduce).
            Tensor: Log STFT magnitude loss value (pre-reduce).
        """
        x_mag = self._stabilized_magnitude(x)
        y_mag = self._stabilized_magnitude(y)

        spectral_convergence = self.spectral_convergenge_loss(y_mag, x_mag)
        log_magnitude = self.log_stft_magnitude_loss(y_mag, x_mag)
        return spectral_convergence, log_magnitude
class TFMultiResolutionSTFT(tf.keras.layers.Layer):
    """Multi resolution STFT loss module.

    Holds one ``TFSTFT`` loss per (frame_length, frame_step, fft_length)
    triple and averages their outputs over all resolutions.
    """

    def __init__(
        self,
        fft_lengths=(1024, 2048, 512),
        frame_lengths=(600, 1200, 240),
        frame_steps=(120, 240, 50),
    ):
        """Initialize Multi resolution STFT loss module.

        Args:
            fft_lengths (sequence of int): List of FFT sizes.
            frame_lengths (sequence of int): List of window lengths.
            frame_steps (sequence of int): List of hop sizes.

        Raises:
            ValueError: If the three sequences do not have the same length.
        """
        super().__init__()
        # Validate explicitly instead of with ``assert``: asserts are removed
        # under ``python -O`` and ``zip`` would then silently drop the extra
        # entries of the longer sequences.
        if not (len(frame_lengths) == len(frame_steps) == len(fft_lengths)):
            raise ValueError(
                "frame_lengths, frame_steps and fft_lengths must have the same "
                "length, got {}, {} and {}.".format(
                    len(frame_lengths), len(frame_steps), len(fft_lengths)
                )
            )
        self.stft_losses = [
            TFSTFT(frame_length, frame_step, fft_length)
            for frame_length, frame_step, fft_length in zip(
                frame_lengths, frame_steps, fft_lengths
            )
        ]

    def call(self, y, x):
        """Calculate forward propagation.

        Args:
            y (Tensor): Groundtruth signal (B, T).
            x (Tensor): Predicted signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.
        """
        sc_loss = 0.0
        mag_loss = 0.0
        for stft_loss in self.stft_losses:
            sc_l, mag_l = stft_loss(y, x)
            # Average every axis except the leading one so that each
            # resolution contributes one value per leading index.
            sc_loss += tf.reduce_mean(sc_l, axis=list(range(1, len(sc_l.shape))))
            mag_loss += tf.reduce_mean(mag_l, axis=list(range(1, len(mag_l.shape))))

        # Average the accumulated losses over the number of resolutions.
        num_resolutions = len(self.stft_losses)
        sc_loss /= num_resolutions
        mag_loss /= num_resolutions
        return sc_loss, mag_loss