Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
# Copyright 2019 Tomoki Hayashi | |
# MIT License (https://opensource.org/licenses/MIT) | |
"""Dataset modules based on kaldi-style scp files.""" | |
import logging | |
from multiprocessing import Manager | |
import kaldiio | |
import numpy as np | |
from torch.utils.data import Dataset | |
from parallel_wavegan.utils import HDF5ScpLoader | |
from parallel_wavegan.utils import NpyScpLoader | |
def _get_feats_scp_loader(feats_scp):
    """Return a lazy feature loader matching the format of a feats.scp file.

    The format is detected from the first entry only, so every entry in the
    file is assumed to use the same format.

    Args:
        feats_scp (str): Path of the Kaldi-style feats.scp file.

    Returns:
        ``kaldiio.load_scp(feats_scp)`` for ``.ark:<index>`` entries,
        ``HDF5ScpLoader(feats_scp)`` for ``.h5`` / ``.h5:<path>`` entries, or
        ``NpyScpLoader(feats_scp)`` for ``.npy`` entries.

    Raises:
        ValueError: If the first line is not a ``<utt_id> <path>`` pair (this
            includes an empty file) or the path type is not supported.
    """
    # only the first line is needed to detect the format, so avoid reading
    # the whole (possibly very large) scp file into memory
    with open(feats_scp) as f:
        fields = f.readline().split()
    if len(fields) != 2:
        raise ValueError(
            f"Invalid first line in {feats_scp}: expected '<utt_id> <path>'."
        )
    _, value = fields

    # check scp type
    if ":" in value:
        # split on the first colon only: anything after it (an ark offset,
        # possibly followed by a range suffix containing more colons, or an
        # hdf5 internal path) does not affect the format decision
        path = value.partition(":")[0]
        if path.endswith(".ark"):
            # kaldi-ark case: utt_id_1 /path/to/utt_id_1.ark:index
            return kaldiio.load_scp(feats_scp)
        elif path.endswith(".h5"):
            # hdf5 case with path in hdf5: utt_id_1 /path/to/utt_id_1.h5:feats
            return HDF5ScpLoader(feats_scp)
        else:
            raise ValueError("Not supported feats.scp type.")
    else:
        if value.endswith(".h5"):
            # hdf5 case without path in hdf5: utt_id_1 /path/to/utt_id_1.h5
            return HDF5ScpLoader(feats_scp)
        elif value.endswith(".npy"):
            # npy case: utt_id_1 /path/to/utt_id_1.npy
            return NpyScpLoader(feats_scp)
        else:
            raise ValueError("Not supported feats.scp type.")
