Spaces:
Paused
Paused
from typing import Dict | |
import numpy as np | |
import torch | |
from matplotlib import pyplot as plt | |
from TTS.tts.utils.visual import plot_spectrogram | |
from TTS.utils.audio import AudioProcessor | |
def interpolate_vocoder_input(scale_factor, spec):
    """Resample a spectrogram by ``scale_factor`` with bilinear interpolation.

    Typically used to adapt the spectrogram produced by a TTS model to the
    sampling rate expected by the vocoder.

    Args:
        scale_factor (float or sequence of float): scale factor(s) forwarded to
            ``torch.nn.functional.interpolate``.
        spec (np.array): 2-D spectrogram to resample.

    Returns:
        torch.tensor: resampled spectrogram with a leading singleton axis,
        i.e. shape ``(1, H', W')`` for an ``(H, W)`` input.
    """
    print(" > before interpolation :", spec.shape)
    # Bilinear interpolation needs a 4-D (N, C, H, W) tensor: add batch and channel axes.
    batched = torch.tensor(spec)[None, None]  # pylint: disable=not-callable
    resized = torch.nn.functional.interpolate(
        batched,
        scale_factor=scale_factor,
        mode="bilinear",
        align_corners=False,
        recompute_scale_factor=True,
    )
    # Drop only the (size-1) batch axis; the channel axis is intentionally kept.
    result = resized[0]
    print(" > after interpolation :", result.shape)
    return result
def plot_results(y_hat: torch.tensor, y: torch.tensor, ap: AudioProcessor, name_prefix: str = None) -> Dict:
    """Plot the predicted and the real waveform and their spectrograms.

    Only the first item of each batch (``y_hat[0]`` / ``y[0]``) is plotted.

    Args:
        y_hat (torch.tensor): Predicted waveform batch.
        y (torch.tensor): Real waveform batch.
        ap (AudioProcessor): Audio processor whose ``melspectrogram`` method is
            used to compute the spectrograms of both waveforms.
        name_prefix (str, optional): Prefix prepended to every figure name.
            Defaults to None (no prefix).

    Returns:
        Dict: output figures keyed by the name of the figures
        (``spectrogram/fake``, ``spectrogram/real``, ``spectrogram/diff`` and
        ``speech_comparison``, each preceded by ``name_prefix``).
    """
    if name_prefix is None:
        name_prefix = ""

    # select the first instance from the batch and move it to numpy
    wav_fake = y_hat[0].squeeze().detach().cpu().numpy()
    wav_real = y[0].squeeze().detach().cpu().numpy()

    # NOTE(review): the transpose assumes melspectrogram returns (n_mels, frames)
    # and plot_spectrogram wants the time axis first — confirm against AudioProcessor.
    spec_fake = ap.melspectrogram(wav_fake).T
    spec_real = ap.melspectrogram(wav_real).T
    spec_diff = np.abs(spec_fake - spec_real)

    # Build the waveform comparison through the object-oriented API so the
    # figure is addressed explicitly rather than via pyplot's "current figure".
    fig_wave = plt.figure()
    ax_real = fig_wave.add_subplot(2, 1, 1)
    ax_real.plot(wav_real)
    ax_real.set_title("groundtruth speech")
    ax_fake = fig_wave.add_subplot(2, 1, 2)
    ax_fake.plot(wav_fake)
    ax_fake.set_title("generated speech")
    fig_wave.tight_layout()
    # Release pyplot's reference to exactly this figure; the Figure object
    # remains usable by the caller through the returned dict.
    plt.close(fig_wave)

    figures = {
        name_prefix + "spectrogram/fake": plot_spectrogram(spec_fake),
        name_prefix + "spectrogram/real": plot_spectrogram(spec_real),
        name_prefix + "spectrogram/diff": plot_spectrogram(spec_diff),
        name_prefix + "speech_comparison": fig_wave,
    }
    return figures