SoundOfWater / sound_of_water /data /audio_transforms.py
bpiyush's picture
Upload sound_of_water/data/audio_transforms.py with huggingface_hub
701e2d4 verified
raw
history blame
5.16 kB
"""Audio transforms."""
# import torchaudio
# import torchvision
from torchvision.transforms import Compose, ToTensor
import torchaudio.transforms as T
# import imgaug.augmenters as iaa
import numpy as np
import torch
class AddNoise(object):
    """Additive Gaussian-noise augmentation for a waveform.

    Every call draws fresh standard-normal noise shaped like the input
    (``torch.randn_like``), scales it by ``noise_level`` and adds it to the
    input.
    """

    def __init__(self, noise_level=0.1):
        # Multiplier applied to the unit-variance noise before it is added.
        self.noise_level = noise_level

    def __call__(self, waveform):
        scaled_noise = self.noise_level * torch.randn_like(waveform)
        return waveform + scaled_noise

    def __repr__(self):
        return f"{self.__class__.__name__}(noise_level={self.noise_level})"
class ChangeVolume(object):
    """Randomly rescale the amplitude of a waveform.

    On every call a single gain is drawn with ``np.random.uniform`` from the
    range given by ``volume_factor`` and multiplied into the input.

    Args:
        volume_factor: Sequence ``(low, high)`` unpacked into
            ``np.random.uniform``. Defaults to ``(0.6, 1.2)``.
    """

    def __init__(self, volume_factor=(0.6, 1.2)):
        # Keep a private list copy. The previous mutable list default was
        # shared by every instance, so mutating one instance's range silently
        # changed the default for all later instances. Storing a list (rather
        # than the tuple) keeps ``repr`` output the same as before.
        self.volume_factor = list(volume_factor)

    def __call__(self, waveform):
        return waveform * np.random.uniform(*self.volume_factor)

    def __repr__(self):
        return self.__class__.__name__ + f"(volume_factor={self.volume_factor})"
def configure_transforms(cfg):
    """Build a ``Compose`` pipeline from a list of transform specs.

    Args:
        cfg: Iterable of mappings, one per transform, applied in order.
            Each mapping has:
              - "name": expression naming a callable visible from this
                module, e.g. ``"AddNoise"`` or ``"T.TimeMasking"``.
              - "args" (optional): dict of keyword arguments for that
                callable. A missing or ``None`` value means "no arguments".

    Returns:
        A ``Compose`` object holding the instantiated transforms in the
        order they appear in ``cfg``.
    """
    transforms = []
    for spec in cfg:
        # SECURITY NOTE(review): ``eval`` executes whatever expression the
        # config supplies. Only use this with trusted, locally-authored
        # configs; never feed it names coming from external input.
        factory = eval(spec["name"])
        # Tolerate specs that omit "args" (e.g. argument-less transforms)
        # instead of failing with a KeyError.
        kwargs = spec["args"] if "args" in spec else None
        transforms.append(factory(**(kwargs or {})))
    return Compose(transforms)
class AudioClipsTransform:
    """Apply one audio transform independently to every clip in a list."""

    def __init__(self, audio_transform):
        """Store the callable that will be applied to each audio clip."""
        self.audio_transform = audio_transform

    def __call__(self, audio_clips):
        """Transform each clip and return the results as a new list.

        Args:
            audio_clips (list): list of audio clips, each tensor [1, M]
                where M is number of samples in each clip

        Returns:
            list: one transformed clip per input clip, in the same order.
        """
        outputs = []
        for clip in audio_clips:
            outputs.append(self.audio_transform(clip))
        return outputs

    def __repr__(self):
        # Show the wrapped transform's own representation.
        return self.audio_transform.__repr__()
class NumpyToTensor:
    """Convert a NumPy array into a float32 torch tensor."""

    def __call__(self, x):
        tensor = torch.from_numpy(x)
        return tensor.float()

    def __repr__(self):
        return f"{self.__class__.__name__}()"
