rsonavane commited on
Commit
265600e
1 Parent(s): 2e75110

Upload student-teacher-distillation-streaming.py

Browse files
student-teacher-distillation-streaming.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Run teacher student distillation for Whisper model
2
+ from transformers import AutoTokenizer, AutoProcessor, AutoConfig, AutoFeatureExtractor, AutoTokenizer, AutoProcessor, \
3
+ AutoModelForSpeechSeq2Seq, set_seed, get_linear_schedule_with_warmup
4
+ from datasets import load_dataset, DatasetDict, interleave_datasets, IterableDatasetDict
5
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
6
+ import transformers
7
+ import argparse
8
+ import datasets
9
+ import evaluate
10
+ import string
11
+ from accelerate import Accelerator
12
+ from dataclasses import dataclass
13
+ from typing import Any, Dict, List, Optional, Union
14
+ import torch
15
+ import os
16
+ from tqdm.auto import tqdm
17
+ import numpy as np
18
+
19
+
20
+ def load_streaming_dataset(dataset_name, dataset_config_name, split, **kwargs):
21
+ if "+" in split:
22
+ # load multiple splits separated by the `+` symbol *with* streaming mode
23
+ dataset_splits = [load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs) for split_name in split.split("+")]
24
+ # interleave multiple splits to form one dataset
25
+ interleaved_dataset = interleave_datasets(dataset_splits)
26
+ return interleaved_dataset
27
+ else:
28
+ # load a single split *with* streaming mode
29
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
30
+ return dataset
31
+
32
+
33
+ def train(args, accelerator: Accelerator):
34
+ # load dataset in streaming mode
35
+ raw_datasets = IterableDatasetDict()
36
+ raw_datasets["train"] = load_streaming_dataset(args.train_dataset_name, args.train_dataset_config_name,
37
+ split=args.train_split_name,
38
+ cache_dir=args.data_cache_dir)
39
+ raw_datasets["validation"] = load_streaming_dataset(args.validation_dataset_name,
40
+ args.validation_dataset_config_name,
41
+ split=args.validation_split_name,
42
+ cache_dir=args.data_cache_dir)
43
+
44
+ # raw_datasets = raw_datasets.remove_columns(["file", "speaker_id", "chapter_id", "id"])
45
+ # raw_datasets = raw_datasets.rename_columns({'audio': 'audio', 'text': 'text'})
46
+
47
+ assert args.audio_column in raw_datasets["train"].column_names
48
+ assert args.text_column in raw_datasets["train"].column_names
49
+
50
+ with accelerator.main_process_first():
51
+ if args.max_train_samples is not None:
52
+ raw_datasets["train"] = raw_datasets["train"].take(args.max_train_samples)
53
+ if args.max_val_samples is not None:
54
+ raw_datasets["validation"] = raw_datasets["validation"].take(args.max_val_samples)
55
+
56
+ student_config = AutoConfig.from_pretrained(args.student_model_name_or_path, cache_dir=args.student_cache_dir)
57
+ teacher_config = AutoConfig.from_pretrained(args.teacher_model_name_or_path, cache_dir=args.teacher_cache_dir)
58
+
59
+ # assuming student and teacher uses same feature extractor, tokenizer and processor
60
+ feature_extractor = AutoFeatureExtractor.from_pretrained(args.teacher_model_name_or_path, cache_dir=args.teacher_cache_dir)
61
+ tokenizer = AutoTokenizer.from_pretrained(args.teacher_model_name_or_path, cache_dir=args.teacher_cache_dir)
62
+ processor = AutoProcessor.from_pretrained(args.teacher_model_name_or_path, cache_dir=args.teacher_cache_dir)
63
+
64
+ # make sure decoder_start_token_id is defined for both
65
+ assert teacher_config.decoder_start_token_id is not None
66
+ assert student_config.decoder_start_token_id is not None
67
+
68
+ # We need to set the language and task ids for previously multilingual checkpoints, default is English and transcribe
69
+ # Set to None if the model is not multilingual
70
+ student_config.forced_decoder_ids = None
71
+ # tokenizer.get_decoder_prompt_ids(language=args.language, task=args.task, no_timestamps=True)
72
+ teacher_config.forced_decoder_ids = None
73
+ # tokenizer.get_decoder_prompt_ids(language=args.language, task=args.task, no_timestamps=True)
74
+ student_config.suppress_tokens = []
75
+ teacher_config.suppress_tokens = []
76
+
77
+ student_model = AutoModelForSpeechSeq2Seq.from_pretrained(args.student_model_name_or_path, config=student_config)
78
+ teacher_model = AutoModelForSpeechSeq2Seq.from_pretrained(args.teacher_model_name_or_path, config=teacher_config,
79
+ cache_dir=args.teacher_cache_dir)
80
+
81
+ accelerator.print(
82
+ f"Loaded the model on device: student: {student_model.device}, teacher:{teacher_model.device}, accelerator:{accelerator.device}")
83
+
84
+ # freeze teacher model
85
+ for p in teacher_model.parameters():
86
+ p.requires_grad = False
87
+
88
+ if args.freeze_encoder:
89
+ accelerator.print("Freezing encoder")
90
+ student_model.freeze_encoder()
91
+ student_model.model.encoder.gradient_checkpointing = False
92
+
93
+ # Resample speech dataset: so we just need to set the correct target sampling rate
94
+ with accelerator.main_process_first():
95
+ # raw_datasets = raw_datasets.cast_column(args.audio_column,
96
+ # datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate))
97
+ raw_datasets = raw_datasets.cast_column(args.audio_column, datasets.Audio(sampling_rate=16000))
98
+
99
+ # Preprocessing the raw_datasets, need to read the audio files as arrays and tokenize the targets.
100
+ # might need to change the normalizer depending on language and task
101
+ normalizer = BasicTextNormalizer()
102
+
103
+ def prepare_dataset(batch):
104
+ # process audio
105
+ sample = batch[args.audio_column]
106
+ # compute log-Mel input features from input audio array
107
+ batch["input_features"] = \
108
+ processor.feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"]).input_features[0]
109
+ # process audio length
110
+ batch["input_length"] = len(sample["array"]) / sample["sampling_rate"]
111
+ # process
112
+ transcription = batch[args.text_column]
113
+ if not args.keep_case:
114
+ transcription = transcription.lower()
115
+ if not args.keep_punctuation:
116
+ transcription = normalizer(transcription).strip()
117
+ batch["labels"] = processor.tokenizer(transcription).input_ids
118
+ return batch
119
+
120
+ with accelerator.main_process_first():
121
+ vectorized_datasets = raw_datasets.map(prepare_dataset,
122
+ remove_columns=raw_datasets["train"].column_names)
123
+
124
+ # filter training data with inputs longer than max_input_length
125
+ def is_audio_in_length_range(length):
126
+ # return min_input_length <= length <= max_input_length
127
+ return args.min_duration_in_seconds <= length <= args.max_duration_in_seconds
128
+
129
+ with accelerator.main_process_first():
130
+ vectorized_datasets = vectorized_datasets.filter(is_audio_in_length_range,
131
+ input_columns=["input_length"])
132
+
133
+ @dataclass
134
+ class DataCollatorForSeq2SeqWithPadding:
135
+ processor: Any
136
+
137
+ def __call__(self, features: List[Union[Dict[str, torch.Tensor], Dict[str, Any]]]) -> Dict[str, torch.Tensor]:
138
+ # split inputs and labels since they have to be of different lengths and need different padding methods
139
+ # first treat the audio inputs by simply returning torch tensors
140
+ input_features = [{"input_features": feature["input_features"]} for feature in features]
141
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
142
+
143
+ # then pad the labels
144
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
145
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
146
+ # replace -100 in labels by tokenizer.pad_token_id to ignore padding in loss
147
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
148
+
149
+ # if bos token is appended in previous step, remove it here as it's appended again in the forward pass
150
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
151
+ labels = labels[:, 1:]
152
+
153
+ batch["labels"] = labels
154
+ return batch
155
+
156
+ data_collator = DataCollatorForSeq2SeqWithPadding(processor=processor)
157
+
158
+ # now define data loaders
159
+ train_dataloader = torch.utils.data.DataLoader(vectorized_datasets["train"], shuffle=False, collate_fn=data_collator,
160
+ batch_size=args.per_device_train_batch_size)
161
+ eval_dataloader = torch.utils.data.DataLoader(vectorized_datasets["validation"], shuffle=False,
162
+ collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
163
+
164
+ # define optimizer
165
+ optimizer = torch.optim.AdamW(list(student_model.parameters()), lr=args.learning_rate)
166
+
167
+ # scheduler
168
+ lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
169
+ num_training_steps=args.train_steps)
170
+
171
+ # accelerator setup for distributed training, this handles all the device mapping, gradient accumulation, fp16 training etc.
172
+ # add eval_dataloader to accelerator.prepare for distributed evaluation
173
+ student_model, teacher_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
174
+ student_model, teacher_model, optimizer, train_dataloader, lr_scheduler)
175
+ accelerator.print(
176
+ f"Distributed the model on device: student: {student_model.device}, teacher:{teacher_model.device}, accelerator:{accelerator.device}")
177
+
178
+ global_step = 0 # global step for logging
179
+ total_loss = 0 # total loss before each eval
180
+ total_kl_loss = 0 # total kl loss before each eval
181
+ total_ce_loss = 0 # total ce loss before each eval
182
+
183
+ if args.resume_from_checkpoint is not None:
184
+ accelerator.print(f"Loading checkpoint: {args.resume_from_checkpoint}")
185
+ accelerator.load_state(args.resume_from_checkpoint)
186
+ steps_completed = int(args.resume_from_checkpoint.split("-")[-1])
187
+ global_step += steps_completed
188
+ train_dataloader = accelerator.skip_first_batches(train_dataloader, steps_completed)
189
+
190
+ # load metric
191
+ wer_metric = evaluate.load("wer")
192
+ cer_metric = evaluate.load("cer")
193
+ all_punctuations = list(string.punctuation.replace("'", ""))
194
+
195
+ def compute_metrics(preds, labels):
196
+ # replace padded labels by padding token
197
+ for idx in range(len(labels)):
198
+ labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
199
+
200
+ pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
201
+ label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
202
+
203
+ pred_str = [_.strip() for _ in pred_str]
204
+ label_str = [_.strip() for _ in label_str]
205
+
206
+ # space punctuation for orthographic WER
207
+ spaced_pred_str = [pred_str[i].replace(punctuation, "") for punctuation in all_punctuations for i in
208
+ range(len(pred_str))]
209
+ spaced_label_str = [label_str[i].replace(punctuation, "") for punctuation in all_punctuations for i in
210
+ range(len(label_str))]
211
+
212
+ # compute WER
213
+ wer_ortho = 100 * wer_metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
214
+ cer_ortho = 100 * cer_metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
215
+ accelerator.print(
216
+ f"\nspaced_pred_str: {[_ for i, _ in enumerate(spaced_pred_str) if i < 3]}, \nspaced_label_str: {[_ for i, _ in enumerate(spaced_label_str) if i < 3]}")
217
+ # normalize everything and re-compute the WER
218
+ norm_pred_str = [normalizer(pred) for pred in pred_str]
219
+ norm_label_str = [normalizer(label) for label in label_str]
220
+ # filtering step to only evaluate the samples that correspond to non-zero normlized references:
221
+ norm_pred_str = [norm_pred_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
222
+ norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
223
+ accelerator.print(
224
+ f"\nnorm_pred_str: {[_ for i, _ in enumerate(norm_pred_str) if i < 3]}, \nnorm_label_str: {[_ for i, _ in enumerate(norm_label_str) if i < 3]}")
225
+
226
+ wer = 100 * wer_metric.compute(predictions=norm_pred_str, references=norm_label_str)
227
+ cer = 100 * cer_metric.compute(predictions=norm_pred_str, references=norm_label_str)
228
+
229
+ return {"wer": wer, "wer_ortho": wer_ortho, "cer": cer, "cer_ortho": cer_ortho}, pred_str, label_str
230
+
231
+ # save feature extractor, tokenizer, config and generation config
232
+ with accelerator.main_process_first():
233
+ output_dir = args.output_dir
234
+ feature_extractor.save_pretrained(output_dir)
235
+ tokenizer.save_pretrained(output_dir)
236
+ student_config.save_pretrained(output_dir)
237
+ teacher_config.save_pretrained(output_dir)
238
+
239
+ progress_bar = tqdm(range(global_step, args.train_steps), disable=not accelerator.is_main_process)
240
+ # define training step
241
+ while global_step < args.train_steps:
242
+ student_model.train()
243
+ for batch in train_dataloader:
244
+ with accelerator.accumulate(student_model):
245
+ # forward pass
246
+ outputs = student_model(**batch)
247
+ ce_loss = outputs.loss
248
+ logits = outputs.logits
249
+ with torch.no_grad():
250
+ teacher_logits = teacher_model(**batch).logits
251
+ # compute kl loss
252
+ kl_loss = torch.nn.functional.kl_div(torch.nn.functional.log_softmax(logits / args.temperature, dim=-1),
253
+ torch.nn.functional.softmax(teacher_logits / args.temperature,
254
+ dim=-1),
255
+ reduction="batchmean") * (args.temperature ** 2)
256
+ # compute total loss
257
+ loss = args.alpha_ce * ce_loss + args.alpha_distil * kl_loss
258
+
259
+ total_kl_loss += kl_loss.detach().item()
260
+ total_ce_loss += ce_loss.detach().item()
261
+ total_loss += loss.detach().item()
262
+ accelerator.backward(loss)
263
+ optimizer.step()
264
+ lr_scheduler.step()
265
+ optimizer.zero_grad()
266
+
267
+ global_step += 1
268
+ progress_bar.update(1)
269
+
270
+ # log metrics
271
+ eval_metrics = {}
272
+ eval_preds = []
273
+ eval_labels = []
274
+
275
+ if global_step % args.eval_steps == 0:
276
+ student_model.eval()
277
+ valid_loss = 0
278
+ # validation_progress_bar = tqdm(range(0, len(eval_dataloader)), disable=not accelerator.is_main_process)
279
+ for batch in eval_dataloader:
280
+ with torch.no_grad():
281
+ batch.to(accelerator.device)
282
+ references = batch.labels
283
+ if not args.predict_without_generate:
284
+ # accelerator.print("\nPredicting with generate")
285
+ # for modules wrapped in DataParallel or DistributedDataParallel, we need to use .module to access the underlying module
286
+ if accelerator.num_processes > 1:
287
+ # accelerator.print("Distributed eval")
288
+ predictions = student_model.module.generate(batch.input_features)
289
+ else:
290
+ predictions = student_model.generate(batch.input_features)
291
+ else:
292
+ # accelerator.print("\nPredicting without generate")
293
+ outputs = student_model(**batch)
294
+ valid_loss += outputs.loss.detach().item()
295
+ pred_logits = outputs.logits
296
+ predictions = pred_logits.argmax(-1)
297
+
298
+ # accelerator.print("Before gather")
299
+ # accelerator.print(f"len of predictions: {len(predictions)}, len of references: {len(references)}")
300
+ # accelerator.print(f"All types for gather has to be tensor: \ntype of predictions: {type(predictions)}, type of references: {type(references)}")
301
+ predictions, references = accelerator.gather_for_metrics((predictions, references))
302
+ # accelerator.print("After gather")
303
+ # accelerator.print(f"len of predictions: {len(predictions)}, len of references: {len(references)}")
304
+
305
+ ###########################
306
+ # convert any token after after first tokenizer.eos_token_id to eos_token_id
307
+ for idx, pred in enumerate(predictions):
308
+ first_eos_token_idx = (pred == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
309
+ if len(first_eos_token_idx) > 0:
310
+ predictions[idx, first_eos_token_idx[0] + 1:] = tokenizer.eos_token_id
311
+ ###########################
312
+
313
+ eval_preds.extend(predictions)
314
+ eval_labels.extend(references)
315
+ # validation_progress_bar.update(1)
316
+
317
+ accelerator.print(f"\npredictions: {eval_preds[:3]}, \nreferences: {eval_preds[:3]}")
318
+ accelerator.print(f"\nlen(eval_preds): {len(eval_preds)}, \nlen(eval_labels): {len(eval_labels)}")
319
+ eval_metrics, eval_preds, eval_labels = compute_metrics(eval_preds, eval_labels)
320
+ train_loss = total_loss / (
321
+ args.eval_steps * args.per_device_train_batch_size * accelerator.num_processes)
322
+ train_kl_loss = total_kl_loss / (
323
+ args.eval_steps * args.per_device_train_batch_size * accelerator.num_processes)
324
+ train_ce_loss = total_ce_loss / (
325
+ args.eval_steps * args.per_device_train_batch_size * accelerator.num_processes)
326
+
327
+ accelerator.print(
328
+ f"Step: {global_step}, Train Loss: {train_loss}, Train KL Loss: {train_kl_loss}, Train CE Loss: {train_ce_loss}, \
329
+ Eval WER: {eval_metrics['wer']}, Eval WER Ortho: {eval_metrics['wer_ortho']}, Eval CER: {eval_metrics['cer']}, \
330
+ Eval CER Ortho: {eval_metrics['cer_ortho']}")
331
+ accelerator.log(
332
+ {"cer": eval_metrics["cer"], "cer_ortho": eval_metrics["cer_ortho"], "wer": eval_metrics["wer"],
333
+ "wer_ortho": eval_metrics["wer_ortho"],
334
+ "train_loss": train_loss,
335
+ "train_kl_loss": train_kl_loss,
336
+ "train_ce_loss": train_ce_loss,
337
+ # "eval_loss": valid_loss / (len(eval_dataloader))
338
+ })
339
+
340
+ output_dir = os.path.join(args.output_dir, f"checkpoint-{global_step}")
341
+ # accelerator.save(student_model.state_dict(), output_dir)
342
+ accelerator.save_state(output_dir)
343
+ accelerator.wait_for_everyone()
344
+ unwrapped_model = accelerator.unwrap_model(student_model)
345
+ unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save,
346
+ is_main_process=accelerator.is_main_process)
347
+ total_loss = 0
348
+ total_kl_loss = 0
349
+ total_ce_loss = 0
350
+ student_model.train()
351
+
352
+
353
+ def main():
354
+ parser = argparse.ArgumentParser()
355
+ parser.add_argument("--teacher_model_name_or_path", type=str, default="openai/whisper-large-v2")
356
+ parser.add_argument("--student_model_name_or_path", type=str, default="distil-whisper/large-v2-8")
357
+ parser.add_argument("--output_dir", type=str, default="output")
358
+ parser.add_argument("--per_device_train_batch_size", type=int, default=16)
359
+ parser.add_argument("--per_device_eval_batch_size", type=int, default=16)
360
+ parser.add_argument("--learning_rate", type=float, default=2e-5)
361
+ parser.add_argument("--freeze_encoder", action="store_true")
362
+ parser.add_argument("--temperature", type=float, default=2.0)
363
+ parser.add_argument("--alpha_ce", type=float, default=0.5)
364
+ parser.add_argument("--alpha_distil", type=float, default=0.5)
365
+ parser.add_argument("--language", type=str, default="en")
366
+ parser.add_argument("--task", type=str, default="transcribe")
367
+ parser.add_argument("--train_steps", type=int, default=100000)
368
+ parser.add_argument("--eval_steps", type=int, default=100)
369
+ parser.add_argument("--warmup_steps", type=int, default=2000)
370
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=4) # increase by 2x for every 2x decrease in batch size 64
371
+ parser.add_argument("--seed", type=int, default=42)
372
+ parser.add_argument("--data_cache_dir", type=str, default="data/cache")
373
+ parser.add_argument("--teacher_cache_dir", type=str, default="model/cache")
374
+ parser.add_argument("--student_cache_dir", type=str, default="model/cache")
375
+ parser.add_argument("--mixed_precision", type=str, default="fp16")
376
+ parser.add_argument("--max_train_samples", type=int, default=None)
377
+ parser.add_argument("--max_val_samples", type=int, default=None)
378
+ parser.add_argument("--max_test_samples", type=int, default=None)
379
+ parser.add_argument("--audio_column", type=str, default="audio")
380
+ parser.add_argument("--text_column", type=str, default="text")
381
+ parser.add_argument("--max_duration_in_seconds", type=float, default=30)
382
+ parser.add_argument("--min_duration_in_seconds", type=float, default=1)
383
+ parser.add_argument("--keep_case", action="store_true")
384
+ parser.add_argument("--keep_punctuation", action="store_true")
385
+ parser.add_argument("--resume_from_checkpoint", type=str, default=None)
386
+ parser.add_argument("--num_workers", type=int, default=16)
387
+ parser.add_argument("--predict_without_generate", action="store_true")
388
+ parser.add_argument("--train_dataset_name", type=str, default="librispeech_asr")
389
+ parser.add_argument("--train_dataset_config_name", type=str, default="all")
390
+ parser.add_argument("--train_split_name", type=str, default="train.clean.100+train.clean.360+train.other.500")
391
+ parser.add_argument("--validation_dataset_name", type=str, default="librispeech_asr")
392
+ parser.add_argument("--validation_dataset_config_name", type=str, default="all")
393
+ parser.add_argument("--validation_split_name", type=str, default="validation.clean")
394
+
395
+ args = parser.parse_args()
396
+
397
+ # set seed
398
+ set_seed(args.seed)
399
+
400
+ if args.teacher_model_name_or_path is None or args.student_model_name_or_path is None:
401
+ raise ValueError("teacher_model_name_or_path and student_model_name_or_path cannot be None")
402
+
403
+ accelerator = Accelerator(mixed_precision=args.mixed_precision, gradient_accumulation_steps=1,
404
+ log_with="tensorboard", logging_dir=args.output_dir)
405
+
406
+ # have only one message per logs of transformers or datasets, so logging verbosity INFO only for the main process
407
+ if accelerator.is_main_process:
408
+ datasets.utils.logging.set_verbosity_info()
409
+ transformers.utils.logging.set_verbosity_info()
410
+ else:
411
+ datasets.utils.logging.set_verbosity_error()
412
+ transformers.utils.logging.set_verbosity_error()
413
+
414
+ # establish trackers for logging
415
+ track_config = {"lr": args.learning_rate,
416
+ "train_batch_size": args.per_device_train_batch_size,
417
+ "eval_batch_size": args.per_device_eval_batch_size,
418
+ "seed": args.seed,
419
+ "train_steps": args.train_steps}
420
+ accelerator.init_trackers('runs', track_config)
421
+ train(args, accelerator)
422
+ accelerator.end_training()
423
+
424
+
425
+ if __name__ == "__main__":
426
+ main()