import torch
import torch.nn as nn
from torch_audiomentations import (
    AddColoredNoise,
    Compose,
    Gain,
    PeakNormalization,
    PitchShift,
    PolarityInversion,
)

# TODO: add a reference to where this code was originally copied from.


class AUG(nn.Module):
    """Random waveform augmentation pipeline built on torch_audiomentations.

    Wraps a ``Compose`` of the following transforms, listed in this order:
    coloured noise, a pitch shift in the range [-1, +1] semitones, peak
    normalization, and a gain change in the range [-6, +6] dB. Each transform
    has its own application probability.

    Args:
        prob: Probability used for AddColoredNoise, PitchShift and Gain.
            PeakNormalization always uses a fixed probability of 0.1.
        sample_rate: Sample rate (Hz) of the audio this module will receive.
            It configures PitchShift and is passed to the pipeline on every
            call, so both always agree. Defaults to 16000, the value that was
            previously hard-coded.
    """

    def __init__(self, prob=0.3, *, sample_rate=16000):
        super().__init__()
        # Stored so that forward() passes the same rate PitchShift was built with.
        self.sample_rate = sample_rate
        self.aug = Compose(
            transforms=[
                AddColoredNoise(p=prob),
                PitchShift(
                    sample_rate=sample_rate,
                    min_transpose_semitones=-1,
                    max_transpose_semitones=1,
                    p=prob,
                ),
                # Deliberately fixed probability, independent of `prob`.
                PeakNormalization(p=0.1),
                Gain(min_gain_in_db=-6, max_gain_in_db=6, p=prob),
            ]
        )

    def forward(self, x):
        """Apply the augmentation pipeline to a batch of waveforms.

        Args:
            x: Audio tensor. Presumably shaped (batch, channels, samples), as
                torch_audiomentations transforms expect — confirm against callers.

        Returns:
            Whatever the underlying ``Compose`` returns for ``x``.
        """
        return self.aug(x, sample_rate=self.sample_rate)