import os
import json

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use("Agg")

from matplotlib import pyplot as plt
from scipy.io import wavfile

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    # Move a collated batch of numpy arrays onto the target device.
    # The three branches correspond to a full training batch (13 fields),
    # plain synthesis input (6 fields), and synthesis input with an
    # emotion id (7 fields).
    if len(data) == 13:
        (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            emotions,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        ) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        emotions = torch.from_numpy(emotions).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)
        mels = torch.from_numpy(mels).float().to(device)
        mel_lens = torch.from_numpy(mel_lens).to(device)
        pitches = torch.from_numpy(pitches).float().to(device)
        energies = torch.from_numpy(energies).to(device)
        durations = torch.from_numpy(durations).long().to(device)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            src_lens,
            max_src_len,
            emotions,
            mels,
            mel_lens,
            max_mel_len,
            pitches,
            energies,
            durations,
        )

    if len(data) == 6:
        (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)

        return (ids, raw_texts, speakers, texts, src_lens, max_src_len)

    if len(data) == 7:
        (ids, raw_texts, speakers, texts, src_lens, max_src_len, emotions) = data

        speakers = torch.from_numpy(speakers).long().to(device)
        emotions = torch.from_numpy(emotions).long().to(device)
        texts = torch.from_numpy(texts).long().to(device)
        src_lens = torch.from_numpy(src_lens).to(device)

        return (ids, raw_texts, speakers, texts, src_lens, max_src_len, emotions)


def log(
    logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""
):
    # Write losses, a figure, or a peak-normalized audio clip to a
    # TensorBoard-style summary writer.
    if losses is not None:
        logger.add_scalar("Loss/total_loss", losses[0], step)
        logger.add_scalar("Loss/mel_loss", losses[1], step)
        logger.add_scalar("Loss/mel_postnet_loss", losses[2], step)
        logger.add_scalar("Loss/pitch_loss", losses[3], step)
        logger.add_scalar("Loss/energy_loss", losses[4], step)
        logger.add_scalar("Loss/duration_loss", losses[5], step)

    if fig is not None:
        logger.add_figure(tag, fig)

    if audio is not None:
        logger.add_audio(
            tag,
            audio / max(abs(audio)),
            sample_rate=sampling_rate,
        )


def get_mask_from_lengths(lengths, max_len=None):
    # Build a boolean mask that is True at padded (out-of-length) positions.
    batch_size = lengths.shape[0]
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)
    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)

    return mask


def expand(values, durations):
    # Repeat each phoneme-level value by its duration to get a frame-level sequence.
    out = list()
    for value, d in zip(values, durations):
        out += [value] * max(0, int(d))
    return np.array(out)


def synth_one_sample(targets, predictions, vocoder, model_config, preprocess_config):
    # Plot predicted vs. ground-truth mel spectrograms for the first sample in
    # a batch and, if a vocoder is given, synthesize both waveforms.
    basename = targets[0][0]
    src_len = predictions[8][0].item()
    mel_len = predictions[9][0].item()
    mel_target = targets[7][0, :mel_len].detach().transpose(0, 1)
    mel_prediction = predictions[1][0, :mel_len].detach().transpose(0, 1)
    duration = targets[12][0, :src_len].detach().cpu().numpy()
    if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
        pitch = targets[10][0, :src_len].detach().cpu().numpy()
        pitch = expand(pitch, duration)
    else:
        pitch = targets[10][0, :mel_len].detach().cpu().numpy()
    if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
        energy = targets[11][0, :src_len].detach().cpu().numpy()
        energy = expand(energy, duration)
    else:
        energy = targets[11][0, :mel_len].detach().cpu().numpy()

    with open(
        os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
    ) as f:
        stats = json.load(f)
        stats = stats["pitch"] + stats["energy"][:2]

    fig = plot_mel(
        [
            (mel_prediction.cpu().numpy(), pitch, energy),
            (mel_target.cpu().numpy(), pitch, energy),
        ],
        stats,
        ["Synthesized Spectrogram", "Ground-Truth Spectrogram"],
    )

    if vocoder is not None:
        from .model import vocoder_infer

        wav_reconstruction = vocoder_infer(
            mel_target.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
        wav_prediction = vocoder_infer(
            mel_prediction.unsqueeze(0),
            vocoder,
            model_config,
            preprocess_config,
        )[0]
    else:
        wav_reconstruction = wav_prediction = None

    return fig, wav_reconstruction, wav_prediction, basename


def synth_samples(targets, predictions, vocoder, model_config, preprocess_config, path):
    # Save a spectrogram plot and a vocoded waveform for every sample in a batch.
    basenames = targets[0]
    for i in range(len(predictions[0])):
        basename = basenames[i]
        src_len = predictions[8][i].item()
        mel_len = predictions[9][i].item()
        mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
        duration = predictions[5][i, :src_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
            pitch = predictions[2][i, :src_len].detach().cpu().numpy()
            pitch = expand(pitch, duration)
        else:
            pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
        if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
            energy = predictions[3][i, :src_len].detach().cpu().numpy()
            energy = expand(energy, duration)
        else:
            energy = predictions[3][i, :mel_len].detach().cpu().numpy()

        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
        ) as f:
            stats = json.load(f)
            stats = stats["pitch"] + stats["energy"][:2]

        fig = plot_mel(
            [
                (mel_prediction.cpu().numpy(), pitch, energy),
            ],
            stats,
            ["Synthesized Spectrogram"],
        )
        plt.savefig(os.path.join(path, "{}.png".format(basename)))
        plt.close()

    from .model import vocoder_infer

    mel_predictions = predictions[1].transpose(1, 2)
    lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
    wav_predictions = vocoder_infer(
        mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
    )

    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
    for wav, basename in zip(wav_predictions, basenames):
        wavfile.write(
            os.path.join(path, "{}.wav".format(basename)), sampling_rate, wav
        )


def plot_mel(data, stats, titles):
    # Plot mel spectrograms with overlaid pitch (left axis) and energy (right
    # axis) curves. `stats` holds pitch min/max/mean/std followed by energy
    # min/max; pitch values are de-normalized before plotting.
    fig, axes = plt.subplots(len(data), 1, squeeze=False)
    if titles is None:
        titles = [None for i in range(len(data))]
    pitch_min, pitch_max, pitch_mean, pitch_std, energy_min, energy_max = stats
    pitch_min = pitch_min * pitch_std + pitch_mean
    pitch_max = pitch_max * pitch_std + pitch_mean

    def add_axis(fig, old_ax):
        ax = fig.add_axes(old_ax.get_position(), anchor="W")
        ax.set_facecolor("None")
        return ax

    for i in range(len(data)):
        mel, pitch, energy = data[i]
        pitch = pitch * pitch_std + pitch_mean
        axes[i][0].imshow(mel, origin="lower")
        axes[i][0].set_aspect(2.5, adjustable="box")
        axes[i][0].set_ylim(0, mel.shape[0])
        axes[i][0].set_title(titles[i], fontsize="medium")
        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
        axes[i][0].set_anchor("W")

        ax1 = add_axis(fig, axes[i][0])
        ax1.plot(pitch, color="tomato")
        ax1.set_xlim(0, mel.shape[1])
        ax1.set_ylim(0, pitch_max)
        ax1.set_ylabel("F0", color="tomato")
        ax1.tick_params(
            labelsize="x-small", colors="tomato", bottom=False, labelbottom=False
        )

        ax2 = add_axis(fig, axes[i][0])
        ax2.plot(energy, color="darkviolet")
        ax2.set_xlim(0, mel.shape[1])
        ax2.set_ylim(energy_min, energy_max)
        ax2.set_ylabel("Energy", color="darkviolet")
        ax2.yaxis.set_label_position("right")
        ax2.tick_params(
            labelsize="x-small",
            colors="darkviolet",
            bottom=False,
            labelbottom=False,
            left=False,
            labelleft=False,
            right=True,
            labelright=True,
        )

    return fig


def pad_1D(inputs, PAD=0):
    # Pad a list of 1-D numpy arrays to a common length and stack them.
    def pad_data(x, length, PAD):
        x_padded = np.pad(
            x, (0, length - x.shape[0]), mode="constant", constant_values=PAD
        )
        return x_padded

    max_len = max((len(x) for x in inputs))
    padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])

    return padded


def pad_2D(inputs, maxlen=None):
    # Pad a list of 2-D numpy arrays along the first axis and stack them.
    def pad(x, max_len):
        PAD = 0
        if np.shape(x)[0] > max_len:
            raise ValueError("not max_len")

        s = np.shape(x)[1]
        x_padded = np.pad(
            x, (0, max_len - np.shape(x)[0]), mode="constant", constant_values=PAD
        )
        return x_padded[:, :s]

    if maxlen:
        output = np.stack([pad(x, maxlen) for x in inputs])
    else:
        max_len = max(np.shape(x)[0] for x in inputs)
        output = np.stack([pad(x, max_len) for x in inputs])

    return output


def pad(input_ele, mel_max_length=None):
    # Pad a list of 1-D or 2-D torch tensors along the time axis and stack them.
    if mel_max_length:
        max_len = mel_max_length
    else:
        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])

    out_list = list()
    for i, batch in enumerate(input_ele):
        if len(batch.shape) == 1:
            one_batch_padded = F.pad(
                batch, (0, max_len - batch.size(0)), "constant", 0.0
            )
        elif len(batch.shape) == 2:
            one_batch_padded = F.pad(
                batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0
            )
        out_list.append(one_batch_padded)
    out_padded = torch.stack(out_list)
    return out_padded


def get_roberta_emotion_embeddings(tokenizer, model, text):
    # Run a sequence-classification model over `text` and return its logits.
    # The label order below must match the classifier's output indices.
    model.to(device)
    tokenized_input = tokenizer(
        text,
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokenized_input["input_ids"].to(model.device)
    attention_mask = tokenized_input["attention_mask"].to(model.device)
    emotions = ("amused", "anger", "disgust", "neutral", "sleepiness")
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        embeddings = outputs.logits
        # Get the index of the predicted emotion.
        emotion_index = torch.argmax(embeddings, dim=1).item()
        # Get the corresponding emotion label from the list.
        predicted_emotion = emotions[emotion_index]
        print("Predicted emotion:", predicted_emotion)
    return embeddings
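

if __name__ == "__main__":
    # Minimal usage sketch of the padding/masking helpers, not part of the
    # training or synthesis pipeline. The sequences below are made-up
    # illustrative data, assuming phoneme-index inputs and mel-like features.
    seqs = [np.array([1, 2, 3]), np.array([4, 5])]
    print(pad_1D(seqs))  # shape (2, 3), second row zero-padded

    mels = [np.ones((3, 4)), np.ones((2, 4))]
    print(pad_2D(mels).shape)  # (2, 3, 4)

    lengths = torch.tensor([3, 2]).to(device)
    print(get_mask_from_lengths(lengths))  # True at padded positions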