"""Data loader for the Emovdb dataset. See: https://github.com/numediart/EmoV-DB"""
import os
import re
import codecs
import unicodedata

import numpy as np
from audio import preprocess
from torch.utils.data import Dataset

vocab = "PE abcdefghijklmnopqrstuvwxyz'.?"  # P: Padding, E: EOS.
char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for idx, char in enumerate(vocab)}


def text_normalize(text):
    """Normalize raw transcript text to the characters of ``vocab``.

    Accents are stripped (NFD decomposition, then combining marks dropped),
    the text is lower-cased, every character outside ``vocab`` becomes a
    space, and runs of spaces collapse to a single space. Leading/trailing
    spaces are not stripped here; callers strip when they need to.
    """
    text = ''.join(char for char in unicodedata.normalize('NFD', text)
                   if unicodedata.category(char) != 'Mn')  # Strip accents
    text = text.lower()
    text = re.sub("[^{}]".format(vocab), " ", text)
    text = re.sub("[ ]+", " ", text)
    return text


def read_metadata(metadata_file):
    """Read a ``fname|text`` transcript file.

    Args:
        metadata_file: path to a UTF-8 file with one ``fname|text`` entry per
            line. Blank lines are ignored.

    Returns:
        ``(fnames, text_lengths, texts)``: the file names, the number of
        encoded characters per entry (EOS included), and each normalized text
        encoded as an int64 array of ``char2idx`` indices ending with EOS.

    Raises:
        ValueError: if a non-blank line does not contain exactly one ``|``.
    """
    fnames, text_lengths, texts = [], [], []
    # ``with`` guarantees the handle is closed, even if a line fails to parse.
    with open(metadata_file, 'r', encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:  # e.g. a trailing empty line at end of file
                continue
            parts = line.split("|")
            if len(parts) != 2:
                raise ValueError("{}:{}: expected 'fname|text', got {!r}".format(
                    metadata_file, lineno, line))
            fname, text = parts
            fnames.append(fname)
            text = text_normalize(text) + "E"  # E: EOS
            encoded = [char2idx[char] for char in text]
            text_lengths.append(len(encoded))
            # np.int64 instead of np.long: that alias was removed in NumPy 1.24.
            texts.append(np.array(encoded, np.int64))
    return fnames, text_lengths, texts


def get_test_data(sentences, max_n):
    """Encode free-form sentences into a zero-padded index matrix.

    Args:
        sentences: iterable of raw sentence strings.
        max_n: maximum normalized sentence length in characters, EOS excluded.

    Returns:
        int64 array with one row per sentence and ``max_n + 1`` columns; each
        row holds the encoded sentence, then EOS, then ``0`` (the ``P``
        padding index).

    Raises:
        ValueError: if a normalized sentence is longer than ``max_n`` characters.
    """
    # text normalization, E: EOS
    normalized_sentences = [text_normalize(line).strip() + "E" for line in sentences]
    texts = np.zeros((len(normalized_sentences), max_n + 1), np.int64)
    for i, sent in enumerate(normalized_sentences):
        if len(sent) > max_n + 1:
            raise ValueError("sentence {} has {} characters after normalization "
                             "(EOS included); max_n + 1 = {}".format(i, len(sent), max_n + 1))
        texts[i, :len(sent)] = [char2idx[char] for char in sent]
    return texts


class Emovdb(Dataset):
    """Emovdb dataset returning a dict of the requested features per utterance.

    Args:
        keys: collection of feature names ``__getitem__`` should return; any of
            'texts', 'mels', 'mags', 'mel_gates', 'mag_gates'.
        dir_name: dataset directory containing the transcript file and the
            ``mels``/``mags`` sub-directories of ``<fname>.npy`` files. A
            relative path is resolved against this module's directory.
        transcript: name of the ``fname|text`` transcript file inside
            ``dir_name``.
    """

    def __init__(self, keys, dir_name='/home/brihi16142/work2/processed_emovdb_disgust',
                 transcript='transcript_bea.csv'):
        self.keys = keys
        self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name)
        self.fnames, self.text_lengths, self.texts = read_metadata(
            os.path.join(self.path, transcript))
        # NOTE(review): preprocess receives the raw dir_name, not the resolved
        # self.path; they differ when dir_name is relative. Presumably it writes
        # the mels/ and mags/ files read by __getitem__ — confirm in audio.py.
        preprocess(dir_name, self)
        print('Generated mels and mags')

    def slice(self, start, end):
        """Restrict the dataset in place to entries ``start:end`` (slice semantics)."""
        self.fnames = self.fnames[start:end]
        self.text_lengths = self.text_lengths[start:end]
        self.texts = self.texts[start:end]

    def __len__(self):
        return len(self.fnames)

    def _load_feature(self, kind, index):
        """Load ``<self.path>/<kind>/<fname>.npy`` for entry ``index``."""
        return np.load(os.path.join(self.path, kind, "%s.npy" % self.fnames[index]))

    def __getitem__(self, index):
        """Return a dict with one entry per requested key for utterance ``index``."""
        data = {}
        if 'texts' in self.keys:
            data['texts'] = self.texts[index]
        if 'mels' in self.keys:
            data['mels'] = self._load_feature('mels', index)
        if 'mags' in self.keys:
            data['mags'] = self._load_feature('mags', index)
        # Gates need the frame count, so load the array on demand when the
        # caller asked for gates without the features themselves.
        if 'mel_gates' in self.keys:
            mels = data['mels'] if 'mels' in data else self._load_feature('mels', index)
            # TODO: all-ones placeholder gates (original note: "because pre processing!")
            data['mel_gates'] = np.ones(mels.shape[0], dtype=np.int64)
        if 'mag_gates' in self.keys:
            mags = data['mags'] if 'mags' in data else self._load_feature('mags', index)
            # TODO: all-ones placeholder gates (original note: "because pre processing!")
            data['mag_gates'] = np.ones(mags.shape[0], dtype=np.int64)
        return data