from typing import Dict, List, Optional

import h5py
import librosa
import numpy as np
import torch
from pytorch_lightning.core.datamodule import LightningDataModule

from bytesep.data.samplers import DistributedSamplerWrapper
from bytesep.utils import int16_to_float32


class DataModule(LightningDataModule):
    def __init__(
        self,
        train_sampler: object,
        train_dataset: object,
        num_workers: int,
        distributed: bool,
    ):
        r"""Data module.

        Args:
            train_sampler: Sampler object
            train_dataset: Dataset object
            num_workers: int
            distributed: bool
        """
        super().__init__()
        self._train_sampler = train_sampler
        self.train_dataset = train_dataset
        self.num_workers = num_workers
        self.distributed = distributed

    def setup(self, stage: Optional[str] = None) -> None:
        r"""Called on every device."""
        # A SegmentSampler is used for selecting segments for training. On
        # multiple devices, each SegmentSampler samples a part of the
        # mini-batch data.
        if self.distributed:
            self.train_sampler = DistributedSamplerWrapper(self._train_sampler)
        else:
            self.train_sampler = self._train_sampler

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        r"""Get train loader."""
        train_loader = torch.utils.data.DataLoader(
            dataset=self.train_dataset,
            batch_sampler=self.train_sampler,
            collate_fn=collate_fn,
            num_workers=self.num_workers,
            pin_memory=True,
        )

        return train_loader


class Dataset:
    def __init__(self, augmentor: object, segment_samples: int):
        r"""Used for getting data according to a meta.

        Args:
            augmentor: Augmentor object
            segment_samples: int
        """
        self.augmentor = augmentor
        self.segment_samples = segment_samples

    def __getitem__(self, meta: Dict) -> Dict:
        r"""Return data according to a meta. E.g., an input meta looks like: {
            'vocals': [
                {'hdf5_path': 'song_A.h5', 'key_in_hdf5': 'vocals',
                 'begin_sample': 6332760, 'end_sample': 6465060},
                {'hdf5_path': 'song_B.h5', 'key_in_hdf5': 'vocals',
                 'begin_sample': 198450, 'end_sample': 330750},
            ],
            'accompaniment': [
                {'hdf5_path': 'song_C.h5', 'key_in_hdf5': 'accompaniment',
                 'begin_sample': 24232920, 'end_sample': 24365250},
                {'hdf5_path': 'song_D.h5', 'key_in_hdf5': 'accompaniment',
                 'begin_sample': 1569960, 'end_sample': 1702260},
            ],
        }

        Then, the vocals segments of song_A and song_B are mixed (mix-audio
        augmentation), and the accompaniment segments of song_C and song_D
        are mixed (mix-audio augmentation). Finally, the mixture is created
        by summing vocals and accompaniment.

        Args:
            meta: dict, as in the example above.

        Returns:
            data_dict: dict, e.g., {
                'vocals': (channels_num, segment_samples),
                'accompaniment': (channels_num, segment_samples),
                'mixture': (channels_num, segment_samples),
            }
        """
        source_types = meta.keys()
        data_dict = {}

        for source_type in source_types:
            # E.g., source_type: 'vocals', 'bass', ...
            waveforms = []  # Audio segments to be mix-audio augmented.
            for m in meta[source_type]:
                # E.g., m: {
                #     'hdf5_path': '.../song_A.h5',
                #     'key_in_hdf5': 'vocals',
                #     'begin_sample': 13406400,
                #     'end_sample': 13538700,
                # }
                hdf5_path = m['hdf5_path']
                key_in_hdf5 = m['key_in_hdf5']
                bgn_sample = m['begin_sample']
                end_sample = m['end_sample']

                with h5py.File(hdf5_path, 'r') as hf:
                    if source_type == 'audioset':
                        index_in_hdf5 = m['index_in_hdf5']
                        waveform = int16_to_float32(
                            hf['waveform'][index_in_hdf5][bgn_sample:end_sample]
                        )
                        waveform = waveform[None, :]
                    else:
                        waveform = int16_to_float32(
                            hf[key_in_hdf5][:, bgn_sample:end_sample]
                        )

                if self.augmentor:
                    waveform = self.augmentor(waveform, source_type)

                waveform = librosa.util.fix_length(
                    waveform, size=self.segment_samples, axis=1
                )
                # waveform: (channels_num, segment_samples)

                waveforms.append(waveform)
            # E.g., waveforms: [(channels_num, segment_samples),
            #                   (channels_num, segment_samples)]

            # Mix-audio augmentation: sum segments of the same source type.
            data_dict[source_type] = np.sum(waveforms, axis=0)
            # data_dict[source_type]: (channels_num, segment_samples)

        # data_dict, e.g.: {
        #     'vocals': (channels_num, segment_samples),
        #     'accompaniment': (channels_num, segment_samples),
        # }

        # Mix segments from different source types to create the mixture.
        mixture = np.sum(
            [data_dict[source_type] for source_type in source_types], axis=0
        )
        data_dict['mixture'] = mixture
        # data_dict['mixture']: (channels_num, segment_samples)

        return data_dict


def collate_fn(list_data_dict: List[Dict]) -> Dict:
    r"""Collate mini-batch data to inputs and targets for training.

    Args:
        list_data_dict: e.g., [
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)},
            {'vocals': (channels_num, segment_samples),
             'accompaniment': (channels_num, segment_samples),
             'mixture': (channels_num, segment_samples)},
            ...]

    Returns:
        data_dict: e.g., {
            'vocals': (batch_size, channels_num, segment_samples),
            'accompaniment': (batch_size, channels_num, segment_samples),
            'mixture': (batch_size, channels_num, segment_samples),
        }
    """
    data_dict = {}

    for key in list_data_dict[0].keys():
        # Stack along a new batch axis; avoid shadowing the output dict.
        data_dict[key] = torch.Tensor(
            np.array([one_data_dict[key] for one_data_dict in list_data_dict])
        )

    return data_dict
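

# The sketch below is a minimal, self-contained usage example and is NOT part
# of the original training pipeline: the temporary HDF5 file, the meta dicts,
# and the trivial _DemoSampler are hypothetical placeholders. In the real
# pipeline, metas are produced by a SegmentSampler (see bytesep.data.samplers)
# and the augmentor comes from bytesep.data.augmentors.
if __name__ == '__main__':
    import os
    import tempfile

    segment_samples = 1000
    hdf5_path = os.path.join(tempfile.mkdtemp(), 'song_A.h5')

    # Write a tiny stereo int16 HDF5 file so that Dataset.__getitem__ has
    # something to read (waveforms are stored as int16, hence the
    # int16_to_float32 conversion above).
    with h5py.File(hdf5_path, 'w') as hf:
        hf.create_dataset('vocals', data=np.zeros((2, 44100), dtype=np.int16))
        hf.create_dataset(
            'accompaniment', data=np.zeros((2, 44100), dtype=np.int16)
        )

    meta = {
        'vocals': [
            {'hdf5_path': hdf5_path, 'key_in_hdf5': 'vocals',
             'begin_sample': 0, 'end_sample': segment_samples},
        ],
        'accompaniment': [
            {'hdf5_path': hdf5_path, 'key_in_hdf5': 'accompaniment',
             'begin_sample': 0, 'end_sample': segment_samples},
        ],
    }

    dataset = Dataset(augmentor=None, segment_samples=segment_samples)
    data_dict = dataset[meta]
    # data_dict['mixture']: (channels_num, segment_samples)

    batch = collate_fn([data_dict, data_dict])
    # batch['mixture']: (batch_size=2, channels_num, segment_samples)

    class _DemoSampler:
        r"""Hypothetical batch sampler: any iterable over lists of metas can
        serve as the batch_sampler consumed by DataModule."""

        def __iter__(self):
            for _ in range(4):
                yield [meta, meta]  # Each yielded list is one mini-batch.

        def __len__(self):
            return 4

    data_module = DataModule(
        train_sampler=_DemoSampler(),
        train_dataset=dataset,
        num_workers=0,
        distributed=False,
    )
    data_module.setup()

    for batch in data_module.train_dataloader():
        print({key: tuple(value.shape) for key, value in batch.items()})
        break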