class AudioMelSCPDataset(Dataset):
    """PyTorch compatible audio and mel dataset based on kaldi-style scp files."""

    def __init__(
        self,
        wav_scp,
        feats_scp,
        segments=None,
        audio_length_threshold=None,
        mel_length_threshold=None,
        return_utt_id=False,
        return_sampling_rate=False,
        allow_cache=False,
    ):
        """Initialize dataset.

        Audio and feature entries are paired by their position in the two scp
        files: the i-th audio entry goes with the i-th feature entry, and
        length filtering drops both members of a pair together. Items are
        later looked up by the audio utterance id, so feats.scp must contain
        those ids.

        Args:
            wav_scp (str): Kaldi-style wav.scp file.
            feats_scp (str): Kaldi-style feats.scp file.
            segments (str): Kaldi-style segments file.
            audio_length_threshold (int): Threshold to remove short audio files.
            mel_length_threshold (int): Threshold to remove short feature files.
            return_utt_id (bool): Whether to return utterance id.
            return_sampling_rate (bool): Whether to return sampling rate.
            allow_cache (bool): Whether to allow cache of the loaded files.

        Raises:
            AssertionError: If the two scp files list different numbers of entries.
        """
        # load scp as lazy dict
        audio_loader = kaldiio.load_scp(wav_scp, segments=segments)
        mel_loader = _get_feats_scp_loader(feats_scp)
        audio_keys = list(audio_loader.keys())
        mel_keys = list(mel_loader.keys())

        # check the number of files BEFORE filtering: the filters below index
        # both key lists with the same positions, so a mismatch would otherwise
        # raise IndexError or silently drop the surplus entries
        assert len(audio_keys) == len(
            mel_keys
        ), f"Number of audio and mel files are different ({len(audio_keys)} vs {len(mel_keys)})."

        # filter by threshold
        if audio_length_threshold is not None:
            audio_lengths = [audio.shape[0] for _, audio in audio_loader.values()]
            idxs = [
                idx
                for idx, length in enumerate(audio_lengths)
                if length > audio_length_threshold
            ]
            if len(audio_keys) != len(idxs):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(audio_keys)} -> {len(idxs)})."
                )
            audio_keys = [audio_keys[idx] for idx in idxs]
            mel_keys = [mel_keys[idx] for idx in idxs]
        if mel_length_threshold is not None:
            # measure only the entries still in mel_keys; lengths of ALL
            # entries would be misaligned with the audio-filtered key list
            mel_lengths = [mel_loader[key].shape[0] for key in mel_keys]
            idxs = [
                idx
                for idx, length in enumerate(mel_lengths)
                if length > mel_length_threshold
            ]
            if len(mel_keys) != len(idxs):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(mel_keys)} -> {len(idxs)})."
                )
            audio_keys = [audio_keys[idx] for idx in idxs]
            mel_keys = [mel_keys[idx] for idx in idxs]

        self.audio_loader = audio_loader
        self.mel_loader = mel_loader
        self.utt_ids = audio_keys
        self.return_utt_id = return_utt_id
        self.return_sampling_rate = return_sampling_rate
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): Manager is need to share memory in dataloader with num_workers > 0
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [() for _ in range(len(self.utt_ids))]

    def __getitem__(self, idx):
        """Get specified idx items.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only in return_utt_id = True).
            ndarray or tuple: Audio signal (T,) or (w/ sampling rate if return_sampling_rate = True).
            ndarray: Feature (T', C).
        """
        # an empty tuple marks a cache slot that has not been filled yet
        if self.allow_cache and len(self.caches[idx]) != 0:
            return self.caches[idx]
        utt_id = self.utt_ids[idx]
        fs, audio = self.audio_loader[utt_id]
        mel = self.mel_loader[utt_id]

        # normalize audio signal to be [-1, 1]
        audio = audio.astype(np.float32)
        audio /= 1 << (16 - 1)  # assume that wav is PCM 16 bit

        if self.return_sampling_rate:
            audio = (audio, fs)

        if self.return_utt_id:
            items = utt_id, audio, mel
        else:
            items = audio, mel

        if self.allow_cache:
            self.caches[idx] = items

        return items

    def __len__(self):
        """Return dataset length.

        Returns:
            int: The length of dataset.
        """
        return len(self.utt_ids)
