aoxo committed on
Commit
6ac0719
1 Parent(s): d6525d6

Upload trainer.py with huggingface_hub

Files changed (1)
  1. trainer.py +289 -0
trainer.py ADDED
import os
import random
from datasets import ClassLabel, Dataset, DatasetDict, load_dataset
from datasets.features import Audio
import pandas as pd
import numpy as np
from tqdm import tqdm
from IPython.display import display, HTML

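# Expected directory layout (as implied by the path handling below):
#   <data_dir>/wav/<name>.wav             - audio clips
#   <data_dir>/transcription/<name>.txt   - matching transcripts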
# Function to load your custom dataset
def load_custom_dataset(data_dir):
    data = {
        "audio": [],
        "text": []
    }

    wav_dir = os.path.join(data_dir, 'wav')
    txt_dir = os.path.join(data_dir, 'transcription')

    # Assuming filenames in 'wav' and 'transcription' match (e.g. 001.wav <-> 001.txt)
    for wav_file in os.listdir(wav_dir):
        if wav_file.endswith('.wav'):
            txt_file = wav_file.replace('.wav', '.txt')
            wav_path = os.path.join(wav_dir, wav_file)
            txt_path = os.path.join(txt_dir, txt_file)

            # Read the transcription text
            with open(txt_path, 'r', encoding='utf-8') as f:
                transcription = f.read().strip()

            # Append to the dataset
            data["audio"].append(wav_path)
            data["text"].append(transcription)

    # Create a pandas dataframe
    df = pd.DataFrame(data)

    # Convert to a Hugging Face dataset
    dataset = Dataset.from_pandas(df)

    # Define the audio feature (for .wav files); audio is decoded and resampled on access
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))  # Adjust the sampling rate if needed

    return dataset

custom_train_dataset = load_custom_dataset("./")

# Combine them into a DatasetDict
dataset_dict = DatasetDict({
    "train": custom_train_dataset,
})

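# Note: random.sample below is unseeded, so the train/test split changes between runs;
# call random.seed(<int>) first (or use Dataset.train_test_split(test_size=975)) if you
# need the held-out set to be reproducible.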
# Hold out 975 random samples from train to use as a test set
train_size = len(dataset_dict["train"])
sample_indices = random.sample(range(train_size), 975)
sample_index_set = set(sample_indices)  # O(1) membership checks for the filter below

# Select the held-out samples
test_samples = dataset_dict["train"].select(sample_indices)

# Filter the selected samples out of the train dataset
remaining_train_samples = dataset_dict["train"].filter(lambda example, idx: idx not in sample_index_set, with_indices=True)

# Register the two splits
dataset_dict["test"] = test_samples
dataset_dict["train"] = remaining_train_samples

print(dataset_dict)

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    # Actually render the sampled rows (otherwise the function has no visible effect)
    display(HTML(df.to_html()))

show_random_elements(dataset_dict["train"])

import re
# Punctuation to strip from transcripts (raw string avoids invalid-escape warnings)
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'

def remove_special_characters(batch):
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower()
    return batch

dataset_dict = dataset_dict.map(remove_special_characters)

show_random_elements(dataset_dict["train"])

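# Build the character vocabulary: with batched=True and batch_size=-1 the whole split is
# passed to extract_all_chars as a single batch, so each split yields one "vocab" list.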
def extract_all_chars(batch):
    all_text = " ".join(batch["text"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

vocabs = dataset_dict.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=dataset_dict.column_names["train"])

# Union of the characters seen in both splits, so test-only characters do not map to [UNK]
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

vocab_dict = {v: k for k, v in enumerate(vocab_list)}
print(vocab_dict)

# CTC convention: represent the word boundary with "|" (the tokenizer's word_delimiter_token)
# instead of a literal space, otherwise every space would be encoded as [UNK]
vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(len(vocab_dict))

import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)
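# vocab.json now maps each character (plus "|", [UNK] and [PAD]) to an integer id;
# it is the vocabulary file the CTC tokenizer below is built from.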

from transformers import Wav2Vec2CTCTokenizer

# The tokenizer reads its vocabulary (and size) from vocab.json
tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
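# Optionally persist the processor so the exact same preprocessing can be reloaded at
# inference time, e.g.: processor.save_pretrained('wav2vec2-large-mal')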

# Sanity-check one random training example
rand_int = random.randint(0, len(dataset_dict["train"]) - 1)

print("Target text:", dataset_dict["train"][rand_int]["text"])
print("Input array shape:", np.asarray(dataset_dict["train"][rand_int]["audio"]["array"]).shape)
print("Sampling rate:", dataset_dict["train"][rand_int]["audio"]["sampling_rate"])

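# prepare_dataset turns each example into model inputs: the decoded waveform becomes
# "input_values" via the feature extractor and the transcript becomes "labels"
# (character ids) via the tokenizer.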
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]

    # encode the transcript to label ids (processor(text=...) replaces the deprecated
    # as_target_processor() context manager)
    batch["labels"] = processor(text=batch["text"]).input_ids
    return batch

dataset_dict = dataset_dict.map(prepare_dataset, remove_columns=dataset_dict.column_names["train"], num_proc=None)

import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

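# The collator below pads on the fly, per batch: raw audio inputs and label sequences are
# padded separately (to the longest item in the batch), and label padding is masked with
# -100 so it is ignored by the CTC loss. group_by_length=True in TrainingArguments keeps
# similarly-sized clips together, which keeps this padding small.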
@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # pad the label ids with the tokenizer (labels=... replaces the deprecated
        # as_target_processor() context manager)
        labels_batch = self.processor.pad(
            labels=label_features,
            padding=self.padding,
            max_length=self.max_length_labels,
            pad_to_multiple_of=self.pad_to_multiple_of_labels,
            return_tensors="pt",
        )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
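# Word error rate (WER) = (substitutions + deletions + insertions) / number of reference
# words; lower is better.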

import evaluate

wer_metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    # -100 was only used to mask the loss; restore the pad token id before decoding
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

from transformers import Wav2Vec2ForCTC

# Pretrained acoustic model; a fresh CTC head is sized to our character vocabulary
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large",
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(vocab_dict),
)

# The convolutional feature encoder is already well trained during pretraining; keep it frozen
model.freeze_feature_encoder()

model.gradient_checkpointing_enable()

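# Training configuration. per_device_train_batch_size=36 assumes a large GPU; if memory is
# tight, lower it and add gradient_accumulation_steps to keep the effective batch size.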
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='wav2vec2-large-mal',
    group_by_length=True,
    per_device_train_batch_size=36,
    eval_strategy="steps",
    num_train_epochs=30,
    fp16=True,
    gradient_checkpointing=True,
    save_steps=500,
    eval_steps=500,
    logging_steps=500,
    learning_rate=1e-4,
    weight_decay=0.005,
    warmup_steps=1000,
    save_total_limit=2,
)

from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["test"],
    processing_class=processor.feature_extractor,
)

trainer.train()
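# Final evaluation: greedy CTC decoding over the held-out split. After trainer.train()
# the model sits on the GPU, so the inputs below are created on the "cuda" device the
# script assumes throughout.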

model.eval()  # disable dropout etc. for inference

def map_to_result(batch):
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
        logits = model(input_values).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["text"] = processor.decode(batch["labels"], group_tokens=False)

    return batch

results = dataset_dict["test"].map(map_to_result, remove_columns=dataset_dict["test"].column_names)

print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["text"])))