# TODO: Might have to introduce normalisation
# to have a consistent pipeline.
class Wav2Vec2WaveformProcessor:
    """Prepare raw waveforms as Wav2Vec2 model inputs.

    Wraps a ``transformers.Wav2Vec2Processor`` loaded with
    ``from_pretrained(model_name)`` and returns the ``input_values`` field
    of its output.

    Args:
        model_name: Identifier passed to ``Wav2Vec2Processor.from_pretrained``.
        sr: Sampling rate forwarded to the processor on every call.
    """

    def __init__(self, model_name="facebook/wav2vec2-base-960h", sr=16000):
        # Imported here so `transformers` is only needed when this transform
        # is actually instantiated.
        from transformers import Wav2Vec2Processor

        # Remembered so the transform can describe itself in __repr__.
        self.model_name = model_name
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        self.sr = sr

    def __call__(self, x):
        processed = self.processor(
            x, sampling_rate=self.sr, return_tensors="pt",
        )
        return processed.input_values

    def __repr__(self):
        # Matches the other transforms in this module, so a composed pipeline
        # prints a readable entry instead of a bare object address.
        return (
            f"{self.__class__.__name__}"
            f"(model_name={self.model_name!r}, sr={self.sr})"
        )
def define_audio_transforms(cfg_transform, augment=False):
    """Build the audio transform dictionary from a transform config.

    Reads the transform specs under ``cfg_transform["audio"]["wave"]``.
    A spec without an "augmentation" key is always kept; a spec with that
    key is kept only when ``augment`` is true and the key's value is truthy.

    Args:
        cfg_transform: Config mapping containing ``["audio"]["wave"]``, a
            list of transform specs understood by ``configure_transforms``.
        augment: Whether specs flagged as augmentations should be included.

    Returns:
        dict with a single key ``"wave"`` holding an ``AudioClipsTransform``
        that applies the composed waveform pipeline to each clip.
    """
    selected_specs = [
        spec
        for spec in cfg_transform["audio"]["wave"]
        if "augmentation" not in spec or (augment and spec["augmentation"])
    ]
    per_clip_transform = AudioClipsTransform(
        configure_transforms(selected_specs)
    )
    # Only waveform transforms are produced; no spectrogram pipeline is built.
    return {"wave": per_clip_transform}
if __name__ == "__main__":
    # Manual smoke test: build three pipelines from configs, run each on
    # random input, print input/output shapes and save comparison plots.

    # 1) Raw-waveform augmentations.
    wave_cfg = [
        {"name": "AddNoise", "args": {"noise_level": 0.1}},
        {"name": "ChangeVolume", "args": {"volume_factor": [0.6, 1.2]}},
    ]
    wave_pipeline = configure_transforms(wave_cfg)
    wave_in = torch.randn([1, 16000])
    wave_out = wave_pipeline(wave_in)
    print(wave_in.shape, wave_out.shape)

    import matplotlib.pyplot as plt

    fig, (ax_before, ax_after) = plt.subplots(2, 1, figsize=(8, 4))
    ax_before.plot(wave_in[0].numpy())
    ax_after.plot(wave_out[0].numpy())
    plt.savefig("waveform_transform.png")

    # 2) Wav2Vec2 input processing.
    w2v_cfg = [
        {
            "name": "Wav2Vec2WaveformProcessor",
            "args": {
                "model_name": "facebook/wav2vec2-base-960h",
                "sr": 16000,
            },
        },
    ]
    w2v_pipeline = configure_transforms(w2v_cfg)
    batch_in = torch.randn([4, 16000])
    batch_out = w2v_pipeline(batch_in)
    print(batch_in.shape, batch_out.shape)

    # 3) Spectrogram masking augmentations.
    spec_cfg = [
        {"name": "T.FrequencyMasking", "args": {"freq_mask_param": 8}},
        {"name": "T.TimeMasking", "args": {"time_mask_param": 16}},
    ]
    spec_pipeline = configure_transforms(spec_cfg)
    spec_in = torch.randn([1, 64, 251])
    spec_out = spec_pipeline(spec_in)
    print(spec_in.shape, spec_out.shape)

    fig, (ax_before, ax_after) = plt.subplots(2, 1, figsize=(8, 4))
    ax_before.imshow(spec_in[0].numpy())
    ax_after.imshow(spec_out[0].numpy())
    plt.savefig("spectrogram_transform.png")