RMSnow's picture
init and interface
df2accb
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
from tqdm import tqdm
from text.g2p_module import G2PModule, LexiconModule
from text.symbol_table import SymbolTable
'''
phoneExtractor: extract phone sequences from text
'''
class phoneExtractor:
    '''
    Convert utterance text to phone sequences and collect the phone symbols seen.

    The backend is selected by cfg.preprocess.phone_extractor:
    "espeak", "pypinyin" and "pypinyin_initials_finals" are served by G2PModule,
    "lexicon" is served by LexiconModule built from cfg.preprocess.lexicon_path.
    '''

    # backends handled by G2PModule
    _G2P_BACKENDS = ("espeak", "pypinyin", "pypinyin_initials_finals")

    def __init__(self, cfg, dataset_name=None, phone_symbol_file=None):
        '''
        Args:
            cfg: config; reads cfg.preprocess.phone_extractor and, depending on
                the arguments and backend, cfg.preprocess.processed_dir,
                cfg.preprocess.symbols_dict and cfg.preprocess.lexicon_path
            dataset_name: name of dataset; used to build the symbols file path
                when phone_symbol_file is not given
            phone_symbol_file: explicit path of the phone symbols file; takes
                precedence over dataset_name
        Raises:
            NotImplementedError: cfg.preprocess.phone_extractor is not supported
            AssertionError: "lexicon" is selected but lexicon_path is empty
        '''
        self.cfg = cfg

        # phone symbols collected by extract_phone (non-lexicon backends only)
        self.phone_symbols = set()

        # phone symbols dict file
        # NOTE: when both phone_symbol_file and dataset_name are None, the
        # attribute is left unset and save_dataset_phone_symbols_to_table
        # cannot be used on this instance.
        if phone_symbol_file is not None:
            self.phone_symbols_file = phone_symbol_file
        elif dataset_name is not None:
            self.dataset_name = dataset_name
            self.phone_symbols_file = os.path.join(cfg.preprocess.processed_dir,
                                                   dataset_name,
                                                   cfg.preprocess.symbols_dict)

        # initialize g2p module
        backend = cfg.preprocess.phone_extractor
        if backend in self._G2P_BACKENDS:
            self.g2p_module = G2PModule(backend=backend)
        elif backend == 'lexicon':
            assert cfg.preprocess.lexicon_path != "", \
                "cfg.preprocess.lexicon_path must be set for the lexicon phone extractor"
            self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path)
        else:
            # NotImplementedError is a RuntimeError subclass, so handlers written
            # against the former bare `raise` (which produced RuntimeError) still match.
            raise NotImplementedError(
                "Unsupported phone extractor: {!r}".format(backend))

    def extract_phone(self, text):
        '''
        Extract phone from text
        Args:
            text: text of utterance
        Returns:
            phone_seq: list of phones of the utterance
        Raises:
            NotImplementedError: cfg.preprocess.phone_extractor is not supported
        '''
        backend = self.cfg.preprocess.phone_extractor
        if backend in self._G2P_BACKENDS:
            # replace curly double quotes with the ASCII double quote before g2p
            text = text.replace("”", '"').replace("“", '"')
            phone = self.g2p_module.g2p_conversion(text=text)
            # only these backends contribute to the collected symbol set
            self.phone_symbols.update(phone)
            phone_seq = list(phone)
        elif backend == 'lexicon':
            phone_seq = self.g2p_module.g2p_conversion(text)
            # a non-list result is split on whitespace into a list of phones
            if not isinstance(phone_seq, list):
                phone_seq = phone_seq.split()
        else:
            raise NotImplementedError(
                "Unsupported phone extractor: {!r}".format(backend))
        return phone_seq

    def save_dataset_phone_symbols_to_table(self):
        '''
        Merge the collected phone symbols with the ones already stored in
        self.phone_symbols_file (when that file exists) and write the sorted
        union back to the same file.
        '''
        # load and merge saved phone symbols
        if os.path.exists(self.phone_symbols_file):
            saved_symbols = SymbolTable.from_file(self.phone_symbols_file)._sym2id.keys()
            self.phone_symbols.update(saved_symbols)

        # save phone symbols
        phone_symbol_dict = SymbolTable()
        for s in sorted(self.phone_symbols):
            phone_symbol_dict.add(s)
        phone_symbol_dict.to_file(self.phone_symbols_file)
def extract_utt_phone_sequence(cfg, metadata):
    '''
    Extract phone sequence from text

    For every utterance, the phones are written space-separated to
    <processed_dir>/<dataset_name>/<phone_dir>/<Uid>.phone. For every
    extractor except 'lexicon', the collected phone symbols are then saved
    to the dataset's symbol table.
    Args:
        cfg: config; the first entry of cfg.dataset is used as dataset name
        metadata: list of dict, each dict contains "Uid", "Text"
    '''
    dataset_name = cfg.dataset[0]

    # output path
    out_path = os.path.join(cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.phone_dir)
    os.makedirs(out_path, exist_ok=True)

    phone_extractor = phoneExtractor(cfg, dataset_name)
    for utt in tqdm(metadata):
        uid = utt["Uid"]
        text = utt["Text"]
        phone_seq = phone_extractor.extract_phone(text)

        phone_path = os.path.join(out_path, uid + '.phone')
        # Write as UTF-8 explicitly: phones may contain non-ASCII symbols and
        # the platform default encoding is not guaranteed to encode them.
        with open(phone_path, 'w', encoding='utf-8') as fout:
            fout.write(' '.join(phone_seq))

    # the symbol set is complete only after all utterances were processed
    if cfg.preprocess.phone_extractor != 'lexicon':
        phone_extractor.save_dataset_phone_symbols_to_table()
def save_all_dataset_phone_symbols_to_table(self, cfg, dataset):
    '''
    Build the union of the phone symbols stored for every dataset and write
    that same sorted table back to each dataset's symbols file.
    Args:
        self: not used by this function
        cfg: config; reads cfg.preprocess.processed_dir and cfg.preprocess.symbols_dict
        dataset: dataset names; iterated once for loading and once for saving
    '''
    def symbols_path(name):
        # <processed_dir>/<dataset name>/<symbols_dict>
        return os.path.join(cfg.preprocess.processed_dir,
                            name,
                            cfg.preprocess.symbols_dict)

    # gather the symbols saved for each dataset
    merged_symbols = set()
    for name in dataset:
        saved_file = symbols_path(name)
        # every dataset is expected to have its symbols file already
        assert os.path.exists(saved_file)
        merged_symbols |= set(SymbolTable.from_file(saved_file)._sym2id.keys())

    # one table holding all symbols in sorted order
    ordered_symbols = list(merged_symbols)
    ordered_symbols.sort()
    merged_table = SymbolTable()
    for symbol in ordered_symbols:
        merged_table.add(symbol)

    # write the merged table to every dataset
    for name in dataset:
        merged_table.to_file(symbols_path(name))