Comparative-Analysis-of-Speech-Synthesis-Models
/
TensorFlowTTS
/tensorflow_tts
/utils
/griffin_lim.py
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Griffin-Lim phase reconstruction algorithm from mel spectrogram."""
import os | |
import librosa | |
import numpy as np | |
import soundfile as sf | |
import tensorflow as tf | |
from sklearn.preprocessing import StandardScaler | |
def griffin_lim_lb(
    mel_spec, stats_path, dataset_config, n_iter=32, output_dir=None, wav_name="lb"
):
    """Generate wave from mel spectrogram with Griffin-Lim algorithm using Librosa.

    The input is de-normalized with the statistics stored at `stats_path`,
    exponentiated with base 10, projected back to a linear-frequency
    spectrogram through the pseudo-inverse of the mel filter bank and finally
    passed to `librosa.griffinlim`.

    Args:
        mel_spec (ndarray): normalized log10 mel spectrogram with the mel
            channels on the last axis (it is transposed before the mel
            inversion).
        stats_path (str): path to the `stats.npy` file containing norm
            statistics; it is unpacked as `(mean, scale)`.
        dataset_config (Dict): dataset configuration parameters. The keys read
            here are `sampling_rate`, `fft_size`, `num_mels`, `fmin`, `fmax`,
            `hop_size` and `win_length`.
        n_iter (int): number of iterations for GL.
        output_dir (str): output directory where audio file will be saved.
            Nothing is written when it is empty or None. The directory is
            created if it does not exist yet.
        wav_name (str): name of the output file (without the `.wav` suffix).

    Returns:
        gl_lb (ndarray): generated wave.
    """
    # Undo the standardization first, then the log10 compression.
    scaler = StandardScaler()
    scaler.mean_, scaler.scale_ = np.load(stats_path)
    mel_spec = np.power(10.0, scaler.inverse_transform(mel_spec)).T

    # `sr` is passed by keyword: recent librosa releases reject positional
    # arguments here, while older releases accept the keyword form as well.
    mel_basis = librosa.filters.mel(
        sr=dataset_config["sampling_rate"],
        n_fft=dataset_config["fft_size"],
        n_mels=dataset_config["num_mels"],
        fmin=dataset_config["fmin"],
        fmax=dataset_config["fmax"],
    )

    # The pseudo-inverse can yield negative values; clip them to a small
    # positive floor before running Griffin-Lim.
    mel_to_linear = np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel_spec))
    gl_lb = librosa.griffinlim(
        mel_to_linear,
        n_iter=n_iter,
        hop_length=dataset_config["hop_size"],
        # Fall back to the FFT size when no window length is configured.
        win_length=dataset_config["win_length"] or dataset_config["fft_size"],
    )

    if output_dir:
        # soundfile does not create missing directories on its own.
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"{wav_name}.wav")
        sf.write(output_path, gl_lb, dataset_config["sampling_rate"], "PCM_16")
    return gl_lb
class TFGriffinLim(tf.keras.layers.Layer):
    """Griffin-Lim algorithm for phase reconstruction from mel spectrogram magnitude."""

    def __init__(self, stats_path, dataset_config, normalized: bool = True):
        """Init GL params.

        Args:
            stats_path (str): path to the `stats.npy` file containing norm
                statistics; it is unpacked as `(mean, scale)`. Only read when
                `normalized` is True.
            dataset_config (Dict): dataset configuration parameters. The keys
                read by this layer are `sampling_rate`, `fft_size`,
                `num_mels`, `fmin`, `fmax`, `hop_size` and `win_length`.
            normalized (bool): whether the mel spectrograms given to `call`
                are standardized and must be de-normalized with the stats.
        """
        super().__init__()
        self.normalized = normalized
        if normalized:
            scaler = StandardScaler()
            scaler.mean_, scaler.scale_ = np.load(stats_path)
            self.scaler = scaler
        self.ds_config = dataset_config
        # `sr` is passed by keyword: recent librosa releases reject positional
        # arguments here, while older releases accept the keyword form as well.
        self.mel_basis = librosa.filters.mel(
            sr=self.ds_config["sampling_rate"],
            n_fft=self.ds_config["fft_size"],
            n_mels=self.ds_config["num_mels"],
            fmin=self.ds_config["fmin"],
            fmax=self.ds_config["fmax"],
        )  # [num_mels, fft_size // 2 + 1]

    def save_wav(self, gl_tf, output_dir, wav_name):
        """Generate WAV file and save it.

        Args:
            gl_tf (tf.Tensor): reconstructed signal from GL algorithm, either
                a single signal or a batch of signals.
            output_dir (str): output directory where audio file will be saved.
            wav_name (str or List[str]): name of the output file(s), without
                the `.wav` suffix. For a batch, one name per signal.
        """
        sampling_rate = self.ds_config["sampling_rate"]

        def encode_fn(x):
            return tf.audio.encode_wav(x, sampling_rate)

        # Add the trailing channel axis before encoding.
        gl_tf = tf.expand_dims(gl_tf, -1)
        if not isinstance(wav_name, list):
            wav_name = [wav_name]

        if len(gl_tf.shape) > 2:
            bs, *_ = gl_tf.shape
            assert bs == len(wav_name), "Batch and 'wav_name' have different size."
            # Plain Python iteration: the names live in a Python list, so there
            # is no need to index it with tensors or to go through tf.map_fn.
            for idx, name in enumerate(wav_name):
                output_path = os.path.join(output_dir, f"{name}.wav")
                tf.io.write_file(output_path, encode_fn(gl_tf[idx]))
        else:
            tf_wav = encode_fn(gl_tf)
            tf.io.write_file(os.path.join(output_dir, f"{wav_name[0]}.wav"), tf_wav)

    def call(self, mel_spec, n_iter=32):
        """Apply GL algorithm to batched mel spectrograms.

        Args:
            mel_spec (tf.Tensor): log10 mel spectrogram with the mel channels
                on the last axis; standardized when the layer was built with
                `normalized=True`.
            n_iter (int): number of iterations to run GL algorithm.

        Returns:
            (tf.Tensor): reconstructed signal from GL algorithm.
        """
        # de-normalize mel spectrogram
        if self.normalized:
            mel_spec = tf.math.pow(
                10.0, mel_spec * self.scaler.scale_ + self.scaler.mean_
            )
        else:
            mel_spec = tf.math.pow(
                10.0, mel_spec
            )  # TODO @dathudeptrai check if its ok without it wavs were too quiet

        inverse_mel = tf.linalg.pinv(self.mel_basis)

        # [:, num_mels] @ [fft_size // 2 + 1, num_mels].T
        mel_to_linear = tf.linalg.matmul(mel_spec, inverse_mel, transpose_b=True)
        # Clip negative values produced by the pseudo-inverse to a small floor.
        mel_to_linear = tf.cast(tf.math.maximum(1e-10, mel_to_linear), tf.complex64)

        # Start from a uniformly random phase.
        init_phase = tf.cast(
            tf.random.uniform(tf.shape(mel_to_linear), maxval=1), tf.complex64
        )
        phase = tf.math.exp(2j * np.pi * init_phase)

        # STFT parameters are identical for every iteration; resolve them once.
        frame_length = self.ds_config["win_length"] or self.ds_config["fft_size"]
        frame_step = self.ds_config["hop_size"]
        fft_length = self.ds_config["fft_size"]
        inverse_window_fn = tf.signal.inverse_stft_window_fn(frame_step)

        for _ in tf.range(n_iter):
            inverse = tf.signal.inverse_stft(
                mel_to_linear * phase,
                frame_length=frame_length,
                frame_step=frame_step,
                fft_length=fft_length,
                window_fn=inverse_window_fn,
            )
            phase = tf.signal.stft(inverse, frame_length, frame_step, fft_length)
            # Keep only the phase: divide by the magnitude (floored to avoid
            # division by zero).
            phase /= tf.cast(tf.maximum(1e-10, tf.abs(phase)), tf.complex64)

        return tf.signal.inverse_stft(
            mel_to_linear * phase,
            frame_length=frame_length,
            frame_step=frame_step,
            fft_length=fft_length,
            window_fn=inverse_window_fn,
        )