# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tacotron Related Dataset modules."""
import itertools
import logging
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow_tts.datasets.abstract_dataset import AbstractDataset
from tensorflow_tts.utils import find_files
class CharactorMelDataset(AbstractDataset):
"""Tensorflow Charactor Mel dataset."""
def __init__(
self,
dataset,
root_dir,
charactor_query="*-ids.npy",
mel_query="*-norm-feats.npy",
align_query="",
charactor_load_fn=np.load,
mel_load_fn=np.load,
mel_length_threshold=0,
reduction_factor=1,
mel_pad_value=0.0,
char_pad_value=0,
ga_pad_value=-1.0,
g=0.2,
use_fixed_shapes=False,
):
"""Initialize dataset.
Args:
            dataset (str): Name of the dataset (e.g. "ljspeech").
            root_dir (str): Root directory including dumped files.
            charactor_query (str): Query to find charactor files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            align_query (str): Query to find forced-alignment (FAL) files in root_dir. If empty, guided attention loss is used instead.
            charactor_load_fn (func): Function to load a charactor file.
            mel_load_fn (func): Function to load a feature file.
            mel_length_threshold (int): Threshold to remove short feature files.
            reduction_factor (int): Reduction factor from the Tacotron-2 paper.
            mel_pad_value (float): Padding value for mel-spectrograms.
            char_pad_value (int): Padding value for charactors.
            ga_pad_value (float): Padding value for guided attention matrices.
            g (float): G value for guided attention.
            use_fixed_shapes (bool): Whether to pad every batch to fixed maximum
                shapes; the maximum charactor and mel lengths are derived from the data.
"""
# find all of charactor and mel files.
charactor_files = sorted(find_files(root_dir, charactor_query))
mel_files = sorted(find_files(root_dir, mel_query))
mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
char_lengths = [charactor_load_fn(f).shape[0] for f in charactor_files]
# assert the number of files
        assert len(mel_files) != 0, f"No mel files found in {root_dir}."
assert (
len(mel_files) == len(charactor_files) == len(mel_lengths)
), f"Number of charactor, mel and duration files are different \
({len(mel_files)} vs {len(charactor_files)} vs {len(mel_lengths)})."
self.align_files = []
        if len(align_query) > 0:
align_files = sorted(find_files(root_dir, align_query))
assert len(align_files) == len(
mel_files
), f"Number of align files ({len(align_files)}) and mel files ({len(mel_files)}) are different"
logging.info("Using FAL loss")
self.align_files = align_files
else:
logging.info("Using guided attention loss")
if ".npy" in charactor_query:
suffix = charactor_query[1:]
utt_ids = [os.path.basename(f).replace(suffix, "") for f in charactor_files]
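            # The line above strips the query suffix from each filename; with the
            # default query "*-ids.npy", a file named "LJ001-0001-ids.npy" yields
            # utt_id "LJ001-0001" (LJSpeech-style naming, shown for illustration).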
# set global params
self.utt_ids = utt_ids
self.mel_files = mel_files
self.charactor_files = charactor_files
self.mel_load_fn = mel_load_fn
self.charactor_load_fn = charactor_load_fn
self.mel_lengths = mel_lengths
self.char_lengths = char_lengths
self.reduction_factor = reduction_factor
self.mel_length_threshold = mel_length_threshold
self.mel_pad_value = mel_pad_value
self.char_pad_value = char_pad_value
self.ga_pad_value = ga_pad_value
self.g = g
self.use_fixed_shapes = use_fixed_shapes
self.max_char_length = np.max(char_lengths)
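        # Round the maximum mel length up to a multiple of reduction_factor,
        # e.g. a 871-frame maximum with reduction_factor=2 becomes 872, so that
        # fixed-shape batches stay divisible by the reduction factor.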
if np.max(mel_lengths) % self.reduction_factor == 0:
self.max_mel_length = np.max(mel_lengths)
else:
self.max_mel_length = (
np.max(mel_lengths)
+ self.reduction_factor
- np.max(mel_lengths) % self.reduction_factor
)
def get_args(self):
return [self.utt_ids]
def generator(self, utt_ids):
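        """Yield per-utterance dicts of file paths; actual array loading is
        deferred to _load_data inside the tf.data pipeline."""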
for i, utt_id in enumerate(utt_ids):
mel_file = self.mel_files[i]
charactor_file = self.charactor_files[i]
            align_file = self.align_files[i] if len(self.align_files) > 0 else ""
items = {
"utt_ids": utt_id,
"mel_files": mel_file,
"charactor_files": charactor_file,
"align_files": align_file,
}
yield items
@tf.function
def _load_data(self, items):
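        """Load charactor, mel (and optional alignment) arrays from .npy files."""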
mel = tf.numpy_function(np.load, [items["mel_files"]], tf.float32)
charactor = tf.numpy_function(np.load, [items["charactor_files"]], tf.int32)
g_att = (
tf.numpy_function(np.load, [items["align_files"]], tf.float32)
            if len(self.align_files) > 0
else None
)
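        # Note: under @tf.function, AutoGraph rewrites len() on these tensors
        # into a dynamic tf.shape(...)[0], so the calls below work even though
        # the tf.numpy_function outputs have unknown static shapes.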
mel_length = len(mel)
char_length = len(charactor)
        # Pad mel so its length is a multiple of the reduction factor,
        # keeping the original length around as real_mel_length.
        real_mel_length = mel_length
remainder = mel_length % self.reduction_factor
if remainder != 0:
new_mel_length = mel_length + self.reduction_factor - remainder
mel = tf.pad(
mel,
[[0, new_mel_length - mel_length], [0, 0]],
constant_values=self.mel_pad_value,
)
mel_length = new_mel_length
items = {
"utt_ids": items["utt_ids"],
"input_ids": charactor,
"input_lengths": char_length,
"speaker_ids": 0,
"mel_gts": mel,
"mel_lengths": mel_length,
"real_mel_lengths": real_mel_length,
"g_attentions": g_att,
}
return items
def _guided_attention(self, items):
"""Guided attention. Refer to page 3 on the paper (https://arxiv.org/abs/1710.08969)."""
items = items.copy()
mel_len = items["mel_lengths"] // self.reduction_factor
char_len = items["input_lengths"]
xv, yv = tf.meshgrid(tf.range(char_len), tf.range(mel_len), indexing="ij")
f32_matrix = tf.cast(yv / mel_len - xv / char_len, tf.float32)
items["g_attentions"] = 1.0 - tf.math.exp(
-(f32_matrix ** 2) / (2 * self.g ** 2)
)
return items
def create(
self,
allow_cache=False,
batch_size=1,
is_shuffle=False,
map_fn=None,
reshuffle_each_iteration=True,
drop_remainder=True,
):
"""Create tf.dataset function."""
output_types = self.get_output_dtypes()
datasets = tf.data.Dataset.from_generator(
            self.generator, output_types=output_types, args=self.get_args()
)
# load data
datasets = datasets.map(
lambda items: self._load_data(items), tf.data.experimental.AUTOTUNE
)
# calculate guided attention
if len(self.align_files) < 1:
datasets = datasets.map(
lambda items: self._guided_attention(items),
tf.data.experimental.AUTOTUNE,
)
datasets = datasets.filter(
lambda x: x["mel_lengths"] > self.mel_length_threshold
)
if allow_cache:
datasets = datasets.cache()
if is_shuffle:
datasets = datasets.shuffle(
self.get_len_dataset(),
reshuffle_each_iteration=reshuffle_each_iteration,
)
# define padding value.
padding_values = {
"utt_ids": " ",
"input_ids": self.char_pad_value,
"input_lengths": 0,
"speaker_ids": 0,
"mel_gts": self.mel_pad_value,
"mel_lengths": 0,
"real_mel_lengths": 0,
"g_attentions": self.ga_pad_value,
}
# define padded shapes.
padded_shapes = {
"utt_ids": [],
"input_ids": [None]
if self.use_fixed_shapes is False
else [self.max_char_length],
"input_lengths": [],
"speaker_ids": [],
"mel_gts": [None, 80]
if self.use_fixed_shapes is False
else [self.max_mel_length, 80],
"mel_lengths": [],
"real_mel_lengths": [],
"g_attentions": [None, None]
if self.use_fixed_shapes is False
else [self.max_char_length, self.max_mel_length // self.reduction_factor],
}
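        # NOTE: the feature dimension of 80 mel bins is hard-coded to match the
        # default preprocessing config. Fixed shapes keep every batch identical,
        # which avoids tf.function retracing and suits accelerators that need
        # static shapes (e.g. TPUs).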
datasets = datasets.padded_batch(
batch_size,
padded_shapes=padded_shapes,
padding_values=padding_values,
drop_remainder=drop_remainder,
)
datasets = datasets.prefetch(tf.data.experimental.AUTOTUNE)
return datasets
def get_output_dtypes(self):
output_types = {
"utt_ids": tf.string,
"mel_files": tf.string,
"charactor_files": tf.string,
"align_files": tf.string,
}
return output_types
def get_len_dataset(self):
return len(self.utt_ids)
def __name__(self):
return "CharactorMelDataset"