import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from scipy.io import wavfile
from vocoder.vocgan_generator import Generator
import hparams as hp
import os
import text
# Module-wide default device: the CUDA device when one is available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def get_alignment(tier, sampling_rate=None, hop_length=None):
    """Extract phones and per-phone frame durations from an alignment tier.

    Leading and trailing silence phones ('sil', 'sp', 'spn') are trimmed;
    silence phones between real phones are kept.

    Args:
        tier: Interval tier whose ``_objects`` items expose ``start_time``,
            ``end_time`` (presumably seconds - TODO confirm) and ``text``.
        sampling_rate: Audio sampling rate; ``None`` uses ``hp.sampling_rate``.
        hop_length: Frame hop in samples; ``None`` uses ``hp.hop_length``.

    Returns:
        (phones, durations, start_time, end_time): the trimmed phone list,
        an ``np.ndarray`` of per-phone durations in frames, the start time of
        the first non-silent phone and the end time of the last one (both 0
        when no non-silent phone exists).
    """
    if sampling_rate is None:
        sampling_rate = hp.sampling_rate
    if hop_length is None:
        hop_length = hp.hop_length
    sil_phones = {'sil', 'sp', 'spn'}

    phones = []
    durations = []
    start_time = 0
    end_time = 0
    end_idx = 0
    for t in tier._objects:
        s, e, p = t.start_time, t.end_time, t.text

        # Trim leading silences: skip them until the first real phone, whose
        # onset becomes the utterance start.
        if not phones:
            if p in sil_phones:
                continue
            start_time = s

        phones.append(p)
        if p not in sil_phones:
            # Remember the last real phone so trailing silences can be cut.
            end_time = e
            end_idx = len(phones)
        # Convert both boundaries to frame indices before subtracting so that,
        # for contiguous intervals, durations telescope without rounding drift.
        durations.append(
            int(e * sampling_rate / hop_length) - int(s * sampling_rate / hop_length))

    # Trim trailing silences: keep everything up to the last real phone.
    phones = phones[:end_idx]
    durations = durations[:end_idx]

    return phones, np.array(durations), start_time, end_time
def process_meta(meta_path):
    """Read a ``name|text`` metadata file.

    Each non-empty line must contain exactly one ``|`` separating an
    utterance name from its transcript. Blank lines are skipped.

    Args:
        meta_path: Path to a UTF-8 encoded metadata file.

    Returns:
        (names, texts): two parallel lists of strings in file order.

    Raises:
        ValueError: if a non-empty line does not split into exactly two fields.
    """
    # Distinct local names avoid shadowing the module-level ``text`` import.
    names = []
    texts = []
    with open(meta_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip('\n')
            if not line:
                continue
            n, t = line.split('|')
            names.append(n)
            texts.append(t)
    return names, texts
def get_param_num(model):
    """Count every scalar element across all parameter tensors of ``model``."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
def plot_data(data, titles=None, filename=None):
    """Plot spectrograms with pitch and energy curves overlaid and save the figure.

    Args:
        data: Sequence of ``(spectrogram, pitch, energy)`` triples, one per
            subplot row. ``spectrogram`` is drawn with ``imshow``; its
            ``shape[1]`` sets the x-range of the pitch and energy overlays.
        titles: Optional per-row titles; ``None`` leaves every row untitled.
        filename: Output path passed to ``savefig`` (saved at 200 dpi).
    """
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None] * len(data)

    def add_axis(fig, old_ax, offset=0):
        # Overlay a new axes exactly on top of ``old_ax``. ``offset`` is
        # accepted for call compatibility but is not used.
        ax = fig.add_axes(old_ax.get_position(), anchor='W')
        # A transparent background keeps the spectrogram underneath visible;
        # the default opaque face of the overlay would cover it.
        ax.set_facecolor('none')
        return ax

    for i, (spectrogram, pitch, energy) in enumerate(data):
        axes[i][0].imshow(spectrogram, origin='lower')
        axes[i][0].set_aspect(2.5, adjustable='box')
        axes[i][0].set_ylim(0, hp.n_mel_channels)
        axes[i][0].set_title(titles[i], fontsize='medium')
        axes[i][0].tick_params(labelsize='x-small', left=False, labelleft=False)

        # Pitch curve on its own overlay, y-axis on the left.
        ax1 = add_axis(fig, axes[i][0])
        ax1.plot(pitch, color='tomato')
        ax1.set_xlim(0, spectrogram.shape[1])
        ax1.set_ylim(0, hp.f0_max)
        ax1.set_ylabel('F0', color='tomato')
        ax1.tick_params(labelsize='x-small', colors='tomato', bottom=False, labelbottom=False)

        # Energy curve on a second overlay, y-axis on the right.
        ax2 = add_axis(fig, axes[i][0], 1.2)
        ax2.plot(energy, color='darkviolet')
        ax2.set_xlim(0, spectrogram.shape[1])
        ax2.set_ylim(hp.energy_min, hp.energy_max)
        ax2.set_ylabel('Energy', color='darkviolet')
        ax2.tick_params(labelsize='x-small', colors='darkviolet', bottom=False, labelbottom=False,
                        left=False, labelleft=False, right=True, labelright=True)

    fig.savefig(filename, dpi=200)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
def get_mask_from_lengths(lengths, max_len=None):
    """Build a padding mask from sequence lengths.

    Args:
        lengths: 1-D integer tensor of shape ``(batch,)``.
        max_len: Mask width; ``None`` uses ``lengths.max()`` (0 for an empty batch).

    Returns:
        Bool tensor of shape ``(batch, max_len)`` on ``lengths.device``, where
        ``True`` marks padded positions (index >= length).
    """
    if max_len is None:
        max_len = torch.max(lengths).item() if lengths.numel() > 0 else 0
    # Build the index row on the same device as ``lengths`` so the comparison
    # never mixes CPU and CUDA tensors.
    ids = torch.arange(0, max_len, device=lengths.device)
    # Broadcasting (1, max_len) against (batch, 1) yields (batch, max_len).
    return ids.unsqueeze(0) >= lengths.unsqueeze(1)
def get_vocgan(ckpt_path, n_mel_channels=hp.n_mel_channels,
               generator_ratio=None, n_residual_layers=4, mult=256,
               out_channels=1):
    """Build a VocGAN generator and load pretrained weights into it.

    Args:
        ckpt_path: Path to a checkpoint file readable by ``torch.load``.
        n_mel_channels: Number of mel channels passed to ``Generator``.
        generator_ratio: ``ratios`` for ``Generator``; ``None`` means
            ``[4, 4, 2, 2, 2, 2]`` (a fresh list per call).
        n_residual_layers, mult, out_channels: forwarded to ``Generator``.

    Returns:
        The generator in eval mode, moved to ``device``.
    """
    if generator_ratio is None:
        generator_ratio = [4, 4, 2, 2, 2, 2]
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # NOTE(review): the ``out_band`` keyword and the 'model_g' checkpoint key
    # follow the upstream VocGAN implementation - confirm both against
    # vocoder.vocgan_generator.Generator and the checkpoint format.
    model = Generator(n_mel_channels, n_residual_layers,
                      ratios=generator_ratio, mult=mult,
                      out_band=out_channels)
    model.load_state_dict(checkpoint['model_g'])
    model.eval().to(device)
    return model
def get_vocoder():
    """Generic vocoder loader (not implemented).

    No generic loader exists in this module. This function raises instead of
    silently returning ``None``. Load the VocGAN vocoder with
    ``get_vocgan(ckpt_path)``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "get_vocoder() has no implementation; call get_vocgan(ckpt_path) instead")
