# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch

from text.symbol_table import SymbolTable
from text import text_to_sequence

"""
TextToken: map text to id
"""


# TextTokenCollator is modified from
# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/collation.py
class TextTokenCollator:
    """Maps text/phone tokens to integer ids and collates them into padded batches."""

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Reserve the lowest ids for the special symbols, then append the
        # sorted token inventory.
        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def index(
        self, tokens_list: List[List[str]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert all(s in self.token2idx for s in tokens)
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Right-pad every sequence to the batch maximum with the pad symbol.
        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, text) -> Tuple[List[int], int]:
        tokens_seq = [p for p in text]
        seq = (
            ([self.bos_symbol] if self.add_bos else [])
            + tokens_seq
            + ([self.eos_symbol] if self.add_eos else [])
        )

        token_ids = [self.token2idx[token] for token in seq]
        # add_bos/add_eos are bools, so they contribute 0 or 1 to the length.
        token_lens = len(tokens_seq) + self.add_eos + self.add_bos

        return token_ids, token_lens


def get_text_token_collater(
    text_tokens_file: str,
) -> Tuple[TextTokenCollator, Dict[str, int]]:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)

    collater = TextTokenCollator(unique_tokens.symbols, add_bos=True, add_eos=True)
    token2idx = collater.token2idx
    return collater, token2idx


class phoneIDCollation:
    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        if cfg.preprocess.phone_extractor != "lexicon":
            ### get text token collator
            if symbols_dict_file is None:
                assert dataset is not None
                symbols_dict_file = os.path.join(
                    cfg.preprocess.processed_dir, dataset, cfg.preprocess.symbols_dict
                )
            self.text_token_collator, token2idx = get_text_token_collater(
                symbols_dict_file
            )
            # unique_tokens = SymbolTable.from_file(symbols_dict_path)
            # text_tokenizer = TextToken(unique_tokens.symbols, add_bos=True, add_eos=True)

            # update phone symbols dict file with pad_symbol or optional tokens
            # (add_bos and add_eos) in TextTokenCollator
            # phone_symbol_dict = SymbolTable()
            # for s in sorted(list(set(token2idx.keys()))):
            #     phone_symbol_dict.add(s)
            # phone_symbol_dict.to_file(symbols_dict_file)

    def get_phone_id_sequence(self, cfg, phones_seq):
        if cfg.preprocess.phone_extractor == "lexicon":
            phones_seq = " ".join(phones_seq)
            sequence = text_to_sequence(phones_seq, cfg.preprocess.text_cleaners)
        else:
            sequence, _ = self.text_token_collator(phones_seq)

        return sequence
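

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module API): exercises
# TextTokenCollator directly with a made-up phone inventory. In real use the
# inventory comes from a symbols-dict file via get_text_token_collater();
# the phones below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_phones = ["AA", "AH", "K", "T"]  # hypothetical inventory
    collator = TextTokenCollator(demo_phones, add_bos=True, add_eos=True)

    # With <pad>=0, <bos>=1, <eos>=2, the sorted phones take ids 3..6,
    # so token2idx is {"<pad>": 0, "<bos>": 1, "<eos>": 2,
    #                  "AA": 3, "AH": 4, "K": 5, "T": 6}.

    # Single sequence: returns (token ids, length including <bos>/<eos>).
    ids, length = collator(["K", "AA", "T"])
    print(ids, length)  # [1, 5, 3, 6, 2] 5

    # Batch indexing: returns a right-padded LongTensor plus per-sequence
    # lengths (the second row is padded with <pad>=0 to the batch maximum).
    batch, lens = collator.index([["K", "AA", "T"], ["AH"]])
    print(batch.shape, lens)  # torch.Size([2, 5]) tensor([5, 3], dtype=torch.int32)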