import os
import json

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
from scipy.io import wavfile
from matplotlib import pyplot as plt

matplotlib.use("Agg")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_device(data, device):
    """Convert a numpy batch to tensors on `device`.

    Handles the 13-field training batch, the 6-field inference batch,
    and the 7-field inference batch that carries emotion labels.
    """
    if len(data) == 13:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            emotions,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        emotions = torch.from_numpy(emotions).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)
        mels = torch.from_numpy(mels).float().to(device)
        mel_lens = torch.from_numpy(mel_lens).to(device)
        pitches = torch.from_numpy(pitches).float().to(device)
        energies = torch.from_numpy(energies).to(device)
        durations = torch.from_numpy(durations).long().to(device)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            emotions,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        )

    if len(data) == 6:
        (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)

        return (ids, raw_texts, speakers, texts, src_lens, max_src_len)

    if len(data) == 7:
        (ids, raw_texts, speakers, texts, src_lens, max_src_len, emotions) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        emotions = torch.from_numpy(emotions).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)

        return (ids, raw_texts, speakers, texts, src_lens, max_src_len, emotions)

    # Fail loudly instead of silently returning None on an unexpected layout.
    raise ValueError("Unexpected batch with {} fields".format(len(data)))
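
# Usage sketch for `to_device` (the 6-field layout is taken from the
# unpacking above; `batch` here is a hypothetical collated batch):
#     batch = (ids, raw_texts, speakers, texts, src_lens, max_src_len)
#     batch = to_device(batch, device)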

def log(
    logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
):
    if losses is not None:
        # Loss order: total, mel, mel postnet, pitch, energy, duration.
        logger.add_scalar("Loss/total_loss", losses[0], step)
        logger.add_scalar("Loss/mel_loss", losses[1], step)
        logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
        logger.add_scalar("Loss/pitch_loss", losses[3], step)
        logger.add_scalar("Loss/energy_loss", losses[4], step)
        logger.add_scalar("Loss/duration_loss", losses[5], step)

    if fig is not None:
        logger.add_figure(tag, fig)

    if audio is not None:
        logger.add_audio(
            tag,
            audio / np.max(np.abs(audio)),  # peak-normalize to [-1, 1]
            sample_rate=sampling_rate,
        )
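
# Note: `logger` is assumed to be a torch.utils.tensorboard.SummaryWriter
# (add_scalar / add_figure / add_audio match that interface). A minimal
# sketch under that assumption:
#     from torch.utils.tensorboard import SummaryWriter
#     writer = SummaryWriter("output/log")
#     log(writer, step=100, losses=[1.2, 0.5, 0.4, 0.1, 0.1, 0.1])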

def get_mask_from_lengths(lengths, max_len=None):
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask
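
# Worked example: lengths = torch.tensor([2, 4]) with max_len 4 gives
#     mask = [[False, False, True, True],
#             [False, False, False, False]]
# i.e. True marks the padded positions past each sequence's length.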

def expand(values, durations):
    # Repeat each phoneme-level value `d` times to get a frame-level sequence.
    out = list()
    for value, d in zip(values, durations):
        out += [value] * max(0, int(d))
    return np.array(out)
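
# Worked example: expand(np.array([1.0, 2.0]), np.array([2, 3]))
# returns array([1., 1., 2., 2., 2.]).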

def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
    # Visualize and (optionally) vocode the first sample of the batch.
    basename = targets[0][0]
    src_len = predictions[8][0].item()
    mel_len = predictions[9][0].item()
    mel_target = targets[7][0, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1)
    duration = targets[12][0, :src_len].detach().cpu().numpy()

    # Phoneme-level pitch/energy are expanded to frame level via durations.
    if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
        pitch = targets[10][0, :src_len].detach().cpu().numpy()
        pitch = expand(pitch, duration)
    else:
        pitch = targets[10][0, :mel_len].detach().cpu().numpy()
    if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
        energy = targets[11][0, :src_len].detach().cpu().numpy()
        energy = expand(energy, duration)
    else:
        energy = targets[11][0, :mel_len].detach().cpu().numpy()

    with open(
        os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
    ) as f:
        stats = json.load(f)
        stats = stats["pitch"] + stats["energy"][:2]

    fig = plot_mel(
        [
            (mel_prediction.cpu().numpy(), pitch, energy),
            (mel_target.cpu().numpy(), pitch, energy),
        ],
        stats,
        ["Synthesized Spectrogram", "Ground-Truth Spectrogram"],
    )

    if vocoder is not None:
        from .model import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename
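
# Note: the `predictions` indices used above (1 = postnet mel output,
# 8 = source lengths, 9 = mel lengths) are assumed from the model's
# forward() return order; `targets` follows the 13-field layout unpacked
# in to_device.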

def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    basenames = targets[0]
    for i in range(len(predictions[0])):
        basename = basenames[i]
        src_len = predictions[8][i].item()
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
        duration = predictions[5][i, :src_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
            pitch = predictions[2][i, :src_len].detach().cpu().numpy()
            pitch = expand(pitch, duration)
        else:
            pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
            energy = predictions[3][i, :src_len].detach().cpu().numpy()
            energy = expand(energy, duration)
        else:
            energy = predictions[3][i, :mel_len].detach().cpu().numpy()

        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
        ) as f:
            stats = json.load(f)
            stats = stats["pitch"] + stats["energy"][:2]

        fig = plot_mel(
            [
                (mel_prediction.cpu().numpy(), pitch, energy),
            ],
            stats,
            ["Synthesized Spectrogram"],
        )
        plt.savefig(os.path.join(path, "{}.png".format(basename)))
        plt.close()

    from .model import vocoder_infer

    # Vocode the whole batch at once; lengths are in waveform samples.
    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav)

def plot_mel(data, stats, titles):
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None for i in range(len(data))]
    pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
    # Pitch stats are stored normalized; map the bounds back for display.
    pitch_min = pitch_min * pitch_std + pitch_mean
    pitch_max = pitch_max * pitch_std + pitch_mean

    def add_axis(fig, old_ax):
        # Overlay a transparent axis on top of the spectrogram axis.
        ax = fig.add_axes(old_ax.get_position(), anchor="W")
        ax.set_facecolor("None")
        return ax

    for i in range(len(data)):
        mel, pitch, energy = data[i]
        pitch = pitch * pitch_std + pitch_mean
        axes[i][0].imshow(mel, origin="lower")
        axes[i][0].set_aspect(2.5, adjustable="box")
        axes[i][0].set_ylim(0, mel.shape[0])
        axes[i][0].set_title(titles[i], fontsize="medium")
        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
        axes[i][0].set_anchor("W")

        # F0 contour, labeled on the left.
        ax1 = add_axis(fig, axes[i][0])
        ax1.plot(pitch, color="tomato")
        ax1.set_xlim(0, mel.shape[1])
        ax1.set_ylim(0, pitch_max)
        ax1.set_ylabel("F0", color="tomato")
        ax1.tick_params(
            labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
        )

        # Energy contour, labeled on the right.
        ax2 = add_axis(fig, axes[i][0])
        ax2.plot(energy, color="darkviolet")
        ax2.set_xlim(0, mel.shape[1])
        ax2.set_ylim(energy_min, energy_max)
        ax2.set_ylabel("Energy", color="darkviolet")
        ax2.yaxis.set_label_position("right")
        ax2.tick_params(
            labelsize="x-small",
            colors="darkviolet",
            bottom=False,
            labelbottom=False,
            left=False,
            labelleft=False,
            right=True,
            labelright=True,
        )

    return fig
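
# Each entry of `data` is a (mel, pitch, energy) triple; `stats` is the
# [pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max]
# list assembled from stats.json in the synthesis helpers above.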

def pad_1D(inputs, PAD=0):
    def pad_data(x, length, PAD):
        x_padded = np.pad(
            x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
        )
        return x_padded

    max_len = max(len(x) for x in inputs)
    padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])

    return padded
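
# Worked example: pad_1D([np.array([1, 2, 3]), np.array([4])])
# returns array([[1, 2, 3], [4, 0, 0]]).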

def pad_2D(inputs, maxlen=None):
    def pad(x, max_len):
        PAD = 0
        if np.shape(x)[0] > max_len:
            raise ValueError("input length exceeds max_len")

        s = np.shape(x)[1]
        # A single (before, after) tuple makes np.pad pad every axis; the
        # second axis is then trimmed back to its original width s.
        x_padded = np.pad(
            x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
        )
        return x_padded[:, :s]

    if maxlen:
        output = np.stack([pad(x, maxlen) for x in inputs])
    else:
        max_len = max(np.shape(x)[0] for x in inputs)
        output = np.stack([pad(x, max_len) for x in inputs])

    return output
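
# Worked example: pad_2D([np.ones((2, 3)), np.ones((4, 3))]) returns an
# array of shape (2, 4, 3), with the first element zero-padded in time.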

def pad(input_ele, mel_max_length=None):
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])

    out_list = list()
    for i, batch in enumerate(input_ele):
        if len(batch.shape) == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded
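
# Worked example: pad([torch.ones(3), torch.ones(5)]) returns a tensor of
# shape (2, 5); 2-D inputs are padded along the first (time) dimension.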

def get_roberta_emotion_embeddings(tokenizer, model, text):
    model.to(device)
    tokenized_input = tokenizer(
        text, padding="max_length", max_length=128, truncation=True, return_tensors="pt"
    )
    input_ids = tokenized_input["input_ids"].to(model.device)
    attention_mask = tokenized_input["attention_mask"].to(model.device)

    # Label order must match the classifier head of the loaded checkpoint.
    emotions = ("amused", "anger", "disgust", "neutral", "sleepiness")
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        embeddings = outputs.logits

    # Argmax over the class dimension; .item() assumes a single input text.
    emotion_index = torch.argmax(embeddings, dim=1).item()
    predicted_emotion = emotions[emotion_index]
    print("Predicted emotion:", predicted_emotion)
    return embeddings
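
# Usage sketch (the checkpoint name below is hypothetical; any Hugging Face
# sequence-classification model whose label order matches `emotions` works):
#     from transformers import AutoTokenizer, AutoModelForSequenceClassification
#     tokenizer = AutoTokenizer.from_pretrained("roberta-emotion-checkpoint")
#     model = AutoModelForSequenceClassification.from_pretrained("roberta-emotion-checkpoint")
#     logits = get_roberta_emotion_embeddings(tokenizer, model, "I feel so sleepy.")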