# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)

"""Dataset modules based on kaldi-style scp files."""

import logging

from multiprocessing import Manager

import kaldiio
import numpy as np

from torch.utils.data import Dataset

from parallel_wavegan.utils import HDF5ScpLoader
from parallel_wavegan.utils import NpyScpLoader


def _get_feats_scp_loader(feats_scp):
    """Return the loader matching the storage format referenced by feats.scp.

    Only the first entry of the file is inspected; its value decides the type.

    Args:
        feats_scp (str): Kaldi-style feats.scp file.

    Returns:
        Loader object created by ``kaldiio.load_scp``, ``HDF5ScpLoader`` or
        ``NpyScpLoader`` depending on the extension found in the first entry.

    Raises:
        ValueError: If the file is empty, the first line is not a
            ``key value`` pair, or the extension is not supported.

    """
    # only the first line is needed, so do not read the whole file
    with open(feats_scp) as f:
        first_line = f.readline()
    if not first_line:
        raise ValueError(f"{feats_scp} is empty.")
    _, value = first_line.replace("\n", "").split()

    # check scp type
    if ":" in value:
        value_1, _ = value.split(":")
        if value_1.endswith(".ark"):
            # kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index
            return kaldiio.load_scp(feats_scp)
        elif value_1.endswith(".h5"):
            # hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats
            return HDF5ScpLoader(feats_scp)
        else:
            raise ValueError("Not supported feats.scp type.")
    else:
        if value.endswith(".h5"):
            # hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5
            return HDF5ScpLoader(feats_scp)
        elif value.endswith(".npy"):
            # npy case: utt_id_1 /path/to/utt_id_1.npy
            return NpyScpLoader(feats_scp)
        else:
            raise ValueError("Not supported feats.scp type.")


