# Uploaded by vishred18 (commit d5ee97c, 8.12 kB).
# -*- coding: utf-8 -*-
# Copyright 2020 TensorFlowTTS Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base Processor for all processor."""
import abc
import json
import os
from typing import Dict, List, Union
from dataclasses import dataclass, field
class DataProcessorError(Exception):
    """Error raised by data processors, e.g. when no symbols are given and no mapper is loaded."""
@dataclass
class BaseProcessor(abc.ABC):
    """Base class for all dataset processors.

    There are two ways to initialise a processor:

    * From a training file: ``<data_dir>/<train_f_name>`` is read into
      ``items``, then ``speakers_map`` and the symbol <-> id tables are built.
    * From a saved JSON mapper: all mappings are restored from
      ``loaded_mapper_path`` and the training file is not read.

    In both cases the token returned by ``setup_eos_token`` (if any) is
    appended to the symbol table and its id is cached in ``eos_id``.

    Attribute access that normal lookup cannot satisfy falls back to
    ``symbol_to_id``: ``proc.<sym>_id`` and ``proc.<sym>`` both return the id
    of ``<sym>``.
    """

    data_dir: str
    symbols: List[str] = field(default_factory=list)
    speakers_map: Dict[str, int] = field(default_factory=dict)
    train_f_name: str = "train.txt"
    delimiter: str = "|"
    # Column index of file, text and speaker_name after splitting a line on
    # ``delimiter``. Left un-annotated so the dataclass machinery treats it as
    # a plain class attribute (an annotated dict default would be rejected as
    # a mutable default).
    positions = {
        "file": 0,
        "text": 1,
        "speaker_name": 2,
    }
    f_extension: str = ".wav"
    saved_mapper_path: str = None
    loaded_mapper_path: str = None
    # extras
    items: List[List[str]] = field(default_factory=list)  # text, wav_path, speaker_name
    symbol_to_id: Dict[str, int] = field(default_factory=dict)
    id_to_symbol: Dict[int, str] = field(default_factory=dict)

    def __post_init__(self):
        if self.loaded_mapper_path is not None:
            # Restore every mapping from a saved mapper; no training file is read.
            self._load_mapper(loaded_path=self.loaded_mapper_path)
            self.reverse_speaker = {v: k for k, v in self.speakers_map.items()}
            self._add_eos_token()
            return

        if not self.symbols:
            raise DataProcessorError("Symbols list is empty but mapper isn't loaded")

        self.create_items()
        self.create_speaker_map()
        self.reverse_speaker = {v: k for k, v in self.speakers_map.items()}
        self.create_symbols()

        # processor name. useful to use it for AutoProcessor
        self._processor_name = type(self).__name__
        # The processor name and the eos token must exist before saving:
        # ``_save_mapper`` reads ``_processor_name``, and the written mapper
        # should match the in-memory symbol table.
        self._add_eos_token()
        if self.saved_mapper_path is not None:
            self._save_mapper(saved_path=self.saved_mapper_path)

    def _add_eos_token(self):
        """Add the subclass's eos symbol (if any) and cache its id in ``eos_id``."""
        eos_token = self.setup_eos_token()
        if eos_token:
            self.add_symbol(eos_token)  # no-op if the token is already present
            self.eos_id = self.symbol_to_id[eos_token]

    def __getattr__(self, name: str) -> Union[str, int]:
        """Resolve unknown attributes through ``symbol_to_id``.

        Only called when normal attribute lookup fails. Every ``"_id"`` is
        removed from ``name`` and the remainder is looked up as a symbol, so
        ``proc.eos_id`` and ``proc.eos`` both give the id of ``"eos"``.

        Raises:
            AttributeError: if the symbol is unknown, the name is a dunder, or
                the symbol table does not exist yet. Raising AttributeError
                (not KeyError) keeps ``hasattr``, ``getattr(..., default)``
                and ``copy``/``pickle`` working.
        """
        # Read the table straight from the instance dict. Going through
        # ``self.symbol_to_id`` would re-enter __getattr__ (the name contains
        # "_id") and recurse forever on half-built objects, e.g. during copy.
        symbol_to_id = self.__dict__.get("symbol_to_id")
        if (name.startswith("__") and name.endswith("__")) or symbol_to_id is None:
            raise AttributeError(name)
        try:
            return symbol_to_id[name.replace("_id", "")]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def create_speaker_map(self):
        """Assign an integer id to every speaker found in ``items``.

        Speakers already present in ``speakers_map`` keep their id. New
        speakers get consecutive ids, starting after the largest existing id,
        in order of first appearance. The speaker name is read from the last
        element of each item.
        """
        next_id = max(self.speakers_map.values(), default=-1) + 1
        for item in self.items:
            speaker_name = item[-1]
            if speaker_name not in self.speakers_map:
                self.speakers_map[speaker_name] = next_id
                next_id += 1

    def get_speaker_id(self, name: str) -> int:
        """Return the id of speaker ``name`` (KeyError if unknown)."""
        return self.speakers_map[name]

    def get_speaker_name(self, speaker_id: int) -> str:
        """Return the speaker name whose id is ``speaker_id``.

        Raises:
            KeyError: if no speaker has that id.
        """
        # speakers_map is name -> id, so it has to be searched by value.
        for name, sp_id in self.speakers_map.items():
            if sp_id == speaker_id:
                return name
        raise KeyError(speaker_id)

    def create_symbols(self):
        """Build ``symbol_to_id`` / ``id_to_symbol`` from the order of ``symbols``."""
        self.symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self.id_to_symbol = {i: s for i, s in enumerate(self.symbols)}

    def create_items(self):
        """Read ``<data_dir>/<train_f_name>`` (UTF-8) and append rows to ``items``.

        Each non-blank line is stripped and split on ``delimiter``. The columns
        named in ``positions`` give the audio file, the text and the speaker
        name. The audio path is joined onto ``data_dir``, and ``f_extension``
        is appended unless the path already ends with it.

        Every item is ``[text, wav_path, speaker_name]``. The speaker name must
        stay last because ``create_speaker_map`` reads ``item[-1]``.
        """
        train_path = os.path.join(self.data_dir, self.train_f_name)
        with open(train_path, mode="r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. a trailing empty line
                parts = line.split(self.delimiter)
                wav_path = os.path.join(self.data_dir, parts[self.positions["file"]])
                if not wav_path.endswith(self.f_extension):
                    wav_path += self.f_extension
                text = parts[self.positions["text"]]
                speaker_name = parts[self.positions["speaker_name"]]
                self.items.append([text, wav_path, speaker_name])

    def add_symbol(self, symbol: Union[str, list]):
        """Append a new symbol (or each symbol of a list) to the symbol table.

        Symbols already in ``symbol_to_id`` are skipped. A new symbol gets id
        ``len(symbol_to_id)``.

        Raises:
            ValueError: if ``symbol`` is neither a string nor a list.
        """
        if isinstance(symbol, str):
            if symbol in self.symbol_to_id:
                return
            self.symbols.append(symbol)
            symbol_id = len(self.symbol_to_id)
            self.symbol_to_id[symbol] = symbol_id
            self.id_to_symbol[symbol_id] = symbol
        elif isinstance(symbol, list):
            for i in symbol:
                self.add_symbol(i)
        else:
            raise ValueError("A new_symbols must be a string or list of string.")

    @abc.abstractmethod
    def get_one_sample(self, item):
        """Get one sample from dataset items.
        Args:
            item: one item in Dataset items.
                Dataset items may include (raw_text, speaker_id, wav_path, ...)
        Returns:
            sample (dict): sample dictionary return all feature used for preprocessing later.
        """
        sample = {
            "raw_text": None,
            "text_ids": None,
            "audio": None,
            "utt_id": None,
            "speaker_name": None,
            "rate": None,
        }
        return sample

    @abc.abstractmethod
    def text_to_sequence(self, text: str):
        """Convert ``text`` to a list of symbol ids (implemented by subclasses)."""
        return []

    @abc.abstractmethod
    def setup_eos_token(self):
        """Return eos symbol of type string."""
        return "eos"

    def convert_symbols_to_ids(self, symbols: Union[str, list]):
        """Map a symbol, or a list of symbols, to a list of ids.

        Raises:
            ValueError: if ``symbols`` is not a string/list, or if a list
                element is not a string.
            KeyError: if a symbol is not in ``symbol_to_id``.
        """
        if isinstance(symbols, str):
            return [self.symbol_to_id[symbols]]
        if isinstance(symbols, list):
            sequence = []
            for s in symbols:
                if not isinstance(s, str):
                    raise ValueError("All elements of symbols must be a string.")
                sequence.append(self.symbol_to_id[s])
            return sequence
        raise ValueError("A symbols must be a string or list of string.")

    def _load_mapper(self, loaded_path: str = None):
        """Load all mappers from a JSON file (default: ``<data_dir>/mapper.json``).

        Restores ``speakers_map``, ``symbol_to_id``, ``id_to_symbol`` and
        ``_processor_name``. Every other key in the file, including
        ``processor_name``, is set as an attribute of the same name.
        """
        if loaded_path is None:
            loaded_path = os.path.join(self.data_dir, "mapper.json")
        with open(loaded_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.speakers_map = data["speakers_map"]
        self.symbol_to_id = data["symbol_to_id"]
        # JSON object keys are always strings; restore the integer ids.
        self.id_to_symbol = {int(k): v for k, v in data["id_to_symbol"].items()}
        self._processor_name = data["processor_name"]
        # other keys
        for key, value in data.items():
            if key not in ("speakers_map", "symbol_to_id", "id_to_symbol"):
                setattr(self, key, value)

    def _save_mapper(self, saved_path: str = None, extra_attrs_to_save: dict = None):
        """Save all needed mappers as JSON (default: ``<data_dir>/mapper.json``).

        ``extra_attrs_to_save`` entries are merged in and override the
        standard keys on collision.
        """
        if saved_path is None:
            saved_path = os.path.join(self.data_dir, "mapper.json")
        # Build the payload before opening the file, so a failure here does
        # not leave a truncated mapper on disk.
        full_mapper = {
            "symbol_to_id": self.symbol_to_id,
            "id_to_symbol": self.id_to_symbol,
            "speakers_map": self.speakers_map,
            "processor_name": self._processor_name,
        }
        if extra_attrs_to_save:
            full_mapper = {**full_mapper, **extra_attrs_to_save}
        with open(saved_path, "w", encoding="utf-8") as f:
            json.dump(full_mapper, f)

    @abc.abstractmethod
    def save_pretrained(self, saved_path):
        """Save mappers to file"""
        pass