Mayank022 committed
Commit fcaae5c · verified · 1 Parent(s): f5b3371

Update data.py

Files changed (1):
  data.py +87 -0
data.py ADDED
@@ -0,0 +1,87 @@
+ import json
+ import torch
+ from torch.utils.data import Dataset
+
+ import transformers
+ import datasets
+ from typing import List, Dict, Any, Optional
+ import dataclasses
+ from config import ModelConfig, TrainConfig
+
+ class AudioTextDataset(Dataset):
+     def __init__(self, train_config: TrainConfig, processor: transformers.AutoProcessor, model_config: ModelConfig, tokenizer: transformers.PreTrainedTokenizer):
+         self.sampling_rate = 16000
+         print(f"Loading dataset: {train_config.dataset_name} ({train_config.dataset_subset}) split={train_config.dataset_split}")
+         self.dataset = datasets.load_dataset(
+             train_config.dataset_name,
+             train_config.dataset_subset,
+             split=train_config.dataset_split,
+             verification_mode="no_checks",  # avoid NonMatchingSplitsSizesError when Hub metadata differs from cached
+         )
+         # Audio(sampling_rate=...) decodes and resamples via TorchCodec; requires system FFmpeg (apt install ffmpeg)
+         self.dataset = self.dataset.cast_column("audio", datasets.Audio(sampling_rate=self.sampling_rate))
+
+         self.processor = processor
+         self.tokenizer = tokenizer
+         self.model_config = model_config
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         # Each item looks like {'audio': {'array': ..., 'sampling_rate': ...}, 'sentence': ...}
+         audio_array = item["audio"]["array"]
+         sampling_rate = item["audio"]["sampling_rate"]
+         text = item.get("sentence", item.get("text", ""))
+         continuation = item.get("continuation", item.get("continuation_text", ""))
+
+         audio = torch.from_numpy(audio_array).float()
+         if audio.ndim == 1:
+             audio = audio.unsqueeze(0)  # (1, T)
+         elif audio.shape[0] > 1:
+             audio = audio.mean(dim=0, keepdim=True)  # downmix to mono
+
+         audio_inputs = self.processor(audio.squeeze().numpy(), sampling_rate=sampling_rate or self.sampling_rate, return_tensors="pt")
+         audio_values = audio_inputs.input_features.squeeze(0)
+
+         text_inputs = self.tokenizer(text, return_tensors="pt", padding=False, truncation=True)
+         input_ids = text_inputs.input_ids.squeeze(0)
+         labels = input_ids.clone()
+
+         return {
+             "audio_values": audio_values,
+             "input_ids": input_ids,
+             "labels": labels,
+             "continuation": continuation,
+         }
+
+ @dataclasses.dataclass
+ class DataCollator:
+     processor: transformers.AutoProcessor
+     tokenizer: transformers.PreTrainedTokenizer
+
+     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+         audio_values = [f["audio_values"] for f in features]
+         input_ids = [f["input_ids"] for f in features]
+         labels = [f["labels"] for f in features]
+         continuations = [f.get("continuation", "") for f in features]
+
+         # 3000 frames corresponds to a fixed 30 s Whisper-style feature window, so all features share one shape
+         if audio_values[0].shape[-1] == 3000:
+             audio_batch = torch.stack(audio_values)
+         else:
+             audio_values_T = [a.T for a in audio_values]
+             audio_batch_T = torch.nn.utils.rnn.pad_sequence(audio_values_T, batch_first=True)
+             audio_batch = audio_batch_T.transpose(1, 2)
+
+         input_ids_batch = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+         labels_batch = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)  # -100 is ignored by the loss
+
+         return {
+             "audio_values": audio_batch,
+             "input_ids": input_ids_batch,
+             "labels": labels_batch,
+             "attention_mask": (input_ids_batch != self.tokenizer.pad_token_id).long(),
+             "continuation": continuations,
+         }
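
For context, a minimal sketch of how the dataset and collator added here could be wired into a torch DataLoader. This is not part of the commit: the checkpoint names, the pad-token fallback, and the no-argument ModelConfig()/TrainConfig() construction are illustrative assumptions; the real values come from config.py.

import transformers
from torch.utils.data import DataLoader

from config import ModelConfig, TrainConfig  # assumed to carry dataset_name / dataset_subset / dataset_split
from data import AudioTextDataset, DataCollator

# Placeholder checkpoints for illustration; substitute the project's actual models.
processor = transformers.AutoProcessor.from_pretrained("openai/whisper-small")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # the collator needs a valid pad_token_id

model_config = ModelConfig()   # assumed defaults; see config.py
train_config = TrainConfig()

dataset = AudioTextDataset(train_config, processor, model_config, tokenizer)
collator = DataCollator(processor=processor, tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collator)

batch = next(iter(loader))
# audio_values: (B, n_mels, 3000) features; input_ids/labels/attention_mask: (B, L) padded text
print(batch["audio_values"].shape, batch["input_ids"].shape, batch["attention_mask"].shape)

Note the asymmetry in the collator: input_ids are padded with pad_token_id and masked via attention_mask, while labels are padded with -100 so the padded positions drop out of the cross-entropy loss.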