"""Audio transforms.""" | |
import torchaudio | |
import torchvision | |
from torchvision.transforms import Compose, ToTensor | |
import torchaudio.transforms as T | |
import imgaug.augmenters as iaa | |
import numpy as np | |
import torch | |
class AddNoise(object):
    """Add Gaussian noise to the waveform."""

    def __init__(self, noise_level=0.1):
        self.noise_level = noise_level

    def __call__(self, waveform):
        noise = torch.randn_like(waveform)
        return waveform + self.noise_level * noise

    def __repr__(self):
        return self.__class__.__name__ + f"(noise_level={self.noise_level})"

class ChangeVolume(object):
    """Scale the waveform by a random gain drawn uniformly from the volume_factor range."""

    def __init__(self, volume_factor=(0.6, 1.2)):
        self.volume_factor = volume_factor

    def __call__(self, waveform):
        return waveform * np.random.uniform(*self.volume_factor)

    def __repr__(self):
        return self.__class__.__name__ + f"(volume_factor={self.volume_factor})"

def configure_transforms(cfg):
    """
    Given a transform config (List[dict]), return a Compose object that
    applies the transforms in order.

    Each entry needs a "name" (resolved with eval() against this module's
    namespace) and an "args" dict of keyword arguments.
    """
    transform = []
    for a in cfg:
        transform.append(eval(a["name"])(**a["args"]))
    return Compose(transform)

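# Illustrative usage (mirrors the demo under __main__ below). The "name"
# string goes through eval(), so it must be defined in this module's
# namespace:
#
#   cfg = [
#       {"name": "AddNoise", "args": {"noise_level": 0.1}},
#       {"name": "ChangeVolume", "args": {"volume_factor": (0.6, 1.2)}},
#   ]
#   pipeline = configure_transforms(cfg)
#   y = pipeline(torch.randn(1, 16000))  # applies AddNoise, then ChangeVolume
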
class AudioClipsTransform:
    """Applies an audio transform to each clip in a list of audio clips."""

    def __init__(self, audio_transform):
        self.audio_transform = audio_transform

    def __call__(self, audio_clips):
        """
        Args:
            audio_clips (list): list of audio clips, each a tensor [1, M]
                where M is the number of samples in each clip
        """
        return [self.audio_transform(x) for x in audio_clips]

    def __repr__(self):
        return self.audio_transform.__repr__()

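# Illustrative usage (a sketch; clip shapes follow the docstring above):
#
#   clip_transform = AudioClipsTransform(Compose([AddNoise(noise_level=0.1)]))
#   clips = [torch.randn(1, 16000) for _ in range(4)]  # four 1 s clips at 16 kHz
#   out = clip_transform(clips)  # list of four transformed [1, 16000] tensors
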
class NumpyToTensor:
    """Convert a NumPy array to a float32 tensor."""

    def __call__(self, x):
        return torch.from_numpy(x).float()

    def __repr__(self):
        return self.__class__.__name__ + "()"

# TODO: Might have to introduce normalisation
# to have a consistent pipeline.
class Wav2Vec2WaveformProcessor:
    """Wraps a pretrained HuggingFace Wav2Vec2 processor as a waveform transform."""

    def __init__(self, model_name="facebook/wav2vec2-base-960h", sr=16000):
        # Lazy import: transformers is only needed when this transform is used.
        from transformers import Wav2Vec2Processor

        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.sr = sr

    def __call__(self, x):
        return self.processor(
            x, sampling_rate=self.sr, return_tensors="pt",
        ).input_values

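# Illustrative usage (a sketch; the exact output shape depends on the
# pretrained feature extractor, which typically pads and zero-mean /
# unit-variance normalises the input):
#
#   proc = Wav2Vec2WaveformProcessor()   # downloads facebook/wav2vec2-base-960h
#   values = proc(torch.randn(16000))    # expected: input_values of shape [1, 16000]
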
def define_audio_transforms(cfg_transform, augment=False):
    """Build waveform transforms from a config, including augmentations only if augment=True."""
    wave_transforms = cfg_transform["audio"]["wave"]

    # Always keep base transforms; keep entries flagged "augmentation" only
    # when augment=True.
    wave_transforms_new = []
    for t in wave_transforms:
        if "augmentation" not in t:
            wave_transforms_new.append(t)
        elif augment and t["augmentation"]:
            wave_transforms_new.append(t)

    wave_transform = configure_transforms(wave_transforms_new)
    wave_transform = AudioClipsTransform(wave_transform)

    # NOTE: spectrogram ("spec") transforms are currently disabled.
    audio_transform = dict(
        wave=wave_transform,
    )
    return audio_transform

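# Example config (the shape implied by the lookups above); entries flagged
# "augmentation": True are dropped unless augment=True:
#
#   cfg_transform = {
#       "audio": {
#           "wave": [
#               {"name": "NumpyToTensor", "args": {}},
#               {"name": "AddNoise", "args": {"noise_level": 0.1}, "augmentation": True},
#           ],
#       },
#   }
#   transforms = define_audio_transforms(cfg_transform, augment=True)
#   wave_tf = transforms["wave"]  # AudioClipsTransform wrapping the Compose pipeline
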
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # Raw waveform transform
    cfg = [
        {"name": "AddNoise", "args": {"noise_level": 0.1}},
        {"name": "ChangeVolume", "args": {"volume_factor": (0.6, 1.2)}},
    ]
    transform = configure_transforms(cfg)
    x = torch.randn([1, 16000])
    z = transform(x)
    print(x.shape, z.shape)

    fig, ax = plt.subplots(2, 1, figsize=(8, 4))
    ax[0].plot(x[0].numpy())
    ax[1].plot(z[0].numpy())
    plt.savefig("waveform_transform.png")

    # Wav2Vec2 transform
    cfg = [
        {
            "name": "Wav2Vec2WaveformProcessor",
            "args": {"model_name": "facebook/wav2vec2-base-960h", "sr": 16000},
        },
    ]
    transform = configure_transforms(cfg)
    x = torch.randn([4, 16000])
    z = transform(x)
    print(x.shape, z.shape)

    # Spectrogram transform (masking expects a [channel, freq, time] tensor)
    cfg = [
        {"name": "T.FrequencyMasking", "args": {"freq_mask_param": 8}},
        {"name": "T.TimeMasking", "args": {"time_mask_param": 16}},
    ]
    transform = configure_transforms(cfg)
    x = torch.randn([1, 64, 251])
    z = transform(x)
    print(x.shape, z.shape)

    fig, ax = plt.subplots(2, 1, figsize=(8, 4))
    ax[0].imshow(x[0].numpy())
    ax[1].imshow(z[0].numpy())
    plt.savefig("spectrogram_transform.png")