from typing import Dict, List, Optional
import h5py
import librosa
import numpy as np
import torch
from pytorch_lightning.core.datamodule import LightningDataModule
from bytesep.data.samplers import DistributedSamplerWrapper
from bytesep.utils import int16_to_float32
class DataModule(LightningDataModule):
def __init__(
self,
train_sampler: object,
train_dataset: object,
num_workers: int,
distributed: bool,
):
r"""Data module.
Args:
train_sampler: Sampler object
train_dataset: Dataset object
num_workers: int
distributed: bool
"""
super().__init__()
self._train_sampler = train_sampler
self.train_dataset = train_dataset
self.num_workers = num_workers
self.distributed = distributed
    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device."""
# SegmentSampler is used for selecting segments for training.
# On multiple devices, each SegmentSampler samples a part of mini-batch
# data.
if self.distributed:
self.train_sampler = DistributedSamplerWrapper(self._train_sampler)
else:
self.train_sampler = self._train_sampler
def train_dataloader(self) -> torch.utils.data.DataLoader:
r"""Get train loader."""
train_loader = torch.utils.data.DataLoader(
dataset=self.train_dataset,
batch_sampler=self.train_sampler,
collate_fn=collate_fn,
num_workers=self.num_workers,
pin_memory=True,
)
return train_loader
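
# A minimal usage sketch, kept as a comment so that importing this module has
# no side effects. The `train_sampler`, `augmentor`, and `segment_samples`
# names below are hypothetical placeholders that are normally built elsewhere
# (e.g., from a training configuration) before being passed in:
#
#     train_dataset = Dataset(augmentor=augmentor, segment_samples=segment_samples)
#     data_module = DataModule(
#         train_sampler=train_sampler,
#         train_dataset=train_dataset,
#         num_workers=8,
#         distributed=False,
#     )
#     data_module.setup()
#     train_loader = data_module.train_dataloader()
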
class Dataset:
def __init__(self, augmentor: object, segment_samples: int):
r"""Used for getting data according to a meta.
Args:
augmentor: Augmentor class
segment_samples: int
"""
self.augmentor = augmentor
self.segment_samples = segment_samples
def __getitem__(self, meta: Dict) -> Dict:
r"""Return data according to a meta. E.g., an input meta looks like: {
'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}.
}
Then, vocals segments of song_A and song_B will be mixed (mix-audio augmentation).
Accompaniment segments of song_C and song_B will be mixed (mix-audio augmentation).
Finally, mixture is created by summing vocals and accompaniment.
Args:
meta: dict, e.g., {
'vocals': [['song_A.h5', 6332760, 6465060], ['song_B.h5', 198450, 330750]],
'accompaniment': [['song_C.h5', 24232920, 24365250], ['song_D.h5', 1569960, 1702260]]}
}
Returns:
data_dict: dict, e.g., {
'vocals': (channels, segments_num),
'accompaniment': (channels, segments_num),
'mixture': (channels, segments_num),
}
"""
source_types = meta.keys()
data_dict = {}
for source_type in source_types:
# E.g., ['vocals', 'bass', ...]
waveforms = [] # Audio segments to be mix-audio augmented.
for m in meta[source_type]:
                # E.g., {
                #     'hdf5_path': '.../song_A.h5',
                #     'key_in_hdf5': 'vocals',
                #     'begin_sample': 13406400,
                #     'end_sample': 13538700,
                # }
hdf5_path = m['hdf5_path']
key_in_hdf5 = m['key_in_hdf5']
bgn_sample = m['begin_sample']
end_sample = m['end_sample']
with h5py.File(hdf5_path, 'r') as hf:
if source_type == 'audioset':
index_in_hdf5 = m['index_in_hdf5']
waveform = int16_to_float32(
hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
)
waveform = waveform[None, :]
else:
waveform = int16_to_float32(
hf[key_in_hdf5][:, bgn_sample:end_sample]
)
if self.augmentor:
waveform = self.augmentor(waveform, source_type)
waveform = librosa.util.fix_length(
waveform, size=self.segment_samples, axis=1
)
# (channels_num, segments_num)
waveforms.append(waveform)
# E.g., waveforms: [(channels_num, audio_samples), (channels_num, audio_samples)]
# mix-audio augmentation
data_dict[source_type] = np.sum(waveforms, axis=0)
# data_dict[source_type]: (channels_num, audio_samples)
        # data_dict looks like: {
        #     'vocals': (channels_num, audio_samples),
        #     'accompaniment': (channels_num, audio_samples)
        # }
# Mix segments from different sources.
mixture = np.sum(
[data_dict[source_type] for source_type in source_types], axis=0
)
data_dict['mixture'] = mixture
# shape: (channels_num, audio_samples)
return data_dict
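
# A minimal fetching sketch, kept as a comment: the HDF5 paths and sample
# indices below are hypothetical and must point to HDF5 files packed in
# advance; 132300 samples equals the segment length implied by the docstring
# example above (3 s at 44.1 kHz).
#
#     meta = {
#         'vocals': [
#             {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals',
#              'begin_sample': 6332760, 'end_sample': 6465060},
#         ],
#         'accompaniment': [
#             {'hdf5_path': 'song_C.h5', 'key_in_hdf5': 'accompaniment',
#              'begin_sample': 24232920, 'end_sample': 24365250},
#         ],
#     }
#     dataset = Dataset(augmentor=None, segment_samples=132300)
#     data_dict = dataset[meta]
#     # data_dict['mixture']: (channels_num, segment_samples)
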
def collate_fn(list_data_dict: List[Dict]) -> Dict:
r"""Collate mini-batch data to inputs and targets for training.
Args:
list_data_dict: e.g., [
{'vocals': (channels_num, segment_samples),
'accompaniment': (channels_num, segment_samples),
'mixture': (channels_num, segment_samples)
},
{'vocals': (channels_num, segment_samples),
'accompaniment': (channels_num, segment_samples),
'mixture': (channels_num, segment_samples)
},
...]
Returns:
data_dict: e.g. {
'vocals': (batch_size, channels_num, segment_samples),
'accompaniment': (batch_size, channels_num, segment_samples),
'mixture': (batch_size, channels_num, segment_samples)
}
"""
data_dict = {}
    for key in list_data_dict[0].keys():
        # Stack this key's arrays from every example along a new batch
        # dimension, then convert to a float tensor.
        data_dict[key] = torch.Tensor(
            np.array([single_data_dict[key] for single_data_dict in list_data_dict])
        )
return data_dict
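

if __name__ == "__main__":
    # Minimal self-check sketch of collate_fn on synthetic data. The batch
    # size, channel count, and segment length below are arbitrary assumptions
    # chosen only to illustrate the collated shapes.
    batch_size, channels_num, segment_samples = 4, 2, 44100

    list_data_dict = [
        {
            'vocals': np.random.randn(channels_num, segment_samples),
            'accompaniment': np.random.randn(channels_num, segment_samples),
            'mixture': np.random.randn(channels_num, segment_samples),
        }
        for _ in range(batch_size)
    ]

    batch_data_dict = collate_fn(list_data_dict)

    for key, tensor in batch_data_dict.items():
        # Each source is collated to (batch_size, channels_num, segment_samples).
        print(key, tuple(tensor.shape))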