Krisshvamsi commited on
Commit
ce88638
1 Parent(s): 4d0d969

Upload 3 files

Browse files
Files changed (3) hide show
  1. TTSModel.py +173 -0
  2. hyperparams.yaml +173 -0
  3. label_encoder.txt +46 -0
TTSModel.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import re
3
+ import logging
4
+ import torch
5
+ import torchaudio
6
+ import random
7
+ import speechbrain as sb
8
+ import torch as nn
9
+ from speechbrain.utils.fetching import fetch
10
+ from speechbrain.inference.interfaces import Pretrained
11
+ from speechbrain.inference.text import GraphemeToPhoneme
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class TTSModel(Pretrained):
18
+ """
19
+ A ready-to-use wrapper for Transformer TTS (text -> mel_spec).
20
+ Arguments
21
+ ---------
22
+ hparams
23
+ Hyperparameters (from HyperPyYAML)"""
24
+
25
+ HPARAMS_NEEDED = ["model", "blank_index", "padding_mask", "lookahead_mask", "mel_spec_feats", "label_encoder"]
26
+ MODULES_NEEDED = ["modules"]
27
+
28
+ def __init__(self, *args, **kwargs):
29
+ super().__init__(*args, **kwargs)
30
+ self.label_encoder = self.hparams.label_encoder
31
+ #self.label_encoder.update_from_iterable(self.hparams["lexicon"], sequence_input=False)
32
+ self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
33
+
34
+
35
+ def text_to_phoneme(self, text):
36
+ """
37
+ Generates phoneme sequences for the given text using a Grapheme-to-Phoneme (G2P) model.
38
+
39
+ Args:
40
+ text (str): The input text.
41
+
42
+ Returns:
43
+ list: List of phoneme sequences for the words in the text.
44
+ """
45
+ abbreviation_expansions = {
46
+ "Mr.": "Mister",
47
+ "Mrs.": "Misess",
48
+ "Dr.": "Doctor",
49
+ "No.": "Number",
50
+ "St.": "Saint",
51
+ "Co.": "Company",
52
+ "Jr.": "Junior",
53
+ "Maj.": "Major",
54
+ "Gen.": "General",
55
+ "Drs.": "Doctors",
56
+ "Rev.": "Reverend",
57
+ "Lt.": "Lieutenant",
58
+ "Hon.": "Honorable",
59
+ "Sgt.": "Sergeant",
60
+ "Capt.": "Captain",
61
+ "Esq.": "Esquire",
62
+ "Ltd.": "Limited",
63
+ "Col.": "Colonel",
64
+ "Ft.": "Fort"
65
+ }
66
+
67
+ # Expand abbreviations
68
+ for abbreviation, expansion in abbreviation_expansions.items():
69
+ text = text.replace(abbreviation, expansion)
70
+
71
+ phonemes = self.g2p(text)
72
+ phonemes = self.label_encoder.encode_sequence(phonemes)
73
+ phoneme_seq = torch.LongTensor(phonemes)
74
+
75
+ return phoneme_seq, len(phoneme_seq)
76
+
77
+ def encode_batch(self, texts):
78
+ """Computes mel-spectrogram for a list of texts
79
+
80
+ Texts must be sorted in decreasing order on their lengths
81
+
82
+ Arguments
83
+ ---------
84
+ texts: List[str]
85
+ texts to be encoded into spectrogram
86
+
87
+ Returns
88
+ -------
89
+ tensors of output spectrograms, output lengths and alignments
90
+ """
91
+ with torch.no_grad():
92
+ phoneme_seqs = [self.text_to_phoneme(text)[0] for text in texts]
93
+ phoneme_seqs_padded, input_lengths = self.pad_sequences(phoneme_seqs)
94
+
95
+ encoded_phoneme = self.mods.encoder_emb(phoneme_seqs_padded)
96
+ encoder_emb = self.mods.enc_pre_net(encoded_phoneme)
97
+ pos_emb_enc = self.mods.pos_emb_enc(encoder_emb)
98
+ encoder_emb = encoder_emb + pos_emb_enc
99
+
100
+
101
+ stop_generated = False
102
+ decoder_input = torch.zeros(1, 80, 1, device=self.device)
103
+ stop_tokens_logits = []
104
+ max_generation_length = 1000
105
+ sequence_length = 0
106
+
107
+ result = []
108
+ result.append(decoder_input)
109
+
110
+ src_mask = torch.zeros(encoder_emb.size(1), encoder_emb.size(1), device=self.device)
111
+ src_key_padding_mask = self.hparams.padding_mask(encoder_emb, self.hparams.blank_index)
112
+
113
+
114
+ while not stop_generated and sequence_length < max_generation_length:
115
+ encoded_mel = self.mods.dec_pre_net(decoder_input)
116
+ pos_emb_dec = self.mods.pos_emb_dec(encoded_mel)
117
+ decoder_emb = encoded_mel + pos_emb_dec
118
+
119
+ decoder_output = self.mods.Seq2SeqTransformer(
120
+ encoder_emb, decoder_emb, src_mask=src_mask,
121
+ src_key_padding_mask=src_key_padding_mask)
122
+
123
+ mel_output = self.mods.mel_lin(decoder_output)
124
+
125
+ stop_token_logit = self.mods.stop_lin(decoder_output).squeeze(-1)
126
+
127
+ post_mel_outputs = self.mods.postnet(mel_output.to(self.device))
128
+ refined_mel_output = mel_output + post_mel_outputs.to(self.device)
129
+ refined_mel_output = refined_mel_output.transpose(1, 2)
130
+
131
+ stop_tokens_logits.append(stop_token_logit)
132
+ stop_token_probs = torch.sigmoid(stop_token_logit)
133
+
134
+ if torch.any(stop_token_probs[:, -1] >= self.hparams.stop_threshold):
135
+ stop_generated = True
136
+
137
+ decoder_input = refined_mel_output
138
+ result.append(decoder_input)
139
+ sequence_length += 1
140
+
141
+ results = torch.cat(result, dim=2)
142
+ stop_tokens_logits = torch.cat(stop_tokens_logits, dim=1)
143
+
144
+ return results
145
+
146
+ def pad_sequences(self, sequences):
147
+ """Pad sequences to the maximum length sequence in the batch.
148
+
149
+ Arguments
150
+ ---------
151
+ sequences: List[torch.Tensor]
152
+ The sequences to pad
153
+
154
+ Returns
155
+ -------
156
+ Padded sequences and original lengths
157
+ """
158
+ max_length = max([len(seq) for seq in sequences])
159
+ padded_seqs = torch.zeros(len(sequences), max_length, dtype=torch.long)
160
+ lengths = []
161
+ for i, seq in enumerate(sequences):
162
+ length = len(seq)
163
+ padded_seqs[i, :length] = seq
164
+ lengths.append(length)
165
+ return padded_seqs, torch.tensor(lengths)
166
+
167
+ def encode_text(self, text):
168
+ """Runs inference for a single text str"""
169
+ return self.encode_batch(text)
170
+
171
+ def forward(self, texts):
172
+ "Encodes the input texts."
173
+ return self.encode_batch(texts)
hyperparams.yaml ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ ################################
4
+ # Audio Parameters #
5
+ ################################
6
+ sample_rate: 22050
7
+ hop_length: 256
8
+ win_length: 1024
9
+ n_mel_channels: 80
10
+ n_fft: 1024
11
+ mel_fmin: 0.0
12
+ mel_fmax: 8000.0
13
+ power: 1
14
+ normalized: False
15
+ min_max_energy_norm: True
16
+ norm: "slaney"
17
+ mel_scale: "slaney"
18
+ dynamic_range_compression: True
19
+ mel_normalized: False
20
+ min_f0: 65 #(torchaudio pyin values)
21
+ max_f0: 2093 #(torchaudio pyin values)
22
+
23
+ positive_weight: 5.0
24
+ lexicon:
25
+ - AA
26
+ - AE
27
+ - AH
28
+ - AO
29
+ - AW
30
+ - AY
31
+ - B
32
+ - CH
33
+ - D
34
+ - DH
35
+ - EH
36
+ - ER
37
+ - EY
38
+ - F
39
+ - G
40
+ - HH
41
+ - IH
42
+ - IY
43
+ - JH
44
+ - K
45
+ - L
46
+ - M
47
+ - N
48
+ - NG
49
+ - OW
50
+ - OY
51
+ - P
52
+ - R
53
+ - S
54
+ - SH
55
+ - T
56
+ - TH
57
+ - UH
58
+ - UW
59
+ - V
60
+ - W
61
+ - Y
62
+ - Z
63
+ - ZH
64
+ - ' '
65
+ n_symbols: 42 #fixed depending on symbols in the lexicon +1 for a dummy symbol used for padding
66
+ padding_idx: 0
67
+
68
+ # Define model architecture
69
+ d_model: 512
70
+ nhead: 8
71
+ num_encoder_layers: 6
72
+ num_decoder_layers: 6
73
+ dim_feedforward: 2048
74
+ dropout: 0.2
75
+ blank_index: 0 # This special token is for padding
76
+ bos_index: 1
77
+ eos_index: 2
78
+ stop_weight: 0.45
79
+ stop_threshold: 0.5
80
+
81
+
82
+ ###################PRENET#######################
83
+ enc_pre_net: !new:models.EncoderPrenet
84
+ dec_pre_net: !new:models.DecoderPrenet
85
+
86
+
87
+ encoder_emb: !new:torch.nn.Embedding
88
+ num_embeddings: 128
89
+ embedding_dim: !ref <d_model>
90
+ padding_idx: !ref <blank_index>
91
+
92
+ pos_emb_enc: !new:models.ScaledPositionalEncoding
93
+ d_model: !ref <d_model>
94
+
95
+ decoder_emb: !new:torch.nn.Embedding
96
+ num_embeddings: 128
97
+ embedding_dim: !ref <d_model>
98
+ padding_idx: !ref <blank_index>
99
+
100
+ pos_emb_dec: !new:models.ScaledPositionalEncoding
101
+ d_model: !ref <d_model>
102
+
103
+
104
+ Seq2SeqTransformer: !new:torch.nn.Transformer
105
+ d_model: !ref <d_model>
106
+ nhead: !ref <nhead>
107
+ num_encoder_layers: !ref <num_encoder_layers>
108
+ num_decoder_layers: !ref <num_decoder_layers>
109
+ dim_feedforward: !ref <dim_feedforward>
110
+ dropout: !ref <dropout>
111
+ batch_first: True
112
+
113
+ postnet: !new:models.PostNet
114
+ mel_channels: !ref <n_mel_channels>
115
+ postnet_channels: 512
116
+ kernel_size: 5
117
+ postnet_layers: 5
118
+
119
+ mel_lin: !new:speechbrain.nnet.linear.Linear
120
+ input_size: !ref <d_model>
121
+ n_neurons: !ref <n_mel_channels>
122
+
123
+ stop_lin: !new:speechbrain.nnet.linear.Linear
124
+ input_size: !ref <d_model>
125
+ n_neurons: 1
126
+
127
+ mel_spec_feats: !name:speechbrain.lobes.models.FastSpeech2.mel_spectogram
128
+ sample_rate: !ref <sample_rate>
129
+ hop_length: !ref <hop_length>
130
+ win_length: !ref <win_length>
131
+ n_fft: !ref <n_fft>
132
+ n_mels: !ref <n_mel_channels>
133
+ f_min: !ref <mel_fmin>
134
+ f_max: !ref <mel_fmax>
135
+ power: !ref <power>
136
+ normalized: !ref <normalized>
137
+ min_max_energy_norm: !ref <min_max_energy_norm>
138
+ norm: !ref <norm>
139
+ mel_scale: !ref <mel_scale>
140
+ compression: !ref <dynamic_range_compression>
141
+
142
+ modules:
143
+ enc_pre_net: !ref <enc_pre_net>
144
+ encoder_emb: !ref <encoder_emb>
145
+ pos_emb_enc: !ref <pos_emb_enc>
146
+
147
+ dec_pre_net: !ref <dec_pre_net>
148
+ #decoder_emb: !ref <decoder_emb>
149
+ pos_emb_dec: !ref <pos_emb_dec>
150
+
151
+ Seq2SeqTransformer: !ref <Seq2SeqTransformer>
152
+ postnet: !ref <postnet>
153
+ mel_lin: !ref <mel_lin>
154
+ stop_lin: !ref <stop_lin>
155
+ model: !ref <model>
156
+
157
+ lookahead_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_lookahead_mask
158
+ padding_mask: !name:speechbrain.lobes.models.transformer.Transformer.get_key_padding_mask
159
+
160
+ model: !new:torch.nn.ModuleList
161
+ - [!ref <enc_pre_net>, !ref <encoder_emb>, !ref <pos_emb_enc>, !ref <dec_pre_net>, !ref <pos_emb_dec>, !ref <Seq2SeqTransformer>, !ref <postnet>, !ref <mel_lin>, !ref <stop_lin>]
162
+
163
+ label_encoder: !new:speechbrain.dataio.encoder.TextEncoder
164
+
165
+ pretrained_path: /content/
166
+
167
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
168
+ loadables:
169
+ model: !ref <model>
170
+ label_encoder: !ref <label_encoder>
171
+ paths:
172
+ model: !ref <pretrained_path>/model.ckpt
173
+ label_encoder: !ref <pretrained_path>/label_encoder.txt
label_encoder.txt ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'AA' => 0
2
+ 'AE' => 40
3
+ 'AH' => 41
4
+ 'AO' => 3
5
+ 'AW' => 4
6
+ 'AY' => 5
7
+ 'B' => 6
8
+ 'CH' => 7
9
+ 'D' => 8
10
+ 'DH' => 9
11
+ 'EH' => 10
12
+ 'ER' => 11
13
+ 'EY' => 12
14
+ 'F' => 13
15
+ 'G' => 14
16
+ 'HH' => 15
17
+ 'IH' => 16
18
+ 'IY' => 17
19
+ 'JH' => 18
20
+ 'K' => 19
21
+ 'L' => 20
22
+ 'M' => 21
23
+ 'N' => 22
24
+ 'NG' => 23
25
+ 'OW' => 24
26
+ 'OY' => 25
27
+ 'P' => 26
28
+ 'R' => 27
29
+ 'S' => 28
30
+ 'SH' => 29
31
+ 'T' => 30
32
+ 'TH' => 31
33
+ 'UH' => 32
34
+ 'UW' => 33
35
+ 'V' => 34
36
+ 'W' => 35
37
+ 'Y' => 36
38
+ 'Z' => 37
39
+ 'ZH' => 38
40
+ ' ' => 39
41
+ '<bos>' => 1
42
+ '<eos>' => 2
43
+ ================
44
+ 'starting_index' => 0
45
+ 'bos_label' => '<bos>'
46
+ 'eos_label' => '<eos>'