def vocgan_infer(mel, vocoder, path):
    """Synthesize audio from ``mel`` with ``vocoder`` and write a WAV to ``path``.

    A 2-D ``mel`` gets a leading batch axis before ``vocoder.infer``. The
    squeezed output loses its last ``hp.hop_length * 10`` samples, is scaled
    by ``hp.max_wav_value``, clamped to
    ``[-hp.max_wav_value, hp.max_wav_value - 1]``, cast with ``.short()``
    and written at ``hp.sampling_rate``. Runs under ``torch.no_grad()``.
    """
    with torch.no_grad():
        batched = mel.unsqueeze(0) if len(mel.shape) == 2 else mel
        wav = vocoder.infer(batched).squeeze()

        tail = hp.hop_length * 10
        peak = hp.max_wav_value
        wav = peak * wav[:-tail]
        wav = wav.clamp(min=-peak, max=peak - 1)

        samples = wav.short().cpu().detach().numpy()
        wavfile.write(path, hp.sampling_rate, samples)
def pad_1D(inputs, PAD=0):
    """Right-pad 1-D sequences to a common length and stack them.

    Args:
        inputs: Non-empty sequence of 1-D array-likes.
        PAD: Fill value for the padded tail.

    Returns:
        ``np.ndarray`` of shape ``(len(inputs), max_len)``.
    """
    def pad_data(x, length, PAD):
        x = np.asarray(x)
        # Pad only on the right so existing values keep their positions.
        return np.pad(x, (0, length - x.shape[0]),
                      mode='constant', constant_values=PAD)

    max_len = max(len(x) for x in inputs)
    return np.stack([pad_data(x, max_len, PAD) for x in inputs])
def pad_2D(inputs, maxlen=None):
    """Zero-pad 2-D arrays along axis 0 to a common length and stack them.

    Args:
        inputs: Non-empty sequence of 2-D arrays sharing the same second
            dimension.
        maxlen: Target length of axis 0. When falsy (``None`` or ``0``) the
            longest input's length is used.

    Returns:
        ``np.ndarray`` of shape ``(len(inputs), length, dim)``.

    Raises:
        ValueError: if an input is longer than the target length.
    """
    def _pad_rows(x, max_len):
        if np.shape(x)[0] > max_len:
            raise ValueError("not max_len")
        # Pad rows only; the feature axis is left untouched.
        return np.pad(x, ((0, max_len - np.shape(x)[0]), (0, 0)),
                      mode='constant', constant_values=0)

    if maxlen:
        max_len = maxlen
    else:
        max_len = max(np.shape(x)[0] for x in inputs)
    return np.stack([_pad_rows(x, max_len) for x in inputs])
