kmfoda committed
Commit 9ff32e7 • 1 Parent(s): 5a6ce3a

Update README.md with correct WER

Files changed (2):
  1. README.md +17 -3
  2. run_common_voice.py +518 -0
README.md CHANGED
@@ -47,8 +47,15 @@ test_dataset = load_dataset("common_voice", "ar", split="test[:2%]")
  processor = Wav2Vec2Processor.from_pretrained("kmfoda/wav2vec2-large-xlsr-arabic")
  model = Wav2Vec2ForCTC.from_pretrained("kmfoda/wav2vec2-large-xlsr-arabic")

+ resamplers = {  # all three sampling rates exist in test split
+     48000: torchaudio.transforms.Resample(48000, 16000),
+     44100: torchaudio.transforms.Resample(44100, 16000),
+     32000: torchaudio.transforms.Resample(32000, 16000),
+ }
+
  def prepare_example(example):
-     example["speech"], _ = librosa.load(example["path"], sr=16000)
+     speech, sampling_rate = torchaudio.load(example["path"])
+     example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
      return example

  test_dataset = test_dataset.map(prepare_example)
@@ -84,8 +91,15 @@ model.to("cuda")

  chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\؟\_\؛\ـ\—]'

+ resamplers = {  # all three sampling rates exist in test split
+     48000: torchaudio.transforms.Resample(48000, 16000),
+     44100: torchaudio.transforms.Resample(44100, 16000),
+     32000: torchaudio.transforms.Resample(32000, 16000),
+ }
+
  def prepare_example(example):
-     example["speech"], _ = librosa.load(example["path"], sr=16000)
+     speech, sampling_rate = torchaudio.load(example["path"])
+     example["speech"] = resamplers[sampling_rate](speech).squeeze().numpy()
      return example

  test_dataset = test_dataset.map(prepare_example)
@@ -108,7 +122,7 @@ print("WER: {:2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))

  ```

- **Test Result**: ??
+ **Test Result**: 52.53


  ## Training
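The hard-coded `resamplers` dict above covers the three sampling rates observed in the Common Voice Arabic test split. A minimal sketch of an alternative (not part of this commit) that resamples whatever rate a clip actually has down to 16 kHz, rather than enumerating the rates:

```python
import torchaudio

# Hypothetical variant of prepare_example: build the Resample transform
# from the clip's own sampling rate instead of a fixed lookup table.
def prepare_example(example):
    speech, sampling_rate = torchaudio.load(example["path"])
    if sampling_rate != 16000:
        speech = torchaudio.transforms.Resample(sampling_rate, 16000)(speech)
    example["speech"] = speech.squeeze().numpy()
    return example
```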
run_common_voice.py ADDED
@@ -0,0 +1,518 @@
+ #!/usr/bin/env python3
+ import json
+ import logging
+ import os
+ import re
+ import sys
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Union
+
+ import datasets
+ import numpy as np
+ import torch
+ import torchaudio
+ from packaging import version
+ from torch import nn
+
+ import transformers
+ from transformers import (
+     HfArgumentParser,
+     Trainer,
+     TrainingArguments,
+     Wav2Vec2CTCTokenizer,
+     Wav2Vec2FeatureExtractor,
+     Wav2Vec2ForCTC,
+     Wav2Vec2Processor,
+     is_apex_available,
+     set_seed,
+ )
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
+
+
+ if is_apex_available():
+     from apex import amp
+
+
+ if version.parse(torch.__version__) >= version.parse("1.6"):
+     _is_native_amp_available = True
+     from torch.cuda.amp import autocast
+
+ logger = logging.getLogger(__name__)
+
+
+ def list_field(default=None, metadata=None):
+     return field(default_factory=lambda: default, metadata=metadata)
+
+ import wandb
+ wandb.login()
+ os.environ['WANDB_PROJECT'] = "ar-base-30e-hyperv3"
+ os.environ['WANDB_LOG_MODEL'] = "true"
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+     """
+
+     model_name_or_path: str = field(
+         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None,
+         metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+     )
+     freeze_feature_extractor: Optional[bool] = field(
+         default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
+     )
+     attention_dropout: Optional[float] = field(
+         default=0.1, metadata={"help": "The dropout ratio for the attention probabilities."}
+     )
+     activation_dropout: Optional[float] = field(
+         default=0.1, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
+     )
+     hidden_dropout: Optional[float] = field(
+         default=0.1,
+         metadata={
+             "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
+         },
+     )
+     feat_proj_dropout: Optional[float] = field(
+         default=0.1,
+         metadata={"help": "The dropout probability for all 1D convolutional layers in the feature extractor."},
+     )
+     mask_time_prob: Optional[float] = field(
+         default=0.05,
+         metadata={
+             "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector "
+             "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature "
+             "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
+         },
+     )
+     gradient_checkpointing: Optional[bool] = field(
+         default=True,
+         metadata={
+             "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+         },
+     )
+     layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+
+     Using `HfArgumentParser` we can turn this class
+     into argparse arguments to be able to specify them on
+     the command line.
+     """
+
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     train_split_name: Optional[str] = field(
+         default="train+validation",
+         metadata={
+             "help": "The name of the training dataset split to use (via the datasets library). Defaults to 'train+validation'."
+         },
+     )
+     overwrite_cache: bool = field(
+         default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
+     )
+     preprocessing_num_workers: Optional[int] = field(
+         default=None,
+         metadata={"help": "The number of processes to use for the preprocessing."},
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+             "value if set."
+         },
+     )
+     max_val_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+             "value if set."
+         },
+     )
+     chars_to_ignore: List[str] = list_field(
+         default=[",", "?", ".", "!", "-", ";", ":", '""', "%", "'", '"', "�"],
+         metadata={"help": "A list of characters to remove from the transcripts."},
+     )
+
+
+ @dataclass
+ class DataCollatorCTCWithPadding:
+     """
+     Data collator that will dynamically pad the inputs received.
+     Args:
+         processor (:class:`~transformers.Wav2Vec2Processor`)
+             The processor used for processing the data.
+         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+             among:
+             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+               sequence is provided).
+             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+               maximum acceptable input length for the model if that argument is not provided.
+             * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+               different lengths).
+         max_length (:obj:`int`, `optional`):
+             Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
+         max_length_labels (:obj:`int`, `optional`):
+             Maximum length of the ``labels`` returned list and optionally padding length (see above).
+         pad_to_multiple_of (:obj:`int`, `optional`):
+             If set will pad the sequence to a multiple of the provided value.
+             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+             7.5 (Volta).
+     """
+
+     processor: Wav2Vec2Processor
+     padding: Union[bool, str] = True
+     max_length: Optional[int] = None
+     max_length_labels: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     pad_to_multiple_of_labels: Optional[int] = None
+
+     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+         # split inputs and labels since they have to be of different lengths and need
+         # different padding methods
+         input_features = [{"input_values": feature["input_values"]} for feature in features]
+         label_features = [{"input_ids": feature["labels"]} for feature in features]
+
+         batch = self.processor.pad(
+             input_features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             return_tensors="pt",
+         )
+         with self.processor.as_target_processor():
+             labels_batch = self.processor.pad(
+                 label_features,
+                 padding=self.padding,
+                 max_length=self.max_length_labels,
+                 pad_to_multiple_of=self.pad_to_multiple_of_labels,
+                 return_tensors="pt",
+             )
+
+         # replace padding with -100 to ignore loss correctly
+         labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
+
+         batch["labels"] = labels
+
+         return batch
+
+
+ class CTCTrainer(Trainer):
+     def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+         """
+         Perform a training step on a batch of inputs.
+
+         Subclass and override to inject custom behavior.
+
+         Args:
+             model (:obj:`nn.Module`):
+                 The model to train.
+             inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+                 The inputs and targets of the model.
+
+                 The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+                 argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+
+         Return:
+             :obj:`torch.Tensor`: The tensor with training loss on this batch.
+         """
+
+         model.train()
+         inputs = self._prepare_inputs(inputs)
+
+         if self.use_amp:
+             with autocast():
+                 loss = self.compute_loss(model, inputs)
+         else:
+             loss = self.compute_loss(model, inputs)
+
+         if self.args.n_gpu > 1:
+             if model.module.config.ctc_loss_reduction == "mean":
+                 loss = loss.mean()
+             elif model.module.config.ctc_loss_reduction == "sum":
+                 loss = loss.sum() / (inputs["labels"] >= 0).sum()
+             else:
+                 raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")
+
+         if self.args.gradient_accumulation_steps > 1:
+             loss = loss / self.args.gradient_accumulation_steps
+
+         if self.use_amp:
+             self.scaler.scale(loss).backward()
+         elif self.use_apex:
+             with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                 scaled_loss.backward()
+         elif self.deepspeed:
+             self.deepspeed.backward(loss)
+         else:
+             loss.backward()
+
+         return loss.detach()
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     # Detecting last checkpoint.
+     last_checkpoint = None
+     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+         last_checkpoint = get_last_checkpoint(training_args.output_dir)
+         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+             raise ValueError(
+                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                 "Use --overwrite_output_dir to overcome."
+             )
+         elif last_checkpoint is not None:
+             logger.info(
+                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+             )
+
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+     logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+
+     # Log on each process the small summary:
+     logger.warning(
+         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+     )
+     # Set the verbosity to info of the Transformers logger (on main process only):
+     if is_main_process(training_args.local_rank):
+         transformers.utils.logging.set_verbosity_info()
+     logger.info("Training/evaluation parameters %s", training_args)
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # Get the datasets:
+     train_dataset = datasets.load_dataset(
+         "common_voice", data_args.dataset_config_name, split=data_args.train_split_name, cache_dir=model_args.cache_dir
+     )
+     eval_dataset = datasets.load_dataset("common_voice", data_args.dataset_config_name, split="test", cache_dir=model_args.cache_dir)
+
+     # Create and save tokenizer
+     # chars_to_ignore_regex = f'[{"".join(data_args.chars_to_ignore)}]'
+
+     chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\؟\_\؛\ـ\—]'
+
+     def remove_special_characters(batch):
+         batch["text"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower() + " "
+         return batch
+
+     train_dataset = train_dataset.map(remove_special_characters, remove_columns=["sentence"])
+     eval_dataset = eval_dataset.map(remove_special_characters, remove_columns=["sentence"])
+
+     def extract_all_chars(batch):
+         all_text = " ".join(batch["text"])
+         vocab = list(set(all_text))
+         return {"vocab": [vocab], "all_text": [all_text]}
+
+     vocab_train = train_dataset.map(
+         extract_all_chars,
+         batched=True,
+         batch_size=-1,
+         keep_in_memory=True,
+         remove_columns=train_dataset.column_names,
+     )
+     vocab_test = eval_dataset.map(
+         extract_all_chars,
+         batched=True,
+         batch_size=-1,
+         keep_in_memory=True,
+         remove_columns=eval_dataset.column_names,
+     )
+
+     vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
+     vocab_dict = {v: k for k, v in enumerate(vocab_list)}
+     vocab_dict["|"] = vocab_dict[" "]
+     del vocab_dict[" "]
+     vocab_dict["[UNK]"] = len(vocab_dict)
+     vocab_dict["[PAD]"] = len(vocab_dict)
+
+     with open("vocab.json", "w") as vocab_file:
+         json.dump(vocab_dict, vocab_file)
+
+     # Load pretrained model and tokenizer
+     #
+     # Distributed training:
+     # The .from_pretrained methods guarantee that only one local process can concurrently
+     # download model & vocab.
+     tokenizer = Wav2Vec2CTCTokenizer(
+         "vocab.json",
+         unk_token="[UNK]",
+         pad_token="[PAD]",
+         word_delimiter_token="|",
+     )
+     feature_extractor = Wav2Vec2FeatureExtractor(
+         feature_size=1, sampling_rate=16_000, padding_value=0.0, do_normalize=True, return_attention_mask=True
+     )
+     processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+     model = Wav2Vec2ForCTC.from_pretrained(
+         model_args.model_name_or_path,
+         cache_dir=model_args.cache_dir,
+         activation_dropout=model_args.activation_dropout,
+         attention_dropout=model_args.attention_dropout,
+         hidden_dropout=model_args.hidden_dropout,
+         feat_proj_dropout=model_args.feat_proj_dropout,
+         mask_time_prob=model_args.mask_time_prob,
+         gradient_checkpointing=model_args.gradient_checkpointing,
+         layerdrop=model_args.layerdrop,
+         ctc_loss_reduction="mean",
+         pad_token_id=processor.tokenizer.pad_token_id,
+         vocab_size=len(processor.tokenizer),
+     )
+
+     if data_args.max_train_samples is not None:
+         train_dataset = train_dataset.select(range(data_args.max_train_samples))
+
+     if data_args.max_val_samples is not None:
+         eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+
+     resampler = torchaudio.transforms.Resample(32_000, 16_000)
+
+     # Preprocessing the datasets.
+     # We need to read the audio files as arrays and tokenize the targets.
+     def speech_file_to_array_fn(batch):
+         speech_array, sampling_rate = torchaudio.load(batch["path"])
+         batch["speech"] = resampler(speech_array).squeeze().numpy()
+         batch["sampling_rate"] = 16_000
+         batch["target_text"] = batch["text"]
+         return batch
+
+     train_dataset = train_dataset.map(
+         speech_file_to_array_fn,
+         remove_columns=train_dataset.column_names,
+         num_proc=data_args.preprocessing_num_workers,
+     )
+     eval_dataset = eval_dataset.map(
+         speech_file_to_array_fn,
+         remove_columns=eval_dataset.column_names,
+         num_proc=data_args.preprocessing_num_workers,
+     )
+
+     def prepare_dataset(batch):
+         # check that all files have the correct sampling rate
+         assert (
+             len(set(batch["sampling_rate"])) == 1
+         ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
+         batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
+         # Setup the processor for targets
+         with processor.as_target_processor():
+             batch["labels"] = processor(batch["target_text"]).input_ids
+         return batch
+
+     train_dataset = train_dataset.map(
+         prepare_dataset,
+         remove_columns=train_dataset.column_names,
+         batch_size=training_args.per_device_train_batch_size,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+     )
+     eval_dataset = eval_dataset.map(
+         prepare_dataset,
+         remove_columns=eval_dataset.column_names,
+         batch_size=training_args.per_device_train_batch_size,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+     )
+
+     # Metric
+     wer_metric = datasets.load_metric("wer")
+
+     def compute_metrics(pred):
+         pred_logits = pred.predictions
+         pred_ids = np.argmax(pred_logits, axis=-1)
+
+         pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
+
+         pred_str = processor.batch_decode(pred_ids)
+         # we do not want to group tokens when computing the metrics
+         label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
+
+         wer = wer_metric.compute(predictions=pred_str, references=label_str)
+
+         return {"wer": wer}
+
+     if model_args.freeze_feature_extractor:
+         model.freeze_feature_extractor()
+
+     # Data collator
+     data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
+
+     # Initialize our Trainer
+     trainer = CTCTrainer(
+         model=model,
+         data_collator=data_collator,
+         args=training_args,
+         compute_metrics=compute_metrics,
+         train_dataset=train_dataset if training_args.do_train else None,
+         eval_dataset=eval_dataset if training_args.do_eval else None,
+         tokenizer=processor.feature_extractor,
+     )
+
+     # Training
+     if training_args.do_train:
+         if last_checkpoint is not None:
+             checkpoint = last_checkpoint
+         elif os.path.isdir(model_args.model_name_or_path):
+             checkpoint = model_args.model_name_or_path
+         else:
+             checkpoint = None
+         train_result = trainer.train(resume_from_checkpoint=checkpoint)
+         trainer.save_model()
+
+         # save the feature_extractor and the tokenizer
+         if is_main_process(training_args.local_rank):
+             processor.save_pretrained(training_args.output_dir)
+
+         metrics = train_result.metrics
+         max_train_samples = (
+             data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
+         )
+         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+         trainer.log_metrics("train", metrics)
+         trainer.save_metrics("train", metrics)
+         trainer.save_state()
+
+     # Evaluation
+     results = {}
+     if training_args.do_eval:
+         logger.info("*** Evaluate ***")
+         metrics = trainer.evaluate()
+         max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
+         metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+
+         trainer.log_metrics("eval", metrics)
+         trainer.save_metrics("eval", metrics)
+
+     return results
+
+
+ if __name__ == "__main__":
+     main()
+
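For reference, a minimal sketch of driving this script through the `parse_json_file` branch in `main()` (a single `.json` argument). The file name, model identifier, and hyperparameter values below are placeholders, not part of this commit; note also that the script calls `wandb.login()` at import time, so W&B credentials must be available.

```python
# Hypothetical launcher for run_common_voice.py via its JSON-args branch
# (len(sys.argv) == 2 and sys.argv[1].endswith(".json") in main()).
# All paths and values below are placeholders.
import json
import subprocess

args = {
    "model_name_or_path": "facebook/wav2vec2-large-xlsr-53",  # base checkpoint to fine-tune (placeholder)
    "dataset_config_name": "ar",                              # Common Voice Arabic, as in the README
    "output_dir": "./wav2vec2-large-xlsr-arabic",             # placeholder output directory
    "do_train": True,
    "do_eval": True,
}

with open("train_args.json", "w") as f:
    json.dump(args, f, indent=2)

subprocess.run(["python", "run_common_voice.py", "train_args.json"], check=True)
```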