"""Collators for T2S and S2A. Copyright PolyAI Limited. """ from pathlib import Path from typing import List, Tuple, Union import numpy as np import torch from utils.symbol_table import SymbolTable class GlobalCollater: def __init__(self, n_codes, n_semantic_codes): self.n_codes = n_codes self.sem_mask_id = n_semantic_codes def collate(self, batch): output = { 'speaker': [], 'tts_quantize_input': [], 'tts_quantize_output': [], 'quantize_mask': [], 'f_names': [], 'semantic_tokens': [], 'quantization_lengths': [], } # Get the max length of everything max_len_q = 0 for _, q_s, q_e, _, _ in batch: if len(q_s) > max_len_q: max_len_q = len(q_s) output['quantization_lengths'].append(len(q_s)) # Pad each element, create mask for spkr, qs, qe, itm_name, s_tokens in batch: # Deal with quantizations q_mask = np.array( [False] * len(qs) + [True] * (max_len_q - len(qs))) qs = np.pad( qs, [[0, max_len_q-len(qs)], [0, 0]], constant_values=self.n_codes ) qe = np.pad( qe, [[0, max_len_q-len(qe)], [0, 0]], constant_values=self.n_codes ) # Deal with semantics s_tokens = s_tokens.flatten() s_tokens = np.pad( s_tokens, (0, max_len_q-len(s_tokens)), constant_values=self.sem_mask_id ) # Speaker padding spkr = np.concatenate( (spkr, np.zeros((max_len_q - len(spkr), 512)))) # Aggregate output['speaker'].append(spkr) output['tts_quantize_input'].append(qs) output['tts_quantize_output'].append(qe) output['quantize_mask'].append(q_mask) output['f_names'].append(itm_name) output["semantic_tokens"].append(s_tokens) for k in output.keys(): if k == 'f_names': continue output[k] = np.array(output[k]) if 'mask' in k: output[k] = torch.BoolTensor(output[k]) elif k in [ 'tts_quantize_input', 'tts_quantize_output', 'semantic_tokens', 'quantization_lengths' ]: output[k] = torch.LongTensor(output[k]) else: output[k] = torch.FloatTensor(output[k]) return output class TextTokenCollater: def __init__( self, text_tokens: List[str], add_eos: bool = True, add_bos: bool = True, pad_symbol: str = "", bos_symbol: str = "", eos_symbol: str = "", spkr_1_symbol: str = "spkr_1", spkr_2_symbol: str = "spkr_2", ): self.pad_symbol = pad_symbol self.add_eos = add_eos self.add_bos = add_bos self.bos_symbol = bos_symbol self.eos_symbol = eos_symbol self.spkr_1_symbol = spkr_1_symbol self.spkr_2_symbol = spkr_2_symbol unique_tokens = ( [pad_symbol] + ([bos_symbol] if add_bos else []) + ([eos_symbol] if add_eos else []) + ([spkr_1_symbol]) + ([spkr_2_symbol]) + sorted(text_tokens) ) self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} self.idx2token = [token for token in unique_tokens] def __call__( self, texts: List[str], texts_2: Union[None, List[str]] = None ) -> Tuple[torch.Tensor, torch.Tensor]: tokens_seqs = [[p for p in text] for text in texts] if texts_2 is None: seqs = [ ([self.bos_symbol] if self.add_bos else []) + [self.spkr_1_symbol] + list(seq) + ([self.eos_symbol] if self.add_eos else []) for seq in tokens_seqs ] else: tokens_seqs_2 = [[p for p in text] for text in texts_2] seqs = [ ([self.bos_symbol] if self.add_bos else []) + [self.spkr_1_symbol] + list(seq) + ([self.spkr_2_symbol]) + list(seq_2) + ([self.eos_symbol] if self.add_eos else []) for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2) ] tokens_batch = torch.from_numpy( np.array( [[self.token2idx[token] for token in seq] for seq in seqs], dtype=np.int64, ) ) return tokens_batch def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: text_tokens_path = Path(text_tokens_file) unique_tokens = SymbolTable.from_file(text_tokens_path) collater = 
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


def get_text_semantic_token_collater(
        text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    # Extend the text symbol table with one symbol per semantic token id.
    for semantic_idx in range(n_semantic_tokens):
        unique_tokens.add(str(semantic_idx))

    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


if __name__ == '__main__':
    text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
    collater = get_text_semantic_token_collater(text_tokens_file)
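
    # Minimal usage sketches (illustrative only, not part of the original
    # module). The semantic ids '0'..'1023' were added by
    # get_text_semantic_token_collater above, so a batch built from them is
    # guaranteed to be in the vocabulary. Note that the np.array call in
    # __call__ requires every sequence in a batch to reach the same length
    # once bos/eos/speaker symbols are prepended and appended.
    tokens = collater([['0', '1', '2']])
    print(tokens.shape, tokens.dtype)  # torch.Size([1, 6]) torch.int64

    # GlobalCollater smoke test with zero-filled arrays shaped like real
    # items (assumed shapes: 512-dim per-frame speaker embeddings and
    # 2-codebook acoustic tokens; the file name is hypothetical).
    g_collater = GlobalCollater(n_codes=1024, n_semantic_codes=500)
    item = (
        np.zeros((4, 512)),                 # per-frame speaker embedding
        np.zeros((4, 2), dtype=np.int64),   # quantize input
        np.zeros((4, 2), dtype=np.int64),   # quantize output
        'example.wav',                      # item name
        np.zeros(4, dtype=np.int64),        # semantic tokens
    )
    out = g_collater.collate([item])
    print(out['tts_quantize_input'].shape)  # torch.Size([1, 4, 2])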