File size: 5,165 Bytes
df2accb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
from tqdm import tqdm
from text.g2p_module import G2PModule, LexiconModule
from text.symbol_table import SymbolTable

class phoneExtractor:
    """Extract phone sequences from text and collect the phone symbols seen.

    The grapheme-to-phoneme backend is selected by
    ``cfg.preprocess.phone_extractor``: ``"espeak"``, ``"pypinyin"`` and
    ``"pypinyin_initials_finals"`` use ``G2PModule``; ``"lexicon"`` uses
    ``LexiconModule`` loaded from ``cfg.preprocess.lexicon_path``.
    """

    # Backends served by G2PModule; shared by __init__ and extract_phone so the
    # two cannot drift apart.
    _G2P_BACKENDS = ("espeak", "pypinyin", "pypinyin_initials_finals")

    def __init__(self, cfg, dataset_name=None, phone_symbol_file=None):
        """
        Args:
            cfg: config; reads ``cfg.preprocess.phone_extractor`` and, as needed,
                ``processed_dir``, ``symbols_dict`` and ``lexicon_path`` from
                ``cfg.preprocess``.
            dataset_name: name of dataset; used to build the default
                symbol-table path when ``phone_symbol_file`` is not given.
            phone_symbol_file: explicit path of the phone symbol table; takes
                precedence over the path derived from ``dataset_name``.

        Raises:
            ValueError: if the lexicon backend is selected but
                ``cfg.preprocess.lexicon_path`` is empty.
            NotImplementedError: if ``cfg.preprocess.phone_extractor`` names an
                unsupported backend.
        """
        self.cfg = cfg
        self.dataset_name = dataset_name

        # Phone symbols collected by extract_phone (G2P backends only).
        self.phone_symbols = set()

        # Path of the phone symbol table; None when neither an explicit file
        # nor a dataset name was supplied.
        if phone_symbol_file is not None:
            self.phone_symbols_file = phone_symbol_file
        elif dataset_name is not None:
            self.phone_symbols_file = os.path.join(
                cfg.preprocess.processed_dir,
                dataset_name,
                cfg.preprocess.symbols_dict,
            )
        else:
            self.phone_symbols_file = None

        # Initialize the g2p module.
        extractor = cfg.preprocess.phone_extractor
        if extractor in self._G2P_BACKENDS:
            self.g2p_module = G2PModule(backend=extractor)
        elif extractor == "lexicon":
            if not cfg.preprocess.lexicon_path:
                raise ValueError(
                    "cfg.preprocess.lexicon_path must be set for the "
                    "'lexicon' phone extractor"
                )
            self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path)
        else:
            # NotImplementedError subclasses RuntimeError, which is what the
            # previous bare ``raise`` (with no active exception) produced.
            raise NotImplementedError(
                f"Unsupported phone extractor: {extractor!r}"
            )

    def extract_phone(self, text):
        """
        Convert the text of one utterance into a phone sequence.

        Args:
            text: text of utterance

        Returns:
            list of phone symbols for the utterance. For the G2P backends the
            symbols are also added to ``self.phone_symbols``; the lexicon
            backend does not record them.

        Raises:
            NotImplementedError: if the configured backend is unsupported.
        """
        extractor = self.cfg.preprocess.phone_extractor

        if extractor in self._G2P_BACKENDS:
            # Normalize curly double quotes to ASCII before conversion.
            text = text.replace("”", '"').replace("“", '"')
            phone = self.g2p_module.g2p_conversion(text=text)
            self.phone_symbols.update(phone)
            return list(phone)

        if extractor == "lexicon":
            phone_seq = self.g2p_module.g2p_conversion(text)
            # The lexicon module may return a whitespace-separated string.
            if not isinstance(phone_seq, list):
                phone_seq = phone_seq.split()
            return phone_seq

        raise NotImplementedError(f"Unsupported phone extractor: {extractor!r}")

    def save_dataset_phone_symbols_to_table(self):
        """Merge collected phone symbols with the saved table and rewrite it.

        Symbols already stored in ``self.phone_symbols_file`` (if that file
        exists) are merged into ``self.phone_symbols``; the union is then
        written back as a ``SymbolTable`` with symbols added in sorted order.

        Raises:
            ValueError: if no symbol-table path was configured.
        """
        if self.phone_symbols_file is None:
            raise ValueError(
                "phone_symbols_file is not set; pass dataset_name or "
                "phone_symbol_file to phoneExtractor"
            )

        # Load and merge previously saved phone symbols.
        if os.path.exists(self.phone_symbols_file):
            saved = SymbolTable.from_file(self.phone_symbols_file)._sym2id.keys()
            self.phone_symbols.update(saved)

        # Save phone symbols.
        phone_symbol_dict = SymbolTable()
        for symbol in sorted(self.phone_symbols):
            phone_symbol_dict.add(symbol)
        phone_symbol_dict.to_file(self.phone_symbols_file)

                
