from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose

import modeling.transforms as transform_module
from modeling.transforms import (
    LabelsFromTxt,
    OneHotEncode,
    ParentMultilabel,
    Preprocess,
    Transform,
)
from modeling.utils import CLASSES, get_wav_files, init_obj, init_transforms


class IRMASDataset(Dataset):
    """Dataset class for IRMAS dataset.

    :param audio_dir: Directory containing the audio files
    :type audio_dir: Union[str, Path]
    :param preprocess: Preprocessing transform applied to each audio file
    :type preprocess: Preprocess
    :param signal_augments: Signal-domain augmentation applied to the raw audio, defaults to None
    :type signal_augments: Optional[Union[Compose, Transform]], optional
    :param transforms: Transform applied to the audio signal (e.g. spectrogram extraction), defaults to None
    :type transforms: Optional[Union[Compose, Transform]], optional
    :param spec_augments: Spectrogram-domain augmentation applied after transforms, defaults to None
    :type spec_augments: Optional[Union[Compose, Transform]], optional
    :param subset: Subset of the data to load (train, valid, or test), defaults to "train"
    :type subset: str, optional
    :raises AssertionError: Raises an assertion error if subset is not train, valid or test
    :raises OSError: Raises an OS error if test_songs.txt is not found in the data folder
    :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
    :rtype: Tuple[Tensor, Tensor]
    """

    def __init__(
        self,
        audio_dir: Union[str, Path],
        preprocess: Preprocess,
        signal_augments: Optional[Union[Compose, Transform]] = None,
        transforms: Optional[Union[Compose, Transform]] = None,
        spec_augments: Optional[Union[Compose, Transform]] = None,
        subset: str = "train",
    ):
        self.files = get_wav_files(audio_dir)
        assert subset in ["train", "valid", "test"], "Subset can only be train, valid or test"
        self.subset = subset

        if self.subset != "train":
            try:
                test_songs = np.genfromtxt("../data/test_songs.txt", dtype=str, ndmin=1, delimiter="\n")
            except OSError as e:
                print("Error: {e}")
                print("test_songs.txt not found in data/. Please generate a split before training")
                raise e

        if self.subset == "valid":
            self.files = [file for file in self.files if Path(file).stem not in test_songs]
        if self.subset == "test":
            self.files = [file for file in self.files if Path(file).stem in test_songs]

        self.preprocess = preprocess
        self.transforms = transforms
        self.signal_augments = signal_augments
        self.spec_augments = spec_augments

    def __len__(self):
        """Return the length of the dataset.

        :return: The length of the dataset
        :rtype: int
        """

        return len(self.files)

    def __getitem__(self, index):
        """Get an item from the dataset.

        :param index: The index of the item to get
        :type index: int
        :return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
        :rtype: Tuple[Tensor, Tensor]
        """

        sample_path = self.files[index]
        signal = self.preprocess(sample_path)

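        # Training labels are parsed from the parent directory name, while
        # the validation/test labels are read from the accompanying .txt files.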
        if self.subset == "train":
            target_transforms = Compose([ParentMultilabel(sep="-"), OneHotEncode(CLASSES)])
        else:
            target_transforms = Compose([LabelsFromTxt(), OneHotEncode(CLASSES)])

        label = target_transforms(sample_path)

        if self.signal_augments is not None and self.subset == "train":
            signal = self.signal_augments(signal)

        if self.transforms is not None:
            signal = self.transforms(signal)

        if self.spec_augments is not None and self.subset == "train":
            signal = self.spec_augments(signal)

        return signal, label.float()
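
# Usage sketch (the path and the Preprocess arguments below are hypothetical,
# for illustration only):
#
#     dataset = IRMASDataset(
#         "../data/train",
#         preprocess=Preprocess(...),  # see modeling.transforms
#         subset="train",
#     )
#     signal, label = dataset[0]  # label is a float one-hot tensor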


def collate_fn(data: List[Tuple[torch.Tensor, torch.Tensor]]):
    """
    Function to collate a batch of audio signals and their corresponding labels.

    :param data: A list of tuples containing the audio signals and their corresponding labels.
    :type data: List[Tuple[torch.Tensor, torch.Tensor]]

    :return: A tuple containing the batch of audio signals and their corresponding labels.
    :rtype: Tuple[torch.Tensor, torch.Tensor]
    """

    features, labels = zip(*data)
    features = [item.squeeze().T for item in features]
    # Pad variable-length items to the length of the longest one in the batch
    features = pad_sequence(features, batch_first=True)
    labels = torch.stack(labels)

    return features, labels
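
# Shape sketch (hypothetical values; assumes the transforms produce mel
# spectrograms of shape (1, n_mels, time) and that CLASSES has 11 entries):
#
#     a = (torch.ones(1, 128, 100), torch.zeros(11))
#     b = (torch.ones(1, 128, 150), torch.zeros(11))
#     feats, labs = collate_fn([a, b])
#     # squeeze().T turns (1, 128, T) into (T, 128); pad_sequence zero-pads
#     # the shorter item, so feats.shape == (2, 150, 128), labs.shape == (2, 11)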


def get_loader(config, subset: str):
    """
    Function to create a PyTorch DataLoader for a given subset of the IRMAS dataset.

    :param config: A configuration object exposing train_dir, valid_dir,
        preprocess, transforms, signal_augments, spec_augments, and batch_size.
    :type config: Any
    :param subset: The subset of the dataset to use. Can be "train", "valid", or "test".
    :type subset: str

    :return: A PyTorch DataLoader for the specified subset of the dataset.
    :rtype: torch.utils.data.DataLoader
    """

    dst = IRMASDataset(
        config.train_dir if subset == "train" else config.valid_dir,
        preprocess=init_obj(config.preprocess, transform_module),
        transforms=init_obj(config.transforms, transform_module),
        signal_augments=init_transforms(config.signal_augments, transform_module),
        spec_augments=init_transforms(config.spec_augments, transform_module),
        subset=subset,
    )

    return DataLoader(
        dst,
        batch_size=config.batch_size,
        shuffle=(subset == "train"),
        pin_memory=torch.cuda.is_available(),
        # Guard against single-thread environments, where subtracting one
        # would pass a negative worker count to DataLoader.
        num_workers=max(torch.get_num_threads() - 1, 0),
        collate_fn=collate_fn,
    )
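

# A minimal smoke-test sketch. The SimpleNamespace fields mirror the
# attributes get_loader reads; the values below are hypothetical, and the
# exact schemas expected by init_obj / init_transforms are defined in
# modeling.utils.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        train_dir="../data/train",  # hypothetical path
        valid_dir="../data/valid",  # hypothetical path
        preprocess={"type": "Preprocess", "args": {}},  # hypothetical schema
        transforms=None,
        signal_augments=None,
        spec_augments=None,
        batch_size=16,
    )
    loader = get_loader(config, subset="train")
    signals, labels = next(iter(loader))
    print(signals.shape, labels.shape)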