Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 4,514 Bytes
			
# Copyright (c) 2023 Amphion.
 
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
from pathlib import Path
from typing import List, Tuple
import os
import numpy as np
import torch
from text.symbol_table import SymbolTable
from text import text_to_sequence
'''
    TextToken: map text to id
'''
# TextTokenCollator is modified from 
# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/collation.py
class TextTokenCollator:
    """Map text tokens to integer ids, for single items and padded batches.

    The vocabulary is laid out as ``[pad_symbol, bos_symbol (optional),
    eos_symbol (optional), *sorted(text_tokens)]``. ``idx2token`` is that
    list and ``token2idx`` is its inverse mapping. If ``text_tokens`` itself
    contains one of the special symbols, the later occurrence wins in
    ``token2idx``.

    Args:
        text_tokens: The regular (non-special) tokens of the vocabulary.
        add_eos: Whether to append ``eos_symbol`` to every sequence.
        add_bos: Whether to prepend ``bos_symbol`` to every sequence.
        pad_symbol: Token used to pad batches to a common length.
        bos_symbol: Begin-of-sequence token.
        eos_symbol: End-of-sequence token.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Special symbols come first so their ids do not depend on the
        # contents of ``text_tokens``.
        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def _with_boundaries(self, tokens) -> list:
        """Return ``tokens`` as a list with the configured bos/eos added."""
        return (
            ([self.bos_symbol] if self.add_bos else [])
            + list(tokens)
            + ([self.eos_symbol] if self.add_eos else [])
        )

    def index(
        self, tokens_list: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a batch of token sequences into a padded id tensor.

        Args:
            tokens_list: Batch of sequences; each entry is iterated token by
                token (a plain string is therefore split into characters).

        Returns:
            A pair ``(tokens, tokens_lens)``. ``tokens`` is an int64 tensor
            of shape ``(len(tokens_list), max_len)`` padded with the id of
            ``pad_symbol``. ``tokens_lens`` is an ``IntTensor`` holding each
            sequence's length including bos/eos. An empty batch yields a
            ``(0, 0)`` tensor and an empty length tensor.

        Raises:
            AssertionError: If a sequence contains a token that is not in
                the vocabulary.
        """
        seqs = []
        for tokens in tokens_list:
            tokens = list(tokens)
            unknown = [s for s in tokens if s not in self.token2idx]
            assert not unknown, f"tokens not in the vocabulary: {unknown}"
            seqs.append(self._with_boundaries(tokens))

        seq_lens = [len(seq) for seq in seqs]
        # ``default=0`` keeps an empty batch from raising in ``max``.
        max_len = max(seq_lens, default=0)
        pad_id = self.token2idx[self.pad_symbol]

        rows = [
            [self.token2idx[token] for token in seq]
            + [pad_id] * (max_len - len(seq))
            for seq in seqs
        ]
        # The explicit reshape keeps the result 2-D even when ``rows`` is
        # empty; for a non-empty batch it is a no-op.
        tokens = torch.from_numpy(
            np.array(rows, dtype=np.int64).reshape(len(seqs), max_len)
        )
        tokens_lens = torch.IntTensor(seq_lens)
        return tokens, tokens_lens

    def __call__(self, text):
        """Convert one token sequence into ids.

        Args:
            text: Iterable of tokens (a string is split into characters).

        Returns:
            A pair ``(token_ids, token_lens)``: the list of ids including
            bos/eos, and its length.

        Raises:
            KeyError: If a token is not in the vocabulary.
        """
        seq = self._with_boundaries(text)
        token_ids = [self.token2idx[token] for token in seq]
        return token_ids, len(seq)
    
def get_text_token_collater(
    text_tokens_file: str,
) -> Tuple[TextTokenCollator, dict]:
    """Build a ``TextTokenCollator`` from a symbol-table file.

    Args:
        text_tokens_file: Path of the file passed to
            ``SymbolTable.from_file``.

    Returns:
        A pair ``(collater, token2idx)``: the collator built from the
        table's ``symbols`` with both bos and eos enabled, and that
        collator's token-to-id mapping.
    """
    symbol_table = SymbolTable.from_file(Path(text_tokens_file))
    collater = TextTokenCollator(
        symbol_table.symbols, add_bos=True, add_eos=True
    )
    return collater, collater.token2idx
class phoneIDCollation:
    """Turn phone sequences into id sequences.

    The conversion depends on ``cfg.preprocess.phone_extractor``: for
    ``'lexicon'`` the phones are joined with spaces and passed to
    ``text_to_sequence``; for anything else a text-token collator built
    from a symbols dict file is used.
    """

    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        """Prepare the text-token collator when the extractor needs one.

        Args:
            cfg: Config object; ``cfg.preprocess.phone_extractor`` is read,
                and ``processed_dir`` / ``symbols_dict`` when the symbols
                file path has to be derived.
            dataset: Dataset name used to derive the symbols file path;
                required when ``symbols_dict_file`` is not given and the
                extractor is not ``'lexicon'``.
            symbols_dict_file: Explicit path of the symbols dict file.
        """
        # Always define the attribute so that lexicon-mode instances do not
        # fail with a bare missing-attribute error on later access.
        self.text_token_colloator = None
        if cfg.preprocess.phone_extractor == 'lexicon':
            return

        if symbols_dict_file is None:
            assert dataset is not None, (
                "dataset is required when symbols_dict_file is not given"
            )
            symbols_dict_file = os.path.join(
                cfg.preprocess.processed_dir,
                dataset,
                cfg.preprocess.symbols_dict,
            )
        # NOTE: the symbols dict file is not rewritten here; the pad/bos/eos
        # entries added by the collator exist only in its in-memory mapping.
        self.text_token_colloator, _ = get_text_token_collater(symbols_dict_file)

    def get_phone_id_sequence(self, cfg, phones_seq):
        """Return the id sequence for ``phones_seq``.

        Args:
            cfg: Config object; ``cfg.preprocess.phone_extractor`` selects
                the conversion, and ``cfg.preprocess.text_cleaners`` is
                passed to ``text_to_sequence`` in lexicon mode.
            phones_seq: Sequence of phone tokens.

        Returns:
            The result of ``text_to_sequence`` in lexicon mode, otherwise
            the id list produced by the text-token collator.

        Raises:
            AttributeError: If a non-lexicon conversion is requested but the
                instance was created in lexicon mode (no collator exists).
        """
        if cfg.preprocess.phone_extractor == 'lexicon':
            phones_text = ' '.join(phones_seq)
            return text_to_sequence(phones_text, cfg.preprocess.text_cleaners)

        if self.text_token_colloator is None:
            raise AttributeError(
                "text_token_colloator is not available: this instance was "
                "created with phone_extractor == 'lexicon'"
            )
        sequence, _ = self.text_token_colloator(phones_seq)
        return sequence
    | 
