# jone's picture
# init
# 75c6e9a
from typing import Dict
import librosa
import numpy as np
from bytesep.utils import db_to_magnitude, get_pitch_shift_factor, magnitude_to_db
class Augmentor:
    r"""Apply a configurable chain of random augmentations to a waveform.

    All randomness is drawn from a private ``np.random.RandomState`` so that
    two instances built with the same seed produce the same sequence of
    augmentations.
    """

    def __init__(self, augmentations: Dict, random_seed=1234):
        r"""Augmentor for data augmentation of a waveform.

        Args:
            augmentations: Dict, e.g, {
                'mixaudio': {'vocals': 2, 'accompaniment': 2},
                'pitch_shift': {'vocals': 4, 'accompaniment': 4},
                ...,
            }
            random_seed: int, seed of the internal random state
        """
        self.augmentations = augmentations
        self.random_state = np.random.RandomState(random_seed)

    def __call__(self, waveform: np.ndarray, source_type: str) -> np.ndarray:
        r"""Augment a waveform.

        Every augmentation whose key is present in ``self.augmentations`` is
        applied, always in the order: pitch_shift, magnitude_scale,
        swap_channel, flip_axis. Other keys are ignored by this method.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str, key used to look up per-source settings

        Returns:
            new_waveform: (channels_num, new_audio_samples)
        """
        if 'pitch_shift' in self.augmentations:
            waveform = self.pitch_shift(waveform, source_type)

        if 'magnitude_scale' in self.augmentations:
            waveform = self.magnitude_scale(waveform, source_type)

        if 'swap_channel' in self.augmentations:
            waveform = self.swap_channel(waveform, source_type)

        if 'flip_axis' in self.augmentations:
            waveform = self.flip_axis(waveform, source_type)

        return waveform

    def pitch_shift(self, waveform: np.ndarray, source_type: str) -> np.ndarray:
        r"""Shift the pitch of a waveform. We use resampling for fast pitch
        shifting, so the speed will also be changed. The length of the returned
        waveform will be changed.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str, key into self.augmentations['pitch_shift']

        Returns:
            new_waveform: (channels_num, new_audio_samples)

        Raises:
            KeyError: if source_type has no 'pitch_shift' setting.
        """
        # Maximum pitch shift in semitones.
        max_pitch_shift = self.augmentations['pitch_shift'][source_type]

        if max_pitch_shift == 0:  # No pitch shift augmentations.
            return waveform

        # Random pitch shift, drawn symmetrically around zero.
        rand_pitch = self.random_state.uniform(
            low=-max_pitch_shift, high=max_pitch_shift
        )

        # We use librosa.resample instead of librosa.effects.pitch_shift
        # because it is 10x times faster.
        pitch_shift_factor = get_pitch_shift_factor(rand_pitch)

        # Placeholder rate: the target rate below is derived from it by the
        # pitch factor, so its absolute value carries no meaning.
        dummy_sample_rate = 10000

        channels_num = waveform.shape[0]

        if channels_num == 1:
            # Drop only the channel axis. Unlike np.squeeze this never
            # collapses a one-sample waveform to a 0-d array.
            waveform = waveform[0]

        new_waveform = librosa.resample(
            y=waveform,
            orig_sr=dummy_sample_rate,
            target_sr=dummy_sample_rate / pitch_shift_factor,
            res_type='linear',
            axis=-1,
        )

        if channels_num == 1:
            # Restore the channel axis removed above.
            new_waveform = new_waveform[None, :]

        return new_waveform

    def magnitude_scale(self, waveform: np.ndarray, source_type: str) -> np.ndarray:
        r"""Scale the magnitude of a waveform.

        A new peak level (in dB) is sampled uniformly between
        ``peak + lower_db`` and ``min(peak + higher_db, 0)`` and the whole
        waveform is scaled by the corresponding gain.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str, key into self.augmentations['magnitude_scale']

        Returns:
            new_waveform: (channels_num, audio_samples)

        Raises:
            KeyError: if source_type has no 'magnitude_scale' setting.
        """
        settings = self.augmentations['magnitude_scale'][source_type]
        lower_db = settings['lower_db']
        higher_db = settings['higher_db']

        if lower_db == 0 and higher_db == 0:  # No magnitude scale augmentation.
            return waveform

        # The magnitude (in dB) of the sample with the maximum value.
        waveform_db = magnitude_to_db(np.max(np.abs(waveform)))

        # The upper bound is capped at 0 dB.
        # NOTE(review): presumably 0 dB corresponds to full scale so the
        # scaled peak does not clip - confirm against magnitude_to_db.
        new_waveform_db = self.random_state.uniform(
            waveform_db + lower_db, min(waveform_db + higher_db, 0)
        )

        relative_db = new_waveform_db - waveform_db
        relative_scale = db_to_magnitude(relative_db)

        return waveform * relative_scale

    def swap_channel(self, waveform: np.ndarray, source_type: str) -> np.ndarray:
        r"""Randomly swap channels.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str, unused; kept so all augmentations share one
                signature

        Returns:
            new_waveform: (channels_num, audio_samples)
        """
        channels_num = waveform.shape[0]

        if channels_num == 1:
            # Nothing to permute; no random number is consumed.
            return waveform

        random_order = self.random_state.permutation(channels_num)
        return waveform[random_order, :]

    def flip_axis(self, waveform: np.ndarray, source_type: str) -> np.ndarray:
        r"""Randomly flip the sign of each channel, i.e. mirror the waveform
        about the time axis. Each channel is independently multiplied by
        either -1 or 1.

        Args:
            waveform: (channels_num, audio_samples)
            source_type: str, unused; kept so all augmentations share one
                signature

        Returns:
            new_waveform: (channels_num, audio_samples). Floating-point
                inputs keep their dtype.
        """
        channels_num = waveform.shape[0]
        random_signs = self.random_state.choice([-1, 1], size=channels_num)

        if np.issubdtype(waveform.dtype, np.floating):
            # The sampled signs are integers; multiplying a float32 waveform
            # by them directly would promote the result to float64.
            random_signs = random_signs.astype(waveform.dtype)

        return waveform * random_signs[:, None]