import numbers
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch
import transformers
from transformers import Wav2Vec2Processor, Wav2Vec2FeatureExtractor


def _label_dtype(label: object) -> torch.dtype:
    """Return the tensor dtype for a batch of labels, judged from one label.

    Integer-valued labels (Python ``int``/``bool``, NumPy integer scalars, or
    non-floating-point tensors) map to ``torch.long``; anything else maps to
    ``torch.float``.
    """
    if isinstance(label, torch.Tensor):
        return torch.float if label.is_floating_point() else torch.long
    # numbers.Integral also matches NumPy integer scalars (np.int64, ...),
    # which a plain ``isinstance(label, int)`` check misses; those would
    # otherwise be silently turned into float targets.
    if isinstance(label, numbers.Integral):
        return torch.long
    return torch.float


@dataclass
class DataCollatorCTCWithPadding:
    """Collate examples into a batch with padded ``input_values`` and a label tensor.

    Inputs are padded with ``feature_extractor.pad``. Labels are NOT padded:
    they are converted directly with ``torch.tensor``, so they must be
    scalars (or equal-length sequences) — presumably class indices or
    regression targets; confirm against the training code. Despite the
    "CTC" in the name, no CTC-style label padding is performed.
    """

    # Used to pad the "input_values" of each example.
    feature_extractor: Wav2Vec2FeatureExtractor
    # Forwarded as ``padding`` to ``feature_extractor.pad``.
    padding: Union[bool, str] = True
    # Forwarded as ``max_length`` to ``feature_extractor.pad``.
    max_length: Optional[int] = None
    # NOTE: accepted for interface compatibility but currently unused
    # (labels are never padded).
    max_length_labels: Optional[int] = None
    # Forwarded as ``pad_to_multiple_of`` to ``feature_extractor.pad``.
    pad_to_multiple_of: Optional[int] = None
    # NOTE: accepted for interface compatibility but currently unused.
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build a padded batch from a list of examples.

        Args:
            features: Per-example dicts, each providing ``"input_values"`` and
                ``"labels"``. Must be non-empty: the label dtype is taken from
                the first example (an empty list raises ``IndexError``).

        Returns:
            Whatever ``feature_extractor.pad`` returns for the inputs (as
            PyTorch tensors), with an added ``"labels"`` tensor whose dtype is
            ``torch.long`` for integer labels and ``torch.float`` otherwise.
        """
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [feature["labels"] for feature in features]

        # Integer labels -> long, everything else -> float.
        d_type = _label_dtype(label_features[0])

        batch = self.feature_extractor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Labels are stacked as-is (no padding), see class docstring.
        batch["labels"] = torch.tensor(label_features, dtype=d_type)
        return batch