from pathlib import Path
from typing import List, Optional, Tuple, Type, Union

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose

import modeling.transforms as transform_module
from modeling.transforms import (
    LabelsFromTxt,
    OneHotEncode,
    ParentMultilabel,
    Preprocess,
    Transform,
)
from modeling.utils import CLASSES, get_wav_files, init_obj, init_transforms


class IRMASDataset(Dataset):
    """Dataset class for IRMAS dataset.

    :param audio_dir: Directory containing the audio files
    :type audio_dir: Union[str, Path]
    :param preprocess: Preprocessing method to apply to the audio files
    :type preprocess: Type[Preprocess]
    :param signal_augments: Signal augmentation method to apply to the audio files, defaults to None
    :type signal_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
    :param transforms: Transform method to apply to the audio files, defaults to None
    :type transforms: Optional[Union[Type[Compose], Type[Transform]]], optional
    :param spec_augments: Spectrogram augmentation method to apply to the audio files, defaults to None
    :type spec_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
    :param subset: Subset of the data to load (train, valid, or test), defaults to "train"
    :type subset: str, optional
    :raises AssertionError: Raises an assertion error if subset is not train, valid or test
    :raises OSError: Raises an OS error if test_songs.txt is not found in the data folder
    :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
    :rtype: Tuple[Tensor, Tensor]
    """

    def __init__(
        self,
        audio_dir: Union[str, Path],
        preprocess: Type[Preprocess],
        signal_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
        transforms: Optional[Union[Type[Compose], Type[Transform]]] = None,
        spec_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
        subset: str = "train",
    ):
        self.files = get_wav_files(audio_dir)
        assert subset in ["train", "valid", "test"], "Subset can only be train, valid or test"
        self.subset = subset

        if self.subset != "train":
            # Valid/test splits are defined by a list of song stems generated
            # ahead of time.
            # NOTE(review): path is relative to the current working directory,
            # so this only works when launched from the expected folder —
            # consider anchoring it to this file's location.
            try:
                test_songs = np.genfromtxt(
                    "../data/test_songs.txt", dtype=str, ndmin=1, delimiter="\n"
                )
            except OSError as e:
                # FIX: was print("Error: {e}") — missing the f-prefix, so it
                # printed the literal text "{e}" instead of the exception.
                print(f"Error: {e}")
                print("test_songs.txt not found in data/. Please generate a split before training")
                raise

            # valid = everything NOT in the test list; test = only listed songs.
            if self.subset == "valid":
                self.files = [file for file in self.files if Path(file).stem not in test_songs]
            if self.subset == "test":
                self.files = [file for file in self.files if Path(file).stem in test_songs]

        self.preprocess = preprocess
        self.transforms = transforms
        self.signal_augments = signal_augments
        self.spec_augments = spec_augments

    def __len__(self):
        """Return the length of the dataset.

        :return: The length of the dataset
        :rtype: int
        """
        return len(self.files)

    def __getitem__(self, index):
        """Get an item from the dataset.

        :param index: The index of the item to get
        :type index: int
        :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
        :rtype: Tuple[Tensor, Tensor]
        """
        sample_path = self.files[index]
        signal = self.preprocess(sample_path)

        # Training labels are encoded in the parent directory name
        # ("cel-flu" style); valid/test labels come from a sidecar .txt file.
        if self.subset == "train":
            target_transforms = Compose([ParentMultilabel(sep="-"), OneHotEncode(CLASSES)])
        else:
            target_transforms = Compose([LabelsFromTxt(), OneHotEncode(CLASSES)])

        label = target_transforms(sample_path)

        # Augmentations are train-only; the deterministic transform
        # (e.g. spectrogram) applies to every subset.
        if self.signal_augments is not None and self.subset == "train":
            signal = self.signal_augments(signal)
        if self.transforms is not None:
            signal = self.transforms(signal)
        if self.spec_augments is not None and self.subset == "train":
            signal = self.spec_augments(signal)

        return signal, label.float()


def collate_fn(data: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    Function to collate a batch of audio signals and their corresponding labels.

    :param data: A list of tuples containing the audio signals and their corresponding labels.
    :type data: List[Tuple[torch.Tensor, torch.Tensor]]
    :return: A tuple containing the batch of audio signals and their corresponding labels.
    :rtype: Tuple[torch.Tensor, torch.Tensor]
    """
    features, labels = zip(*data)
    # Drop singleton dims and put time first so variable-length samples can
    # be padded along dim 0.
    features = [item.squeeze().T for item in features]

    # Pads items to same length if they're not
    features = pad_sequence(features, batch_first=True)
    labels = torch.stack(labels)

    return features, labels


def get_loader(config: dict, subset: str):
    """
    Function to create a PyTorch DataLoader for a given subset of the IRMAS dataset.

    :param config: A configuration object.
    :type config: Any
    :param subset: The subset of the dataset to use. Can be "train" or "valid".
    :type subset: str
    :return: A PyTorch DataLoader for the specified subset of the dataset.
    :rtype: torch.utils.data.DataLoader
    """
    # NOTE(review): annotated as dict but accessed with attribute syntax
    # (config.train_dir, ...) — presumably a config namespace object; the
    # annotation should be confirmed against the caller.
    dst = IRMASDataset(
        config.train_dir if subset == "train" else config.valid_dir,
        preprocess=init_obj(config.preprocess, transform_module),
        transforms=init_obj(config.transforms, transform_module),
        signal_augments=init_transforms(config.signal_augments, transform_module),
        spec_augments=init_transforms(config.spec_augments, transform_module),
        subset=subset,
    )

    return DataLoader(
        dst,
        batch_size=config.batch_size,
        # Only shuffle the training split; pin memory when a GPU is present.
        shuffle=subset == "train",
        pin_memory=torch.cuda.is_available(),
        num_workers=torch.get_num_threads() - 1,
        collate_fn=collate_fn,
    )