|
from typing import Any, Dict, List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from transformers import Wav2Vec2Processor |
|
|
|
from torch_audiomentations import Compose, Gain |
|
from audiomentations import ( |
|
Compose, |
|
AddGaussianNoise, |
|
AddGaussianSNR, |
|
ClippingDistortion, |
|
FrequencyMask, |
|
Gain, |
|
LoudnessNormalization, |
|
Normalize, |
|
PitchShift, |
|
PolarityInversion, |
|
Shift, |
|
TimeMask, |
|
TimeStretch, |
|
) |
|
|
|
|
|
class DataCollatorCTCWithPadding:
    """Collate CTC examples into one padded batch, optionally augmenting the audio.

    Each incoming feature is a mapping with an ``"input_values"`` entry (the
    waveform samples) and a ``"labels"`` entry (the target token ids). Both
    are padded through ``processor.pad`` and returned as PyTorch tensors.

    Args:
        processor: Object used to pad inputs and labels (its ``pad`` and
            ``as_target_processor`` methods are called).
        padding: Forwarded unchanged to ``processor.pad`` as ``padding``.
        sample_rate: Sample rate handed to the augmentation pipeline.
        apply_gaussian_noise_with_p: Probability passed to ``AddGaussianNoise``.
        apply_gain_with_p: Probability passed to ``Gain``.
        apply_pitch_shift_with_p: Probability passed to ``PitchShift``.
        apply_time_stretch_with_p: Probability passed to ``TimeStretch``.

    The augmentation pipeline is only built when at least one of the four
    probabilities is strictly positive; otherwise ``augmentator`` stays
    ``None`` and the waveforms are passed through untouched.
    """

    def __init__(
        self,
        processor: "Wav2Vec2Processor",
        padding: Union[bool, str] = True,
        sample_rate: int = 16_000,
        apply_gaussian_noise_with_p: float = 0,
        apply_gain_with_p: float = 0,
        apply_pitch_shift_with_p: float = 0,
        apply_time_stretch_with_p: float = 0,
    ):
        self.processor = processor
        self.padding = padding
        self.apply_gaussian_noise_with_p = apply_gaussian_noise_with_p
        self.apply_gain_with_p = apply_gain_with_p
        self.apply_pitch_shift_with_p = apply_pitch_shift_with_p
        self.apply_time_stretch_with_p = apply_time_stretch_with_p
        self.sample_rate = sample_rate

        self.augmentator = None
        probabilities = (
            self.apply_gaussian_noise_with_p,
            self.apply_gain_with_p,
            self.apply_pitch_shift_with_p,
            self.apply_time_stretch_with_p,
        )
        # Test each probability on its own instead of summing them, so that a
        # stray negative value cannot cancel out a positive one.
        if any(p > 0 for p in probabilities):
            # Transforms run in list order: stretch, pitch, gain, then noise.
            self.augmentator = Compose([
                TimeStretch(
                    min_rate=0.8,
                    max_rate=1.2,
                    leave_length_unchanged=False,
                    p=self.apply_time_stretch_with_p,
                ),
                PitchShift(
                    min_semitones=-1,
                    max_semitones=1,
                    p=self.apply_pitch_shift_with_p,
                ),
                Gain(
                    min_gain_in_db=-1,
                    max_gain_in_db=1,
                    p=self.apply_gain_with_p,
                ),
                AddGaussianNoise(
                    min_amplitude=0.0001,
                    max_amplitude=0.001,
                    p=self.apply_gaussian_noise_with_p,
                ),
            ])

    def _apply_augmentation(self, input_values: List[float]) -> List[float]:
        """Run the augmentation pipeline over one waveform.

        Args:
            input_values: Waveform samples of a single example.

        Returns:
            ``input_values`` itself when no pipeline is configured, otherwise
            the augmented samples converted back to a plain list.
        """
        if self.augmentator is None:
            return input_values

        # Build the array as float32 up front: a list of Python floats would
        # otherwise become float64, while audiomentations documents float32
        # input.
        samples = np.asarray(input_values, dtype=np.float32)
        augmented = self.augmentator(samples=samples, sample_rate=self.sample_rate)
        return augmented.tolist()

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into a single batch.

        Args:
            features: Examples, each providing ``"input_values"`` and
                ``"labels"``.

        Returns:
            The padded input batch produced by ``processor.pad`` with an extra
            ``"labels"`` entry holding the padded, masked label ids.
        """
        # Inputs and labels are padded separately because they go through
        # different padding calls on the processor.
        input_features = [
            {"input_values": self._apply_augmentation(feature["input_values"])}
            for feature in features
        ]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        # NOTE(review): as_target_processor() looks deprecated in recent
        # transformers releases - confirm against the installed version.
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # Positions whose attention mask is not 1 are padding; overwrite them
        # with -100 (presumably the label id the loss ignores - verify against
        # the model in use).
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        batch["labels"] = labels

        return batch