Spaces:
Runtime error
Runtime error
"""Data loader for the LJSpeech dataset. See: https://keithito.com/LJ-Speech-Dataset/""" | |
import os | |
import re | |
import codecs | |
import unicodedata | |
import numpy as np | |
from torch.utils.data import Dataset | |
# Model alphabet. Index 0 ('P') is the padding symbol and index 1 ('E') is the
# end-of-sequence marker; the remaining symbols are the only characters that
# text_normalize lets through.
vocab = "PE abcdefghijklmnopqrstuvwxyz'.?"  # P: Padding, E: EOS.
# Bidirectional lookup tables between characters and their ids (= position in vocab).
idx2char = dict(enumerate(vocab))
char2idx = {char: idx for idx, char in idx2char.items()}
def text_normalize(text):
    """Normalize raw text so that it only contains characters from ``vocab``.

    Steps, in order:
      1. strip accents (NFD decomposition, then drop combining marks of
         Unicode category ``Mn``);
      2. lowercase;
      3. replace every character not in ``vocab`` with a space;
      4. collapse runs of spaces into a single space.

    Leading/trailing spaces are NOT removed here; callers strip if needed.

    Args:
        text: input string.

    Returns:
        The normalized string.
    """
    text = ''.join(char for char in unicodedata.normalize('NFD', text)
                   if unicodedata.category(char) != 'Mn')  # Strip accents
    # Lowercasing happens before filtering, so the uppercase control symbols
    # 'P' and 'E' in vocab can never come through from input text.
    text = text.lower()
    # Escape vocab so regex metacharacters in it (e.g. ']', '^', '-', '\\')
    # are taken literally inside the character class.
    text = re.sub("[^{}]".format(re.escape(vocab)), " ", text)
    text = re.sub("[ ]+", " ", text)
    return text
def read_metadata(metadata_file):
    """Parse an LJSpeech-style ``metadata.csv`` into encoded transcripts.

    Every non-blank line must contain exactly three ``|``-separated fields.
    The first field is used as the file name and the third as the transcript;
    the second is ignored (for LJSpeech the third column is presumably the
    normalized transcription -- confirm against the dataset page). Each
    transcript is passed through ``text_normalize``, an ``"E"`` (EOS) is
    appended, and it is mapped to integer ids via ``char2idx``.

    Args:
        metadata_file: path to the metadata file, read as UTF-8.

    Returns:
        ``(fnames, text_lengths, texts)``: a list of file names, a list of
        encoded lengths (EOS included), and a list of int64 numpy arrays of
        character ids.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
    """
    fnames, text_lengths, texts = [], [], []
    # ``with`` guarantees the file handle is closed even if parsing fails.
    with open(metadata_file, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. extra newlines at EOF
            fields = line.split("|")
            if len(fields) != 3:
                raise ValueError("{}:{}: expected 3 '|'-separated fields, got {}".format(
                    metadata_file, lineno, len(fields)))
            fname, _, text = fields
            text = text_normalize(text) + "E"  # E: EOS
            ids = [char2idx[char] for char in text]
            fnames.append(fname)
            text_lengths.append(len(ids))
            # int64 explicitly: the removed ``np.long`` alias used to resolve to it.
            texts.append(np.array(ids, dtype=np.int64))
    return fnames, text_lengths, texts
def get_test_data(sentences, max_n):
    """Encode free-form sentences into a zero-padded id matrix.

    Each sentence is normalized with ``text_normalize``, stripped of
    leading/trailing spaces, terminated with ``"E"`` (EOS), and mapped to ids
    via ``char2idx``. Rows are right-padded with 0, which is the id of ``"P"``
    (padding) in ``vocab``.

    Args:
        sentences: iterable of strings.
        max_n: maximum number of normalized characters per sentence, not
            counting EOS; each row therefore has ``max_n + 1`` slots.

    Returns:
        int64 numpy array of shape ``(len(sentences), max_n + 1)``.

    Raises:
        ValueError: if a normalized sentence plus EOS is longer than
            ``max_n + 1``.
    """
    normalized_sentences = [text_normalize(line).strip() + "E" for line in sentences]  # E: EOS
    # int64 explicitly: the removed ``np.long`` alias used to resolve to it.
    texts = np.zeros((len(normalized_sentences), max_n + 1), dtype=np.int64)
    for i, sent in enumerate(normalized_sentences):
        if len(sent) > max_n + 1:
            # Fail with a clear message instead of numpy's opaque broadcast error.
            raise ValueError(
                "sentence {} has {} characters after normalization (including EOS), "
                "which exceeds max_n + 1 = {}".format(i, len(sent), max_n + 1))
        texts[i, :len(sent)] = [char2idx[char] for char in sent]
    return texts
class LJSpeech(Dataset):
    """Map-style dataset over an LJSpeech-style directory.

    Each item is a dict containing the requested subset of these keys:
      * ``'texts'``: int64 array of character ids (ending in the EOS id);
      * ``'mels'`` / ``'mags'``: arrays loaded from ``<path>/mels/<fname>.npy``
        and ``<path>/mags/<fname>.npy``;
      * ``'mel_gates'`` / ``'mag_gates'``: int64 vectors of ones with one entry
        per row (``shape[0]``) of the corresponding spectrogram.
    """

    def __init__(self, keys, dir_name='LJSpeech-1.1'):
        """Load metadata for the dataset.

        Args:
            keys: container of item fields to return (membership is tested with
                ``in``); any of ``'texts'``, ``'mels'``, ``'mags'``,
                ``'mel_gates'``, ``'mag_gates'``.
            dir_name: dataset directory, resolved relative to the directory of
                this module; it must contain ``metadata.csv``.
        """
        self.keys = keys
        self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name)
        self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'metadata.csv'))

    def slice(self, start, end):
        """Restrict the dataset in place to items ``[start:end]`` (slice semantics)."""
        self.fnames = self.fnames[start:end]
        self.text_lengths = self.text_lengths[start:end]
        self.texts = self.texts[start:end]

    def __len__(self):
        return len(self.fnames)

    def _load(self, subdir, index):
        """Load ``<path>/<subdir>/<fname>.npy`` for the item at ``index``."""
        return np.load(os.path.join(self.path, subdir, "%s.npy" % self.fnames[index]))

    def __getitem__(self, index):
        # Load each spectrogram at most once. Gates need the array even when the
        # spectrogram itself was not requested (previously this raised KeyError).
        # Spectrogram layout presumably (frames, bins) -- TODO confirm; the first
        # axis is what the gates are sized by.
        mels = None
        if 'mels' in self.keys or 'mel_gates' in self.keys:
            mels = self._load('mels', index)
        mags = None
        if 'mags' in self.keys or 'mag_gates' in self.keys:
            mags = self._load('mags', index)

        data = {}
        if 'texts' in self.keys:
            data['texts'] = self.texts[index]
        if 'mels' in self.keys:
            data['mels'] = mels
        if 'mags' in self.keys:
            data['mags'] = mags
        # TODO(review): gates are all ones ("because pre processing!");
        # presumably stop/padding frames are handled elsewhere -- confirm.
        # int64 explicitly: the removed ``np.int`` alias used to resolve to it.
        if 'mel_gates' in self.keys:
            data['mel_gates'] = np.ones(mels.shape[0], dtype=np.int64)
        if 'mag_gates' in self.keys:
            data['mag_gates'] = np.ones(mags.shape[0], dtype=np.int64)
        return data