aspram / aspram /collator.py
lilitket's picture
Move to package
cab7f7b
raw history blame
No virus
3.15 kB
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Wav2Vec2Processor
from torch_audiomentations import Compose, Gain
from audiomentations import (
Compose,
AddGaussianNoise,
AddGaussianSNR,
ClippingDistortion,
FrequencyMask,
Gain,
LoudnessNormalization,
Normalize,
PitchShift,
PolarityInversion,
Shift,
TimeMask,
TimeStretch,
)
class DataCollatorCTCWithPadding:
    """Collate CTC examples into a padded batch, optionally augmenting the audio.

    Each feature passed to ``__call__`` is a mapping with an ``"input_values"``
    entry (the waveform) and a ``"labels"`` entry (the target token ids).
    Inputs and labels are padded separately through ``processor.pad`` and the
    padded label positions are replaced with ``-100``.

    When at least one of the ``apply_*_with_p`` probabilities is positive, an
    augmentation pipeline (time stretch, pitch shift, gain, gaussian noise) is
    built once at construction time and applied to every waveform before
    padding.

    Args:
        processor: Object providing ``pad`` and ``as_target_processor``.
        padding: Forwarded as ``padding`` to both ``processor.pad`` calls.
        sample_rate: Sample rate handed to the augmentation pipeline.
        apply_gaussian_noise_with_p: Probability of applying ``AddGaussianNoise``.
        apply_gain_with_p: Probability of applying ``Gain``.
        apply_pitch_shift_with_p: Probability of applying ``PitchShift``.
        apply_time_stretch_with_p: Probability of applying ``TimeStretch``.

    Raises:
        ValueError: If any probability lies outside ``[0, 1]``.
    """

    def __init__(
        self,
        processor: "Wav2Vec2Processor",
        padding: Union[bool, str] = True,
        sample_rate: int = 16_000,
        apply_gaussian_noise_with_p: float = 0,
        apply_gain_with_p: float = 0,
        apply_pitch_shift_with_p: float = 0,
        apply_time_stretch_with_p: float = 0,
    ):
        probabilities = {
            "apply_gaussian_noise_with_p": apply_gaussian_noise_with_p,
            "apply_gain_with_p": apply_gain_with_p,
            "apply_pitch_shift_with_p": apply_pitch_shift_with_p,
            "apply_time_stretch_with_p": apply_time_stretch_with_p,
        }
        # Fail early on invalid values: summing the probabilities (as a way to
        # detect "any augmentation enabled") would let out-of-range values
        # cancel each other out and silently disable augmentation.
        for name, p in probabilities.items():
            if not 0 <= p <= 1:
                raise ValueError(
                    f"{name} must be a probability in [0, 1], got {p!r}"
                )

        self.processor = processor
        self.padding = padding
        self.apply_gaussian_noise_with_p = apply_gaussian_noise_with_p
        self.apply_gain_with_p = apply_gain_with_p
        self.apply_pitch_shift_with_p = apply_pitch_shift_with_p
        self.apply_time_stretch_with_p = apply_time_stretch_with_p
        self.sample_rate = sample_rate

        # The pipeline is only built when at least one transform can fire;
        # otherwise the collator passes waveforms through untouched.
        self.augmentator = None
        if any(p > 0 for p in probabilities.values()):
            self.augmentator = Compose([
                TimeStretch(
                    min_rate=0.8,
                    max_rate=1.2,
                    leave_length_unchanged=False,
                    p=self.apply_time_stretch_with_p,
                ),
                PitchShift(
                    min_semitones=-1,
                    max_semitones=1,
                    p=self.apply_pitch_shift_with_p,
                ),
                Gain(
                    min_gain_in_db=-1,
                    max_gain_in_db=1,
                    p=self.apply_gain_with_p,
                ),
                AddGaussianNoise(
                    min_amplitude=0.0001,
                    max_amplitude=0.001,
                    p=self.apply_gaussian_noise_with_p,
                ),
            ])

    def _apply_augmentation(self, input_values: List[float]):
        """Return ``input_values`` after running the augmentation pipeline.

        When no pipeline is configured the input object is returned unchanged.
        Otherwise the waveform is converted to a float32 array, augmented, and
        returned as a plain Python list.
        """
        if self.augmentator is None:
            return input_values
        # audiomentations works on float32 waveforms; np.array() on a list of
        # Python floats would yield float64 instead.
        samples = np.asarray(input_values, dtype=np.float32)
        augmented = self.augmentator(samples=samples, sample_rate=self.sample_rate)
        return augmented.tolist()

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build a padded batch from ``features``.

        Args:
            features: Examples, each with ``"input_values"`` and ``"labels"``.

        Returns:
            The batch returned by ``processor.pad`` for the inputs, with an
            added ``"labels"`` entry whose padded positions are ``-100``.
        """
        # TODO maybe disable augmentation in inference mode?
        # Inputs and labels have different lengths and are padded by different
        # parts of the processor, so they are split into two lists first.
        input_features = [
            {"input_values": self._apply_augmentation(feature["input_values"])}
            for feature in features
        ]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )
        batch["labels"] = labels
        return batch