class AudioSCPDataset(Dataset):
    """PyTorch dataset that serves audio waveforms listed in a kaldi-style wav.scp."""

    def __init__(
        self,
        wav_scp,
        segments=None,
        audio_length_threshold=None,
        return_utt_id=False,
        return_sampling_rate=False,
        allow_cache=False,
    ):
        """Build the dataset from a wav.scp file.

        Args:
            wav_scp (str): Kaldi-style wav.scp file.
            segments (str): Kaldi-style segments file, forwarded to kaldiio.
            audio_length_threshold (int): Entries whose sample count is not
                greater than this value are dropped; ``None`` keeps every entry.
            return_utt_id (bool): Whether items also carry the utterance id.
            return_sampling_rate (bool): Whether the audio is returned as
                ``(audio, sampling_rate)``.
            allow_cache (bool): Whether loaded items are kept in a
                ``multiprocessing.Manager`` list.
        """
        # lazy mapping from utterance id to (sampling_rate, samples)
        loader = kaldiio.load_scp(wav_scp, segments=segments)
        utt_ids = list(loader.keys())

        if audio_length_threshold is not None:
            # keys() and values() share one iteration order, so zipping pairs
            # every id with its own waveform
            kept = [
                utt_id
                for utt_id, (_, wav) in zip(utt_ids, loader.values())
                if wav.shape[0] > audio_length_threshold
            ]
            if len(kept) != len(utt_ids):
                logging.warning(
                    f"Some files are filtered by audio length threshold "
                    f"({len(utt_ids)} -> {len(kept)})."
                )
            utt_ids = kept

        self.audio_loader = loader
        self.utt_ids = utt_ids
        self.return_utt_id = return_utt_id
        self.return_sampling_rate = return_sampling_rate
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): a Manager-backed list lets DataLoader workers
            # (num_workers > 0) share the cache
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [()] * len(self.utt_ids)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only when return_utt_id = True).
            ndarray or tuple: float32 audio signal (T,) divided by 2 ** 15, or
                ``(audio, sampling_rate)`` when return_sampling_rate = True.
        """
        if self.allow_cache:
            cached = self.caches[idx]
            # an empty tuple marks a slot that has not been filled yet
            if len(cached) != 0:
                return cached

        utt_id = self.utt_ids[idx]
        sampling_rate, wav = self.audio_loader[utt_id]
        # scale to [-1, 1]; assumes 16-bit PCM samples
        wav = wav.astype(np.float32) / (1 << 15)

        payload = (wav, sampling_rate) if self.return_sampling_rate else wav
        items = (utt_id, payload) if self.return_utt_id else payload

        if self.allow_cache:
            self.caches[idx] = items
        return items

    def __len__(self):
        """Return the number of utterances kept after filtering.

        Returns:
            int: The length of dataset.
        """
        return len(self.utt_ids)
class MelSCPDataset(Dataset):
    """PyTorch dataset that serves features listed in a kaldi-style feats.scp."""

    def __init__(
        self,
        feats_scp,
        mel_length_threshold=None,
        return_utt_id=False,
        allow_cache=False,
    ):
        """Build the dataset from a feats.scp file.

        Args:
            feats_scp (str): Kaldi-style feats.scp file.
            mel_length_threshold (int): Entries whose first dimension is not
                greater than this value are dropped; ``None`` keeps every entry.
            return_utt_id (bool): Whether items also carry the utterance id.
            allow_cache (bool): Whether loaded items are kept in a
                ``multiprocessing.Manager`` list.
        """
        # lazy mapping from utterance id to feature array
        loader = _get_feats_scp_loader(feats_scp)
        utt_ids = list(loader.keys())

        if mel_length_threshold is not None:
            # keys() and values() share one iteration order, so zipping pairs
            # every id with its own feature matrix
            kept = [
                utt_id
                for utt_id, feats in zip(utt_ids, loader.values())
                if feats.shape[0] > mel_length_threshold
            ]
            if len(kept) != len(utt_ids):
                logging.warning(
                    f"Some files are filtered by mel length threshold "
                    f"({len(utt_ids)} -> {len(kept)})."
                )
            utt_ids = kept

        self.mel_loader = loader
        self.utt_ids = utt_ids
        self.return_utt_id = return_utt_id
        self.allow_cache = allow_cache
        if allow_cache:
            # NOTE(kan-bayashi): a Manager-backed list lets DataLoader workers
            # (num_workers > 0) share the cache
            self.manager = Manager()
            self.caches = self.manager.list()
            self.caches += [()] * len(self.utt_ids)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``.

        Args:
            idx (int): Index of the item.

        Returns:
            str: Utterance id (only when return_utt_id = True).
            ndarray: Feature (T', C).
        """
        if self.allow_cache:
            cached = self.caches[idx]
            # an empty tuple marks a slot that has not been filled yet
            if len(cached) != 0:
                return cached

        utt_id = self.utt_ids[idx]
        feats = self.mel_loader[utt_id]
        items = (utt_id, feats) if self.return_utt_id else feats

        if self.allow_cache:
            self.caches[idx] = items
        return items

    def __len__(self):
        """Return the number of utterances kept after filtering.

        Returns:
            int: The length of dataset.
        """
        return len(self.utt_ids)