sudip1310 committed
Commit 78655fa
1 Parent(s): e460b84

Upload 6 files

datasets/.gitignore ADDED
@@ -0,0 +1,3 @@
+ LJSpeech-1.1/
+ MBSpeech-1.0/
+ *.tar.gz
datasets/__init__.py ADDED
(empty file)
datasets/data_loader.py ADDED
@@ -0,0 +1,121 @@
+ import random
+
+ import numpy as np
+ import torch
+ from torch.utils.data.dataloader import default_collate, DataLoader
+ from torch.utils.data.sampler import Sampler
+
+ __all__ = ['Text2MelDataLoader', 'SSRNDataLoader']
+
+
+ class Text2MelDataLoader(DataLoader):
+     def __init__(self, text2mel_dataset, batch_size, mode='train', num_workers=8):
+         if mode == 'train':
+             text2mel_dataset.slice(0, -batch_size)
+         elif mode == 'valid':
+             text2mel_dataset.slice(len(text2mel_dataset) - batch_size, -1)
+         else:
+             raise ValueError("mode must be either 'train' or 'valid'")
+         super().__init__(text2mel_dataset,
+                          batch_size=batch_size,
+                          num_workers=num_workers,
+                          collate_fn=collate_fn,
+                          shuffle=True)
+
+
+ class SSRNDataLoader(DataLoader):
+     def __init__(self, ssrn_dataset, batch_size, mode='train', num_workers=8):
+         if mode == 'train':
+             ssrn_dataset.slice(0, -batch_size)
+             super().__init__(ssrn_dataset,
+                              batch_size=batch_size,
+                              num_workers=num_workers,
+                              collate_fn=collate_fn,
+                              sampler=PartiallyRandomizedSimilarTimeLengthSampler(lengths=ssrn_dataset.text_lengths,
+                                                                                  data_source=None,
+                                                                                  batch_size=batch_size))
+         elif mode == 'valid':
+             ssrn_dataset.slice(len(ssrn_dataset) - batch_size, -1)
+             super().__init__(ssrn_dataset,
+                              batch_size=batch_size,
+                              num_workers=num_workers,
+                              collate_fn=collate_fn,
+                              shuffle=True)
+         else:
+             raise ValueError("mode must be either 'train' or 'valid'")
+
+
+ def collate_fn(batch):
+     keys = batch[0].keys()
+     max_lengths = {key: 0 for key in keys}
+     collated_batch = {key: [] for key in keys}
+
+     # find out the max lengths
+     for row in batch:
+         for key in keys:
+             max_lengths[key] = max(max_lengths[key], row[key].shape[0])
+
+     # pad to the max lengths
+     for row in batch:
+         for key in keys:
+             array = row[key]
+             dim = len(array.shape)
+             assert dim == 1 or dim == 2
+             # TODO: because of pre processing, later we want to have (n_mels, T)
+             if dim == 1:
+                 padded_array = np.pad(array, (0, max_lengths[key] - array.shape[0]), mode='constant')
+             else:
+                 padded_array = np.pad(array, ((0, max_lengths[key] - array.shape[0]), (0, 0)), mode='constant')
+             collated_batch[key].append(padded_array)
+
+     # use the default_collate to convert to tensors
+     for key in keys:
+         collated_batch[key] = default_collate(collated_batch[key])
+     return collated_batch
+
+
+ class PartiallyRandomizedSimilarTimeLengthSampler(Sampler):
+     """Copied from: https://github.com/r9y9/deepvoice3_pytorch/blob/master/train.py.
+     Partially randomized sampler
+     1. Sort by lengths
+     2. Pick a small patch and randomize it
+     3. Permutate mini-batches
+     """
+
+     def __init__(self, lengths, data_source, batch_size=16, batch_group_size=None, permutate=True):
+         super().__init__(data_source)
+         self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths))
+         self.batch_size = batch_size
+         if batch_group_size is None:
+             batch_group_size = min(batch_size * 32, len(self.lengths))
+             if batch_group_size % batch_size != 0:
+                 batch_group_size -= batch_group_size % batch_size
+
+         self.batch_group_size = batch_group_size
+         assert batch_group_size % batch_size == 0
+         self.permutate = permutate
+
+     def __iter__(self):
+         indices = self.sorted_indices.clone()
+         batch_group_size = self.batch_group_size
+         s, e = 0, 0
+         for i in range(len(indices) // batch_group_size):
+             s = i * batch_group_size
+             e = s + batch_group_size
+             random.shuffle(indices[s:e])
+
+         # Permutate batches
+         if self.permutate:
+             perm = np.arange(len(indices[:e]) // self.batch_size)
+             random.shuffle(perm)
+             indices[:e] = indices[:e].view(-1, self.batch_size)[perm, :].view(-1)
+
+         # Handle last elements
+         s += batch_group_size
+         if s < len(indices):
+             random.shuffle(indices[s:])
+
+         return iter(indices)
+
+     def __len__(self):
+         return len(self.sorted_indices)
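
To make the padding behaviour of collate_fn above concrete, here is a small self-contained sketch; the key names match the datasets below, but the array shapes (two items, 80 mel bins) are made up for illustration:

import numpy as np
from datasets.data_loader import collate_fn

# two variable-length items, as a dataset's __getitem__ would return them (shapes are made up)
batch = [
    {'texts': np.array([5, 6, 7], dtype=np.int64), 'mels': np.zeros((10, 80), dtype=np.float32)},
    {'texts': np.array([5, 6, 7, 8, 9], dtype=np.int64), 'mels': np.zeros((14, 80), dtype=np.float32)},
]

collated = collate_fn(batch)
print(collated['texts'].shape)  # torch.Size([2, 5])      - zero-padded to the longest text
print(collated['mels'].shape)   # torch.Size([2, 14, 80]) - zero-padded along the time axis

Each key is padded independently to the longest item in the batch, then default_collate stacks the padded arrays into tensors.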
datasets/emovdb.py ADDED
@@ -0,0 +1,82 @@
+ """Data loader for the Emovdb dataset. See: https://github.com/numediart/EmoV-DB"""
+ import os
+ import re
+ import codecs
+ import unicodedata
+ import numpy as np
+ from audio import preprocess
+
+ from torch.utils.data import Dataset
+
+ vocab = "PE abcdefghijklmnopqrstuvwxyz'.?"  # P: Padding, E: EOS.
+ char2idx = {char: idx for idx, char in enumerate(vocab)}
+ idx2char = {idx: char for idx, char in enumerate(vocab)}
+
+
+ def text_normalize(text):
+     text = ''.join(char for char in unicodedata.normalize('NFD', text)
+                    if unicodedata.category(char) != 'Mn')  # Strip accents
+
+     text = text.lower()
+     text = re.sub("[^{}]".format(vocab), " ", text)
+     text = re.sub("[ ]+", " ", text)
+     return text
+
+
+ def read_metadata(metadata_file):
+     fnames, text_lengths, texts = [], [], []
+     transcript = os.path.join(metadata_file)
+     lines = codecs.open(transcript, 'r', 'utf-8').readlines()
+     for line in lines:
+         fname, text = line.strip().split("|")
+
+         fnames.append(fname)
+
+         text = text_normalize(text) + "E"  # E: EOS
+         text = [char2idx[char] for char in text]
+         text_lengths.append(len(text))
+         texts.append(np.array(text, np.long))
+
+     return fnames, text_lengths, texts
+
+
+ def get_test_data(sentences, max_n):
+     normalized_sentences = [text_normalize(line).strip() + "E" for line in sentences]  # text normalization, E: EOS
+     texts = np.zeros((len(normalized_sentences), max_n + 1), np.long)
+     for i, sent in enumerate(normalized_sentences):
+         texts[i, :len(sent)] = [char2idx[char] for char in sent]
+     return texts
+
+
+ class Emovdb(Dataset):
+     def __init__(self, keys, dir_name='/home/brihi16142/work2/processed_emovdb_disgust'):
+         self.keys = keys
+         self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name)
+         self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'transcript_bea.csv'))
+         preprocess(dir_name, self)
+         print('Generated mels and mags')
+
+     def slice(self, start, end):
+         self.fnames = self.fnames[start:end]
+         self.text_lengths = self.text_lengths[start:end]
+         self.texts = self.texts[start:end]
+
+     def __len__(self):
+         return len(self.fnames)
+
+     def __getitem__(self, index):
+         data = {}
+         if 'texts' in self.keys:
+             data['texts'] = self.texts[index]
+         if 'mels' in self.keys:
+             # (39, 80)
+             data['mels'] = np.load(os.path.join(self.path, 'mels', "%s.npy" % self.fnames[index]))
+         if 'mags' in self.keys:
+             # (39, 80)
+             data['mags'] = np.load(os.path.join(self.path, 'mags', "%s.npy" % self.fnames[index]))
+         if 'mel_gates' in self.keys:
+             data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=np.int)  # TODO: because pre processing!
+         if 'mag_gates' in self.keys:
+             data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=np.int)  # TODO: because pre processing!
+         return data
+
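
A short sketch of the text pipeline defined above (normalization, the P/E vocabulary, and fixed-length encoding). The sentence is invented, and importing datasets.emovdb assumes the repository's audio module (used by preprocess) is importable:

from datasets.emovdb import text_normalize, char2idx, idx2char, get_test_data

sentence = "Mary had a little lamb."
normalized = text_normalize(sentence)              # -> "mary had a little lamb."
encoded = [char2idx[c] for c in normalized + "E"]  # append "E" as the end-of-sentence symbol
decoded = ''.join(idx2char[i] for i in encoded)    # round-trips back to the normalized text + "E"

# get_test_data pads every sentence with "P" (index 0) up to max_n + 1 characters
batch = get_test_data([sentence], max_n=64)        # shape (1, 65)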
datasets/lj_speech.py ADDED
@@ -0,0 +1,78 @@
+ """Data loader for the LJSpeech dataset. See: https://keithito.com/LJ-Speech-Dataset/"""
+ import os
+ import re
+ import codecs
+ import unicodedata
+ import numpy as np
+
+ from torch.utils.data import Dataset
+
+ vocab = "PE abcdefghijklmnopqrstuvwxyz'.?"  # P: Padding, E: EOS.
+ char2idx = {char: idx for idx, char in enumerate(vocab)}
+ idx2char = {idx: char for idx, char in enumerate(vocab)}
+
+
+ def text_normalize(text):
+     text = ''.join(char for char in unicodedata.normalize('NFD', text)
+                    if unicodedata.category(char) != 'Mn')  # Strip accents
+
+     text = text.lower()
+     text = re.sub("[^{}]".format(vocab), " ", text)
+     text = re.sub("[ ]+", " ", text)
+     return text
+
+
+ def read_metadata(metadata_file):
+     fnames, text_lengths, texts = [], [], []
+     transcript = os.path.join(metadata_file)
+     lines = codecs.open(transcript, 'r', 'utf-8').readlines()
+     for line in lines:
+         fname, _, text = line.strip().split("|")
+
+         fnames.append(fname)
+
+         text = text_normalize(text) + "E"  # E: EOS
+         text = [char2idx[char] for char in text]
+         text_lengths.append(len(text))
+         texts.append(np.array(text, np.long))
+
+     return fnames, text_lengths, texts
+
+
+ def get_test_data(sentences, max_n):
+     normalized_sentences = [text_normalize(line).strip() + "E" for line in sentences]  # text normalization, E: EOS
+     texts = np.zeros((len(normalized_sentences), max_n + 1), np.long)
+     for i, sent in enumerate(normalized_sentences):
+         texts[i, :len(sent)] = [char2idx[char] for char in sent]
+     return texts
+
+
+ class LJSpeech(Dataset):
+     def __init__(self, keys, dir_name='LJSpeech-1.1'):
+         self.keys = keys
+         self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name)
+         self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'metadata.csv'))
+
+     def slice(self, start, end):
+         self.fnames = self.fnames[start:end]
+         self.text_lengths = self.text_lengths[start:end]
+         self.texts = self.texts[start:end]
+
+     def __len__(self):
+         return len(self.fnames)
+
+     def __getitem__(self, index):
+         data = {}
+         if 'texts' in self.keys:
+             data['texts'] = self.texts[index]
+         if 'mels' in self.keys:
+             # (39, 80)
+             data['mels'] = np.load(os.path.join(self.path, 'mels', "%s.npy" % self.fnames[index]))
+         if 'mags' in self.keys:
+             # (39, 80)
+             data['mags'] = np.load(os.path.join(self.path, 'mags', "%s.npy" % self.fnames[index]))
+         if 'mel_gates' in self.keys:
+             data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=np.int)  # TODO: because pre processing!
+         if 'mag_gates' in self.keys:
+             data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=np.int)  # TODO: because pre processing!
+         return data
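
A minimal sketch of how this dataset plugs into the loaders in datasets/data_loader.py. The batch sizes and key sets here are illustrative, and metadata.csv plus precomputed mels/ and mags/ folders are assumed to already exist under datasets/LJSpeech-1.1/:

from datasets.lj_speech import LJSpeech
from datasets.data_loader import Text2MelDataLoader, SSRNDataLoader

# Text2Mel consumes character ids, mel targets and per-frame gate flags (illustrative key set)
text2mel_train = Text2MelDataLoader(LJSpeech(['texts', 'mels', 'mel_gates']),
                                    batch_size=64, mode='train')

# SSRN maps mel spectrograms to full magnitude spectrograms;
# its sampler groups utterances of similar length to reduce padding
ssrn_train = SSRNDataLoader(LJSpeech(['mels', 'mags']), batch_size=24, mode='train')

batch = next(iter(text2mel_train))
texts, mels = batch['texts'], batch['mels']  # zero-padded to the longest item in the batch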
datasets/mb_speech.py ADDED
@@ -0,0 +1,139 @@
+ """Data loader for the Mongolian Bible dataset."""
+ import os
+ import codecs
+ import numpy as np
+
+ from torch.utils.data import Dataset
+
+ vocab = "PE абвгдеёжзийклмноөпрстуүфхцчшъыьэюя-.,!?"  # P: Padding, E: EOS.
+ char2idx = {char: idx for idx, char in enumerate(vocab)}
+ idx2char = {idx: char for idx, char in enumerate(vocab)}
+
+
+ def text_normalize(text):
+     text = text.lower()
+     # text = text.replace(",", "'")
+     # text = text.replace("!", "?")
+     for c in "-—:":
+         text = text.replace(c, "-")
+     for c in "()\"«»“”'":
+         text = text.replace(c, ",")
+     return text
+
+
+ def read_metadata(metadata_file):
+     fnames, text_lengths, texts = [], [], []
+     transcript = os.path.join(metadata_file)
+     lines = codecs.open(transcript, 'r', 'utf-8').readlines()
+     for line in lines:
+         fname, _, text = line.strip().split("|")
+
+         fnames.append(fname)
+
+         text = text_normalize(text) + "E"  # E: EOS
+         text = [char2idx[char] for char in text]
+         text_lengths.append(len(text))
+         texts.append(np.array(text, np.long))
+
+     return fnames, text_lengths, texts
+
+
+ def get_test_data(sentences, max_n):
+     normalized_sentences = [text_normalize(line).strip() + "E" for line in sentences]  # text normalization, E: EOS
+     texts = np.zeros((len(normalized_sentences), max_n + 1), np.long)
+     for i, sent in enumerate(normalized_sentences):
+         texts[i, :len(sent)] = [char2idx[char] for char in sent]
+     return texts
+
+
+ class MBSpeech(Dataset):
+     def __init__(self, keys, dir_name='MBSpeech-1.0'):
+         self.keys = keys
+         self.path = os.path.join(os.path.dirname(os.path.realpath(__file__)), dir_name)
+         self.fnames, self.text_lengths, self.texts = read_metadata(os.path.join(self.path, 'metadata.csv'))
+
+     def slice(self, start, end):
+         self.fnames = self.fnames[start:end]
+         self.text_lengths = self.text_lengths[start:end]
+         self.texts = self.texts[start:end]
+
+     def __len__(self):
+         return len(self.fnames)
+
+     def __getitem__(self, index):
+         data = {}
+         if 'texts' in self.keys:
+             data['texts'] = self.texts[index]
+         if 'mels' in self.keys:
+             # (39, 80)
+             data['mels'] = np.load(os.path.join(self.path, 'mels', "%s.npy" % self.fnames[index]))
+         if 'mags' in self.keys:
+             # (39, 80)
+             data['mags'] = np.load(os.path.join(self.path, 'mags', "%s.npy" % self.fnames[index]))
+         if 'mel_gates' in self.keys:
+             data['mel_gates'] = np.ones(data['mels'].shape[0], dtype=np.int)  # TODO: because pre processing!
+         if 'mag_gates' in self.keys:
+             data['mag_gates'] = np.ones(data['mags'].shape[0], dtype=np.int)  # TODO: because pre processing!
+         return data
+
+ #
+ # simple method to convert mongolian numbers to text, copied from somewhere
+ #
+
+
+ def number2word(number):
+     digit_len = len(number)
+     digit_name = {1: '', 2: 'мянга', 3: 'сая', 4: 'тэрбум', 5: 'их наяд', 6: 'тунамал'}
+
+     if digit_len == 1:
+         return _last_digit_2_str(number)
+     if digit_len == 2:
+         return _2_digits_2_str(number)
+     if digit_len == 3:
+         return _3_digits_to_str(number)
+     if digit_len < 7:
+         return _3_digits_to_str(number[:-3], False) + ' ' + digit_name[2] + ' ' + _3_digits_to_str(number[-3:])
+
+     digitgroup = [number[0 if i - 3 < 0 else i - 3:i] for i in reversed(range(len(number), 0, -3))]
+     count = len(digitgroup)
+     i = 0
+     result = ''
+     while i < count - 1:
+         result += ' ' + (_3_digits_to_str(digitgroup[i], False) + ' ' + digit_name[count - i])
+         i += 1
+     return result.strip() + ' ' + _3_digits_to_str(digitgroup[-1])
+
+
+ def _1_digit_2_str(digit):
+     return {'0': '', '1': 'нэгэн', '2': 'хоёр', '3': 'гурван', '4': 'дөрвөн', '5': 'таван', '6': 'зургаан',
+             '7': 'долоон', '8': 'найман', '9': 'есөн'}[digit]
+
+
+ def _last_digit_2_str(digit):
+     return {'0': 'тэг', '1': 'нэг', '2': 'хоёр', '3': 'гурав', '4': 'дөрөв', '5': 'тав', '6': 'зургаа', '7': 'долоо',
+             '8': 'найм', '9': 'ес'}[digit]
+
+
+ def _2_digits_2_str(digit, is_fina=True):
+     word2 = {'0': '', '1': 'арван', '2': 'хорин', '3': 'гучин', '4': 'дөчин', '5': 'тавин', '6': 'жаран', '7': 'далан',
+              '8': 'наян', '9': 'ерэн'}
+     word2fina = {'10': 'арав', '20': 'хорь', '30': 'гуч', '40': 'дөч', '50': 'тавь', '60': 'жар', '70': 'дал',
+                  '80': 'ная', '90': 'ер'}
+     if digit[1] == '0':
+         return word2fina[digit] if is_fina else word2[digit[0]]
+     digit1 = _last_digit_2_str(digit[1]) if is_fina else _1_digit_2_str(digit[1])
+     return (word2[digit[0]] + ' ' + digit1).strip()
+
+
+ def _3_digits_to_str(digit, is_fina=True):
+     digstr = digit.lstrip('0')
+     if len(digstr) == 0:
+         return ''
+     if len(digstr) == 1:
+         return _1_digit_2_str(digstr)
+     if len(digstr) == 2:
+         return _2_digits_2_str(digstr, is_fina)
+     if digit[-2:] == '00':
+         return _1_digit_2_str(digit[0]) + ' зуу' if is_fina else _1_digit_2_str(digit[0]) + ' зуун'
+     else:
+         return _1_digit_2_str(digit[0]) + ' зуун ' + _2_digits_2_str(digit[-2:], is_fina)
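
The helpers at the end of this file spell digit strings out as Mongolian words, presumably so numerals (which are not in the vocabulary above) can be expanded during transcript preparation. A quick sketch of calling them directly; the inputs are made up, and the outputs shown are what the code above produces:

from datasets.mb_speech import number2word

# convert digit strings to Mongolian words
print(number2word("7"))     # -> долоо
print(number2word("2019"))  # -> хоёр мянга арван ес
print(number2word("1234"))  # -> нэгэн мянга хоёр зуун гучин дөрөв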