# Emotion_Aware_TTS / dataset.py
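"""Dataset classes for emotion-aware TTS training and synthesis.

``Dataset`` pairs each utterance with its precomputed mel-spectrogram, pitch,
energy, and duration targets; ``TextDataset`` serves text-only metadata for
synthesis. Both read the pipe-delimited metadata files produced by the
preprocessing step.
"""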
import json
import os
import numpy as np
from torch.utils.data import Dataset
from text import text_to_sequence
from utils.tools import pad_1D, pad_2D
class Dataset(Dataset):
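    """Training/validation dataset backed by preprocessed features.

    Each item bundles the phoneme sequence with the precomputed mel, pitch,
    energy, and duration arrays stored under ``preprocessed_path``, plus the
    speaker and emotion IDs looked up from ``speakers.json`` and ``emotions.json``.
    """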
def __init__(
self, filename, preprocess_config, train_config, sort=False, drop_last=False
):
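        """
        Args:
            filename: metadata file (e.g. ``train.txt``) inside ``preprocessed_path``.
            preprocess_config: preprocessing config dict (paths, text cleaners).
            train_config: training config dict (provides the batch size).
            sort: sort samples by phoneme length before batching in ``collate_fn``.
            drop_last: drop the final incomplete batch instead of keeping it.
        """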
self.dataset_name = preprocess_config["dataset"]
self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
self.batch_size = train_config["optimizer"]["batch_size"]
self.basename, self.speaker, self.text, self.raw_text, self.emotion = self.process_meta(
filename
)
with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
self.speaker_map = json.load(f)
with open(os.path.join(self.preprocessed_path, "emotions.json")) as f:
self.emotion_map = json.load(f)
self.sort = sort
self.drop_last = drop_last
def __len__(self):
return len(self.text)
def __getitem__(self, idx):
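        """Load one utterance and its precomputed features as a dict of numpy arrays."""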
basename = self.basename[idx]
speaker = self.speaker[idx]
emotion = self.emotion[idx]
speaker_id = self.speaker_map[speaker]
emotion_id = self.emotion_map[emotion]
raw_text = self.raw_text[idx]
phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
mel_path = os.path.join(
self.preprocessed_path,
"mel",
"{}-mel-{}.npy".format(speaker, basename),
)
mel = np.load(mel_path)
pitch_path = os.path.join(
self.preprocessed_path,
"pitch",
"{}-pitch-{}.npy".format(speaker, basename),
)
pitch = np.load(pitch_path)
energy_path = os.path.join(
self.preprocessed_path,
"energy",
"{}-energy-{}.npy".format(speaker, basename),
)
energy = np.load(energy_path)
duration_path = os.path.join(
self.preprocessed_path,
"duration",
"{}-duration-{}.npy".format(speaker, basename),
)
duration = np.load(duration_path)
sample = {
"id": basename,
"speaker": speaker_id,
"emotion": emotion_id,
"text": phone,
"raw_text": raw_text,
"mel": mel,
"pitch": pitch,
"energy": energy,
"duration": duration,
}
return sample
def process_meta(self, filename):
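        """Parse a pipe-delimited metadata file: basename|speaker|phonemes|raw text|emotion."""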
with open(
os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8"
) as f:
name = []
speaker = []
emotion = []
text = []
raw_text = []
for line in f.readlines():
n, s, t, r, e = line.strip("\n").split("|")
name.append(n)
speaker.append(s)
text.append(t)
raw_text.append(r)
emotion.append(e)
return name, speaker, text, raw_text, emotion
def reprocess(self, data, idxs):
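        """Pad the samples selected by ``idxs`` into a single batch tuple."""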
ids = [data[idx]["id"] for idx in idxs]
speakers = [data[idx]["speaker"] for idx in idxs]
emotions = [data[idx]["emotion"] for idx in idxs]
texts = [data[idx]["text"] for idx in idxs]
raw_texts = [data[idx]["raw_text"] for idx in idxs]
mels = [data[idx]["mel"] for idx in idxs]
pitches = [data[idx]["pitch"] for idx in idxs]
energies = [data[idx]["energy"] for idx in idxs]
durations = [data[idx]["duration"] for idx in idxs]
text_lens = np.array([text.shape[0] for text in texts])
mel_lens = np.array([mel.shape[0] for mel in mels])
speakers = np.array(speakers)
emotions = np.array(emotions)
texts = pad_1D(texts)
mels = pad_2D(mels)
pitches = pad_1D(pitches)
energies = pad_1D(energies)
durations = pad_1D(durations)
return (
ids,
raw_texts,
speakers,
texts,
text_lens,
max(text_lens),
emotions,
mels,
mel_lens,
max(mel_lens),
pitches,
energies,
durations,
)
def collate_fn(self, data):
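        """Group samples into batches of ``batch_size``.

        Optionally sorts by phoneme length (longest first) so utterances of
        similar length are padded together; the leftover tail becomes a final
        smaller batch unless ``drop_last`` is set.
        """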
data_size = len(data)
if self.sort:
len_arr = np.array([d["text"].shape[0] for d in data])
idx_arr = np.argsort(-len_arr)
else:
idx_arr = np.arange(data_size)
tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size):]
idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]
idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
if not self.drop_last and len(tail) > 0:
idx_arr += [tail.tolist()]
output = list()
for idx in idx_arr:
output.append(self.reprocess(data, idx))
return output
class TextDataset(Dataset):
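    """Text-only dataset used at synthesis time (no acoustic targets)."""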
def __init__(self, filepath, preprocess_config):
self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
self.basename, self.speaker, self.text, self.raw_text, self.emotion = self.process_meta(
filepath
)
        with open(
            os.path.join(
                preprocess_config["path"]["preprocessed_path"], "speakers.json"
            )
        ) as f:
            self.speaker_map = json.load(f)
        with open(
            os.path.join(
                preprocess_config["path"]["preprocessed_path"], "emotions.json"
            )
        ) as f:
            self.emotion_map = json.load(f)
def __len__(self):
return len(self.text)
def __getitem__(self, idx):
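        """Return (basename, speaker_id, phoneme_sequence, raw_text, emotion_label) for one line."""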
basename = self.basename[idx]
speaker = self.speaker[idx]
speaker_id = self.speaker_map[speaker]
raw_text = self.raw_text[idx]
emotion = self.emotion[idx]
phone = np.array(text_to_sequence(self.text[idx], self.cleaners))
return (basename, speaker_id, phone, raw_text, emotion)
def process_meta(self, filename):
with open(filename, "r", encoding="utf-8") as f:
name = []
speaker = []
text = []
raw_text = []
emotion = []
for line in f.readlines():
n, s, t, r, e = line.strip("\n").split("|")
name.append(n)
speaker.append(s)
text.append(t)
raw_text.append(r)
emotion.append(e)
return name, speaker, text, raw_text, emotion
def collate_fn(self, data):
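        """Pad a list of synthesis samples into a single batch tuple."""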
ids = [d[0] for d in data]
speakers = np.array([d[1] for d in data])
texts = [d[2] for d in data]
raw_texts = [d[3] for d in data]
emotions = [d[4] for d in data]
text_lens = np.array([text.shape[0] for text in texts])
texts = pad_1D(texts)
return ids, raw_texts, speakers, texts, emotions, text_lens, max(text_lens)
if __name__ == "__main__":
# Test
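    # Sanity check: build the LJSpeech train/val datasets and count how many
    # batches the custom collate_fn produces.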
import torch
import yaml
from torch.utils.data import DataLoader
from utils.utils import to_device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with open("./config/LJSpeech/preprocess.yaml", "r") as f:
        preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
    with open("./config/LJSpeech/train.yaml", "r") as f:
        train_config = yaml.load(f, Loader=yaml.FullLoader)
train_dataset = Dataset(
"train.txt", preprocess_config, train_config, sort=True, drop_last=True
)
val_dataset = Dataset(
"val.txt", preprocess_config, train_config, sort=False, drop_last=False
)
train_loader = DataLoader(
train_dataset,
batch_size=train_config["optimizer"]["batch_size"] * 4,
shuffle=True,
collate_fn=train_dataset.collate_fn,
)
val_loader = DataLoader(
val_dataset,
batch_size=train_config["optimizer"]["batch_size"],
shuffle=False,
collate_fn=val_dataset.collate_fn,
)
n_batch = 0
for batchs in train_loader:
for batch in batchs:
to_device(batch, device)
n_batch += 1
print(
"Training set with size {} is composed of {} batches.".format(
len(train_dataset), n_batch
)
)
n_batch = 0
for batchs in val_loader:
for batch in batchs:
to_device(batch, device)
n_batch += 1
print(
"Validation set with size {} is composed of {} batches.".format(
len(val_dataset), n_batch
)
)