Comparative-Analysis-of-Speech-Synthesis-Models
/
TensorFlowTTS
/tensorflow_tts
/processor
/base_processor.py
# -*- coding: utf-8 -*- | |
# Copyright 2020 TensorFlowTTS Team. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Base Processor for all processor.""" | |
import abc | |
import json | |
import os | |
from typing import Dict, List, Union | |
from dataclasses import dataclass, field | |
class DataProcessorError(Exception): | |
pass | |
class BaseProcessor(abc.ABC): | |
data_dir: str | |
symbols: List[str] = field(default_factory=list) | |
speakers_map: Dict[str, int] = field(default_factory=dict) | |
train_f_name: str = "train.txt" | |
delimiter: str = "|" | |
positions = { | |
"file": 0, | |
"text": 1, | |
"speaker_name": 2, | |
} # positions of file,text,speaker_name after split line | |
f_extension: str = ".wav" | |
saved_mapper_path: str = None | |
loaded_mapper_path: str = None | |
# extras | |
items: List[List[str]] = field(default_factory=list) # text, wav_path, speaker_name | |
symbol_to_id: Dict[str, int] = field(default_factory=dict) | |
id_to_symbol: Dict[int, str] = field(default_factory=dict) | |
def __post_init__(self): | |
if self.loaded_mapper_path is not None: | |
self._load_mapper(loaded_path=self.loaded_mapper_path) | |
if self.setup_eos_token(): | |
self.add_symbol( | |
self.setup_eos_token() | |
) # if this eos token not yet present in symbols list. | |
self.eos_id = self.symbol_to_id[self.setup_eos_token()] | |
return | |
if self.symbols.__len__() < 1: | |
raise DataProcessorError("Symbols list is empty but mapper isn't loaded") | |
self.create_items() | |
self.create_speaker_map() | |
self.reverse_speaker = {v: k for k, v in self.speakers_map.items()} | |
self.create_symbols() | |
if self.saved_mapper_path is not None: | |
self._save_mapper(saved_path=self.saved_mapper_path) | |
# processor name. usefull to use it for AutoProcessor | |
self._processor_name = type(self).__name__ | |
if self.setup_eos_token(): | |
self.add_symbol( | |
self.setup_eos_token() | |
) # if this eos token not yet present in symbols list. | |
self.eos_id = self.symbol_to_id[self.setup_eos_token()] | |
def __getattr__(self, name: str) -> Union[str, int]: | |
if "_id" in name: # map symbol to id | |
return self.symbol_to_id[name.replace("_id", "")] | |
return self.symbol_to_id[name] # map symbol to value | |
def create_speaker_map(self): | |
""" | |
Create speaker map for dataset. | |
""" | |
sp_id = 0 | |
for i in self.items: | |
speaker_name = i[-1] | |
if speaker_name not in self.speakers_map: | |
self.speakers_map[speaker_name] = sp_id | |
sp_id += 1 | |
def get_speaker_id(self, name: str) -> int: | |
return self.speakers_map[name] | |
def get_speaker_name(self, speaker_id: int) -> str: | |
return self.speakers_map[speaker_id] | |
def create_symbols(self): | |
self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)} | |
self.id_to_symbol = {i: s for i, s in enumerate(self.symbols)} | |
def create_items(self): | |
""" | |
Method used to create items from training file | |
items struct example => text, wav_file_path, speaker_name. | |
Note that the speaker_name should be a last. | |
""" | |
with open( | |
os.path.join(self.data_dir, self.train_f_name), mode="r", encoding="utf-8" | |
) as f: | |
for line in f: | |
parts = line.strip().split(self.delimiter) | |
wav_path = os.path.join(self.data_dir, parts[self.positions["file"]]) | |
wav_path = ( | |
wav_path + self.f_extension | |
if wav_path[-len(self.f_extension) :] != self.f_extension | |
else wav_path | |
) | |
text = parts[self.positions["text"]] | |
speaker_name = parts[self.positions["speaker_name"]] | |
self.items.append([text, wav_path, speaker_name]) | |
def add_symbol(self, symbol: Union[str, list]): | |
if isinstance(symbol, str): | |
if symbol in self.symbol_to_id: | |
return | |
self.symbols.append(symbol) | |
symbol_id = len(self.symbol_to_id) | |
self.symbol_to_id[symbol] = symbol_id | |
self.id_to_symbol[symbol_id] = symbol | |
elif isinstance(symbol, list): | |
for i in symbol: | |
self.add_symbol(i) | |
else: | |
raise ValueError("A new_symbols must be a string or list of string.") | |
def get_one_sample(self, item): | |
"""Get one sample from dataset items. | |
Args: | |
item: one item in Dataset items. | |
Dataset items may include (raw_text, speaker_id, wav_path, ...) | |
Returns: | |
sample (dict): sample dictionary return all feature used for preprocessing later. | |
""" | |
sample = { | |
"raw_text": None, | |
"text_ids": None, | |
"audio": None, | |
"utt_id": None, | |
"speaker_name": None, | |
"rate": None, | |
} | |
return sample | |
def text_to_sequence(self, text: str): | |
return [] | |
def setup_eos_token(self): | |
"""Return eos symbol of type string.""" | |
return "eos" | |
def convert_symbols_to_ids(self, symbols: Union[str, list]): | |
sequence = [] | |
if isinstance(symbols, str): | |
sequence.append(self._symbol_to_id[symbols]) | |
return sequence | |
elif isinstance(symbols, list): | |
for s in symbols: | |
if isinstance(s, str): | |
sequence.append(self._symbol_to_id[s]) | |
else: | |
raise ValueError("All elements of symbols must be a string.") | |
else: | |
raise ValueError("A symbols must be a string or list of string.") | |
return sequence | |
def _load_mapper(self, loaded_path: str = None): | |
""" | |
Save all needed mappers to file | |
""" | |
loaded_path = ( | |
os.path.join(self.data_dir, "mapper.json") | |
if loaded_path is None | |
else loaded_path | |
) | |
with open(loaded_path, "r") as f: | |
data = json.load(f) | |
self.speakers_map = data["speakers_map"] | |
self.symbol_to_id = data["symbol_to_id"] | |
self.id_to_symbol = {int(k): v for k, v in data["id_to_symbol"].items()} | |
self._processor_name = data["processor_name"] | |
# other keys | |
all_data_keys = data.keys() | |
for key in all_data_keys: | |
if key not in ["speakers_map", "symbol_to_id", "id_to_symbol"]: | |
setattr(self, key, data[key]) | |
def _save_mapper(self, saved_path: str = None, extra_attrs_to_save: dict = None): | |
""" | |
Save all needed mappers to file | |
""" | |
saved_path = ( | |
os.path.join(self.data_dir, "mapper.json") | |
if saved_path is None | |
else saved_path | |
) | |
with open(saved_path, "w") as f: | |
full_mapper = { | |
"symbol_to_id": self.symbol_to_id, | |
"id_to_symbol": self.id_to_symbol, | |
"speakers_map": self.speakers_map, | |
"processor_name": self._processor_name, | |
} | |
if extra_attrs_to_save: | |
full_mapper = {**full_mapper, **extra_attrs_to_save} | |
json.dump(full_mapper, f) | |
def save_pretrained(self, saved_path): | |
"""Save mappers to file""" | |
pass | |