|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Perform preprocessing and raw feature extraction for LJSpeech dataset.""" |
|
|
|
import os |
|
import re |
|
|
|
import numpy as np |
|
import soundfile as sf |
|
from dataclasses import dataclass |
|
from tensorflow_tts.processor import BaseProcessor |
|
from tensorflow_tts.utils import cleaners |
|
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME |
|
|
|
# Phoneme/character inventory used to build the LJSpeech symbol table.
#
# ``valid_symbols`` is the CMU ARPAbet phoneme set; vowels additionally carry
# optional stress markers (0 = unstressed, 1 = primary, 2 = secondary).
# The ordering below is frozen: a symbol's position determines its integer id,
# so entries must never be reordered or removed.
valid_symbols = [
    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2",
    "AH", "AH0", "AH1", "AH2", "AO", "AO0", "AO1", "AO2",
    "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
    "B", "CH", "D", "DH",
    "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2",
    "EY", "EY0", "EY1", "EY2",
    "F", "G", "HH",
    "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1", "IY2",
    "JH", "K", "L", "M", "N", "NG",
    "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2",
    "P", "R", "S", "SH", "T", "TH",
    "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2",
    "V", "W", "Y", "Z", "ZH",
]

# Special tokens and the plain-text character set.
_pad = "pad"
_eos = "eos"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

# ARPAbet symbols are prefixed with "@" so they can never collide with an
# ordinary character in the symbol table.
_arpabet = [f"@{symbol}" for symbol in valid_symbols]

# Complete symbol table: list index == symbol id, pad first, eos last.
LJSPEECH_SYMBOLS = [_pad, *_special, *_punctuation, *_letters, *_arpabet, _eos]

# Matches one "prefix {ARPABET CODES} suffix" span in a transcript.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
|
|
|
|
|
@dataclass
class LJSpeechProcessor(BaseProcessor):
    """LJSpeech processor.

    Reads ``metadata.csv`` (``id|raw text|normalized text``) from
    ``data_dir``, loads the matching wav files from ``<data_dir>/wavs``, and
    converts transcripts (optionally containing ``{ARPABET CODES}`` spans)
    into symbol-id sequences terminated by the eos id.
    """

    cleaner_names: str = "english_cleaners"
    # Column layout of one metadata.csv row.  Intentionally left un-annotated
    # so @dataclass treats it as a shared class attribute rather than a field
    # (an annotated mutable dict default would be rejected by the dataclass
    # machinery).
    positions = {
        "wave_file": 0,
        "text": 1,
        "text_norm": 2,
    }
    train_f_name: str = "metadata.csv"

    def create_items(self):
        """Populate ``self.items`` with (text_norm, wav_path, speaker) tuples."""
        if self.data_dir:
            meta_path = os.path.join(self.data_dir, self.train_f_name)
            with open(meta_path, encoding="utf-8") as f:
                # Skip blank lines (e.g. a trailing newline in metadata.csv),
                # which would otherwise crash split_line with an IndexError.
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Parse one metadata line into (text_norm, wav_path, speaker_name)."""
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text_norm = parts[self.positions["text_norm"]]
        wav_path = os.path.join(data_dir, "wavs", f"{wave_file}.wav")
        speaker_name = "ljspeech"  # LJSpeech is a single-speaker corpus
        return text_norm, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sentence token appended to every sequence."""
        return _eos

    def save_pretrained(self, saved_path):
        """Save the symbol mapper to ``saved_path`` for later reloading."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def get_one_sample(self, item):
        """Load one (text, wav_path, speaker_name) item into a sample dict.

        Returns a dict with the raw text, int32 symbol ids, float32 audio,
        utterance id, speaker name, and the wav sample rate.
        """
        text, wav_path, speaker_name = item

        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            # Basename without its extension; splitext (unlike split("."))
            # keeps an utterance id that itself contains dots intact.
            "utt_id": os.path.splitext(os.path.basename(wav_path))[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Convert a transcript into a list of symbol ids.

        Plain text is run through the configured cleaner before lookup;
        ``{...}`` spans are treated as space-separated ARPAbet codes and
        bypass cleaning.  The eos id is always appended.
        """
        sequence = []

        while len(text):
            m = _curly_re.match(text)
            if not m:
                # No (more) curly spans: clean and encode the remainder.
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        sequence += [self.eos_id]
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from tensorflow_tts.utils.cleaners in order."""
        for name in cleaner_names:
            # Default to None: a bare getattr would raise AttributeError, so
            # the intended "Unknown cleaner" error below could never fire.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to ids, silently dropping anything not in the table."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _arpabet_to_sequence(self, text):
        """Map a space-separated ARPAbet string to ids via the "@" prefix."""
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Keep known symbols, excluding the legacy "_" and "~" placeholders."""
        return s in self.symbol_to_id and s != "_" and s != "~"
|
|