othrif committed
Commit 96dcf78 (parent: e3717e3)

added my scripts

Files changed (3)
  1. evaluate.py +46 -0
  2. finetune.sh +37 -0
  3. run_common_voice.py +518 -0
evaluate.py ADDED
@@ -0,0 +1,46 @@
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re
import tnkeeh as tn

test_dataset = load_dataset("common_voice", "ar", split="test")
wer = load_metric("wer")

processor = Wav2Vec2Processor.from_pretrained("othrif/wav2vec2-large-xlsr-arabic")
model = Wav2Vec2ForCTC.from_pretrained("othrif/wav2vec2-large-xlsr-arabic")
model.to("cuda")

# Characters (punctuation and stray symbols) stripped from the transcripts before scoring.
# chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\“\'\�]'
chars_to_ignore_regex = '[\؛\—\_get\«\»\ـ\ـ\,\?\.\!\-\;\:\"\“\%\‘\”\�\#\،\☭,\؟\'ۚ\چ\ڨ\ﺃ\ھ\ﻻ\'ۖ]'
resampler = torchaudio.transforms.Resample(48_000, 16_000)


# Preprocessing the datasets.
# We need to read the audio files as arrays.
def speech_file_to_array_fn(batch):
    batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower()
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch


test_dataset = test_dataset.map(speech_file_to_array_fn)

# Remove Arabic diacritics from the transcripts.
cleaner = tn.Tnkeeh(remove_diacritics=True)
test_dataset = cleaner.clean_hf_dataset(test_dataset, "sentence")


# Run inference on the test set and decode the predictions.
def evaluate(batch):
    inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

    with torch.no_grad():
        logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits

    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_strings"] = processor.batch_decode(pred_ids)
    return batch


result = test_dataset.map(evaluate, batched=True, batch_size=32)

print("WER: {:.2f}".format(100 * wer.compute(predictions=result["pred_strings"], references=result["sentence"])))
finetune.sh ADDED
@@ -0,0 +1,37 @@
#!/usr/bin/env bash
# Usage: finetune.sh <output_model_path>

export model_path=$1
mkdir -p ${model_path}

# To restart training from scratch, add --overwrite_output_dir to the command below.
python run_common_voice.py \
    --dataloader_num_workers="8" \
    --model_name_or_path="facebook/wav2vec2-large-xlsr-53" \
    --dataset_config_name="ar" \
    --output_dir=${model_path} \
    --num_train_epochs="50" \
    --per_device_train_batch_size="16" \
    --per_device_eval_batch_size="16" \
    --evaluation_strategy="steps" \
    --warmup_steps="500" \
    --fp16 \
    --freeze_feature_extractor \
    --save_steps="400" \
    --eval_steps="400" \
    --logging_steps="400" \
    --save_total_limit="1" \
    --group_by_length \
    --attention_dropout="0.094" \
    --activation_dropout="0.055" \
    --feat_proj_dropout="0.04" \
    --hidden_dropout="0.047" \
    --layerdrop="0.041" \
    --mask_time_prob="0.082" \
    --gradient_checkpointing \
    --learning_rate="3e-4" \
    --do_train --do_eval
run_common_voice.py ADDED
@@ -0,0 +1,518 @@
#!/usr/bin/env python3
import json
import logging
import os
import re
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

import datasets
import numpy as np
import torch
import torchaudio
from packaging import version
from torch import nn

import transformers
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    is_apex_available,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process


if is_apex_available():
    from apex import amp

import tnkeeh as tn

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

logger = logging.getLogger(__name__)


def list_field(default=None, metadata=None):
    return field(default_factory=lambda: default, metadata=metadata)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    freeze_feature_extractor: Optional[bool] = field(
        default=True, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    attention_dropout: Optional[float] = field(
        default=0.1, metadata={"help": "The dropout ratio for the attention probabilities."}
    )
    activation_dropout: Optional[float] = field(
        default=0.1, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
    )
    hidden_dropout: Optional[float] = field(
        default=0.1,
        metadata={
            "help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
        },
    )
    feat_proj_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "The dropout probability for all 1D convolutional layers in the feature extractor."},
    )
    mask_time_prob: Optional[float] = field(
        default=0.05,
        metadata={
            "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector "
            "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature "
            "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
        },
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )
    layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_split_name: Optional[str] = field(
        default="train+validation",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train+validation'."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
    )
    chars_to_ignore: List[str] = list_field(
        default=[",", "?", ".", "!", "-", ";", ":", '""', "%", "'", '"', "�"],
        metadata={"help": "A list of characters to remove from the transcripts."},
    )


@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for processing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                max_length=self.max_length_labels,
                pad_to_multiple_of=self.pad_to_multiple_of_labels,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch


