Spaces:
Configuration error
Configuration error
from dataclasses import dataclass | |
from typing import Dict, List, Optional, Union | |
import torch | |
from transformers import Wav2Vec2Processor | |
# Keys looked up in every example dict: the model inputs and the labels.
INPUT_FIELD, LABEL_FIELD = "input_values", "labels"
@dataclass
class DataCollatorCTCWithPadding:
    """Collate a list of examples into a single PyTorch batch.

    The ``INPUT_FIELD`` entries are passed to ``processor.pad`` together with
    the padding options below. The ``LABEL_FIELD`` entries are converted with
    ``torch.tensor`` exactly as given, with no padding.

    NOTE(review): despite the "CTC" name, labels are neither padded nor
    masked here, so every label must have the same shape (presumably one
    integer class id per example). Confirm this against the dataset.

    Attributes:
        processor: Object whose ``pad`` method builds the input batch.
        padding: Forwarded to ``processor.pad`` as ``padding``.
        max_length: Forwarded to ``processor.pad`` as ``max_length``.
        max_length_labels: Not used by this implementation; kept so existing
            constructor calls keep working.
        pad_to_multiple_of: Forwarded to ``processor.pad`` as
            ``pad_to_multiple_of``.
        pad_to_multiple_of_labels: Not used by this implementation; kept so
            existing constructor calls keep working.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(
        self, examples: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build one batch from ``examples``.

        Args:
            examples: One mapping per dataset row. Each must contain
                ``INPUT_FIELD`` and ``LABEL_FIELD``.

        Returns:
            The mapping returned by ``processor.pad`` (requested with
            ``return_tensors="pt"``), with ``LABEL_FIELD`` set to
            ``torch.tensor`` of the collected labels.

        Raises:
            KeyError: If an example dict lacks either field.
        """
        # Only the model inputs go through the processor; any other keys in
        # the examples are ignored.
        input_features = [{INPUT_FIELD: example[INPUT_FIELD]} for example in examples]
        labels = [example[LABEL_FIELD] for example in examples]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # Labels are not padded: torch.tensor needs them all to share a shape.
        batch[LABEL_FIELD] = torch.tensor(labels)
        return batch