import os
from abc import ABC, abstractmethod
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.transforms import FrequencyMasking, TimeMasking
from torchvision.transforms import Compose
from transformers import ASTFeatureExtractor


class Transform(ABC):
    """Abstract base class for audio transformations."""

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Apply the transformation.

        :raises NotImplementedError: If the subclass does not implement this method.
        """


class Preprocess(ABC):
    """Abstract base class for preprocessing data.

    Defines the preprocessing interface; subclasses must implement ``__call__``.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Process the data.

        :raises NotImplementedError: Subclasses must implement this method.
        """


class OneHotEncode(Transform):
    """Transform labels into a one-hot encoded tensor.

    Takes a list of labels and returns a multi-hot tensor over the specified
    classes, with a 1 at the index of every label that is present.

    :param c: The list of classes used for one-hot encoding.
    :type c: list
    """

    def __init__(self, c: list):
        self.c = c

    def __call__(self, labels):
        """Transform labels into a one-hot encoded tensor.

        :param labels: The labels to encode.
        :type labels: list
        :return: A one-hot encoded tensor.
        :rtype: torch.Tensor
        """
        target = torch.zeros(len(self.c), dtype=torch.float)
        for label in labels:
            idx = self.c.index(label)
            target[idx] = 1
        return target


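# Illustrative usage (a minimal sketch; the class list is made up for the example):
#
#   >>> encode = OneHotEncode(["cat", "dog", "bird"])
#   >>> encode(["dog", "bird"])
#   tensor([0., 1., 1.])

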
class ParentMultilabel(Transform):
    """A transform that extracts a list of labels from the parent directory name of a file path.

    :param sep: The separator used to split the parent directory name into labels. Defaults to " ".
    :type sep: str
    """

    def __init__(self, sep=" "):
        self.sep = sep

    def __call__(self, path):
        """Extract a list of labels from the parent directory name of a file path.

        :param path: The file path from which to extract labels.
        :type path: str
        :return: A list of labels extracted from the parent directory name of the input file path.
        :rtype: List[str]
        """
        label = path.split(os.path.sep)[-2].split(self.sep)
        return label


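# Illustrative usage (a sketch; the path is hypothetical):
#
#   >>> ParentMultilabel(sep=" ")("data/cat dog/clip01.wav")
#   ['cat', 'dog']

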
class LabelsFromTxt(Transform):
    """Load labels from a sidecar ``.txt`` file.

    This transform looks up the ``.txt`` file that shares its name with the
    audio file and reads the labels from it, optionally splitting on a
    delimiter.

    :param delimiter: The delimiter used to split labels within the file. Defaults to None (whitespace).
    :type delimiter: str
    """

    def __init__(self, delimiter=None):
        self.delimiter = delimiter

    def __call__(self, path):
        """Load labels from the ``.txt`` file next to the audio file.

        :param path: The path of the audio file whose labels should be loaded.
        :type path: str
        :return: An array of labels.
        :rtype: numpy.ndarray
        """
        # Swap only the extension; a bare str.replace("wav", "txt") would also
        # rewrite directory names that happen to contain "wav".
        path = os.path.splitext(path)[0] + ".txt"
        label = np.loadtxt(path, dtype=str, ndmin=1, delimiter=self.delimiter)
        return label


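# Illustrative usage (a sketch; assumes a hypothetical "clip01.txt" sits next
# to "clip01.wav" and contains one label per line):
#
#   >>> LabelsFromTxt()("data/clip01.wav")
#   array(['cat', 'dog'], dtype='<U3')

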
class PreprocessPipeline(Preprocess):
    """A preprocessing pipeline for audio data.

    The pipeline loads audio from a file, resamples it to a target sampling
    rate, and mixes down stereo to mono.

    :param target_sr: The target sampling rate to resample to.
    :type target_sr: int
    """

    def __init__(self, target_sr):
        self.target_sr = target_sr

    def __call__(self, path):
        """Preprocess audio data using the pipeline.

        :param path: The path to the audio file to load.
        :type path: str
        :return: A NumPy array of preprocessed audio data.
        :rtype: numpy.ndarray
        """
        signal, sr = torchaudio.load(path)
        signal = self._resample(signal, sr)
        signal = self._mix_down(signal)
        return signal.numpy()

    def _mix_down(self, signal):
        """Mix down stereo to mono.

        :param signal: The audio signal to mix down.
        :type signal: torch.Tensor
        :return: The mixed-down audio signal.
        :rtype: torch.Tensor
        """
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _resample(self, signal, input_sr):
        """Resample an audio signal to the target sampling rate.

        :param signal: The audio signal to resample.
        :type signal: torch.Tensor
        :param input_sr: The current sampling rate of the audio signal.
        :type input_sr: int
        :return: The resampled audio signal.
        :rtype: torch.Tensor
        """
        if input_sr != self.target_sr:
            resampler = torchaudio.transforms.Resample(input_sr, self.target_sr)
            signal = resampler(signal)
        return signal


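# Illustrative usage (a sketch; the file path is hypothetical):
#
#   >>> preprocess = PreprocessPipeline(target_sr=16000)
#   >>> signal = preprocess("data/clip01.wav")  # mono, 16 kHz
#   >>> signal.shape
#   (1, 160000)  # e.g. for a 10-second clip

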
class SpecToImage(Transform):
    """Convert a spectrogram to a 3-channel uint8 image.

    The spectrogram is stacked into three identical channels, standardized,
    then min-max scaled to the 0-255 range.

    :param mean: The mean used for standardization. Computed from the input if None.
    :param std: The standard deviation used for standardization. Computed from the input if None.
    :param eps: A small constant that guards against division by zero.
    """

    def __init__(self, mean=None, std=None, eps=1e-6):
        self.mean = mean
        self.std = std
        self.eps = eps

    def __call__(self, spec):
        spec = torch.stack([spec, spec, spec], dim=-1)

        mean = torch.mean(spec) if self.mean is None else self.mean
        std = torch.std(spec) if self.std is None else self.std
        spec_norm = (spec - mean) / (std + self.eps)

        spec_min, spec_max = torch.min(spec_norm), torch.max(spec_norm)
        spec_scaled = 255 * (spec_norm - spec_min) / (spec_max - spec_min)

        return spec_scaled.type(torch.uint8)


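# Illustrative usage (a sketch with a random spectrogram):
#
#   >>> spec = torch.rand(128, 256)  # (n_mels, time)
#   >>> image = SpecToImage()(spec)
#   >>> image.shape, image.dtype
#   (torch.Size([128, 256, 3]), torch.uint8)

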
class MinMaxScale(Transform):
    """Scale a spectrogram to the [0, 1] range via min-max normalization."""

    def __call__(self, spec):
        spec_min, spec_max = torch.min(spec), torch.max(spec)

        return (spec - spec_min) / (spec_max - spec_min)


class Normalize(Transform):
    """Standardize a spectrogram with a fixed mean and standard deviation.

    :param mean: The mean to subtract.
    :param std: The standard deviation to divide by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, spec):
        return (spec - self.mean) / self.std


class FeatureExtractor(Transform):
    """Extract features from an audio signal using an AST feature extractor.

    This transform runs the signal through a Hugging Face ``ASTFeatureExtractor``
    and returns the features as a PyTorch tensor.

    :param sr: The sampling rate of the audio signal.
    :type sr: int
    """

    def __init__(self, sr):
        self.transform = partial(ASTFeatureExtractor(), sampling_rate=sr, return_tensors="pt")

    def __call__(self, signal):
        """Extract features from an audio signal using an AST feature extractor.

        :param signal: The audio signal to extract features from.
        :type signal: numpy.ndarray
        :return: A tensor of extracted audio features.
        :rtype: torch.Tensor
        """
        return self.transform(signal.squeeze()).input_values.mT


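# Illustrative usage (a sketch; assumes the default AST extractor settings,
# which expect 16 kHz audio, 128 mel bins, and 1024 frames):
#
#   >>> signal = np.random.randn(16000).astype(np.float32)  # 1 s of noise
#   >>> features = FeatureExtractor(sr=16000)(signal)
#   >>> features.shape
#   torch.Size([1, 128, 1024])  # (batch, n_mels, frames) after the .mT transpose

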
class Preemphasis(Transform):
    """Perform preemphasis on the input signal.

    Applies the filter ``y[n] = x[n] - coeff * x[n - 1]``.

    :param coeff: The preemphasis coefficient. 0 is no filtering. Defaults to 0.97.
    :type coeff: float
    """

    def __init__(self, coeff: float = 0.97):
        self.coeff = coeff

    def __call__(self, signal):
        """Filter the signal.

        :param signal: The signal to filter, of shape (channels, time).
        :type signal: torch.Tensor
        :return: The filtered signal.
        :rtype: torch.Tensor
        """
        return torch.cat([signal[:, :1], signal[:, 1:] - self.coeff * signal[:, :-1]], dim=1)


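# Illustrative usage (a sketch; the sample values are made up):
#
#   >>> x = torch.tensor([[1.0, 2.0, 3.0]])
#   >>> Preemphasis(coeff=0.97)(x)  # [1.0, 2.0 - 0.97*1.0, 3.0 - 0.97*2.0]
#   tensor([[1.0000, 1.0300, 1.0600]])

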
class Spectrogram(Transform):
    """Compute a mel spectrogram with :class:`torchaudio.transforms.MelSpectrogram`.

    :param sample_rate: The sampling rate of the input audio.
    :param n_mels: The number of mel filterbanks.
    :param hop_length: The hop length between STFT windows.
    :param n_fft: The FFT size.
    """

    def __init__(self, sample_rate, n_mels, hop_length, n_fft):
        self.transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft, f_min=20, center=False
        )

    def __call__(self, signal):
        return self.transform(signal)


class LogTransform(Transform):
    """Apply a numerically stable log to a spectrogram."""

    def __call__(self, signal):
        return torch.log(signal + 1e-8)


class PadCutToLength(Transform):
    """Pad or cut the last dimension of a spectrogram to a fixed length.

    :param max_length: The target length of the time dimension.
    :type max_length: int
    """

    def __init__(self, max_length):
        self.max_length = max_length

    def __call__(self, spec):
        seq_len = spec.shape[-1]

        if seq_len > self.max_length:
            return spec[..., : self.max_length]
        if seq_len < self.max_length:
            diff = self.max_length - seq_len
            return F.pad(spec, (0, diff), mode="constant", value=0)
        # Already at the target length.
        return spec


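# Illustrative usage (a sketch with random spectrograms, showing both cutting
# and padding):
#
#   >>> pad_cut = PadCutToLength(max_length=100)
#   >>> pad_cut(torch.rand(128, 250)).shape
#   torch.Size([128, 100])
#   >>> pad_cut(torch.rand(128, 60)).shape
#   torch.Size([128, 100])

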
class CustomFeatureExtractor(Transform):
    """Extract normalized log-mel features from a raw waveform.

    Chains preemphasis, a mel spectrogram, a log transform, padding/cutting
    to ``max_length`` frames, and normalization with the given mean and std.
    """

    def __init__(self, sample_rate, n_mels, hop_length, n_fft, max_length, mean, std):
        self.extract = Compose(
            [
                Preemphasis(),
                Spectrogram(sample_rate=sample_rate, n_mels=n_mels, hop_length=hop_length, n_fft=n_fft),
                LogTransform(),
                PadCutToLength(max_length=max_length),
                Normalize(mean=mean, std=std),
            ]
        )

    def __call__(self, x):
        return self.extract(x)


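# Illustrative usage (a sketch; the parameter values, including mean/std, are
# made up for the example):
#
#   >>> extractor = CustomFeatureExtractor(
#   ...     sample_rate=16000, n_mels=128, hop_length=160, n_fft=1024,
#   ...     max_length=1024, mean=-4.27, std=4.57,
#   ... )
#   >>> waveform = torch.randn(1, 160000)  # 10 s of noise at 16 kHz
#   >>> extractor(waveform).shape
#   torch.Size([1, 128, 1024])

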
class RepeatAudio(Transform):
    """A transform to repeat audio data.

    Tiles the audio a random number of times, up to a specified maximum.

    :param max_repeats: The maximum number of times to repeat the audio data.
    :type max_repeats: int
    """

    def __init__(self, max_repeats: int = 2):
        self.max_repeats = max_repeats

    def __call__(self, signal):
        """Repeat audio data a random number of times up to the specified maximum.

        :param signal: The audio data to repeat.
        :type signal: numpy.ndarray
        :return: The repeated audio data.
        :rtype: numpy.ndarray
        """
        # torch.randint's upper bound is exclusive, so add 1 to make
        # max_repeats reachable.
        num_repeats = torch.randint(1, self.max_repeats + 1, (1,)).item()
        return np.tile(signal, reps=num_repeats)


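# Illustrative usage (a sketch; with max_repeats=3 the signal is tiled along
# the time axis 1 to 3 times):
#
#   >>> signal = np.ones((1, 16000), dtype=np.float32)
#   >>> RepeatAudio(max_repeats=3)(signal).shape[-1] in (16000, 32000, 48000)
#   True

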
class MaskFrequency(Transform):
    """A transform that masks frequencies of a spectrogram.

    Masks out a random number of consecutive frequency bins from a spectrogram.

    :param max_mask_length: The maximum number of consecutive frequency bins to mask out.
    :type max_mask_length: int
    """

    def __init__(self, max_mask_length: int = 0):
        self.aug = FrequencyMasking(max_mask_length)

    def __call__(self, spec):
        """Mask out a random number of consecutive frequency bins from a spectrogram.

        :param spec: The input spectrogram.
        :type spec: torch.Tensor
        :return: The spectrogram with masked frequencies.
        :rtype: torch.Tensor
        """
        return self.aug(spec)


class MaskTime(Transform):
    """A transform that masks time steps of a spectrogram.

    Masks out a random number of consecutive time steps from a spectrogram.

    :param max_mask_length: The maximum number of consecutive time steps to mask out.
    :type max_mask_length: int
    """

    def __init__(self, max_mask_length: int = 0):
        self.aug = TimeMasking(max_mask_length)

    def __call__(self, spec):
        """Mask out a random number of consecutive time steps from a spectrogram.

        :param spec: The input spectrogram.
        :type spec: torch.Tensor
        :return: The spectrogram with masked time steps.
        :rtype: torch.Tensor
        """
        return self.aug(spec)


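# Illustrative SpecAugment-style usage of the two masking transforms (a sketch
# with a random spectrogram; the mask lengths are made up):
#
#   >>> augment = Compose([MaskFrequency(max_mask_length=24), MaskTime(max_mask_length=48)])
#   >>> spec = torch.rand(1, 128, 1024)
#   >>> augment(spec).shape
#   torch.Size([1, 128, 1024])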