|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Perform preprocessing and raw feature extraction for KSS dataset.""" |
|
|
|
import os |
|
import re |
|
|
|
import numpy as np |
|
import soundfile as sf |
|
from dataclasses import dataclass |
|
from tensorflow_tts.processor import BaseProcessor |
|
from tensorflow_tts.utils import cleaners |
|
from tensorflow_tts.utils.korean import symbols as KSS_SYMBOLS |
|
from tensorflow_tts.utils.utils import PROCESSOR_FILE_NAME |
|
|
|
|
|
# Splits text of the form "prefix{SYMBOLS}suffix" into three groups:
# (text before the braces, the brace contents, the remainder).  Used by
# KSSProcessor.text_to_sequence to route {…} spans through the
# arpabet/symbol path while cleaning ordinary text normally.
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
|
|
|
|
|
@dataclass
class KSSProcessor(BaseProcessor):
    """KSS (Korean Single Speaker) dataset processor.

    Reads the KSS transcript file, turns each transcript line into a
    ``(text, wav_path, speaker_name)`` item, and converts Korean text into
    sequences of integer symbol IDs for feature extraction.
    """

    # Name of the cleaner function (attribute of ``cleaners``) applied to text.
    cleaner_names: str = "korean_cleaners"
    # Column indices within a "|"-separated transcript line.
    # NOTE: plain class attribute (no annotation), so the dataclass machinery
    # ignores it — it is shared across instances, which is fine as it is
    # never mutated.
    positions = {
        "wave_file": 0,
        "text_norm": 2,
    }
    train_f_name: str = "transcript.v.1.4.txt"

    def create_items(self):
        """Populate ``self.items`` from the transcript file.

        Each non-blank transcript line becomes one
        ``(text_norm, wav_path, speaker_name)`` tuple.  Blank lines (e.g. a
        trailing newline at end of file) are skipped — the original code
        raised ``IndexError`` on them.  Does nothing when ``data_dir`` is
        falsy.
        """
        if self.data_dir:
            with open(
                os.path.join(self.data_dir, self.train_f_name), encoding="utf-8"
            ) as f:
                self.items = [
                    self.split_line(self.data_dir, line, "|")
                    for line in f
                    if line.strip()
                ]

    def split_line(self, data_dir, line, split):
        """Parse one transcript line.

        Args:
            data_dir: Dataset root directory.
            line: Raw transcript line, e.g. ``"1/1_0000.wav|...|text|..."``.
            split: Field separator (``"|"`` for KSS).

        Returns:
            Tuple ``(text_norm, wav_path, speaker_name)``.
        """
        parts = line.strip().split(split)
        wave_file = parts[self.positions["wave_file"]]
        text_norm = parts[self.positions["text_norm"]]
        # Audio files live under <data_dir>/kss/<relative wave path>.
        wav_path = os.path.join(data_dir, "kss", wave_file)
        speaker_name = "kss"
        return text_norm, wav_path, speaker_name

    def setup_eos_token(self):
        """Return the end-of-sequence token string used by this processor."""
        return "eos"

    def save_pretrained(self, saved_path):
        """Write the processor mapper file into ``saved_path`` (created if needed)."""
        os.makedirs(saved_path, exist_ok=True)
        self._save_mapper(os.path.join(saved_path, PROCESSOR_FILE_NAME), {})

    def get_one_sample(self, item):
        """Load one training sample.

        Args:
            item: ``(text, wav_path, speaker_name)`` as produced by
                ``split_line``.

        Returns:
            Dict with raw text, text IDs, float32 audio, utterance id
            (wav filename stem), speaker name, and sample rate.
        """
        text, wav_path, speaker_name = item

        audio, rate = sf.read(wav_path)
        audio = audio.astype(np.float32)

        text_ids = np.asarray(self.text_to_sequence(text), np.int32)

        sample = {
            "raw_text": text,
            "text_ids": text_ids,
            "audio": audio,
            "utt_id": os.path.split(wav_path)[-1].split(".")[0],
            "speaker_name": speaker_name,
            "rate": rate,
        }

        return sample

    def text_to_sequence(self, text):
        """Convert a text string to a list of symbol IDs.

        Text wrapped in curly braces (``{...}``) is treated as
        space-separated phoneme symbols and bypasses the cleaner; all other
        text is passed through ``self.cleaner_names`` first.  The EOS id is
        appended at the end.
        """
        sequence = []

        while len(text):
            m = _curly_re.match(text)
            if not m:
                # No more {…} spans: clean and encode the remainder.
                sequence += self._symbols_to_sequence(
                    self._clean_text(text, [self.cleaner_names])
                )
                break
            sequence += self._symbols_to_sequence(
                self._clean_text(m.group(1), [self.cleaner_names])
            )
            sequence += self._arpabet_to_sequence(m.group(2))
            text = m.group(3)

        sequence.append(self.eos_id)
        return sequence

    def _clean_text(self, text, cleaner_names):
        """Apply each named cleaner from ``cleaners`` to ``text`` in order.

        Raises:
            Exception: If a cleaner name does not exist in ``cleaners``.
        """
        for name in cleaner_names:
            # Default of None so a missing cleaner reaches the explicit
            # error below; bare getattr() raised AttributeError first,
            # making the "Unknown cleaner" message unreachable.
            cleaner = getattr(cleaners, name, None)
            if not cleaner:
                raise Exception("Unknown cleaner: %s" % name)
            text = cleaner(text)
        return text

    def _symbols_to_sequence(self, symbols):
        """Map symbols to IDs, dropping padding/EOS placeholders and unknowns."""
        return [self.symbol_to_id[s] for s in symbols if self._should_keep_symbol(s)]

    def _arpabet_to_sequence(self, text):
        """Encode space-separated phoneme symbols, each prefixed with '@'."""
        return self._symbols_to_sequence(["@" + s for s in text.split()])

    def _should_keep_symbol(self, s):
        """Keep only known symbols that are not the pad ('_') or EOS ('~') markers."""
        return s in self.symbol_to_id and s != "_" and s != "~"
|
|