# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Dataset modules."""

import logging
import os

from multiprocessing import Manager

import numpy as np

from torch.utils.data import Dataset

from parallel_wavegan.utils import find_files
from parallel_wavegan.utils import read_hdf5


class AudioMelDataset(Dataset):
    """PyTorch compatible audio and mel dataset."""

    def __init__(
        self,
        root_dir,
        audio_query="*.h5",
        mel_query="*.h5",
        audio_load_fn=lambda x: read_hdf5(x, "wave"),
        mel_load_fn=lambda x: read_hdf5(x, "feats"),
        audio_length_threshold=None,
        mel_length_threshold=None,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            root_dir (str): Root directory including dumped files.
            audio_query (str): Query to find audio files in root_dir.
            mel_query (str): Query to find feature files in root_dir.
            audio_load_fn (func): Function to load audio file.
            mel_load_fn (func): Function to load feature file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return the utterance id with arrays.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # find all audio and mel files
        audio_files = sorted(find_files(root_dir, audio_query))
        mel_files = sorted(find_files(root_dir, mel_query))

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files]
            idxs = [
                idx
                for idx in range(len(audio_files))
                if audio_lengths[idx] > audio_length_threshold
            ]
            if len(audio_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_files)} -> {len(idxs)})."
                )
            audio_files = [audio_files[idx] for idx in idxs]
            mel_files = [mel_files[idx] for idx in idxs]
        if mel_length_threshold is not None:
            mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files]
            idxs = [
                idx
                for idx in range(len(mel_files))
                if mel_lengths[idx] > mel_length_threshold
            ]
            if len(mel_files) != len(idxs):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(mel_files)} -> {len(idxs)})."
                )
            audio_files = [audio_files[idx] for idx in idxs]
            mel_files = [mel_files[idx] for idx in idxs]

        # assert the number of files
        assert len(audio_files) != 0, f"Not found any audio files in {root_dir}."
        assert len(audio_files) == len(mel_files), (
            f"Number of audio and mel files are different "
            f"({len(audio_files)} vs {len(mel_files)})."
        )

        self.audio_files = audio_files
        self.audio_load_fn = audio_load_fn
        self.mel_load_fn = mel_load_fn
        self.mel_files = mel_files
        if ".npy" in audio_query:
            self.utt_ids = [
                os.path.basename(f).replace("-wave.npy", "") for f in audio_files
            ]
        else:
            self.utt_ids = [
                os.path.splitext(os.path.basename(f))[0] for f in audio_files
            ]
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(audio_files))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only if return_utt_id = True).
            ndarray: Audio signal (T,).
            ndarray: Feature (T', C).

        """
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]

        utt_id = self.utt_ids[idx]
        audio = self.audio_load_fn(self.audio_files[idx])
        mel = self.mel_load_fn(self.mel_files[idx])

        if self.return_utt_id:
            items = utt_id, audio, mel
        else:
            items = audio, mel

        if self.allow_cache:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.audio_files)
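# A minimal usage sketch for AudioMelDataset (a hedged example, not part of
# the library API): "dump/train" is a hypothetical directory of HDF5 dumps
# whose "wave"/"feats" datasets match the default loaders above. Guarded so
# importing this module stays side-effect free.
if __name__ == "__main__":
    dataset = AudioMelDataset("dump/train", return_utt_id=True, allow_cache=True)
    utt_id, audio, mel = dataset[0]  # audio: (T,), mel: (T', C)
    print(f"{utt_id}: audio {audio.shape}, mel {mel.shape}, #utts {len(dataset)}")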
""" if self.allow_cache and len(self.caches[idx]) != 0: return self.caches[idx] utt_id = self.utt_ids[idx] audio = self.audio_load_fn(self.audio_files[idx]) mel = self.mel_load_fn(self.mel_files[idx]) if self.return_utt_id: items = utt_id, audio, mel else: items = audio, mel if self.allow_cache: self.caches[idx] = items return items def __len__(self): """Return dataset length. Returns: int: The length of dataset. """ return len(self.audio_files) class AudioDataset(Dataset): """PyTorch compatible audio dataset.""" def __init__( self, root_dir, audio_query="*-wave.npy", audio_length_threshold=None, audio_load_fn=np.load, return_utt_id=False, allow_cache=False, ): """Initialize dataset. Args: root_dir (str): Root directory including dumped files. audio_query (str): Query to find audio files in root_dir. audio_load_fn (func): Function to load audio file. audio_length_threshold (int): Threshold to remove short audio files. return_utt_id (bool): Whether to return the utterance id with arrays. allow_cache (bool): Whether to allow cache of the loaded files. """ # find all of audio and mel files audio_files = sorted(find_files(root_dir, audio_query)) # filter by threshold if audio_length_threshold is not None: audio_lengths = [audio_load_fn(f).shape[0] for f in audio_files] idxs = [ idx for idx in range(len(audio_files)) if audio_lengths[idx] > audio_length_threshold ] if len(audio_files) != len(idxs): logging.waning( f"some files are filtered by audio length threshold " f"({len(audio_files)} -> {len(idxs)})." ) audio_files = [audio_files[idx] for idx in idxs] # assert the number of files assert len(audio_files) != 0, f"Not found any audio files in ${root_dir}." self.audio_files = audio_files self.audio_load_fn = audio_load_fn self.return_utt_id = return_utt_id if ".npy" in audio_query: self.utt_ids = [ os.path.basename(f).replace("-wave.npy", "") for f in audio_files ] else: self.utt_ids = [ os.path.splitext(os.path.basename(f))[0] for f in audio_files ] self.allow_cache = allow_cache if allow_cache: # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0 self.manager = Manager() self.caches = self.manager.list() self.caches += [() for _ in range(len(audio_files))] def __getitem__(self, idx): """Get specified idx items. Args: idx (int): Index of the item. Returns: str: Utterance id (only in return_utt_id = True). ndarray: Audio (T,). """ if self.allow_cache and len(self.caches[idx]) != 0: return self.caches[idx] utt_id = self.utt_ids[idx] audio = self.audio_load_fn(self.audio_files[idx]) if self.return_utt_id: items = utt_id, audio else: items = audio if self.allow_cache: self.caches[idx] = items return items def __len__(self): """Return dataset length. Returns: int: The length of dataset. """ return len(self.audio_files) class MelDataset(Dataset): """PyTorch compatible mel dataset.""" def __init__( self, root_dir, mel_query="*-feats.npy", mel_length_threshold=None, mel_load_fn=np.load, return_utt_id=False, allow_cache=False, ): """Initialize dataset. Args: root_dir (str): Root directory including dumped files. mel_query (str): Query to find feature files in root_dir. mel_load_fn (func): Function to load feature file. mel_length_threshold (int): Threshold to remove short feature files. return_utt_id (bool): Whether to return the utterance id with arrays. allow_cache (bool): Whether to allow cache of the loaded files. 
""" # find all of the mel files mel_files = sorted(find_files(root_dir, mel_query)) # filter by threshold if mel_length_threshold is not None: mel_lengths = [mel_load_fn(f).shape[0] for f in mel_files] idxs = [ idx for idx in range(len(mel_files)) if mel_lengths[idx] > mel_length_threshold ] if len(mel_files) != len(idxs): logging.warning( f"Some files are filtered by mel length threshold " f"({len(mel_files)} -> {len(idxs)})." ) mel_files = [mel_files[idx] for idx in idxs] # assert the number of files assert len(mel_files) != 0, f"Not found any mel files in ${root_dir}." self.mel_files = mel_files self.mel_load_fn = mel_load_fn self.utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in mel_files] if ".npy" in mel_query: self.utt_ids = [ os.path.basename(f).replace("-feats.npy", "") for f in mel_files ] else: self.utt_ids = [os.path.splitext(os.path.basename(f))[0] for f in mel_files] self.return_utt_id = return_utt_id self.allow_cache = allow_cache if allow_cache: # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0 self.manager = Manager() self.caches = self.manager.list() self.caches += [() for _ in range(len(mel_files))] def __getitem__(self, idx): """Get specified idx items. Args: idx (int): Index of the item. Returns: str: Utterance id (only in return_utt_id = True). ndarray: Feature (T', C). """ if self.allow_cache and len(self.caches[idx]) != 0: return self.caches[idx] utt_id = self.utt_ids[idx] mel = self.mel_load_fn(self.mel_files[idx]) if self.return_utt_id: items = utt_id, mel else: items = mel if self.allow_cache: self.caches[idx] = items return items def __len__(self): """Return dataset length. Returns: int: The length of dataset. """ return len(self.mel_files)