File size: 3,146 Bytes
5f1c16f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import Wav2Vec2Processor
from torch_audiomentations import Compose, Gain
from audiomentations import (
Compose,
AddGaussianNoise,
AddGaussianSNR,
ClippingDistortion,
FrequencyMask,
Gain,
LoudnessNormalization,
Normalize,
PitchShift,
PolarityInversion,
Shift,
TimeMask,
TimeStretch,
)
class DataCollatorCTCWithPadding:
    """Collate CTC examples into one padded batch, optionally augmenting the audio.

    Each incoming feature is a dict holding ``"input_values"`` (the waveform)
    and ``"labels"`` (the token ids). Waveforms and labels are padded
    separately because they have different lengths and need different padding
    logic; padded label positions are then replaced with ``-100``.

    Args:
        processor: Processor used to pad the audio inputs; its ``tokenizer``
            is used to pad the labels.
        padding: Padding strategy forwarded unchanged to both ``pad`` calls.
        sample_rate: Sample rate (Hz) passed to the augmentation pipeline.
        apply_gaussian_noise_with_p: Probability of adding Gaussian noise.
        apply_gain_with_p: Probability of applying a random gain.
        apply_pitch_shift_with_p: Probability of applying a pitch shift.
        apply_time_stretch_with_p: Probability of applying a time stretch.

    The augmentation pipeline is only built when at least one of the four
    probabilities is positive; otherwise ``augmentator`` stays ``None`` and the
    waveforms are passed through untouched.
    """

    def __init__(
        self,
        processor: Wav2Vec2Processor,
        padding: Union[bool, str] = True,
        sample_rate: int = 16_000,
        apply_gaussian_noise_with_p: float = 0,
        apply_gain_with_p: float = 0,
        apply_pitch_shift_with_p: float = 0,
        apply_time_stretch_with_p: float = 0,
    ):
        self.processor = processor
        self.padding = padding
        self.apply_gaussian_noise_with_p = apply_gaussian_noise_with_p
        self.apply_gain_with_p = apply_gain_with_p
        self.apply_pitch_shift_with_p = apply_pitch_shift_with_p
        self.apply_time_stretch_with_p = apply_time_stretch_with_p
        self.sample_rate = sample_rate

        self.augmentator = None
        probabilities = (
            self.apply_gaussian_noise_with_p,
            self.apply_gain_with_p,
            self.apply_pitch_shift_with_p,
            self.apply_time_stretch_with_p,
        )
        # Check each probability on its own rather than their sum, so a stray
        # negative value cannot cancel out a positive one and silently disable
        # the whole pipeline.
        if any(p > 0 for p in probabilities):
            # The pipeline is called with a numpy array and a ``sample_rate``
            # keyword below, so these names must resolve to the numpy-based
            # audiomentations transforms.
            self.augmentator = Compose([
                TimeStretch(
                    min_rate=0.8,
                    max_rate=1.2,
                    leave_length_unchanged=False,
                    p=self.apply_time_stretch_with_p,
                ),
                PitchShift(
                    min_semitones=-1,
                    max_semitones=1,
                    p=self.apply_pitch_shift_with_p,
                ),
                Gain(
                    min_gain_in_db=-1,
                    max_gain_in_db=1,
                    p=self.apply_gain_with_p,
                ),
                AddGaussianNoise(
                    min_amplitude=0.0001,
                    max_amplitude=0.001,
                    p=self.apply_gaussian_noise_with_p,
                ),
            ])

    def _apply_augmentation(self, input_values: List[float]) -> List[float]:
        """Run the augmentation pipeline over one waveform.

        Args:
            input_values: Waveform samples of a single example.

        Returns:
            The augmented samples as a plain list, or ``input_values`` itself
            when no augmentation pipeline is configured.
        """
        if self.augmentator is None:
            return input_values

        # Build the array as float32: ``np.array`` on Python floats defaults to
        # float64, which is not the sample dtype the audio transforms expect.
        samples = np.asarray(input_values, dtype=np.float32)
        augmented = self.augmentator(samples=samples, sample_rate=self.sample_rate)
        return augmented.tolist()

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Pad a list of examples into a batch.

        Args:
            features: Examples, each with ``"input_values"`` and ``"labels"``.

        Returns:
            The padded input batch produced by the processor, with an added
            ``"labels"`` tensor whose padded positions are set to ``-100``.
        """
        # TODO maybe disable augmentation in inference mode?
        input_features = [
            {"input_values": self._apply_augmentation(feature["input_values"])}
            for feature in features
        ]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # Pad the labels with the tokenizer directly. This is what the
        # deprecated ``as_target_processor()`` context manager routed
        # ``processor.pad`` to, without depending on that context manager.
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )
        batch["labels"] = labels

        return batch