class AudioMelSCPDataset(Dataset):
    """PyTorch compatible audio and mel dataset based on kaldi-style scp files."""

    def __init__(
        self,
        wav_scp,
        feats_scp,
        segments=None,
        audio_length_threshold=None,
        mel_length_threshold=None,
        return_utt_id=False,
        return_sampling_rate=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            wav_scp (str): Kaldi-style wav.scp file.
            feats_scp (str): Kaldi-style feats.scp file.
            segments (str): Kaldi-style segments file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return utterance id.
            return_sampling_rate (bool): Whether to return sampling rate.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # load scp as lazy dict
        audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
        mel_loader = _get_feats_scp_loader(feats_scp)
        audio_keys = list(audio_loader.keys())
        mel_keys = list(mel_loader.keys())

        # filter by threshold
        # NOTE: both key lists are filtered with the same positions, i.e. the
        #   two scp files are assumed to list utterances in the same order.
        if audio_length_threshold is not None:
            audio_lengths = [audio.shape[0] for _, audio in audio_loader.values()]
            idxs = [
                idx
                for idx, length in enumerate(audio_lengths)
                if length > audio_length_threshold
            ]
            if len(audio_keys) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_keys)} -> {len(idxs)})."
                )
            audio_keys = [audio_keys[idx] for idx in idxs]
            mel_keys = [mel_keys[idx] for idx in idxs]
        if mel_length_threshold is not None:
            # Lengths must be looked up through the *current* mel_keys: the
            # audio filter above may already have removed entries, so lengths
            # taken from mel_loader.values() would no longer line up with the
            # positions used below.
            mel_lengths = [mel_loader[key].shape[0] for key in mel_keys]
            idxs = [
                idx
                for idx, length in enumerate(mel_lengths)
                if length > mel_length_threshold
            ]
            if len(mel_keys) != len(idxs):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(mel_keys)} -> {len(idxs)})."
                )
            audio_keys = [audio_keys[idx] for idx in idxs]
            mel_keys = [mel_keys[idx] for idx in idxs]

        # assert the number of files
        assert len(audio_keys) == len(
            mel_keys
        ), f"Number of audio and mel files are different ({len(audio_keys)} vs {len(mel_keys)})."

        self.audio_loader = audio_loader
        self.mel_loader = mel_loader
        self.utt_ids = audio_keys
        self.return_utt_id = return_utt_id
        self.return_sampling_rate = return_sampling_rate
        self.allow_cache = allow_cache

        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # an empty tuple marks "not cached yet"
            self.caches += [() for _ in range(len(self.utt_ids))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray or tuple: Audio signal (T,) or (w/ sampling rate if return_sampling_rate = True).
            ndarray: Feature (T', C).

        """
        if self.allow_cache:
            # read the shared list once instead of once for the check and
            # once more for the return value
            cached = self.caches[idx]
            if len(cached) != 0:
                return cached

        utt_id = self.utt_ids[idx]
        fs, audio = self.audio_loader[utt_id]
        mel = self.mel_loader[utt_id]

        # normalize audio signal to be [-1, 1]
        audio = audio.astype(np.float32)
        audio /= 1 << (16 - 1)  # assume that wav is PCM 16 bit

        if self.return_sampling_rate:
            audio = (audio, fs)

        if self.return_utt_id:
            items = utt_id, audio, mel
        else:
            items = audio, mel

        if self.allow_cache:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.utt_ids)


class AudioSCPDataset(Dataset):
    """PyTorch compatible audio dataset based on kaldi-style scp files."""

    def __init__(
        self,
        wav_scp,
        segments=None,
        audio_length_threshold=None,
        return_utt_id=False,
        return_sampling_rate=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            wav_scp (str): Kaldi-style wav.scp file.
            segments (str): Kaldi-style segments file.
            audio_length_threshold (int): Threshold to remove short audio files.
            return_utt_id (bool): Whether to return utterance id.
            return_sampling_rate (bool): Whether to return sampling rate.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # load scp as lazy dict
        audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
        audio_keys = list(audio_loader.keys())

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio.shape[0] for _, audio in audio_loader.values()]
            idxs = [
                idx
                for idx, length in enumerate(audio_lengths)
                if length > audio_length_threshold
            ]
            if len(audio_keys) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_keys)} -> {len(idxs)})."
                )
            audio_keys = [audio_keys[idx] for idx in idxs]

        self.audio_loader = audio_loader
        self.utt_ids = audio_keys
        self.return_utt_id = return_utt_id
        self.return_sampling_rate = return_sampling_rate
        self.allow_cache = allow_cache

        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # an empty tuple marks "not cached yet"
            self.caches += [() for _ in range(len(self.utt_ids))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray or tuple: Audio signal (T,) or (w/ sampling rate if return_sampling_rate = True).

        """
        if self.allow_cache:
            # read the shared list once instead of once for the check and
            # once more for the return value
            cached = self.caches[idx]
            if len(cached) != 0:
                return cached

        utt_id = self.utt_ids[idx]
        fs, audio = self.audio_loader[utt_id]

        # normalize audio signal to be [-1, 1]
        audio = audio.astype(np.float32)
        audio /= 1 << (16 - 1)  # assume that wav is PCM 16 bit

        if self.return_sampling_rate:
            audio = (audio, fs)

        if self.return_utt_id:
            items = utt_id, audio
        else:
            items = audio

        if self.allow_cache:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.utt_ids)


class MelSCPDataset(Dataset):
    """PyTorch compatible mel dataset based on kaldi-style scp files."""

    def __init__(
        self,
        feats_scp,
        mel_length_threshold=None,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Args:
            feats_scp (str): Kaldi-style feats.scp file.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return utterance id.
            allow_cache (bool): Whether to allow cache of the loaded files.

        """
        # load scp as lazy dict
        mel_loader = _get_feats_scp_loader(feats_scp)
        mel_keys = list(mel_loader.keys())

        # filter by threshold
        if mel_length_threshold is not None:
            mel_lengths = [mel.shape[0] for mel in mel_loader.values()]
            idxs = [
                idx
                for idx, length in enumerate(mel_lengths)
                if length > mel_length_threshold
            ]
            if len(mel_keys) != len(idxs):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(mel_keys)} -> {len(idxs)})."
                )
            mel_keys = [mel_keys[idx] for idx in idxs]

        self.mel_loader = mel_loader
        self.utt_ids = mel_keys
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache

        if allow_cache:
            # NOTE(kan-bayashi): Manager is needed to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            # an empty tuple marks "not cached yet"
            self.caches += [() for _ in range(len(self.utt_ids))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray: Feature (T', C).

        """
        if self.allow_cache:
            # read the shared list once instead of once for the check and
            # once more for the return value
            cached = self.caches[idx]
            if len(cached) != 0:
                return cached

        utt_id = self.utt_ids[idx]
        mel = self.mel_loader[utt_id]

        if self.return_utt_id:
            items = utt_id, mel
        else:
            items = mel

        if self.allow_cache:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.

        """
        return len(self.utt_ids)