File size: 4,363 Bytes
b725c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) 2023 Amphion.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from typing import List, Tuple
import os
import numpy as np
import torch
from text.symbol_table import SymbolTable
from text import text_to_sequence


"""
    TextToken: map text to id
"""


# TextTokenCollator is modified from
# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/collation.py
class TextTokenCollator:
    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert all([True if s in self.token2idx else False for s in tokens]) is True
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        max_len = max(seq_lens)
        for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, text):
        tokens_seq = [p for p in text]
        seq = (
            ([self.bos_symbol] if self.add_bos else [])
            + tokens_seq
            + ([self.eos_symbol] if self.add_eos else [])
        )

        token_ids = [self.token2idx[token] for token in seq]
        token_lens = len(tokens_seq) + self.add_eos + self.add_bos

        return token_ids, token_lens


def get_text_token_collater(text_tokens_file: str) -> TextTokenCollator:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollator(unique_tokens.symbols, add_bos=True, add_eos=True)
    token2idx = collater.token2idx
    return collater, token2idx


class phoneIDCollation:
    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        if cfg.preprocess.phone_extractor != "lexicon":
            ### get text token collator
            if symbols_dict_file is None:
                assert dataset is not None
                symbols_dict_file = os.path.join(
                    cfg.preprocess.processed_dir, dataset, cfg.preprocess.symbols_dict
                )
            self.text_token_colloator, token2idx = get_text_token_collater(
                symbols_dict_file
            )
            # # unique_tokens = SymbolTable.from_file(symbols_dict_path)
            # # text_tokenizer = TextToken(unique_tokens.symbols, add_bos=True, add_eos=True)

            # # update phone symbols dict file with pad_symbol or optional tokens (add_bos and add_eos) in TextTokenCollator
            # phone_symbol_dict = SymbolTable()
            # for s in sorted(list(set(token2idx.keys()))):
            #     phone_symbol_dict.add(s)
            # phone_symbol_dict.to_file(symbols_dict_file)

    def get_phone_id_sequence(self, cfg, phones_seq):
        if cfg.preprocess.phone_extractor == "lexicon":
            phones_seq = " ".join(phones_seq)
            sequence = text_to_sequence(phones_seq, cfg.preprocess.text_cleaners)
        else:
            sequence, seq_len = self.text_token_colloator(phones_seq)
        return sequence