class CTCTrainer(Trainer):
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_amp:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            if model.module.config.ctc_loss_reduction == "mean":
                loss = loss.mean()
            elif model.module.config.ctc_loss_reduction == "sum":
                loss = loss.sum() / (inputs["labels"] >= 0).sum()
            else:
                raise ValueError(f"{model.config.ctc_loss_reduction} is not valid. Choose one of ['mean', 'sum']")

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_amp:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets:
    train_dataset = datasets.load_dataset(
        "common_voice", data_args.dataset_config_name, split=data_args.train_split_name
    )
    eval_dataset = datasets.load_dataset("common_voice", data_args.dataset_config_name, split="test")

    # Create and save tokenizer
    # chars_to_ignore_regex = f'[{"".join(data_args.chars_to_ignore)}]'
    chars_to_ignore_regex = '[\؛\—\_get\«\»\ـ\ـ\,\?\.\!\-\;\:\"\“\%\‘\”\�\#\،\☭,\؟\'ۚ\چ\ڨ\ﺃ\ھ\ﻻ\'ۖ]'

    def remove_special_characters(batch):
        batch["text"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower() + " "
        return batch

    train_dataset = train_dataset.map(remove_special_characters, remove_columns=["sentence"])
    eval_dataset = eval_dataset.map(remove_special_characters, remove_columns=["sentence"])

    # Remove Arabic diacritics from the cleaned transcripts (stored in the "text" column,
    # since the "sentence" column was dropped above).
    cleaner = tn.Tnkeeh(remove_diacritics=True)
    train_dataset = cleaner.clean_hf_dataset(train_dataset, "text")
    eval_dataset = cleaner.clean_hf_dataset(eval_dataset, "text")

    def extract_all_chars(batch):
        all_text = " ".join(batch["text"])
        vocab = list(set(all_text))
        return {"vocab": [vocab], "all_text": [all_text]}

    vocab_train = train_dataset.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        keep_in_memory=True,
        remove_columns=train_dataset.column_names,
    )
    vocab_test = eval_dataset.map(
        extract_all_chars,
        batched=True,
        batch_size=-1,
        keep_in_memory=True,
        remove_columns=eval_dataset.column_names,
    )

    vocab_list = list(set(vocab_train["vocab"][0]) | set(vocab_test["vocab"][0]))
    vocab_dict = {v: k for k, v in enumerate(vocab_list)}
    vocab_dict["|"] = vocab_dict[" "]
    del vocab_dict[" "]
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)

    with open("vocab.json", "w") as vocab_file:
        json.dump(vocab_dict, vocab_file)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer = Wav2Vec2CTCTokenizer(
        "vocab.json",
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    )
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1, sampling_rate=16_000, padding_value=0.0, do_normalize=True, return_attention_mask=True
    )
    processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
    model = Wav2Vec2ForCTC.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        activation_dropout=model_args.activation_dropout,
        attention_dropout=model_args.attention_dropout,
        hidden_dropout=model_args.hidden_dropout,
        feat_proj_dropout=model_args.feat_proj_dropout,
        mask_time_prob=model_args.mask_time_prob,
        gradient_checkpointing=model_args.gradient_checkpointing,
        layerdrop=model_args.layerdrop,
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        vocab_size=len(processor.tokenizer),
    )

    if data_args.max_train_samples is not None:
        train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if data_args.max_val_samples is not None:
        eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

    resampler = torchaudio.transforms.Resample(48_000, 16_000)

    # Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    def speech_file_to_array_fn(batch):
        speech_array, sampling_rate = torchaudio.load(batch["path"])
        batch["speech"] = resampler(speech_array).squeeze().numpy()
        batch["sampling_rate"] = 16_000
        batch["target_text"] = batch["text"]
        return batch

    train_dataset = train_dataset.map(
        speech_file_to_array_fn,
        remove_columns=train_dataset.column_names,
        num_proc=data_args.preprocessing_num_workers,
    )
    eval_dataset = eval_dataset.map(
        speech_file_to_array_fn,
        remove_columns=eval_dataset.column_names,
        num_proc=data_args.preprocessing_num_workers,
    )

    def prepare_dataset(batch):
        # check that all files have the correct sampling rate
        assert (
            len(set(batch["sampling_rate"])) == 1
        ), f"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}."
        batch["input_values"] = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0]).input_values
        # Setup the processor for targets
        with processor.as_target_processor():
            batch["labels"] = processor(batch["target_text"]).input_ids
        return batch

    train_dataset = train_dataset.map(
        prepare_dataset,
        remove_columns=train_dataset.column_names,
        batch_size=training_args.per_device_train_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
    )
    eval_dataset = eval_dataset.map(
        prepare_dataset,
        remove_columns=eval_dataset.column_names,
        batch_size=training_args.per_device_train_batch_size,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
    )

    # Metric
    wer_metric = datasets.load_metric("wer")

    def compute_metrics(pred):
        pred_logits = pred.predictions
        pred_ids = np.argmax(pred_logits, axis=-1)

        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        wer = wer_metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    if model_args.freeze_feature_extractor:
        model.freeze_feature_extractor()

    # Data collator
    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

    # Initialize our Trainer
    trainer = CTCTrainer(
        model=model,
        data_collator=data_collator,
        args=training_args,
        compute_metrics=compute_metrics,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=processor.feature_extractor,
    )

    # Training
    if training_args.do_train:
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()

        # save the feature_extractor and the tokenizer
        if is_main_process(training_args.local_rank):
            processor.save_pretrained(training_args.output_dir)

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    return results


if __name__ == "__main__":
    main()