def extract_utt_phone_sequence(cfg, metadata):
    """
    Extract the phone sequence of every utterance and write it to disk.

    Each utterance's phones are written space-separated to
    ``<processed_dir>/<dataset>/<phone_dir>/<Uid>.phone`` (UTF-8), where the
    dataset is ``cfg.dataset[0]``. For non-lexicon extractors the collected
    phone symbols are also merged into the dataset's phone symbol table.

    Args:
        cfg: config
        metadata: list of dict, each dict contains "Uid", "Text"
    """
    dataset_name = cfg.dataset[0]

    # Output directory for the per-utterance .phone files.
    out_path = os.path.join(
        cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.phone_dir
    )
    os.makedirs(out_path, exist_ok=True)

    phone_extractor = phoneExtractor(cfg, dataset_name)

    for utt in tqdm(metadata):
        uid = utt["Uid"]
        text = utt["Text"]

        phone_seq = phone_extractor.extract_phone(text)

        phone_path = os.path.join(out_path, f"{uid}.phone")
        # Phone symbols may be non-ASCII; write UTF-8 explicitly instead of
        # relying on the platform's locale encoding.
        with open(phone_path, "w", encoding="utf-8") as fout:
            fout.write(" ".join(phone_seq))

    # extract_phone does not collect symbols for the lexicon backend, so there
    # is nothing to merge into the symbol table in that case.
    if cfg.preprocess.phone_extractor != "lexicon":
        phone_extractor.save_dataset_phone_symbols_to_table()
    
    
        
def save_all_dataset_phone_symbols_to_table(self, cfg, dataset):
    """
    Merge the phone symbol tables of several datasets and write the union
    back to every one of them.

    Args:
        self: unused; this is a module-level function, and the parameter is
            kept only so existing callers' positional arguments still line up.
        cfg: config; reads ``cfg.preprocess.processed_dir`` and
            ``cfg.preprocess.symbols_dict``.
        dataset: iterable of dataset names; every dataset must already have a
            saved phone symbol table.

    Raises:
        FileNotFoundError: if a dataset has no saved phone symbol table.
    """
    # Resolve every path once; this also lets `dataset` be a one-shot iterator,
    # since it is needed for both the merge pass and the write pass.
    symbol_files = [
        os.path.join(cfg.preprocess.processed_dir, name, cfg.preprocess.symbols_dict)
        for name in dataset
    ]

    # Load and merge saved phone symbols.
    phone_symbols = set()
    for symbol_file in symbol_files:
        if not os.path.exists(symbol_file):
            raise FileNotFoundError(f"Phone symbol table not found: {symbol_file}")
        phone_symbols.update(SymbolTable.from_file(symbol_file)._sym2id.keys())

    # Save all phone symbols to each dataset.
    phone_symbol_dict = SymbolTable()
    for symbol in sorted(phone_symbols):
        phone_symbol_dict.add(symbol)
    for symbol_file in symbol_files:
        phone_symbol_dict.to_file(symbol_file)