Bingsu committed
Commit d07276d
1 Parent(s): 33fff92

Upload 3 files

Files changed (3)
  1. collator.py +68 -0
  2. modeling_tacotron2.py +323 -0
  3. processing_tacotron2.py +224 -0
collator.py ADDED
from typing import Union

import numpy as np
from transformers.utils import TensorType
from transformers.feature_extraction_utils import BatchFeature


class PadAndSortCollator:
    def __init__(self, processor, return_tensors: Union[str, TensorType] = "pt"):
        self.processor = processor
        self.return_tensors = return_tensors

    def __call__(self, batch):
        """
        Expects a batch produced by the processor with `return_tensors=None`.
        Each item holds: input_ids, length (optional), mel_specgram,
        mel_specgram_length (optional).
        """
        text_batch = {}
        text_batch["input_ids"] = [x["input_ids"] for x in batch]
        if "length" in batch[0]:
            text_batch["length"] = [x["length"] for x in batch]
        else:
            text_batch["length"] = [len(x["input_ids"]) for x in batch]

        audio_batch = {}
        # transpose mel_specgram to (time, n_mels) so padding runs along time
        audio_batch["mel_specgram"] = [
            x["mel_specgram"][0].transpose(1, 0) for x in batch
        ]
        if "mel_specgram_length" in batch[0]:
            audio_batch["mel_specgram_length"] = [
                x["mel_specgram_length"] for x in batch
            ]
        else:
            audio_batch["mel_specgram_length"] = [
                x["mel_specgram"][0].shape[1] for x in batch
            ]

        text_batch = self.processor.tokenizer.pad(
            text_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=False,
        )

        audio_batch = self.processor.feature_extractor.pad(
            audio_batch,
            padding=True,
            return_tensors="np",
            return_attention_mask=True,
        )
        # back to (n_batch, n_mels, time)
        audio_batch["mel_specgram"] = audio_batch["mel_specgram"].transpose(0, 2, 1)

        # stop-token target: 0 on valid frames, 1 on padding, shifted left by
        # one step so the final valid frame is marked as the stop frame
        attention_mask = audio_batch.pop("attention_mask")
        gate_padded = 1 - attention_mask
        gate_padded = np.roll(gate_padded, -1, axis=1)
        gate_padded[:, -1] = 1
        gate_padded = gate_padded.astype(np.float32)

        output = {**text_batch, **audio_batch, "gate_padded": gate_padded}

        # sort the whole batch by text length, descending
        sort_idx = np.argsort(output["length"])[::-1]

        for key, value in output.items():
            output[key] = value[sort_idx]

        return BatchFeature(output, tensor_type=self.return_tensors)
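
A minimal usage sketch: wiring the collator into a PyTorch DataLoader. The names `processor` (a Tacotron2Processor) and `dataset` (yielding unbatched processor outputs produced with `return_tensors=None`) are assumptions, not part of this file.

from torch.utils.data import DataLoader

# `processor` and `dataset` are assumed to exist (see note above)
collator = PadAndSortCollator(processor, return_tensors="pt")
loader = DataLoader(dataset, batch_size=16, collate_fn=collator)
batch = next(iter(loader))
# batch holds input_ids, length, mel_specgram, mel_specgram_length, gate_padded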
modeling_tacotron2.py ADDED
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor, nn
from torchaudio.models import Tacotron2
from transformers import PretrainedConfig, PreTrainedModel
from transformers.utils import ModelOutput


@dataclass
class Tacotron2Output(ModelOutput):
    """
    mel_outputs_postnet
        The predicted mel spectrogram with shape
        `(n_batch, n_mels, max of mel_specgram_lengths)`.
    mel_specgram_lengths
        The length of the predicted mel spectrogram with shape `(n_batch, )`.
    alignments
        Sequence of attention weights from the decoder with shape
        `(n_batch, max of mel_specgram_lengths, max of lengths)`.
    """

    mel_outputs_postnet: Tensor = None
    mel_specgram_lengths: Tensor = None
    alignments: Tensor = None


@dataclass
class Tacotron2ForPreTrainingOutput(ModelOutput):
    """
    mel_specgram
        Mel spectrogram before Postnet with shape
        `(n_batch, n_mels, max of mel_specgram_lengths)`.
    mel_specgram_postnet
        Mel spectrogram after Postnet with shape
        `(n_batch, n_mels, max of mel_specgram_lengths)`.
    gate_outputs
        The output for the stop token at each time step with shape
        `(n_batch, max of mel_specgram_lengths)`.
    alignments
        Sequence of attention weights from the decoder with shape
        `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
    """

    mel_specgram: Tensor = None
    mel_specgram_postnet: Tensor = None
    gate_outputs: Tensor = None
    alignments: Tensor = None
    loss: Optional[Tensor] = None
    mel_loss: Optional[Tensor] = None
    mel_postnet_loss: Optional[Tensor] = None
    gate_loss: Optional[Tensor] = None


class Tacotron2Config(PretrainedConfig):
    def __init__(
        self,
        mask_padding: bool = False,
        n_mels: int = 80,
        n_symbol: int = 392,
        n_frames_per_step: int = 1,
        symbol_embedding_dim: int = 512,
        encoder_embedding_dim: int = 512,
        encoder_n_convolution: int = 3,
        encoder_kernel_size: int = 5,
        decoder_rnn_dim: int = 1024,
        decoder_max_step: int = 2000,
        decoder_dropout: float = 0.1,
        decoder_early_stopping: bool = True,
        attention_rnn_dim: int = 1024,
        attention_hidden_dim: int = 128,
        attention_location_n_filter: int = 32,
        attention_location_kernel_size: int = 31,
        attention_dropout: float = 0.1,
        prenet_dim: int = 256,
        postnet_n_convolution: int = 5,
        postnet_kernel_size: int = 5,
        postnet_embedding_dim: int = 512,
        gate_threshold: float = 0.5,
        **kwargs,
    ):
        # https://pytorch.org/audio/stable/generated/torchaudio.models.Tacotron2.html#torchaudio.models.Tacotron2  # noqa
        if n_frames_per_step != 1:
            raise ValueError(
                f"n_frames_per_step: only 1 is supported, got {n_frames_per_step}"
            )

        self.mask_padding = mask_padding
        self.n_mels = n_mels
        self.n_symbol = n_symbol
        self.n_frames_per_step = n_frames_per_step
        self.symbol_embedding_dim = symbol_embedding_dim
        self.encoder_embedding_dim = encoder_embedding_dim
        self.encoder_n_convolution = encoder_n_convolution
        self.encoder_kernel_size = encoder_kernel_size
        self.decoder_rnn_dim = decoder_rnn_dim
        self.decoder_max_step = decoder_max_step
        self.decoder_dropout = decoder_dropout
        self.decoder_early_stopping = decoder_early_stopping
        self.attention_rnn_dim = attention_rnn_dim
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_location_n_filter = attention_location_n_filter
        self.attention_location_kernel_size = attention_location_kernel_size
        self.attention_dropout = attention_dropout
        self.prenet_dim = prenet_dim
        self.postnet_n_convolution = postnet_n_convolution
        self.postnet_kernel_size = postnet_kernel_size
        self.postnet_embedding_dim = postnet_embedding_dim
        self.gate_threshold = gate_threshold
        super().__init__(**kwargs)


class Tacotron2PreTrainedModel(PreTrainedModel):
    config_class = Tacotron2Config
    base_model_prefix = "tacotron2"
    main_input_name = "input_ids"


class Tacotron2Model(Tacotron2PreTrainedModel):
    def __init__(self, config: Tacotron2Config):
        super().__init__(config)
        self.tacotron2 = Tacotron2(
            mask_padding=config.mask_padding,
            n_mels=config.n_mels,
            n_symbol=config.n_symbol,
            n_frames_per_step=config.n_frames_per_step,
            symbol_embedding_dim=config.symbol_embedding_dim,
            encoder_embedding_dim=config.encoder_embedding_dim,
            encoder_n_convolution=config.encoder_n_convolution,
            encoder_kernel_size=config.encoder_kernel_size,
            decoder_rnn_dim=config.decoder_rnn_dim,
            decoder_max_step=config.decoder_max_step,
            decoder_dropout=config.decoder_dropout,
            decoder_early_stopping=config.decoder_early_stopping,
            attention_rnn_dim=config.attention_rnn_dim,
            attention_hidden_dim=config.attention_hidden_dim,
            attention_location_n_filter=config.attention_location_n_filter,
            attention_location_kernel_size=config.attention_location_kernel_size,
            attention_dropout=config.attention_dropout,
            prenet_dim=config.prenet_dim,
            postnet_n_convolution=config.postnet_n_convolution,
            postnet_kernel_size=config.postnet_kernel_size,
            postnet_embedding_dim=config.postnet_embedding_dim,
            gate_threshold=config.gate_threshold,
        )

    def forward(
        self,
        input_ids: Tensor,
        length: Optional[Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Runs Tacotron2 inference. The input is a batch of encoded sentences
        (``input_ids``) and their corresponding lengths (``length``). The
        output is the generated mel spectrograms, their corresponding lengths,
        and the attention weights from the decoder.

        ``input_ids`` should be padded with zeros up to the maximum of ``length``.

        Args:
            input_ids (Tensor):
                The input tokens to Tacotron2 with shape `(n_batch, max of length)`.
            length (Tensor or None, optional):
                The valid length of each sample in ``input_ids`` with shape
                `(n_batch, )`. If ``None``, all the tokens are assumed to be
                valid. Default: ``None``

        Returns:
            Tacotron2Output, or (Tensor, Tensor, Tensor) when ``return_dict=False``:
                Tensor
                    The predicted mel spectrogram with shape
                    `(n_batch, n_mels, max of mel_specgram_lengths)`.
                Tensor
                    The length of the predicted mel spectrogram with shape
                    `(n_batch, )`.
                Tensor
                    Sequence of attention weights from the decoder with shape
                    `(n_batch, max of mel_specgram_lengths, max of length)`.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.tacotron2.infer(tokens=input_ids, lengths=length)

        if not return_dict:
            return outputs

        return Tacotron2Output(
            mel_outputs_postnet=outputs[0],
            mel_specgram_lengths=outputs[1],
            alignments=outputs[2],
        )
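
A quick inference sketch with a randomly initialized model; the config is the default one and the token ids and length are dummies, so the output is meaningless audio but exercises the API.

import torch

config = Tacotron2Config()
model = Tacotron2Model(config).eval()

input_ids = torch.randint(0, config.n_symbol, (1, 50))  # dummy token ids
length = torch.tensor([50])
with torch.no_grad():
    out = model(input_ids=input_ids, length=length, return_dict=True)
# out.mel_outputs_postnet: (1, n_mels, number of decoded frames)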
class Tacotron2Loss(nn.Module):
    """Tacotron2 loss function modified from:
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py  # noqa
    """

    def __init__(self):
        super().__init__()

        self.mse_loss = nn.MSELoss(reduction="mean")
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

    def forward(
        self,
        model_outputs: Tuple[Tensor, Tensor, Tensor],
        targets: Tuple[Tensor, Tensor],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Pass the input through the Tacotron2 loss.
        The original implementation was introduced in
        *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
        [:footcite:`shen2018natural`].

        Args:
            model_outputs (tuple of three Tensors): The outputs of Tacotron2:
                (1) the predicted mel spectrogram before the postnet
                (``mel_specgram``) with shape (batch, mel, time),
                (2) the predicted mel spectrogram after the postnet
                (``mel_specgram_postnet``) with shape (batch, mel, time), and
                (3) the stop token prediction (``gate_out``) with shape
                (batch, time).
            targets (tuple of two Tensors):
                The ground truth mel spectrogram (batch, mel, time) and the
                stop token target with shape (batch, time).

        Returns:
            mel_loss (Tensor): The mean MSE between ``mel_specgram`` and the
                ground truth mel spectrogram, with shape ``torch.Size([])``.
            mel_postnet_loss (Tensor): The mean MSE between
                ``mel_specgram_postnet`` and the ground truth mel spectrogram,
                with shape ``torch.Size([])``.
            gate_loss (Tensor): The mean binary cross entropy of the stop
                token prediction, with shape ``torch.Size([])``.
        """
        mel_target, gate_target = targets[0], targets[1]
        gate_target = gate_target.view(-1, 1)

        mel_specgram, mel_specgram_postnet, gate_out = model_outputs
        gate_out = gate_out.view(-1, 1)
        mel_loss = self.mse_loss(mel_specgram, mel_target)
        mel_postnet_loss = self.mse_loss(mel_specgram_postnet, mel_target)
        gate_loss = self.bce_loss(gate_out, gate_target)
        return mel_loss, mel_postnet_loss, gate_loss


class Tacotron2ForPreTraining(Tacotron2PreTrainedModel):
    def __init__(self, config: Tacotron2Config):
        super().__init__(config)
        self.tacotron2 = Tacotron2(
            mask_padding=config.mask_padding,
            n_mels=config.n_mels,
            n_symbol=config.n_symbol,
            n_frames_per_step=config.n_frames_per_step,
            symbol_embedding_dim=config.symbol_embedding_dim,
            encoder_embedding_dim=config.encoder_embedding_dim,
            encoder_n_convolution=config.encoder_n_convolution,
            encoder_kernel_size=config.encoder_kernel_size,
            decoder_rnn_dim=config.decoder_rnn_dim,
            decoder_max_step=config.decoder_max_step,
            decoder_dropout=config.decoder_dropout,
            decoder_early_stopping=config.decoder_early_stopping,
            attention_rnn_dim=config.attention_rnn_dim,
            attention_hidden_dim=config.attention_hidden_dim,
            attention_location_n_filter=config.attention_location_n_filter,
            attention_location_kernel_size=config.attention_location_kernel_size,
            attention_dropout=config.attention_dropout,
            prenet_dim=config.prenet_dim,
            postnet_n_convolution=config.postnet_n_convolution,
            postnet_kernel_size=config.postnet_kernel_size,
            postnet_embedding_dim=config.postnet_embedding_dim,
            gate_threshold=config.gate_threshold,
        )

        self.loss_fct = Tacotron2Loss()

    def sync_batchnorm(self):
        # convert BatchNorm layers to SyncBatchNorm for multi-GPU (DDP) training
        self.tacotron2 = nn.SyncBatchNorm.convert_sync_batchnorm(self.tacotron2)

    def forward(
        self,
        input_ids: Tensor,
        length: Tensor,
        mel_specgram: Tensor,
        mel_specgram_length: Tensor,
        gate_padded: Optional[Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.tacotron2(
            tokens=input_ids,
            token_lengths=length,
            mel_specgram=mel_specgram,
            mel_specgram_lengths=mel_specgram_length,
        )

        loss = mel_loss = mel_postnet_loss = gate_loss = None
        if gate_padded is not None:
            # ground-truth targets must not receive gradients
            targets = (mel_specgram, gate_padded)
            targets[0].requires_grad = False
            targets[1].requires_grad = False
            mel_loss, mel_postnet_loss, gate_loss = self.loss_fct(outputs[:3], targets)
            loss = mel_loss + mel_postnet_loss + gate_loss

        if not return_dict:
            if loss is not None:
                return outputs + (loss, mel_loss, mel_postnet_loss, gate_loss)
            return outputs

        return Tacotron2ForPreTrainingOutput(
            mel_specgram=outputs[0],
            mel_specgram_postnet=outputs[1],
            gate_outputs=outputs[2],
            alignments=outputs[3],
            loss=loss,
            mel_loss=mel_loss,
            mel_postnet_loss=mel_postnet_loss,
            gate_loss=gate_loss,
        )
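
A single training step, sketched under assumptions: `batch` comes from the PadAndSortCollator in collator.py above, and the optimizer choice is illustrative rather than taken from this commit.

import torch

config = Tacotron2Config()
model = Tacotron2ForPreTraining(config).train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# `batch` is assumed to be one PadAndSortCollator output
outputs = model(
    input_ids=batch["input_ids"],
    length=batch["length"],
    mel_specgram=batch["mel_specgram"],
    mel_specgram_length=batch["mel_specgram_length"],
    gate_padded=batch["gate_padded"],
    return_dict=True,
)
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()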
processing_tacotron2.py ADDED
import copy
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torchaudio.transforms import MelSpectrogram
from transformers import Wav2Vec2PhonemeCTCTokenizer
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.utils import TensorType, logging

logger = logging.get_logger(__name__)
AudioType = Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]]


class Tacotron2FeatureExtractor(SequenceFeatureExtractor):
    model_input_names = ["mel_specgram", "mel_specgram_length", "gate_padded"]

    def __init__(
        self,
        feature_size: int = 80,  # n_mels
        sampling_rate: int = 22050,
        n_fft: int = 1024,
        hop_length: int = 256,
        win_length: int = 1024,
        mel_fmin: float = 0.0,
        mel_fmax: float = 8000.0,
        padding_value: float = 0.0,
        **kwargs,
    ):
        super().__init__(
            feature_size=feature_size,
            sampling_rate=sampling_rate,
            padding_value=padding_value,
            **kwargs,
        )
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax

    def mel_specgram(self, waveform: torch.Tensor) -> torch.Tensor:
        # lazily build the torchaudio transform; it is excluded from
        # serialization in `to_dict` below
        if not hasattr(self, "_mel_specgram"):
            self._mel_specgram = MelSpectrogram(
                sample_rate=self.sampling_rate,
                n_fft=self.n_fft,
                win_length=self.win_length,
                hop_length=self.hop_length,
                f_min=self.mel_fmin,
                f_max=self.mel_fmax,
                n_mels=self.feature_size,
                mel_scale="slaney",
                normalized=False,
                power=1,
                norm="slaney",
            )
        melspectrogram = self._mel_specgram(waveform)
        # spectral normalization
        output = torch.log(torch.clamp(melspectrogram, min=1e-5))

        # transpose to (time, n_mels) for padding
        return output.permute(1, 0)

    def __call__(
        self,
        audio: AudioType,
        sampling_rate: Optional[int] = None,
        padding: Union[bool, str] = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_length: bool = False,
        return_gate_padded: bool = False,
        **kwargs,
    ) -> BatchFeature:
        if sampling_rate is not None:
            if sampling_rate != self.sampling_rate:
                raise ValueError(
                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
                    f" {self.sampling_rate}. Please make sure that the provided `audio` input was sampled with"
                    f" {self.sampling_rate} and not {sampling_rate}."
                )
        else:
            logger.warning(
                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
                "Failing to do so can result in silent errors that might be hard to debug."
            )

        is_batched = bool(
            isinstance(audio, (list, tuple))
            and (
                isinstance(audio[0], np.ndarray) or isinstance(audio[0], (tuple, list))
            )
        )

        if is_batched:
            audio = [np.asarray(speech, dtype=np.float32) for speech in audio]
        elif not is_batched and not isinstance(audio, np.ndarray):
            audio = np.asarray(audio, dtype=np.float32)
        elif isinstance(audio, np.ndarray) and audio.dtype is np.dtype(np.float64):
            audio = audio.astype(np.float32)

        # always return a batch
        if not is_batched:
            audio = [audio]

        features = [
            self.mel_specgram(torch.from_numpy(one_waveform)).numpy()
            for one_waveform in audio
        ]

        encoded_inputs = BatchFeature({"mel_specgram": features})

        padded_inputs = self.pad(
            encoded_inputs,
            padding=padding,
            return_attention_mask=return_gate_padded,
            **kwargs,
        )

        if return_length:
            mel_specgram_length = [mel.shape[0] for mel in features]
            if len(mel_specgram_length) == 1 and return_tensors is None:
                mel_specgram_length = mel_specgram_length[0]
            padded_inputs["mel_specgram_length"] = mel_specgram_length

        if return_gate_padded:
            # stop-token target derived from the attention mask, shifted left
            # so the last valid frame is marked as the stop frame
            gate_padded = 1 - padded_inputs.pop("attention_mask")
            gate_padded = np.roll(gate_padded, -1, axis=1)
            gate_padded[:, -1] = 1
            gate_padded = gate_padded.astype(np.float32)
            padded_inputs["gate_padded"] = gate_padded

        mel_specgram = padded_inputs["mel_specgram"]
        if isinstance(mel_specgram[0], list):
            padded_inputs["mel_specgram"] = [
                np.asarray(feature, dtype=np.float32) for feature in mel_specgram
            ]

        # transpose back to (n_mels, time)
        padded_inputs["mel_specgram"] = [
            spec.transpose(1, 0) for spec in padded_inputs["mel_specgram"]
        ]

        if return_tensors is not None:
            padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

        return padded_inputs

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up
            this feature extractor instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["feature_extractor_type"] = self.__class__.__name__
        # the torchaudio transform is not JSON-serializable; drop it
        output.pop("_mel_specgram", None)

        return output
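
A short sketch of the feature extractor on a synthetic waveform; the 440 Hz sine is just a stand-in for real 22050 Hz speech.

import numpy as np

extractor = Tacotron2FeatureExtractor()
t = np.arange(22050) / 22050  # one second of audio
waveform = np.sin(2 * np.pi * 440 * t).astype(np.float32)
features = extractor(waveform, sampling_rate=22050, return_length=True)
# features["mel_specgram"][0]: (n_mels, frames); features["mel_specgram_length"]: frames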
class Tacotron2Processor(ProcessorMixin):
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "Wav2Vec2PhonemeCTCTokenizer"

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor

    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[AudioType] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_length: bool = True,
        **kwargs,
    ) -> Any:
        if text is None and audio is None:
            raise ValueError(
                "You have to specify either text or audio. Both cannot be None."
            )

        if text is not None:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=True,
                return_attention_mask=False,
                return_length=return_length,
            )

        if audio is not None:
            features = self.feature_extractor(
                audio,
                return_tensors=return_tensors,
                return_length=return_length,
                **kwargs,
            )

        if text is not None and audio is not None:
            return BatchFeature({**features, **encoding})
        elif text is not None:
            return encoding
        else:
            return features

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's
        [`~PreTrainedTokenizer.batch_decode`]. Please refer to the docstring
        of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's
        [`~PreTrainedTokenizer.decode`]. Please refer to the docstring of this
        method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)
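
An end-to-end sketch of the processor. The tokenizer checkpoint name is hypothetical, the text is assumed to be already phonemized for the tokenizer's vocabulary, and `extractor`/`waveform` are reused from the sketch above.

from transformers import Wav2Vec2PhonemeCTCTokenizer

# hypothetical checkpoint name; any Wav2Vec2PhonemeCTCTokenizer vocabulary would do
tokenizer = Wav2Vec2PhonemeCTCTokenizer.from_pretrained("some-org/phoneme-tokenizer")
processor = Tacotron2Processor(feature_extractor=extractor, tokenizer=tokenizer)

inputs = processor(text="h ə l oʊ", audio=waveform, sampling_rate=22050)
# inputs holds input_ids, length, mel_specgram, and mel_specgram_length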