def pad(input_ele, mel_max_length=None):
    """Zero-pad tensors along dim 0 to a common length and stack them.

    Args:
        input_ele: Non-empty sequence of tensors that agree on every dimension
            except the first (1-D and 2-D inputs; higher ranks are padded the
            same way).
        mel_max_length: Target length for dim 0. When falsy the longest
            input's length is used.

    Returns:
        Tensor of shape ``(len(input_ele), max_len, *rest)``.
    """
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max(t.size(0) for t in input_ele)

    out_list = []
    for t in input_ele:
        # F.pad reads (left, right) pairs starting from the LAST dimension,
        # so dim 0's pair goes at the end after zero pairs for the others.
        pad_spec = (0, 0) * (t.dim() - 1) + (0, max_len - t.size(0))
        out_list.append(F.pad(t, pad_spec, "constant", 0.0))
    return torch.stack(out_list)
# The helpers below are adapted from dathudeptrai's FastSpeech2 implementation.
def standard_norm(x, mean, std, is_mel=False):
    """Standardize ``x`` as ``(x - mean) / std``.

    When ``is_mel`` is False, outliers are first replaced via
    ``remove_outlier`` (which modifies ``x`` in place), and entries that are
    exactly 0.0 after that step are forced back to 0.0 after normalization.
    Zero positions are taken from the first axis of ``np.where``.
    """
    if is_mel:
        return (x - mean) / std

    cleaned = remove_outlier(x)
    zero_positions = np.where(cleaned == 0.0)[0]
    normalized = (cleaned - mean) / std
    normalized[zero_positions] = 0.0
    return normalized
def de_norm(x, mean, std):
    """Undo standardization: return ``mean + std * x`` with zeros kept at zero.

    Entries that are exactly 0.0 in ``x`` are written back as 0.0, mirroring
    ``standard_norm``, which forces such positions to 0.0. ``x`` itself is
    not modified. Works for torch tensors and numpy arrays of any shape.
    """
    zero_mask = x == 0.0
    x = mean + std * x
    # Index with the element-wise mask. ``torch.where(...)[0]`` keeps only
    # first-axis indices and would zero entire rows of a 2-D input.
    x[zero_mask] = 0.0
    return x
def _is_outlier(x, p25, p75):
"""Check if value is an outlier."""
lower = p25 - 1.5 * (p75 - p25)
upper = p75 + 1.5 * (p75 - p25)
return np.logical_or(x <= lower, x >= upper)
def remove_outlier(x):
    """Replace outliers in ``x`` in place and return it.

    Values on or beyond 1.5 x IQR from the 25th/75th percentiles (see
    ``_is_outlier``) are first set to 0.0. They are then overwritten with
    ``np.max`` of the resulting array: the largest non-outlier value, or 0.0
    if every remaining value is negative.

    Args:
        x: numpy array; modified in place.

    Returns:
        The same array ``x``.
    """
    p25 = np.percentile(x, 25)
    p75 = np.percentile(x, 75)

    # _is_outlier uses element-wise comparisons, so one call yields a boolean
    # mask over the whole array.
    outliers = _is_outlier(x, p25, p75)

    # Zero the outliers first so they cannot win the max below.
    x[outliers] = 0.0
    # NOTE(review): upstream commented "replace by mean f0", but the value
    # used is the max; the max is kept here - confirm which was intended.
    x[outliers] = np.max(x)
    return x
def average_by_duration(x, durs):
    """Average a frame-level feature over each token's span of frames.

    Zero-valued frames are excluded from each average. A token with no
    non-zero frames (including a zero-duration token) gets 0.0.

    Args:
        x: 1-D array of frame-level values (character f0/energy per the
            original usage).
        durs: 1-D integer array of per-token durations in frames.

    Returns:
        float32 array with one value per entry of ``durs``.
    """
    x = np.asarray(x)
    durs = np.asarray(durs)
    # Token i covers frames [durs_cum[i], durs_cum[i + 1]).
    durs_cum = np.cumsum(np.pad(durs, (1, 0)))

    # Calculate the per-character f0/energy.
    x_char = np.zeros((durs.shape[0],), dtype=np.float32)
    # Enumerate tokens rather than zipping with range(durs.sum()): that bound
    # dropped trailing tokens whenever zero durations made the frame count
    # smaller than the token count.
    for idx, (start, end) in enumerate(zip(durs_cum[:-1], durs_cum[1:])):
        segment = x[start:end]
        nonzero = segment[segment != 0.0]
        # np.mean of an empty array is nan, so fall back to 0.0 explicitly.
        x_char[idx] = np.mean(nonzero) if nonzero.size > 0 else 0.0
    return x_char