Spaces:
Configuration error
Configuration error
File size: 1,138 Bytes
b7f4dbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import torch
from transformers import Wav2Vec2Processor
# Key under which each example stores its model input; the collator reads it
# from every example and forwards it to the processor for padding.
INPUT_FIELD: str = "input_values"

# Key under which each example stores its target; the collator reads it from
# every example and writes the batched targets back under the same key.
LABEL_FIELD: str = "labels"
@dataclass
class DataCollatorCTCWithPadding:
    """Collate a list of examples into one padded batch of tensors.

    Inputs are padded by ``processor.pad``. Labels are batched by
    ``_collate_labels``: scalar labels are converted directly, while 1-D
    label sequences are right-padded to a common length.

    Attributes:
        processor: Object whose ``pad`` method pads the input features.
        padding: Padding strategy forwarded to ``processor.pad``. When it
            equals ``"max_length"``, ``max_length_labels`` is also honoured
            for label sequences.
        max_length: Forwarded to ``processor.pad`` for the inputs.
        max_length_labels: Minimum padded length of label sequences when
            ``padding == "max_length"``. Labels are never truncated.
        pad_to_multiple_of: Forwarded to ``processor.pad`` for the inputs.
        pad_to_multiple_of_labels: If set, the padded label length is
            rounded up to a multiple of this value.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    # Fill value for padded label positions.
    # NOTE(review): -100 is the conventional "ignore" label for Hugging Face
    # CTC heads and PyTorch losses -- confirm against the model in use.
    _LABEL_PAD_VALUE = -100

    def __call__(
        self, examples: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build a batch from ``examples``.

        Args:
            examples: One dict per dataset row; each must contain the
                ``INPUT_FIELD`` and ``LABEL_FIELD`` keys.

        Returns:
            The object returned by ``processor.pad`` with the batched labels
            stored under ``LABEL_FIELD``.
        """
        # Only the input field is handed to the processor; labels are
        # batched separately below.
        input_features = [
            {INPUT_FIELD: example[INPUT_FIELD]} for example in examples
        ]
        labels = [example[LABEL_FIELD] for example in examples]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        batch[LABEL_FIELD] = self._collate_labels(labels)
        return batch

    @staticmethod
    def _is_flat_sequence(label) -> bool:
        """Return True for a list/tuple or a 1-D tensor."""
        if isinstance(label, torch.Tensor):
            return label.dim() == 1
        return isinstance(label, (list, tuple))

    def _collate_labels(self, labels: list) -> torch.Tensor:
        """Turn per-example labels into a single tensor.

        Scalar labels (and anything that is not a flat sequence) go through
        ``torch.tensor(labels)`` unchanged. Flat sequences are right-padded
        with ``_LABEL_PAD_VALUE`` so that ragged batches no longer raise.
        """
        if not labels or not all(self._is_flat_sequence(lb) for lb in labels):
            return torch.tensor(labels)

        rows = [torch.as_tensor(lb) for lb in labels]
        if any(row.dim() != 1 for row in rows):
            # Nested sequences: keep the direct conversion.
            return torch.tensor(labels)

        # Use the promoted dtype of the non-empty rows so an empty label
        # (which defaults to float) does not turn integer targets into floats.
        dtype = None
        for row in rows:
            if row.numel():
                dtype = (
                    row.dtype
                    if dtype is None
                    else torch.promote_types(dtype, row.dtype)
                )
        if dtype is None:
            dtype = rows[0].dtype

        target_len = max(row.shape[0] for row in rows)
        if self.padding == "max_length" and self.max_length_labels is not None:
            target_len = max(target_len, self.max_length_labels)
        multiple = self.pad_to_multiple_of_labels
        if multiple:
            # Round up to the next multiple.
            target_len = -(-target_len // multiple) * multiple

        padded = torch.full(
            (len(rows), target_len), self._LABEL_PAD_VALUE, dtype=dtype
        )
        for index, row in enumerate(rows):
            padded[index, : row.shape[0]] = row
        return padded
|