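# NOTE: the lines below are residue from the web page this file was scraped
# from; they are not part of the module and are kept only as comments.
# supun9's picture
# Upload 5 files
# b7f4dbe
# raw
# history blame
# No virus
# 1.14 kB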
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import torch
from transformers import Wav2Vec2Processor
# Key under which each example stores its model input; the collator reads it
# from every example and forwards it to the processor for padding.
INPUT_FIELD = "input_values"
# Key under which each example stores its target; the collator reads it from
# every example and writes the batched labels back under the same key.
LABEL_FIELD = "labels"
@dataclass
class DataCollatorCTCWithPadding:
    """Collate a list of examples into one padded batch of PyTorch tensors.

    The inputs stored under ``INPUT_FIELD`` are padded by ``processor.pad``.
    The targets stored under ``LABEL_FIELD`` are batched into a tensor:
    scalar labels (one value per example) are stacked as they are, while
    1-D label sequences of differing lengths are right-padded with ``-100``
    so they can be stacked as well.

    Attributes:
        processor: Object whose ``pad`` method pads the input features.
        padding: Padding strategy forwarded to ``processor.pad``; ``False``
            or ``"do_not_pad"`` also disables label padding.
        max_length: Forwarded to ``processor.pad`` for the inputs.
        max_length_labels: Minimum padded label length, applied only when
            ``padding == "max_length"``. Labels are never truncated.
        pad_to_multiple_of: Forwarded to ``processor.pad`` for the inputs.
        pad_to_multiple_of_labels: If set, the padded label length is rounded
            up to a multiple of this value.
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(
        self, examples: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        """Build a batch from ``examples``.

        Args:
            examples: One dict per dataset row, each holding at least the
                ``INPUT_FIELD`` and ``LABEL_FIELD`` keys.

        Returns:
            The object returned by ``processor.pad`` with the batched labels
            added under ``LABEL_FIELD``.
        """
        # Each example is one dataset row; only the input field is handed to
        # the processor, labels are batched separately below.
        input_features = [
            {INPUT_FIELD: example[INPUT_FIELD]} for example in examples
        ]
        labels = [example[LABEL_FIELD] for example in examples]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        batch[LABEL_FIELD] = self._collate_labels(labels)
        return batch

    def _collate_labels(self, labels: List[object]) -> torch.Tensor:
        """Stack ``labels`` into one tensor, padding 1-D sequences with -100."""
        # -100 is the value Hugging Face CTC heads treat as "ignore" in the loss.
        ignore_index = -100

        if not labels or self.padding is False or self.padding == "do_not_pad":
            return torch.tensor(labels)

        try:
            rows = [torch.as_tensor(label) for label in labels]
        except (TypeError, ValueError, RuntimeError):
            # Let the plain conversion raise its usual error for bad labels.
            return torch.tensor(labels)

        # Scalars (classification / regression targets) and multi-dimensional
        # labels keep the plain conversion.
        if any(row.dim() != 1 for row in rows):
            return torch.tensor(labels)

        target = max(row.shape[0] for row in rows)
        if self.padding == "max_length" and self.max_length_labels is not None:
            target = max(target, self.max_length_labels)
        multiple = self.pad_to_multiple_of_labels
        if multiple and target % multiple:
            target += multiple - target % multiple

        needs_padding = any(row.shape[0] != target for row in rows)
        has_tensor_label = any(isinstance(label, torch.Tensor) for label in labels)
        if not needs_padding and not has_tensor_label:
            # Equal-length plain sequences: plain conversion is sufficient.
            return torch.tensor(labels)

        # Take the dtype from the first non-empty row so that an empty label
        # (which converts to float) does not decide the dtype of the batch.
        dtype = next((row.dtype for row in rows if row.numel()), rows[0].dtype)
        padded = rows[0].new_full((len(rows), target), ignore_index, dtype=dtype)
        for index, row in enumerate(rows):
            padded[index, : row.shape[0]] = row
        return padded