Spaces:
Build error
Build error
"""
BSD 3-Clause License
Copyright (c) 2018, NVIDIA Corporation
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import torch | |
class TextMelCollate:
    """Zero-pad a batch of (text, mel) pairs into dense tensors.

    Instances are callable and take a list of samples, so they can be used
    wherever a batch-collation callable is expected.

    The padded mel length is rounded up to a multiple of
    ``n_frames_per_step``.
    """

    def __init__(self, n_frames_per_step=1):
        """Store the frame-group size used when padding mel-spectrograms.

        PARAMS
        ------
        n_frames_per_step: positive int; the padded mel length is rounded up
            to a multiple of this value. Defaults to 1 (no rounding).

        RAISES
        ------
        ValueError: if ``n_frames_per_step`` is smaller than 1.
        """
        if n_frames_per_step < 1:
            raise ValueError(
                f"n_frames_per_step must be >= 1, got {n_frames_per_step}"
            )
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate a training batch from normalized text and mel-spectrograms.

        PARAMS
        ------
        batch: sequence of ``[text_normalized, mel_normalized]`` samples.
            ``text_normalized`` is indexed as a 1-D tensor and
            ``mel_normalized`` as a 2-D tensor whose dim 0 is the number of
            mel channels and dim 1 the number of frames.

        RETURNS
        -------
        text_padded: int64 tensor (batch, max_text_len), right zero-padded.
        input_lengths: int64 tensor (batch,), text lengths, descending.
        mel_padded: float32 tensor (batch, num_mels, max_mel_len), right
            zero-padded.
        gate_padded: float32 tensor (batch, max_mel_len); 1 from the last
            real frame of each sample onward, 0 before it.
        output_lengths: int64 tensor (batch,), unpadded mel lengths.

        All outputs are ordered by decreasing text length.
        """
        # Sort samples by text length, longest first.
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]), dim=0, descending=True
        )
        # Plain Python ints so they can index lists and size tensors.
        sorted_ids = ids_sorted_decreasing.tolist()
        max_input_len = int(input_lengths[0])
        batch_size = len(batch)

        # Right zero-pad all text sequences to the max input length.
        text_padded = torch.zeros(batch_size, max_input_len, dtype=torch.long)
        for row, sample_idx in enumerate(sorted_ids):
            text = batch[sample_idx][0]
            text_padded[row, : text.size(0)] = text

        # Right zero-pad mel-specs; round the target length up so that it is
        # a whole number of n_frames_per_step groups.
        num_mels = batch[0][1].size(0)
        max_target_len = max(x[1].size(1) for x in batch)
        remainder = max_target_len % self.n_frames_per_step
        if remainder:
            max_target_len += self.n_frames_per_step - remainder

        mel_padded = torch.zeros(
            batch_size, num_mels, max_target_len, dtype=torch.float32
        )
        gate_padded = torch.zeros(batch_size, max_target_len, dtype=torch.float32)
        output_lengths = torch.zeros(batch_size, dtype=torch.long)
        for row, sample_idx in enumerate(sorted_ids):
            mel = batch[sample_idx][1]
            n_frames = mel.size(1)
            mel_padded[row, :, :n_frames] = mel
            # Gate is 1 from the final real frame through the padding.
            gate_padded[row, n_frames - 1 :] = 1
            output_lengths[row] = n_frames

        return text_padded, input_lengths, mel_padded, gate_padded, output_lengths