Zaid committed on
Commit
de5bb56
1 Parent(s): ff53392

Training in progress, step 10000

pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:3e692eb6043acda8671ca15866b5800d1926740b4f023fd9fb15ee413f9d6e5c
+ oid sha256:8d030879de7c6cd0ae429b34490b7cf104969ce12b2ae4217f5a266aa22e7b01
  size 3055754841
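
Only the LFS pointer for the checkpoint changes here: the file size stays at 3055754841 bytes while the sha256 oid is updated. As an illustration (the local file path is an assumption, not part of the commit), a downloaded pytorch_model.bin can be checked against the new pointer like this:

    import hashlib

    def lfs_sha256(path: str, chunk_size: int = 1 << 20) -> str:
        """Stream the file and return its sha256 hex digest (the LFS pointer's oid)."""
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()

    # Expected oid taken from the new pointer above.
    expected = "8d030879de7c6cd0ae429b34490b7cf104969ce12b2ae4217f5a266aa22e7b01"
    assert lfs_sha256("pytorch_model.bin") == expected
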
run_speech_recognition_seq2seq_mixed_mgb2_wandb.py ADDED
@@ -0,0 +1,873 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence speech recognition
with 🤗 Datasets' streaming mode.
"""
# You can also adapt this script for your own sequence to sequence speech
# recognition task. Pointers for this are left as comments.

import logging
import numbers
import os
import sys
import tempfile
from dataclasses import dataclass, field
from io import StringIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import datasets
import holoviews as hv
import jiwer
import numpy as np
import pandas as pd
import panel as pn
import torch
from datasets import Dataset, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
from torch.utils.data import IterableDataset

import evaluate
import transformers
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    set_seed,
)
from transformers.integrations import WandbCallback
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from transformers.trainer_pt_utils import IterableDatasetShard
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
import wandb

# A plotting backend is needed so the holoviews/panel audio and spectrogram panes can be saved to HTML.
hv.extension("bokeh", logo=False)

run = wandb.init(project="whisper_finetuning", job_type="fine-tuning", group="medium", resume="must", id="2k10w4qq")

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.25.0.dev0")

require_version(
    "datasets>=1.18.2",
    "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt",
)

logger = logging.getLogger(__name__)


def load_samples_dataset(dataset, num_samples=10):
    # Take the first `num_samples` examples of a (possibly streamed) dataset as an in-memory Dataset.
    samples = []
    for i, item in enumerate(dataset):
        samples.append(item)
        if i == (num_samples - 1):
            break
    sample_dataset = Dataset.from_list(samples)
    return sample_dataset


def compute_spectrograms(example, feature_extractor):
    # Log-mel spectrogram for a single example; the feature extractor is passed in
    # explicitly via `fn_kwargs` when this function is used with `datasets.map`.
    waveform = example["audio"]["array"]
    specs = feature_extractor(
        waveform, sampling_rate=16000, padding="do_not_pad"
    ).input_features[0]
    return {"spectrogram": specs}


def record_to_html(sample_record):
    # Render one example as an interactive HTML panel (audio player, waveform and spectrogram) for W&B.
    audio_array = np.array(sample_record["audio"]["array"])
    audio_sr = sample_record["audio"]["sampling_rate"]
    audio_duration = sample_record["length"]
    audio_spectrogram = np.array(sample_record["spectrogram"])

    bounds = (0, 0, audio_duration, audio_spectrogram.max())

    waveform_int = np.int16(audio_array * 32767)

    hv_audio = pn.pane.Audio(waveform_int, sample_rate=audio_sr, name='Audio', throttle=500)

    slider = pn.widgets.FloatSlider(end=audio_duration, visible=False, step=0.001)
    line_audio = hv.VLine(0).opts(color='black')
    line_spec = hv.VLine(0).opts(color='red')

    slider.jslink(hv_audio, value='time', bidirectional=True)
    slider.jslink(line_audio, value='glyph.location')
    slider.jslink(line_spec, value='glyph.location')

    time = np.linspace(0, audio_duration, num=len(audio_array))
    line_plot_hv = hv.Curve(
        (time, audio_array), ["Time (s)", "amplitude"]).opts(
        width=500, height=150, axiswise=True) * line_audio

    hv_spec_gram = hv.Image(
        audio_spectrogram, bounds=(bounds), kdims=["Time (s)", "Frequency (hz)"]).opts(
        width=500, height=150, labelled=[], axiswise=True, color_levels=512) * line_spec

    combined = pn.Row(hv_audio, hv_spec_gram, line_plot_hv, slider)
    audio_html = StringIO()
    combined.save(audio_html)
    return audio_html


def dataset_to_records(dataset):
    # Build a DataFrame of reference sentences, lengths and HTML audio panels for logging to W&B.
    records = []
    for item in dataset:
        record = {}
        record["audio_with_spec"] = wandb.Html(record_to_html(item))
        record["sentence"] = item["sentence"]
        record["length"] = item["length"]
        records.append(record)
    records = pd.DataFrame(records)
    return records


def decode_predictions(trainer, predictions):
    # Decode generated token ids back to strings with the trainer's tokenizer.
    pred_ids = predictions.predictions
    pred_str = trainer.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    return pred_str


def compute_measures(predictions, labels):
    # Per-sample WER statistics (hits, substitutions, deletions, insertions) via jiwer.
    measures = [jiwer.compute_measures(ls, ps) for ps, ls in zip(predictions, labels)]
    measures_df = pd.DataFrame(measures)[["wer", "hits", "substitutions", "deletions", "insertions"]]
    return measures_df


class WandbProgressResultsCallback(WandbCallback):
    """W&B callback that logs sample predictions and WER measures at every log step
    and uploads the saved checkpoint as an artifact on save."""

    def __init__(self, trainer, sample_dataset):
        super().__init__()
        self.trainer = trainer
        self.sample_dataset = sample_dataset
        self.records_df = dataset_to_records(sample_dataset)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        super().on_log(args, state, control, model, logs)
        predictions = self.trainer.predict(self.sample_dataset)
        predictions = decode_predictions(self.trainer, predictions)
        measures_df = compute_measures(predictions, self.records_df["sentence"].tolist())
        records_df = pd.concat([self.records_df, measures_df], axis=1)
        records_df["prediction"] = predictions
        records_df["step"] = state.global_step
        records_table = self._wandb.Table(dataframe=records_df)
        self._wandb.log({"sample_predictions": records_table})

    def on_save(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            with tempfile.TemporaryDirectory() as temp_dir:
                self.trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(
                    name=f"model-{self._wandb.run.id}",
                    type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={
            "help": "Path to pretrained model or model identifier from huggingface.co/models"
        }
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    feature_extractor_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Feature extractor name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": "Where to store the pretrained models downloaded from huggingface.co"
        },
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={
            "help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."
        },
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use (can be a branch name, tag name or commit id)."
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    freeze_feature_encoder: bool = field(
        default=True,
        metadata={"help": "Whether to freeze the feature encoder layers of the model."},
    )
    freeze_encoder: bool = field(
        default=False,
        metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."},
    )
    forced_decoder_ids: List[List[int]] = field(
        default=None,
        metadata={
            "help": (
                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
                "will always be a token of index 123."
            )
        },
    )
    suppress_tokens: List[int] = field(
        default=None,
        metadata={"help": "A list of tokens that will be suppressed at generation."},
    )
    model_index_name: str = field(
        default=None, metadata={"help": "Pretty name for the model card."}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: str = field(
        default=None,
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library)."
        },
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the column in the datasets containing the full texts (for summarization)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    audio_column_name: str = field(
        default="audio",
        metadata={
            "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"
        },
    )
    text_column_name: str = field(
        default="text",
        metadata={
            "help": "The name of the dataset column containing the text data. Defaults to 'text'"
        },
    )
    max_duration_in_seconds: float = field(
        default=20.0,
        metadata={
            "help": (
                "Truncate audio files that are longer than `max_duration_in_seconds` seconds to"
                " `max_duration_in_seconds`"
            )
        },
    )
    min_duration_in_seconds: float = field(
        default=0.0,
        metadata={
            "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"
        },
    )
    train_split_name: str = field(
        default="train",
        metadata={
            "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
        },
    )
    eval_split_name: str = field(
        default="test",
        metadata={
            "help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'test'"
        },
    )
    do_lower_case: bool = field(
        default=False,
        metadata={"help": "Whether the target text should be lower cased."},
    )
    do_remove_punctuation: bool = field(
        default=False,
        metadata={"help": "Whether the target text should be stripped of punctuation."},
    )
    do_normalize_eval: bool = field(
        default=True,
        metadata={
            "help": "Whether to normalise the references and predictions in the eval WER calculation."
        },
    )
    language: str = field(
        default=None,
        metadata={
            "help": (
                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
                "only. For English speech recognition, it should be set to `None`."
            )
        },
    )
    task: str = field(
        default="transcribe",
        metadata={
            "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."
        },
    )
    shuffle_buffer_size: Optional[int] = field(
        default=500,
        metadata={
            "help": (
                "The number of streamed examples to download before shuffling them. The larger the buffer, "
                "the closer it is to real offline shuffling."
            )
        },
    )
    streaming: bool = field(
        default=True,
        metadata={
            "help": "Whether to use streaming mode to load and pre-process the data."
        },
    )


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor ([`WhisperProcessor`])
            The processor used for processing the data.
        decoder_start_token_id (`int`)
            The begin-of-sentence token id of the decoder.
    """

    processor: Any
    decoder_start_token_id: int

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        model_input_name = self.processor.model_input_names[0]
        input_features = [
            {model_input_name: feature[model_input_name]} for feature in features
        ]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt"
        )

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


def load_maybe_streaming_dataset(
    dataset_name, dataset_config_name, split="train", streaming=True, **kwargs
):
    """
    Utility function to load a dataset in streaming mode. For datasets with multiple splits,
    each split is loaded individually and then splits combined by taking alternating examples from
    each (interleaving).
    """
    if "+" in split:
        # load multiple splits separated by the `+` symbol with streaming mode
        dataset_splits = [
            load_dataset(
                dataset_name,
                dataset_config_name,
                split=split_name,
                streaming=streaming,
                **kwargs,
            )
            for split_name in split.split("+")
        ]
        # interleave multiple splits to form one dataset
        interleaved_dataset = interleave_datasets(dataset_splits)
        return interleaved_dataset
    else:
        # load a single split *with* streaming mode
        dataset = load_dataset(
            dataset_name,
            dataset_config_name,
            split=split,
            streaming=streaming,
            **kwargs,
        )
        return dataset


def main():
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)
    )

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry(
        "run_speech_recognition_seq2seq_streaming", model_args, data_args
    )

    # 2. Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.setLevel(
        logging.INFO if is_main_process(training_args.local_rank) else logging.WARN
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # 3. Detecting last checkpoint and eventually continue from last checkpoint
    last_checkpoint = None
    if (
        os.path.isdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif (
            last_checkpoint is not None and training_args.resume_from_checkpoint is None
        ):
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 4. Load dataset
    raw_datasets = IterableDatasetDict()

    if training_args.do_train:
        raw_datasets["train"] = load_maybe_streaming_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=data_args.train_split_name,
            streaming=True,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    if training_args.do_eval:
        raw_datasets["eval"] = load_maybe_streaming_dataset(
            "arbml/mgb3",
            data_args.dataset_config_name,
            split="train",
            streaming=False,
            use_auth_token=True if model_args.use_auth_token else None,
        )

    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())

    if data_args.audio_column_name not in raw_datasets_features:
        raise ValueError(
            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--audio_column_name` to the correct audio column - one of "
            f"{', '.join(raw_datasets_features)}."
        )

    if data_args.text_column_name not in raw_datasets_features:
        raise ValueError(
            f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
            "Make sure to set `--text_column_name` to the correct text column - one of "
            f"{', '.join(raw_datasets_features)}."
        )

    # 5. Load pretrained model, tokenizer, and feature extractor
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    config.update(
        {
            "forced_decoder_ids": model_args.forced_decoder_ids,
            "suppress_tokens": model_args.suppress_tokens,
        }
    )

    if training_args.gradient_checkpointing:
        config.update({"use_cache": False})

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name
        if model_args.feature_extractor_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name
        else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError(
            "Make sure that `config.decoder_start_token_id` is correctly defined"
        )

    max_label_length = model.config.max_length

    if model_args.freeze_feature_encoder:
        model.freeze_feature_encoder()

    if model_args.freeze_encoder:
        model.freeze_encoder()
        model.model.encoder.gradient_checkpointing = False

    if data_args.language is not None:
        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)

    # 6. Resample speech dataset if necessary
    dataset_sampling_rate = (
        next(iter(raw_datasets.values()))
        .features[data_args.audio_column_name]
        .sampling_rate
    )
    if dataset_sampling_rate != feature_extractor.sampling_rate:
        raw_datasets = raw_datasets.cast_column(
            data_args.audio_column_name,
            datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
        )

    # 7. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    max_input_length = (
        data_args.max_duration_in_seconds * feature_extractor.sampling_rate
    )
    min_input_length = (
        data_args.min_duration_in_seconds * feature_extractor.sampling_rate
    )
    audio_column_name = data_args.audio_column_name
    text_column_name = data_args.text_column_name
    model_input_name = feature_extractor.model_input_names[0]
    do_lower_case = data_args.do_lower_case
    do_remove_punctuation = data_args.do_remove_punctuation
    normalizer = BasicTextNormalizer()  # 'official' text normalizer from OpenAI

    if data_args.max_train_samples is not None:
        raw_datasets["train"] = raw_datasets["train"].take(data_args.max_train_samples)

    if data_args.max_eval_samples is not None:
        raw_datasets["eval"] = raw_datasets["eval"].select(
            range(data_args.max_eval_samples)
        )

    def prepare_dataset(batch):
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(
            sample["array"], sampling_rate=sample["sampling_rate"]
        )
        # process audio length
        batch[model_input_name] = inputs.get(model_input_name)[0]
        batch["input_length"] = len(sample["array"])

        # process targets
        input_str = (
            batch[text_column_name].lower()
            if do_lower_case
            else batch[text_column_name]
        )
        if do_remove_punctuation:
            input_str = normalizer(input_str).strip()
        batch["labels"] = tokenizer(input_str).input_ids
        return batch

    with training_args.main_process_first(desc="dataset map pre-processing"):
        vectorized_datasets = raw_datasets.map(
            prepare_dataset,
            remove_columns=raw_datasets_features,
        ).with_format("torch")

        if training_args.do_train:
            vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
                buffer_size=data_args.shuffle_buffer_size,
                seed=training_args.seed,
            )

    # filter training data that is shorter than min_input_length or longer than
    # max_input_length
    def is_audio_in_length_range(length):
        return min_input_length < length < max_input_length

    vectorized_datasets["train"] = vectorized_datasets["train"].filter(
        is_audio_in_length_range,
        input_columns=["input_length"],
    )

    def filter_labels(labels):
        """Filter label sequences longer than max length"""
        return len(labels) < max_label_length

    vectorized_datasets = vectorized_datasets.filter(filter_labels, input_columns=["labels"])

    # 8. Load Metric
    metric = evaluate.load("wer")
    do_normalize_eval = data_args.do_normalize_eval

    def compute_metrics(pred):
        pred_ids = pred.predictions

        pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True)

        if do_normalize_eval:
            pred_str = [normalizer(pred) for pred in pred_str]
            label_str = [normalizer(label) for label in label_str]
            # filtering step to only evaluate the samples that correspond to non-zero references:
            pred_str = [
                pred_str[i] for i in range(len(pred_str)) if len(label_str[i]) > 0
            ]
            label_str = [
                label_str[i] for i in range(len(label_str)) if len(label_str[i]) > 0
            ]

        wer = 100 * metric.compute(predictions=pred_str, references=label_str)

        return {"wer": wer}

    # 9. Create a single speech processor
    if is_main_process(training_args.local_rank):
        # save feature extractor, tokenizer and config
        feature_extractor.save_pretrained(training_args.output_dir)
        tokenizer.save_pretrained(training_args.output_dir)
        config.save_pretrained(training_args.output_dir)

    processor = AutoProcessor.from_pretrained(training_args.output_dir)

    # 10. Define data collator
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # 11. Configure Trainer
    # Trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
    # Only required for streaming: Trainer automatically shuffles non-streaming datasets
    class ShuffleCallback(TrainerCallback):
        def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
            if isinstance(train_dataloader.dataset, IterableDatasetShard):
                pass  # set_epoch() is handled by the Trainer
            elif isinstance(train_dataloader.dataset, IterableDataset):
                train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)

    # Initialize Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
        tokenizer=feature_extractor,
        data_collator=data_collator,
        compute_metrics=compute_metrics
        if training_args.predict_with_generate
        else None,
        callbacks=[ShuffleCallback()],
    )

    if training_args.do_eval:
        # Small fixed sample from the eval split for the W&B progress callback. This
        # assumes the eval split exposes the "sentence" and "length" columns that
        # dataset_to_records() reads when building the logged table.
        samples_dataset = load_samples_dataset(raw_datasets["eval"]).map(
            compute_spectrograms, fn_kwargs={"feature_extractor": feature_extractor}
        )
        progress_callback = WandbProgressResultsCallback(trainer, samples_dataset)
        trainer.add_callback(progress_callback)

    # 12. Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the feature extractor too for easy upload

        metrics = train_result.metrics
        if data_args.max_train_samples:
            metrics["train_samples"] = data_args.max_train_samples
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # 13. Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(
            metric_key_prefix="eval",
            max_length=training_args.generation_max_length,
            num_beams=training_args.generation_num_beams,
        )
        if data_args.max_eval_samples:
            metrics["eval_samples"] = data_args.max_eval_samples

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # 14. Write Training Stats
    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "tasks": "automatic-speech-recognition",
        "tags": "whisper-event",
    }
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs[
                "dataset"
            ] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name
        if "common_voice" in data_args.dataset_name:
            kwargs["language"] = data_args.dataset_config_name[:2]
        if model_args.model_index_name is not None:
            kwargs["model_name"] = model_args.model_index_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)

    return results


if __name__ == "__main__":
    main()
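
For reference, a minimal sketch of using the checkpoint this run produces for transcription; the repository id and audio file name below are placeholders for illustration, not values taken from this commit:

    from transformers import pipeline

    # Placeholder repo id; substitute the actual model repository this commit belongs to.
    asr = pipeline(
        "automatic-speech-recognition",
        model="<namespace>/<whisper-medium-mgb2-repo>",
    )

    # Whisper pipelines accept a path to an audio file (or a numpy array of samples);
    # chunk_length_s enables long-form transcription in 30 s windows.
    print(asr("example.wav", chunk_length_s=30)["text"])
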
runs/Dec14_09-02-25_129-146-107-47/events.out.tfevents.1671008564.129-146-107-47.118226.0 CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:4af926d7444467a8f05dc84dffabb614c9f26a2af2d5dc2f41623550b39d3815
- size 63637
+ oid sha256:1a323350203dc36ce667b31d80d8451952b7800742b109d826566c8ba077ac8b
+ size 70235
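
The updated event file can be inspected offline with TensorBoard's event accumulator. A sketch, assuming the run directory has been downloaded locally; the exact scalar tag names depend on the Trainer version that wrote the log:

    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    # Run directory updated in this commit.
    run_dir = "runs/Dec14_09-02-25_129-146-107-47"
    ea = EventAccumulator(run_dir)
    ea.Reload()

    # List available scalar tags, then print the last few values of the first series
    # logged up to step 10000.
    tags = ea.Tags()["scalars"]
    print(tags)
    for event in ea.Scalars(tags[0])[-5:]:
        print(event.step, event.value)
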