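"""Add phoneme transcriptions to a Hugging Face dataset saved with save_to_disk.

The dataset is expected to have a 'text' column. Usage (via python-fire):
    python <script>.py run DATASET_PATH OUTPUT_PATH --nproc=4
"""
from datasets import load_from_disk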
from dp.phonemizer import Phonemizer
from speechbrain.pretrained import GraphemeToPhoneme
import cmudict
import re
import fire
import torch
from os.path import join

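# CUDA cannot be re-initialised in a forked subprocess, so switch to the
# 'spawn' start method when several GPUs may be used by worker processes.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    torch.multiprocessing.set_start_method('spawn')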

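class phonemization:
    """Add phoneme transcriptions to a dataset using CMUdict, DeepPhonemizer (dp)
    and/or a SpeechBrain (sb) grapheme-to-phoneme model."""

    def __init__(self):
        # Punctuation replaced by spaces during normalisation. The hyphen must be
        # escaped: unescaped, '!-;' is a character range (0x21-0x3B) that also
        # matches digits, apostrophes, parentheses, '/' and other symbols.
        self.chars_to_ignore_regex = r'[,?.!\-;:"]'
        self.dp_phonemizer_model_path = join('models', 'd_phonemizer', 'en_us_cmudict_forward.pt')
        self.sb_phonemizer_model_path = join('models', 'sb_phonemizer')

        self.cmu_dict = cmudict.dict()
        self.dp_phonemizer = Phonemizer.from_checkpoint(self.dp_phonemizer_model_path)
        # Run the SpeechBrain model on the GPU when one is available.
        if torch.cuda.is_available():
            self.sb_phonemizer = GraphemeToPhoneme.from_hparams(
                self.sb_phonemizer_model_path, run_opts={"device": "cuda"})
        else:
            self.sb_phonemizer = GraphemeToPhoneme.from_hparams(self.sb_phonemizer_model_path)
        # Whether phonemize_batch reads the normalised 'text_norm' column instead of 'text'.
        self.normalize = False
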
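    def dp_phonemize(self, text):
        """Phonemize text with DeepPhonemizer and return a list of phonemes."""
        # Square brackets in the output delimit phoneme symbols: replace them
        # with spaces and split on whitespace.
        return self.dp_phonemizer(text, lang='en_us', expand_acronyms=False).replace('[', ' ').replace(']', ' ').split()
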
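    def cmu_phonemize(self, text, fallback_phonemizer=None):
        """Look each word up in CMUdict and fall back to another phonemizer
        (DeepPhonemizer by default) for out-of-vocabulary words."""
        if fallback_phonemizer is None:
            fallback_phonemizer = self.dp_phonemize
        phoneme_lst = []
        for word in text.split():
            if word in self.cmu_dict:
                # Take the first pronunciation and strip stress digits (AH0 -> AH).
                phoneme_lst.extend(re.sub('[0-9]', '', ' '.join(self.cmu_dict[word][0])).split())
            else:
                phoneme_lst.extend(fallback_phonemizer(word))
        return [p.lower() for p in phoneme_lst]
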
    def sb_phonemize(self, text):
        """Phonemize a string with the SpeechBrain G2P model (returns a list of phonemes)."""
        return self.sb_phonemizer(text)

    def remove_special_characters(self, text):
        """Replace ignored punctuation with spaces and lowercase the text."""
        return re.sub(self.chars_to_ignore_regex, ' ', text).lower() + " "

    def replace_multiple_spaces_with_single_space(self, input_string):
        """Collapse every run of whitespace into a single space."""
        return re.sub(r'\s+', ' ', input_string)

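    def phonemize_batch(self, batch, phonemizer_fn=None, suffix=''):
        """Phonemize one example and store the result in batch['phoneme' + suffix]
        as a lowercase, space-separated string."""
        if phonemizer_fn is None:
            phonemizer_fn = self.dp_phonemize
        if self.normalize:
            text = batch['text_norm'].lower()
        else:
            text = batch['text'].lower()
        phoneme_str = ' '.join(phonemizer_fn(text)).lower()
        phoneme_str = self.replace_multiple_spaces_with_single_space(phoneme_str)
        batch[f'phoneme{suffix}'] = phoneme_str.strip()
        return batch

    def remove_special_characters_batch(self, batch):
        """Store a punctuation-free, lowercase copy of batch['text'] in batch['text_norm']."""
        batch["text_norm"] = self.remove_special_characters(batch["text"])
        return batch

    def run(self,
            dataset_path,
            output_path,
            phonemizers='dp,sb,cmu',
            normalize=True,
            nproc=1,
            sb_cache_file='/g/data/iv96/mostafa/cache_sb'):
        """Add a 'phoneme_<name>' column for each requested phonemizer
        ('dp', 'sb', 'cmu') and save the dataset to output_path.

        normalize: strip punctuation into a 'text_norm' column and phonemize
            that column instead of the raw 'text' column.
        nproc: number of worker processes for datasets.map.
        sb_cache_file: cache file for the SpeechBrain pass. Existing cache files
            for this path are reused, so change it for a different dataset.
        """
        data = load_from_disk(dataset_path)

        # phonemize_batch reads this flag to choose 'text_norm' or 'text'.
        self.normalize = normalize
        if normalize:
            data = data.map(self.remove_special_characters_batch, num_proc=nproc)

        # Fire may parse 'dp,cmu' into a tuple, so accept either form.
        if isinstance(phonemizers, str):
            phonemizers = phonemizers.split(',')
        for phonemizer in phonemizers:
            phonemizer = phonemizer.strip()
            if phonemizer == 'cmu':
                data = data.map(self.phonemize_batch,
                                fn_kwargs={'phonemizer_fn': self.cmu_phonemize, 'suffix': '_cmu'},
                                num_proc=nproc)
            elif phonemizer == 'dp':
                data = data.map(self.phonemize_batch,
                                fn_kwargs={'phonemizer_fn': self.dp_phonemize, 'suffix': '_dp'},
                                num_proc=nproc)
            elif phonemizer == 'sb':
                # Use one worker per GPU for this pass only, so later phonemizers
                # still use nproc. The model itself was loaded on the default
                # 'cuda' device.
                sb_nproc = torch.cuda.device_count() if torch.cuda.is_available() else nproc
                data = data.map(self.phonemize_batch,
                                fn_kwargs={'phonemizer_fn': self.sb_phonemize, 'suffix': '_sb'},
                                num_proc=sb_nproc,
                                cache_file_name=sb_cache_file,
                                load_from_cache_file=True)
            else:
                raise ValueError(f"Unknown phonemizer '{phonemizer}'; expected 'dp', 'sb' or 'cmu'.")
        data.save_to_disk(output_path)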


if __name__ == '__main__':
    fire.Fire(phonemization)