supawichwac committed on
Commit 17a9ba0
1 Parent(s): ca6a3e2

Saving train state of step 25

.ipynb_checkpoints/run_distillation-checkpoint.py CHANGED
@@ -750,11 +750,14 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+
+
     # 2. Initialize the accelerator
     # We will let the accelerator handle device placement for us in this example
     # We simply have to specify the training precision and any trackers being used
     # We'll use the same dtype arguments as our JAX/Flax training script and convert
     # it to accelerate format
+
     if training_args.dtype == "float16":
         mixed_precision = "fp16"
         teacher_dtype = torch.float16
@@ -1007,686 +1010,688 @@ def main():
         )
     else:
         is_multilingual = False
 
     # 8. Create a single speech processor - make sure all processes wait until data is saved
-    if accelerator.is_main_process:
-        feature_extractor.save_pretrained(training_args.output_dir)
-        tokenizer.save_pretrained(training_args.output_dir)
-        # save the config and generation config as well
-        config.save_pretrained(training_args.output_dir)
-        student_model.generation_config.save_pretrained(training_args.output_dir)
-
-    accelerator.wait_for_everyone()
-    processor = WhisperProcessor.from_pretrained(training_args.output_dir)
-
-    # 9. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
-    # so we just need to set the correct target sampling rate.
-    sampling_rate = feature_extractor.sampling_rate
-    raw_datasets = raw_datasets.cast_column(
-        data_args.audio_column_name,
-        datasets.features.Audio(sampling_rate=sampling_rate),
-    )
-
-    # 10. Preprocessing the datasets: we need to read the audio files as arrays and tokenize the targets.
-    # 10.1: Define the pre-processing constants
-    max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
-    min_input_length = int(data_args.min_duration_in_seconds * sampling_rate)
-    max_label_length = (
-        data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
-    )
-
-    timestamp_probability = data_args.timestamp_probability
-    condition_on_prev_probability = data_args.condition_on_prev_probability
-    return_timestamps = data_args.return_timestamps if timestamp_probability > 0 else False
-
-    timestamp_ids = tokenizer.timestamp_ids()
-    timestamp_begin = tokenizer.all_special_ids[-1]
-    timestamp_position = 3 if is_multilingual else 1
-
-    decoder_start_token_id = student_model.config.decoder_start_token_id  # <|startoftranscript|>
-    decoder_prev_token_id = tokenizer.all_special_ids[-3]  # <|startofprev|>
-    prompt_cutoff_length = max_label_length // 2
-
-    num_workers = data_args.preprocessing_num_workers
-    dataloader_num_workers = training_args.dataloader_num_workers
-    prefetch_factor = training_args.dataloader_prefetch_factor
-
-    metric = evaluate.load("wer")
-    normalizer = (
-        BasicTextNormalizer()
-        if data_args.language is not None
-        else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
-    )
-    wer_threshold = data_args.wer_threshold
-    use_pseudo_labels = data_args.use_pseudo_labels
-    train_text_column_name = "whisper_transcript" if use_pseudo_labels else "text"
-
1064
- # 10.2: filter based on maximum number of training/evaluation samples
1065
- if training_args.do_train and data_args.max_train_samples is not None:
1066
- raw_datasets["train"] = (
1067
- raw_datasets["train"].take(data_args.max_train_samples)
1068
- if data_args.streaming
1069
- else raw_datasets["train"].select(range(data_args.max_train_samples))
1070
- )
1071
-
1072
- if training_args.do_eval and data_args.max_eval_samples is not None:
1073
- for eval_split in all_eval_splits:
1074
- raw_datasets[eval_split] = (
1075
- raw_datasets[eval_split].take(data_args.max_eval_samples)
1076
- if data_args.streaming
1077
- else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1078
- )
1079
-
1080
- # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
1081
- def is_wer_in_range(ground_truth, whisper_transcript):
1082
- norm_ground_truth = normalizer(ground_truth)
1083
- if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1084
- # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1085
- return False
1086
- elif len(norm_ground_truth) > 0 and whisper_transcript is not None:
1087
- norm_whisper_transcript = normalizer(whisper_transcript)
1088
- wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1089
- return wer < wer_threshold
1090
- else:
1091
- # filter automatically since we can't know the WER
1092
- return False
1093
-
1094
- filter_by_wer_threshold = partial(
1095
- raw_datasets["train"].filter,
1096
- function=is_wer_in_range,
1097
- input_columns=["text", "whisper_transcript"],
1098
- )
1099
-
1100
- if wer_threshold is not None and use_pseudo_labels:
1101
- with accelerator.main_process_first():
1102
- raw_datasets["train"] = (
1103
- filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
1104
- if not data_args.streaming
1105
- else filter_by_wer_threshold()
1106
- )
1107
-
1108
- # 10.4: pre-process training/evaluation datasets
1109
- def prepare_train_dataset(batch):
1110
- """
1111
- Pre-process the raw dataset in a three stage process:
1112
- 1. Convert the audio arrays to log-mel spectrogram inputs
1113
- 2. Possibly filter the timestamp tokens from the token ids (depending on the timestamp probability)
1114
- 3. Possibly add prompt tokens if conditioning on previous text (depending on the conditioning probability)
1115
- """
1116
- # process audio input
1117
- audio = [sample["array"] for sample in batch["audio"]]
1118
- inputs = feature_extractor(audio, sampling_rate=sampling_rate)
1119
- batch["input_features"] = inputs.input_features
1120
- batch["input_length"] = [len(sample) for sample in audio]
1121
-
1122
- # process text targets - for training these are the Whisper-generated pseudo-labels
1123
- input_str_batched = batch[train_text_column_name]
1124
- condition_on_prev_batched = batch.get("condition_on_prev", len(input_str_batched) * [None])
1125
-
1126
- all_token_ids = []
1127
- all_token_ids_unprompted = []
1128
- for prev_ids, input_str in zip(condition_on_prev_batched, input_str_batched):
1129
- token_ids = tokenizer(input_str, add_special_tokens=not use_pseudo_labels).input_ids
1130
-
1131
- # check whether we have timestamps in the PLs and filter if required
1132
- has_timestamps = len(set(token_ids) & set(timestamp_ids)) > 0
1133
- if has_timestamps:
1134
- # sample from binomial distribution to get probability of training on timestamps
1135
- predict_timestamps = bool(np.random.binomial(1, timestamp_probability))
1136
- if not predict_timestamps:
1137
- # filter timestamps and insert the <|notimestamps|> task token
1138
- token_ids = [token for token in token_ids if token < timestamp_begin]
1139
- token_ids.insert(timestamp_position, timestamp_begin)
1140
-
1141
- all_token_ids_unprompted.append(token_ids)
1142
- # check whether to condition on previous text - we do this with probability condition_on_prev_probability
1143
- condition_on_prev = bool(np.random.binomial(1, condition_on_prev_probability))
1144
- if not condition_on_prev:
1145
- prev_ids = None
1146
- elif "condition_on_prev" not in batch and len(all_token_ids_unprompted) > 1:
1147
- # prompt ids are the penultimate token ids in the batch
1148
- prev_ids = all_token_ids_unprompted[-2]
1149
-
1150
- if prev_ids is not None:
1151
- if has_timestamps and not predict_timestamps:
1152
- # filter timestamp ids from prompt when not predicting timestamps
1153
- prev_ids = [token for token in prev_ids if token < timestamp_begin]
1154
-
1155
- # check that the length of the prompt does not exceed more than half the max label length (224)
1156
- if len(prev_ids) > prompt_cutoff_length:
1157
- prev_ids = prev_ids[-prompt_cutoff_length + 1 :]
1158
- prev_ids = [decoder_prev_token_id] + prev_ids
1159
-
1160
- # and that the total length of the labels does not exceed the max label length (448)
1161
- if len(prev_ids + token_ids) > max_label_length:
1162
- trim_length = len(prev_ids + token_ids) - max_label_length + 1
1163
- prev_ids = prev_ids[trim_length:]
1164
- prev_ids = [decoder_prev_token_id] + prev_ids
1165
-
1166
- token_ids = prev_ids + token_ids
1167
-
1168
- all_token_ids.append(token_ids)
1169
-
1170
- batch["labels"] = all_token_ids
1171
- return batch
1172
-
1173
- def prepare_eval_dataset(batch):
1174
- # process audio input
1175
- sample = batch["audio"]
1176
- inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1177
- batch["input_features"] = inputs.input_features[0]
1178
- batch["input_length"] = len(sample["array"])
1179
-
1180
- # process targets - for evaluation these are the ground-truth transcriptions
1181
- input_str = batch["text"]
1182
- batch["labels"] = tokenizer(input_str).input_ids
1183
- return batch
1184
-
1185
- vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1186
- if training_args.do_train:
1187
- # with streaming mode we can only have 1 worker, whereas with non-streaming
1188
- # we can use `num_workers` (which is much faster)
1189
- # We gate the pre-processing function accordingly
1190
- map_fn_train = partial(
1191
- raw_datasets["train"].map,
1192
- function=prepare_train_dataset,
1193
- remove_columns=raw_datasets_train_features,
1194
- batched=True,
1195
- batch_size=data_args.preprocessing_batch_size,
1196
- )
1197
- with accelerator.main_process_first():
1198
- vectorized_datasets["train"] = (
1199
- map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1200
- if not data_args.streaming
1201
- else map_fn_train()
1202
- )
1203
- if training_args.do_eval:
1204
- for eval_split in all_eval_splits:
1205
- raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1206
- map_fn_eval = partial(
1207
- raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1208
- )
1209
- with accelerator.main_process_first():
1210
- vectorized_datasets[eval_split] = (
1211
- map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1212
- if not data_args.streaming
1213
- else map_fn_eval()
1214
- )
1215
-
1216
- # 10.5: Filter training data with inputs longer than `max_input_length`
1217
- def is_audio_in_length_range(length):
1218
- return min_input_length < length < max_input_length
1219
-
1220
- filter_by_audio_fn = partial(
1221
- vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1222
- )
1223
- with accelerator.main_process_first():
1224
- vectorized_datasets = (
1225
- filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1226
- if not data_args.streaming
1227
- else filter_by_audio_fn()
1228
- )
1229
-
1230
- # 10.6: Filter training data with labels longer than `max_label_length`
1231
- def is_labels_in_length_range(labels):
1232
- return 0 < len(labels) <= max_label_length
1233
-
1234
- filter_by_labels_fn = partial(
1235
- vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1236
- )
1237
- with accelerator.main_process_first():
1238
- vectorized_datasets = (
1239
- filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1240
- if not data_args.streaming
1241
- else filter_by_labels_fn()
1242
- )
1243
-
1244
- # Pre-processing complete!
1245
- # For large datasets it is advised to run the preprocessing on a
1246
- # single machine first with `--preprocessing_only` since there will mostly likely
1247
- # be a timeout when running the script in distributed mode.
1248
- # In a second step, `--preprocessing_only` can then be set to `False` to load the
1249
- # cached dataset
1250
- if data_args.preprocessing_only:
1251
- if data_args.streaming:
1252
- raise ValueError(
1253
- "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1254
- "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1255
- "on the fly with streaming mode."
1256
- )
1257
- cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1258
- logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1259
- return
1260
-
1261
- # 11. Define Evaluation Metrics
1262
- def compute_metrics(preds, labels):
1263
- # replace padded labels by the padding token
1264
- for idx in range(len(labels)):
1265
- labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1266
-
1267
- pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1268
- # we do not want to group tokens when computing the metrics
1269
- label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1270
- wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1271
-
1272
- # normalize everything and re-compute the WER
1273
- norm_pred_str = [normalizer(pred) for pred in pred_str]
1274
- norm_label_str = [normalizer(label) for label in label_str]
1275
- # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1276
- pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1277
- label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1278
- # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1279
- norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1280
- norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1281
-
1282
- wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1283
- return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1284
-
1285
- # 12. Define Training Schedule
1286
- # Store some constants
1287
- per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1288
- train_batch_size = per_device_train_batch_size * accelerator.num_processes
1289
- gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1290
- per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1291
-
1292
- if not data_args.streaming and training_args.max_steps < 0:
1293
- num_epochs = int(training_args.num_train_epochs)
1294
- steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1295
- total_train_steps = steps_per_epoch * num_epochs
1296
- elif training_args.max_steps > 0:
1297
- logger.info("max_steps is given, it will override any value given in num_train_epochs")
1298
- total_train_steps = int(training_args.max_steps)
1299
- if not data_args.streaming:
1300
- steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1301
- num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1302
- else:
1303
- # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1304
- num_epochs = sys.maxsize
1305
- steps_per_epoch = total_train_steps
1306
- else:
1307
- raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1308
-
1309
- if training_args.eval_steps is None:
1310
- logger.info(
1311
- f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1312
- )
1313
- eval_steps = steps_per_epoch
1314
- else:
1315
- eval_steps = training_args.eval_steps
1316
-
1317
- # 13. Define optimizer, LR scheduler, collator
1318
- decay_parameters = get_parameter_names(
1319
- student_model,
1320
- [nn.LayerNorm],
1321
- forbidden_module=[student_model.model.encoder] if training_args.freeze_encoder else None,
1322
- )
1323
- decay_parameters = [name for name in decay_parameters if "bias" not in name]
1324
- optimizer_grouped_parameters = [
1325
- {
1326
- "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1327
- "weight_decay": training_args.weight_decay,
1328
- },
1329
- {
1330
- "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1331
- "weight_decay": 0.0,
1332
- },
1333
- ]
1334
- optimizer = torch.optim.AdamW(
1335
- params=optimizer_grouped_parameters,
1336
- lr=training_args.learning_rate,
1337
- betas=(training_args.adam_beta1, training_args.adam_beta2),
1338
- eps=training_args.adam_epsilon,
1339
- )
1340
-
1341
- # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1342
- lr_scheduler = get_scheduler(
1343
- name=training_args.lr_scheduler_type,
1344
- optimizer=optimizer,
1345
- num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1346
- num_training_steps=total_train_steps * accelerator.num_processes,
1347
- )
1348
-
1349
- data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1350
- processor=processor,
1351
- decoder_start_token_id=decoder_start_token_id,
1352
- decoder_prev_token_id=decoder_prev_token_id,
1353
- input_padding="longest",
1354
- target_padding="max_length",
1355
- max_target_length=max_label_length,
1356
- )
1357
-
1358
- # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1359
- # so that we can still access the configs
1360
- num_beams = (
1361
- training_args.generation_num_beams
1362
- if training_args.generation_num_beams is not None
1363
- else getattr(student_model.generation_config, "num_beams", 1)
1364
- )
1365
-
1366
- gen_kwargs = {
1367
- "max_length": max_label_length,
1368
- "num_beams": num_beams,
1369
- "return_timestamps": return_timestamps,
1370
- }
1371
- if is_multilingual:
1372
- # forcing the language and task tokens helps multilingual models in their generations
1373
- gen_kwargs.update(
1374
- {
1375
- "language": data_args.language,
1376
- "task": data_args.task,
1377
- }
1378
- )
1379
-
1380
- # 15. Prepare everything with accelerate
1381
- student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1382
- student_model, teacher_model, optimizer, lr_scheduler
1383
- )
1384
-
1385
- def kl_divergence(target_distribution, log_predicted_distribution, labels):
1386
- kl_loss = nn.KLDivLoss(reduction="none")
1387
- divergence = kl_loss(log_predicted_distribution, target_distribution)
1388
- # ignore padded tokens from divergence, i.e. where labels are not set to -100
1389
- padding_mask = labels >= 0
1390
- padding_mask = padding_mask.unsqueeze(-1)
1391
- divergence = divergence * padding_mask
1392
- # take the average over the mini-batch
1393
- divergence = divergence.sum() / padding_mask.sum()
1394
- return divergence
1395
-
1396
- # Define gradient update step fn
1397
- def train_step(
1398
- batch,
1399
- temperature=2.0,
1400
- ):
1401
- student_model.train()
1402
- teacher_model.eval()
1403
-
1404
- student_outputs = student_model(**batch)
1405
- with torch.no_grad():
1406
- if share_hidden_states:
1407
- # if the student and teacher share the same frozen encoder then we don't have to recompute the
1408
- # encoder hidden-states for the teacher model, we can just re-use from the student
1409
- encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1410
- teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1411
- else:
1412
- # do the full forward pass for the teacher model (encoder + decoder)
1413
- teacher_outputs = teacher_model(**batch)
1414
-
1415
- # CE (data) loss
1416
- ce_loss = student_outputs.loss
1417
- # rescale distribution by temperature to ensure gradients scale correctly
1418
- teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1419
- # log softmax of student predictions for numerical stability
1420
- student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1421
- # KL-divergence loss (scaled by temperature)
1422
- kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1423
-
1424
- # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1425
- loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1426
- metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1427
- return loss, metrics
1428
-
1429
- # Define eval fn
1430
- def eval_step(batch):
1431
- student_model.eval()
1432
- teacher_model.eval()
1433
-
1434
- with torch.no_grad():
1435
- student_outputs = student_model(**batch)
1436
- if share_hidden_states:
1437
- encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1438
- teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1439
- else:
1440
- teacher_outputs = teacher_model(**batch)
1441
-
1442
- # CE (data) loss
1443
- ce_loss = student_outputs.loss
1444
-
1445
- # log softmax / softmax for numerical stability
1446
- student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1447
- teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1448
- # temperature is always 1 for eval
1449
- kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1450
-
1451
- # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1452
- loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1453
- metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1454
- return metrics
1455
-
1456
- def generate_step(batch):
1457
- student_model.eval()
1458
- output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1459
- output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1460
- return output_ids
1461
-
1462
- logger.info("***** Running training *****")
1463
- logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1464
- if not data_args.streaming:
1465
- logger.info(f" Num epochs = {num_epochs}")
1466
- logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1467
- logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1468
- logger.info(
1469
- f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1470
- )
1471
- logger.info(f" Total optimization steps = {total_train_steps}")
1472
-
1473
- # ======================== Training ================================
1474
- train_time = 0
1475
- train_start = time.time()
1476
- steps_trained_progress_bar = tqdm(
1477
- range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1478
- )
1479
- continue_training = True
1480
- epochs_trained = 0
1481
- cur_step = 0
1482
-
1483
- checkpoint = None
1484
- if training_args.resume_from_checkpoint is not None:
1485
- checkpoint = training_args.resume_from_checkpoint
1486
- elif last_checkpoint is not None:
1487
- checkpoint = last_checkpoint
1488
-
1489
- if checkpoint is not None:
1490
- accelerator.load_state(checkpoint)
1491
- # Find num steps and epoch from saved state string pattern
1492
- pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1493
- match = re.search(pattern, checkpoint)
1494
- cur_step = int(match.group(1))
1495
- epochs_trained = int(match.group(2))
1496
-
1497
- logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1498
- logger.info(f" Continuing training from epoch {epochs_trained}")
1499
- logger.info(f" Continuing training from global step {cur_step}")
1500
-
1501
- steps_trained_progress_bar.update(cur_step)
1502
-
1503
- for epoch in range(0, epochs_trained):
1504
- vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1505
-
1506
- if not data_args.streaming and training_args.max_steps < 0:
1507
- # we know exactly the number of steps per epoch, so can skip through the required number of batches
1508
- resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1509
- else:
1510
- # Currently we don't know how many steps we've taken in the current epoch
1511
- # So we just shuffle the dataset one extra time and start from a fresh epoch
1512
- # This is "good enough" for our purposes but not fully correct
1513
- resume_step = None
1514
- vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1515
- else:
1516
- resume_step = None
1517
-
1518
- for epoch in range(epochs_trained, num_epochs):
1519
- vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1520
- train_dataloader = DataLoader(
1521
- vectorized_datasets["train"],
1522
- collate_fn=data_collator,
1523
- batch_size=per_device_train_batch_size,
1524
- num_workers=dataloader_num_workers,
1525
- prefetch_factor=prefetch_factor,
1526
- pin_memory=training_args.dataloader_pin_memory,
1527
- )
1528
- train_dataloader = accelerator.prepare(train_dataloader)
1529
- if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1530
- train_dataloader.dataset.set_epoch(epoch)
1531
-
1532
- if resume_step is not None:
1533
- # Skip the first N batches in the dataloader when resuming from a checkpoint
1534
- train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1535
- resume_step = None
1536
-
1537
- for batch in train_dataloader:
1538
- with accelerator.accumulate(student_model):
1539
- loss, train_metric = train_step(batch, temperature=training_args.temperature)
1540
- accelerator.backward(loss)
1541
- if accelerator.sync_gradients:
1542
- accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1543
- optimizer.step()
1544
- lr_scheduler.step()
1545
- optimizer.zero_grad()
1546
-
1547
- # Check if the accelerator has performed an optimization step behind the scenes
1548
- if accelerator.sync_gradients:
1549
- steps_trained_progress_bar.update(1)
1550
- cur_step += 1
1551
-
1552
- if cur_step % training_args.logging_steps == 0:
1553
- steps_trained_progress_bar.write(
1554
- f"Step... ({cur_step} / {total_train_steps} | Loss:"
1555
- f" {train_metric['loss']}, Learning Rate:"
1556
- f" {lr_scheduler.get_last_lr()[0]})"
1557
- )
1558
- log_metric(
1559
- accelerator,
1560
- metrics=train_metric,
1561
- learning_rate=lr_scheduler.get_last_lr()[0],
1562
- train_time=train_time + time.time() - train_start,
1563
- step=cur_step,
1564
- epoch=epoch,
1565
- prefix="train",
1566
- )
1567
-
1568
- # save checkpoint and weights after each save_steps and at the end of training
1569
- if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1570
- intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1571
- accelerator.save_state(output_dir=intermediate_dir)
1572
- accelerator.wait_for_everyone()
1573
- if accelerator.is_main_process:
1574
- rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1575
-
1576
- if training_args.push_to_hub:
1577
- upload_folder(
1578
- folder_path=training_args.output_dir,
1579
- repo_id=repo_name,
1580
- repo_type="model",
1581
- commit_message=f"Saving train state of step {cur_step}",
1582
- )
1583
-
1584
- if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1585
- train_time += time.time() - train_start
1586
- student_model.eval()
1587
- # ======================== Evaluating ==============================
1588
- for eval_split in all_eval_splits:
1589
- eval_metrics = []
1590
- eval_preds = []
1591
- eval_labels = []
1592
- eval_start = time.time()
1593
-
1594
- validation_dataloader = DataLoader(
1595
- vectorized_datasets[eval_split],
1596
- collate_fn=data_collator,
1597
- batch_size=per_device_eval_batch_size,
1598
- drop_last=False,
1599
- num_workers=dataloader_num_workers,
1600
- prefetch_factor=prefetch_factor,
1601
- pin_memory=training_args.dataloader_pin_memory,
1602
- )
1603
- validation_dataloader = accelerator.prepare(validation_dataloader)
1604
-
1605
- for batch in tqdm(
1606
- validation_dataloader,
1607
- desc=f"Evaluating {eval_split}...",
1608
- position=2,
1609
- disable=not accelerator.is_local_main_process,
1610
- ):
1611
- # Model forward
1612
- eval_metric = eval_step(batch)
1613
- eval_metric = accelerator.gather_for_metrics(eval_metric)
1614
- eval_metrics.append(eval_metric)
1615
-
1616
- # generation
1617
- if training_args.predict_with_generate:
1618
- generated_ids = generate_step(batch)
1619
- # Gather all predictions and targets
1620
- generated_ids, labels = accelerator.gather_for_metrics(
1621
- (generated_ids, batch["labels"])
1622
- )
1623
- eval_preds.extend(generated_ids)
1624
- eval_labels.extend(labels)
1625
-
1626
- eval_time = time.time() - eval_start
1627
- # normalize eval metrics
1628
- eval_metrics = {
1629
- key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1630
- }
1631
-
1632
- # compute WER metric
1633
- wer_desc = ""
1634
- if training_args.predict_with_generate:
1635
- wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1636
- eval_preds, eval_labels
1637
- )
1638
- eval_metrics.update(wer_metric)
1639
- wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1640
- log_pred(
1641
- accelerator,
1642
- pred_str,
1643
- label_str,
1644
- norm_pred_str,
1645
- norm_label_str,
1646
- step=cur_step,
1647
- prefix=eval_split,
1648
- )
1649
-
1650
- # Print metrics and update progress bar
1651
- steps_trained_progress_bar.write(
1652
- f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1653
- f" {wer_desc})"
1654
- )
1655
-
1656
- log_metric(
1657
- accelerator,
1658
- metrics=eval_metrics,
1659
- train_time=eval_time,
1660
- step=cur_step,
1661
- epoch=epoch,
1662
- prefix=eval_split,
1663
- )
1664
-
1665
- # flush the train metrics
1666
- train_start = time.time()
1667
-
1668
- # break condition
1669
- if cur_step == total_train_steps:
1670
-
1671
- # un-wrap student model for save
1672
- student_model = accelerator.unwrap_model(student_model)
1673
- student_model.save_pretrained(training_args.output_dir)
1674
-
1675
- if training_args.push_to_hub:
1676
- upload_folder(
1677
- folder_path=training_args.output_dir,
1678
- repo_id=repo_name,
1679
- repo_type="model",
1680
- commit_message=f"Saving final weights of step {cur_step}",
1681
- )
1682
-
1683
- continue_training = False
1684
- break
1685
-
1686
- if not continue_training:
1687
- break
1688
-
1689
- accelerator.end_training()
 
 
 if __name__ == "__main__":
 
 
         )
     else:
         is_multilingual = False
+
+    print(f" is_multilingual : {is_multilingual}")
 
     # 8. Create a single speech processor - make sure all processes wait until data is saved
+    # if accelerator.is_main_process:
+    #     feature_extractor.save_pretrained(training_args.output_dir)
+    #     tokenizer.save_pretrained(training_args.output_dir)
+    #     # save the config and generation config as well
+    #     config.save_pretrained(training_args.output_dir)
+    #     student_model.generation_config.save_pretrained(training_args.output_dir)
+
+    # accelerator.wait_for_everyone()
+    # processor = WhisperProcessor.from_pretrained(training_args.output_dir)
1026
+
1027
+ # # 9. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
1028
+ # # so we just need to set the correct target sampling rate.
1029
+ # sampling_rate = feature_extractor.sampling_rate
1030
+ # raw_datasets = raw_datasets.cast_column(
1031
+ # data_args.audio_column_name,
1032
+ # datasets.features.Audio(sampling_rate=sampling_rate),
1033
+ # )
1034
+
1035
+ # # 10. Preprocessing the datasets: we need to read the audio files as arrays and tokenize the targets.
1036
+ # # 10.1: Define the pre-processing constants
1037
+ # max_input_length = int(data_args.max_duration_in_seconds * sampling_rate)
1038
+ # min_input_length = int(data_args.min_duration_in_seconds * sampling_rate)
1039
+ # max_label_length = (
1040
+ # data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
1041
+ # )
1042
+
1043
+ # timestamp_probability = data_args.timestamp_probability
1044
+ # condition_on_prev_probability = data_args.condition_on_prev_probability
1045
+ # return_timestamps = data_args.return_timestamps if timestamp_probability > 0 else False
1046
+
1047
+ # timestamp_ids = tokenizer.timestamp_ids()
1048
+ # timestamp_begin = tokenizer.all_special_ids[-1]
1049
+ # timestamp_position = 3 if is_multilingual else 1
1050
+
1051
+ # decoder_start_token_id = student_model.config.decoder_start_token_id # <|startoftranscript|>
1052
+ # decoder_prev_token_id = tokenizer.all_special_ids[-3] # <|startofprev|>
1053
+ # prompt_cutoff_length = max_label_length // 2
1054
+
1055
+ # num_workers = data_args.preprocessing_num_workers
1056
+ # dataloader_num_workers = training_args.dataloader_num_workers
1057
+ # prefetch_factor = training_args.dataloader_prefetch_factor
1058
+
1059
+ # metric = evaluate.load("wer")
1060
+ # normalizer = (
1061
+ # BasicTextNormalizer()
1062
+ # if data_args.language is not None
1063
+ # else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
1064
+ # )
1065
+ # wer_threshold = data_args.wer_threshold
1066
+ # use_pseudo_labels = data_args.use_pseudo_labels
1067
+ # train_text_column_name = "whisper_transcript" if use_pseudo_labels else "text"
1068
+
1069
+ # # 10.2: filter based on maximum number of training/evaluation samples
1070
+ # if training_args.do_train and data_args.max_train_samples is not None:
1071
+ # raw_datasets["train"] = (
1072
+ # raw_datasets["train"].take(data_args.max_train_samples)
1073
+ # if data_args.streaming
1074
+ # else raw_datasets["train"].select(range(data_args.max_train_samples))
1075
+ # )
1076
+
1077
+ # if training_args.do_eval and data_args.max_eval_samples is not None:
1078
+ # for eval_split in all_eval_splits:
1079
+ # raw_datasets[eval_split] = (
1080
+ # raw_datasets[eval_split].take(data_args.max_eval_samples)
1081
+ # if data_args.streaming
1082
+ # else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
1083
+ # )
1084
+
1085
+ # # 10.3: filter training data based on WER threshold -> this is KEY to good distillation performance
1086
+ # def is_wer_in_range(ground_truth, whisper_transcript):
1087
+ # norm_ground_truth = normalizer(ground_truth)
1088
+ # if whisper_transcript is not None and whisper_transcript.upper() == whisper_transcript:
1089
+ # # filter entirely upper-case transcriptions: these are erroneous generations from large-v3
1090
+ # return False
1091
+ # elif len(norm_ground_truth) > 0 and whisper_transcript is not None:
1092
+ # norm_whisper_transcript = normalizer(whisper_transcript)
1093
+ # wer = 100 * metric.compute(predictions=[norm_whisper_transcript], references=[norm_ground_truth])
1094
+ # return wer < wer_threshold
1095
+ # else:
1096
+ # # filter automatically since we can't know the WER
1097
+ # return False
1098
+
1099
+ # filter_by_wer_threshold = partial(
1100
+ # raw_datasets["train"].filter,
1101
+ # function=is_wer_in_range,
1102
+ # input_columns=["text", "whisper_transcript"],
1103
+ # )
1104
+
1105
+ # if wer_threshold is not None and use_pseudo_labels:
1106
+ # with accelerator.main_process_first():
1107
+ # raw_datasets["train"] = (
1108
+ # filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
1109
+ # if not data_args.streaming
1110
+ # else filter_by_wer_threshold()
1111
+ # )
1112
+
1113
+ # # 10.4: pre-process training/evaluation datasets
1114
+ # def prepare_train_dataset(batch):
1115
+ # """
1116
+ # Pre-process the raw dataset in a three stage process:
1117
+ # 1. Convert the audio arrays to log-mel spectrogram inputs
1118
+ # 2. Possibly filter the timestamp tokens from the token ids (depending on the timestamp probability)
1119
+ # 3. Possibly add prompt tokens if conditioning on previous text (depending on the conditioning probability)
1120
+ # """
1121
+ # # process audio input
1122
+ # audio = [sample["array"] for sample in batch["audio"]]
1123
+ # inputs = feature_extractor(audio, sampling_rate=sampling_rate)
1124
+ # batch["input_features"] = inputs.input_features
1125
+ # batch["input_length"] = [len(sample) for sample in audio]
1126
+
1127
+ # # process text targets - for training these are the Whisper-generated pseudo-labels
1128
+ # input_str_batched = batch[train_text_column_name]
1129
+ # condition_on_prev_batched = batch.get("condition_on_prev", len(input_str_batched) * [None])
1130
+
1131
+ # all_token_ids = []
1132
+ # all_token_ids_unprompted = []
1133
+ # for prev_ids, input_str in zip(condition_on_prev_batched, input_str_batched):
1134
+ # token_ids = tokenizer(input_str, add_special_tokens=not use_pseudo_labels).input_ids
1135
+
1136
+ # # check whether we have timestamps in the PLs and filter if required
1137
+ # has_timestamps = len(set(token_ids) & set(timestamp_ids)) > 0
1138
+ # if has_timestamps:
1139
+ # # sample from binomial distribution to get probability of training on timestamps
1140
+ # predict_timestamps = bool(np.random.binomial(1, timestamp_probability))
1141
+ # if not predict_timestamps:
1142
+ # # filter timestamps and insert the <|notimestamps|> task token
1143
+ # token_ids = [token for token in token_ids if token < timestamp_begin]
1144
+ # token_ids.insert(timestamp_position, timestamp_begin)
1145
+
1146
+ # all_token_ids_unprompted.append(token_ids)
1147
+ # # check whether to condition on previous text - we do this with probability condition_on_prev_probability
1148
+ # condition_on_prev = bool(np.random.binomial(1, condition_on_prev_probability))
1149
+ # if not condition_on_prev:
1150
+ # prev_ids = None
1151
+ # elif "condition_on_prev" not in batch and len(all_token_ids_unprompted) > 1:
1152
+ # # prompt ids are the penultimate token ids in the batch
1153
+ # prev_ids = all_token_ids_unprompted[-2]
1154
+
1155
+ # if prev_ids is not None:
1156
+ # if has_timestamps and not predict_timestamps:
1157
+ # # filter timestamp ids from prompt when not predicting timestamps
1158
+ # prev_ids = [token for token in prev_ids if token < timestamp_begin]
1159
+
1160
+ # # check that the length of the prompt does not exceed more than half the max label length (224)
1161
+ # if len(prev_ids) > prompt_cutoff_length:
1162
+ # prev_ids = prev_ids[-prompt_cutoff_length + 1 :]
1163
+ # prev_ids = [decoder_prev_token_id] + prev_ids
1164
+
1165
+ # # and that the total length of the labels does not exceed the max label length (448)
1166
+ # if len(prev_ids + token_ids) > max_label_length:
1167
+ # trim_length = len(prev_ids + token_ids) - max_label_length + 1
1168
+ # prev_ids = prev_ids[trim_length:]
1169
+ # prev_ids = [decoder_prev_token_id] + prev_ids
1170
+
1171
+ # token_ids = prev_ids + token_ids
1172
+
1173
+ # all_token_ids.append(token_ids)
1174
+
1175
+ # batch["labels"] = all_token_ids
1176
+ # return batch
1177
+
1178
+ # def prepare_eval_dataset(batch):
1179
+ # # process audio input
1180
+ # sample = batch["audio"]
1181
+ # inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
1182
+ # batch["input_features"] = inputs.input_features[0]
1183
+ # batch["input_length"] = len(sample["array"])
1184
+
1185
+ # # process targets - for evaluation these are the ground-truth transcriptions
1186
+ # input_str = batch["text"]
1187
+ # batch["labels"] = tokenizer(input_str).input_ids
1188
+ # return batch
1189
+
1190
+ # vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
1191
+ # if training_args.do_train:
1192
+ # # with streaming mode we can only have 1 worker, whereas with non-streaming
1193
+ # # we can use `num_workers` (which is much faster)
1194
+ # # We gate the pre-processing function accordingly
1195
+ # map_fn_train = partial(
1196
+ # raw_datasets["train"].map,
1197
+ # function=prepare_train_dataset,
1198
+ # remove_columns=raw_datasets_train_features,
1199
+ # batched=True,
1200
+ # batch_size=data_args.preprocessing_batch_size,
1201
+ # )
1202
+ # with accelerator.main_process_first():
1203
+ # vectorized_datasets["train"] = (
1204
+ # map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
1205
+ # if not data_args.streaming
1206
+ # else map_fn_train()
1207
+ # )
1208
+ # if training_args.do_eval:
1209
+ # for eval_split in all_eval_splits:
1210
+ # raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1211
+ # map_fn_eval = partial(
1212
+ # raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1213
+ # )
1214
+ # with accelerator.main_process_first():
1215
+ # vectorized_datasets[eval_split] = (
1216
+ # map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
1217
+ # if not data_args.streaming
1218
+ # else map_fn_eval()
1219
+ # )
1220
+
1221
+ # # 10.5: Filter training data with inputs longer than `max_input_length`
1222
+ # def is_audio_in_length_range(length):
1223
+ # return min_input_length < length < max_input_length
1224
+
1225
+ # filter_by_audio_fn = partial(
1226
+ # vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1227
+ # )
1228
+ # with accelerator.main_process_first():
1229
+ # vectorized_datasets = (
1230
+ # filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1231
+ # if not data_args.streaming
1232
+ # else filter_by_audio_fn()
1233
+ # )
1234
+
1235
+ # # 10.6: Filter training data with labels longer than `max_label_length`
1236
+ # def is_labels_in_length_range(labels):
1237
+ # return 0 < len(labels) <= max_label_length
1238
+
1239
+ # filter_by_labels_fn = partial(
1240
+ # vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1241
+ # )
1242
+ # with accelerator.main_process_first():
1243
+ # vectorized_datasets = (
1244
+ # filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1245
+ # if not data_args.streaming
1246
+ # else filter_by_labels_fn()
1247
+ # )
1248
+
1249
+ # # Pre-processing complete!
1250
+ # # For large datasets it is advised to run the preprocessing on a
1251
+ # # single machine first with `--preprocessing_only` since there will mostly likely
1252
+ # # be a timeout when running the script in distributed mode.
1253
+ # # In a second step, `--preprocessing_only` can then be set to `False` to load the
1254
+ # # cached dataset
1255
+ # if data_args.preprocessing_only:
1256
+ # if data_args.streaming:
1257
+ # raise ValueError(
1258
+ # "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1259
+ # "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1260
+ # "on the fly with streaming mode."
1261
+ # )
1262
+ # cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1263
+ # logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1264
+ # return
1265
+
1266
+ # # 11. Define Evaluation Metrics
1267
+ # def compute_metrics(preds, labels):
1268
+ # # replace padded labels by the padding token
1269
+ # for idx in range(len(labels)):
1270
+ # labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1271
+
1272
+ # pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1273
+ # # we do not want to group tokens when computing the metrics
1274
+ # label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1275
+ # wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1276
+
1277
+ # # normalize everything and re-compute the WER
1278
+ # norm_pred_str = [normalizer(pred) for pred in pred_str]
1279
+ # norm_label_str = [normalizer(label) for label in label_str]
1280
+ # # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1281
+ # pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1282
+ # label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1283
+ # # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1284
+ # norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1285
+ # norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1286
+
1287
+ # wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1288
+ # return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1289
+
1290
+ # # 12. Define Training Schedule
1291
+ # # Store some constants
1292
+ # per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1293
+ # train_batch_size = per_device_train_batch_size * accelerator.num_processes
1294
+ # gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1295
+ # per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1296
+
1297
+ # if not data_args.streaming and training_args.max_steps < 0:
1298
+ # num_epochs = int(training_args.num_train_epochs)
1299
+ # steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1300
+ # total_train_steps = steps_per_epoch * num_epochs
1301
+ # elif training_args.max_steps > 0:
1302
+ # logger.info("max_steps is given, it will override any value given in num_train_epochs")
1303
+ # total_train_steps = int(training_args.max_steps)
1304
+ # if not data_args.streaming:
1305
+ # steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1306
+ # num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1307
+ # else:
1308
+ # # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1309
+ # num_epochs = sys.maxsize
1310
+ # steps_per_epoch = total_train_steps
1311
+ # else:
1312
+ # raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1313
+
1314
+ # if training_args.eval_steps is None:
1315
+ # logger.info(
1316
+ # f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1317
+ # )
1318
+ # eval_steps = steps_per_epoch
1319
+ # else:
1320
+ # eval_steps = training_args.eval_steps
1321
+
1322
+ # # 13. Define optimizer, LR scheduler, collator
1323
+ # decay_parameters = get_parameter_names(
1324
+ # student_model,
1325
+ # [nn.LayerNorm],
1326
+ # forbidden_module=[student_model.model.encoder] if training_args.freeze_encoder else None,
1327
+ # )
1328
+ # decay_parameters = [name for name in decay_parameters if "bias" not in name]
1329
+ # optimizer_grouped_parameters = [
1330
+ # {
1331
+ # "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1332
+ # "weight_decay": training_args.weight_decay,
1333
+ # },
1334
+ # {
1335
+ # "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1336
+ # "weight_decay": 0.0,
1337
+ # },
1338
+ # ]
1339
+ # optimizer = torch.optim.AdamW(
1340
+ # params=optimizer_grouped_parameters,
1341
+ # lr=training_args.learning_rate,
1342
+ # betas=(training_args.adam_beta1, training_args.adam_beta2),
1343
+ # eps=training_args.adam_epsilon,
1344
+ # )
1345
+
1346
+ # # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1347
+ # lr_scheduler = get_scheduler(
1348
+ # name=training_args.lr_scheduler_type,
1349
+ # optimizer=optimizer,
1350
+ # num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1351
+ # num_training_steps=total_train_steps * accelerator.num_processes,
1352
+ # )
1353
+
1354
+ # data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1355
+ # processor=processor,
1356
+ # decoder_start_token_id=decoder_start_token_id,
1357
+ # decoder_prev_token_id=decoder_prev_token_id,
1358
+ # input_padding="longest",
1359
+ # target_padding="max_length",
1360
+ # max_target_length=max_label_length,
1361
+ # )
1362
+
1363
+ # # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1364
+ # # so that we can still access the configs
1365
+ # num_beams = (
1366
+ # training_args.generation_num_beams
1367
+ # if training_args.generation_num_beams is not None
1368
+ # else getattr(student_model.generation_config, "num_beams", 1)
1369
+ # )
1370
+
1371
+ # gen_kwargs = {
1372
+ # "max_length": max_label_length,
1373
+ # "num_beams": num_beams,
1374
+ # "return_timestamps": return_timestamps,
1375
+ # }
1376
+ # if is_multilingual:
1377
+ # # forcing the language and task tokens helps multilingual models in their generations
1378
+ # gen_kwargs.update(
1379
+ # {
1380
+ # "language": data_args.language,
1381
+ # "task": data_args.task,
1382
+ # }
1383
+ # )
1384
+
1385
+ # # 15. Prepare everything with accelerate
1386
+ # student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1387
+ # student_model, teacher_model, optimizer, lr_scheduler
1388
+ # )
1389
+
1390
+ # def kl_divergence(target_distribution, log_predicted_distribution, labels):
1391
+ # kl_loss = nn.KLDivLoss(reduction="none")
1392
+ # divergence = kl_loss(log_predicted_distribution, target_distribution)
1393
+ # # ignore padded tokens from divergence, i.e. where labels are not set to -100
1394
+ # padding_mask = labels >= 0
1395
+ # padding_mask = padding_mask.unsqueeze(-1)
1396
+ # divergence = divergence * padding_mask
1397
+ # # take the average over the mini-batch
1398
+ # divergence = divergence.sum() / padding_mask.sum()
1399
+ # return divergence
1400
+
1401
+ # # Define gradient update step fn
1402
+ # def train_step(
1403
+ # batch,
1404
+ # temperature=2.0,
1405
+ # ):
1406
+ # student_model.train()
1407
+ # teacher_model.eval()
1408
+
1409
+ # student_outputs = student_model(**batch)
1410
+ # with torch.no_grad():
1411
+ # if share_hidden_states:
1412
+ # # if the student and teacher share the same frozen encoder then we don't have to recompute the
1413
+ # # encoder hidden-states for the teacher model, we can just re-use from the student
1414
+ # encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1415
+ # teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1416
+ # else:
1417
+ # # do the full forward pass for the teacher model (encoder + decoder)
1418
+ # teacher_outputs = teacher_model(**batch)
1419
+
1420
+ # # CE (data) loss
1421
+ # ce_loss = student_outputs.loss
1422
+ # # rescale distribution by temperature to ensure gradients scale correctly
1423
+ # teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1424
+ # # log softmax of student predictions for numerical stability
1425
+ # student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1426
+ # # KL-divergence loss (scaled by temperature)
1427
+ # kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1428
+
1429
+ # # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1430
+ # loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1431
+ # metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1432
+ # return loss, metrics
1433
+
1434
+ # # Define eval fn
1435
+ # def eval_step(batch):
1436
+ # student_model.eval()
1437
+ # teacher_model.eval()
1438
+
1439
+ # with torch.no_grad():
1440
+ # student_outputs = student_model(**batch)
1441
+ # if share_hidden_states:
1442
+ # encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1443
+ # teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1444
+ # else:
1445
+ # teacher_outputs = teacher_model(**batch)
1446
+
1447
+ # # CE (data) loss
1448
+ # ce_loss = student_outputs.loss
1449
+
1450
+ # # log softmax / softmax for numerical stability
1451
+ # student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1452
+ # teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1453
+ # # temperature is always 1 for eval
1454
+ # kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1455
+
1456
+ # # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1457
+ # loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1458
+ # metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1459
+ # return metrics
1460
+
1461
+ # def generate_step(batch):
1462
+ # student_model.eval()
1463
+ # output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1464
+ # output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1465
+ # return output_ids
1466
+
1467
+ # logger.info("***** Running training *****")
1468
+ # logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
1469
+ # if not data_args.streaming:
1470
+ # logger.info(f" Num epochs = {num_epochs}")
1471
+ # logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1472
+ # logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1473
+ # logger.info(
1474
+ # f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1475
+ # )
1476
+ # logger.info(f" Total optimization steps = {total_train_steps}")
1477
+
1478
+ # # ======================== Training ================================
1479
+ # train_time = 0
1480
+ # train_start = time.time()
1481
+ # steps_trained_progress_bar = tqdm(
1482
+ # range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1483
+ # )
1484
+ # continue_training = True
1485
+ # epochs_trained = 0
1486
+ # cur_step = 0
1487
+
1488
+ # checkpoint = None
1489
+ # if training_args.resume_from_checkpoint is not None:
1490
+ # checkpoint = training_args.resume_from_checkpoint
1491
+ # elif last_checkpoint is not None:
1492
+ # checkpoint = last_checkpoint
1493
+
1494
+ # if checkpoint is not None:
1495
+ # accelerator.load_state(checkpoint)
1496
+ # # Find num steps and epoch from saved state string pattern
1497
+ # pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1498
+ # match = re.search(pattern, checkpoint)
1499
+ # cur_step = int(match.group(1))
1500
+ # epochs_trained = int(match.group(2))
1501
+
1502
+ # logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1503
+ # logger.info(f" Continuing training from epoch {epochs_trained}")
1504
+ # logger.info(f" Continuing training from global step {cur_step}")
1505
+
1506
+ # steps_trained_progress_bar.update(cur_step)
1507
+
1508
+ # for epoch in range(0, epochs_trained):
1509
+ # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1510
+
1511
+ # if not data_args.streaming and training_args.max_steps < 0:
1512
+ # # we know exactly the number of steps per epoch, so can skip through the required number of batches
1513
+ # resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1514
+ # else:
1515
+ # # Currently we don't know how many steps we've taken in the current epoch
1516
+ # # So we just shuffle the dataset one extra time and start from a fresh epoch
1517
+ # # This is "good enough" for our purposes but not fully correct
1518
+ # resume_step = None
1519
+ # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1520
+ # else:
1521
+ # resume_step = None
1522
+
1523
+ # for epoch in range(epochs_trained, num_epochs):
1524
+ # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1525
+ # train_dataloader = DataLoader(
1526
+ # vectorized_datasets["train"],
1527
+ # collate_fn=data_collator,
1528
+ # batch_size=per_device_train_batch_size,
1529
+ # num_workers=dataloader_num_workers,
1530
+ # prefetch_factor=prefetch_factor,
1531
+ # pin_memory=training_args.dataloader_pin_memory,
1532
+ # )
1533
+ # train_dataloader = accelerator.prepare(train_dataloader)
1534
+ # if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1535
+ # train_dataloader.dataset.set_epoch(epoch)
1536
+
1537
+ # if resume_step is not None:
1538
+ # # Skip the first N batches in the dataloader when resuming from a checkpoint
1539
+ # train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1540
+ # resume_step = None
1541
+
1542
+ # for batch in train_dataloader:
1543
+ # with accelerator.accumulate(student_model):
1544
+ # loss, train_metric = train_step(batch, temperature=training_args.temperature)
1545
+ # accelerator.backward(loss)
1546
+ # if accelerator.sync_gradients:
1547
+ # accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1548
+ # optimizer.step()
1549
+ # lr_scheduler.step()
1550
+ # optimizer.zero_grad()
1551
+
1552
+ # # Check if the accelerator has performed an optimization step behind the scenes
1553
+ # if accelerator.sync_gradients:
1554
+ # steps_trained_progress_bar.update(1)
1555
+ # cur_step += 1
1556
+
1557
+ # if cur_step % training_args.logging_steps == 0:
1558
+ # steps_trained_progress_bar.write(
1559
+ # f"Step... ({cur_step} / {total_train_steps} | Loss:"
1560
+ # f" {train_metric['loss']}, Learning Rate:"
1561
+ # f" {lr_scheduler.get_last_lr()[0]})"
1562
+ # )
1563
+ # log_metric(
1564
+ # accelerator,
1565
+ # metrics=train_metric,
1566
+ # learning_rate=lr_scheduler.get_last_lr()[0],
1567
+ # train_time=train_time + time.time() - train_start,
1568
+ # step=cur_step,
1569
+ # epoch=epoch,
1570
+ # prefix="train",
1571
+ # )
1572
+
1573
+ # # save checkpoint and weights after each save_steps and at the end of training
1574
+ # if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1575
+ # intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1576
+ # accelerator.save_state(output_dir=intermediate_dir)
1577
+ # accelerator.wait_for_everyone()
1578
+ # if accelerator.is_main_process:
1579
+ # rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1580
+
1581
+ # if training_args.push_to_hub:
1582
+ # upload_folder(
1583
+ # folder_path=training_args.output_dir,
1584
+ # repo_id=repo_name,
1585
+ # repo_type="model",
1586
+ # commit_message=f"Saving train state of step {cur_step}",
1587
+ # )
1588
+
1589
+ # if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1590
+ # train_time += time.time() - train_start
1591
+ # student_model.eval()
1592
+ # # ======================== Evaluating ==============================
1593
+ # for eval_split in all_eval_splits:
1594
+ # eval_metrics = []
1595
+ # eval_preds = []
1596
+ # eval_labels = []
1597
+ # eval_start = time.time()
1598
+
1599
+ # validation_dataloader = DataLoader(
1600
+ # vectorized_datasets[eval_split],
1601
+ # collate_fn=data_collator,
1602
+ # batch_size=per_device_eval_batch_size,
1603
+ # drop_last=False,
1604
+ # num_workers=dataloader_num_workers,
1605
+ # prefetch_factor=prefetch_factor,
1606
+ # pin_memory=training_args.dataloader_pin_memory,
1607
+ # )
1608
+ # validation_dataloader = accelerator.prepare(validation_dataloader)
1609
+
1610
+ # for batch in tqdm(
1611
+ # validation_dataloader,
1612
+ # desc=f"Evaluating {eval_split}...",
1613
+ # position=2,
1614
+ # disable=not accelerator.is_local_main_process,
1615
+ # ):
1616
+ # # Model forward
1617
+ # eval_metric = eval_step(batch)
1618
+ # eval_metric = accelerator.gather_for_metrics(eval_metric)
1619
+ # eval_metrics.append(eval_metric)
1620
+
1621
+ # # generation
1622
+ # if training_args.predict_with_generate:
1623
+ # generated_ids = generate_step(batch)
1624
+ # # Gather all predictions and targets
1625
+ # generated_ids, labels = accelerator.gather_for_metrics(
1626
+ # (generated_ids, batch["labels"])
1627
+ # )
1628
+ # eval_preds.extend(generated_ids)
1629
+ # eval_labels.extend(labels)
1630
+
1631
+ # eval_time = time.time() - eval_start
1632
+ # # normalize eval metrics
1633
+ # eval_metrics = {
1634
+ # key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1635
+ # }
1636
+
1637
+ # # compute WER metric
1638
+ # wer_desc = ""
1639
+ # if training_args.predict_with_generate:
1640
+ # wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1641
+ # eval_preds, eval_labels
1642
+ # )
1643
+ # eval_metrics.update(wer_metric)
1644
+ # wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1645
+ # log_pred(
1646
+ # accelerator,
1647
+ # pred_str,
1648
+ # label_str,
1649
+ # norm_pred_str,
1650
+ # norm_label_str,
1651
+ # step=cur_step,
1652
+ # prefix=eval_split,
1653
+ # )
1654
+
1655
+ # # Print metrics and update progress bar
1656
+ # steps_trained_progress_bar.write(
1657
+ # f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1658
+ # f" {wer_desc})"
1659
+ # )
1660
+
1661
+ # log_metric(
1662
+ # accelerator,
1663
+ # metrics=eval_metrics,
1664
+ # train_time=eval_time,
1665
+ # step=cur_step,
1666
+ # epoch=epoch,
1667
+ # prefix=eval_split,
1668
+ # )
1669
+
1670
+ # # flush the train metrics
1671
+ # train_start = time.time()
1672
+
1673
+ # # break condition
1674
+ # if cur_step == total_train_steps:
1675
+
1676
+ # # un-wrap student model for save
1677
+ # student_model = accelerator.unwrap_model(student_model)
1678
+ # student_model.save_pretrained(training_args.output_dir)
1679
+
1680
+ # if training_args.push_to_hub:
1681
+ # upload_folder(
1682
+ # folder_path=training_args.output_dir,
1683
+ # repo_id=repo_name,
1684
+ # repo_type="model",
1685
+ # commit_message=f"Saving final weights of step {cur_step}",
1686
+ # )
1687
+
1688
+ # continue_training = False
1689
+ # break
1690
+
1691
+ # if not continue_training:
1692
+ # break
1693
+
1694
+ # accelerator.end_training()
1695
 
1696
 
1697
  if __name__ == "__main__":
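Note: the commented-out eval_step above combines a cross-entropy (data) loss with a KL term between the teacher's softmax and the student's log-softmax, weighted as 0.8 * ce_loss + training_args.kl_weight * kl_loss. The kl_divergence helper itself is not shown in this diff; the sketch below is an illustrative, label-masked word-level KL of the same shape, not the repository's exact implementation.

# Minimal sketch of a label-masked KL term, as used conceptually in eval_step above.
# Assumption (not shown in this diff): labels are padded with -100, logits are (batch, seq_len, vocab).
import torch.nn as nn

def kl_divergence(target_distribution, log_predicted_distribution, labels, pad_id=-100):
    # elementwise KL(teacher || student); input is log-probs, target is probs
    kl = nn.KLDivLoss(reduction="none")(log_predicted_distribution, target_distribution)
    per_token_kl = kl.sum(-1)                       # sum over the vocabulary dimension
    padding_mask = (labels != pad_id).float()       # ignore padded label positions
    return (per_token_kl * padding_mask).sum() / padding_mask.sum()

# combined objective from the snippet above (fixed CE weight, tunable KL weight):
# loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss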
__pycache__/evaluate.cpython-39.pyc ADDED
Binary file (142 Bytes).
 
distil-whisper/events.out.tfevents.1715057787.server02.1349950.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fb8763b003a4a4d0209179e68aac6e43453e4693f8cee09cd3a53b74ae1f1fa
3
+ size 88
distil-whisper/events.out.tfevents.1715063050.server02.1368197.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad00996d03cdb169f2976796433bb5c0f4b367ecbe8b4ae2c0b22e7472f45793
3
+ size 88
distil-whisper/events.out.tfevents.1715063266.server02.1369570.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfe9e8c42d8e2d53b9e63b1b235b356699c4986e401482da4f033bee21824cbe
3
+ size 88
distil-whisper/events.out.tfevents.1715063402.server02.1370564.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6d637cc9f5a873ec11177c4509c741f2d4e5a13099a60dc50e722fb95533961
3
+ size 88
distil-whisper/events.out.tfevents.1715063677.server02.1372191.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55b02873c18707d53d847d8e5b6cb1df617e81977e42349876eb8abc83573afd
3
+ size 88
distil-whisper/events.out.tfevents.1715063742.server02.1372871.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d8c2457fceadefc49be2acdfd53edc83a1a778537b28c2b81a86b591dc464f8
3
+ size 88
distil-whisper/events.out.tfevents.1715064564.server02.1376229.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d38fabb5cd9df9bbdb7b8ebea98ed65d62d7dd69515c9b1e730923dc12733a10
3
+ size 88
distil-whisper/events.out.tfevents.1715065478.server02.1379863.0 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4306c5a95d18085ad3d309c137786b0a2bbe4bebda405c86b23a6617a11b10a5
3
+ size 392
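The eight files above are git-lfs pointers (version / oid / size) for TensorBoard event logs written to the distil-whisper/ run directory; the size field refers to the actual event file, not the pointer. A minimal sketch for inspecting them locally, assuming the LFS objects have been pulled (e.g. via git lfs pull) and tensorboard is installed:

# Illustrative only: list the scalar tags recorded in the pulled event files.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

accumulator = EventAccumulator("distil-whisper")  # directory holding the events.out.tfevents.* files
accumulator.Reload()                              # parse the event files
print(accumulator.Tags()["scalars"])              # e.g. the train loss / ce_loss / kl_loss metrics logged above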
run_distillation.py CHANGED
@@ -1010,8 +1010,10 @@ def main():
  )
  else:
  is_multilingual = False
+
+ print(f" is_multilingual : {is_multilingual}")
 
- # 8. Create a single speech processor - make sure all processes wait until data is saved
+ #8. Create a single speech processor - make sure all processes wait until data is saved
  if accelerator.is_main_process:
  feature_extractor.save_pretrained(training_args.output_dir)
  tokenizer.save_pretrained(training_args.output_dir)
@@ -1379,8 +1381,8 @@ def main():
  "task": data_args.task,
  }
  )
-
- # 15. Prepare everything with accelerate
+ print(f" gen_kwargs : {gen_kwargs}")
+ #15. Prepare everything with accelerate
  student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
  student_model, teacher_model, optimizer, lr_scheduler
  )
@@ -1485,7 +1487,7 @@ def main():
 
  checkpoint = None
  if training_args.resume_from_checkpoint is not None:
- checkpoint = training_args.resume_from_checkpoint
+ checkpoint = training_args.resume_from_checkpoint
  elif last_checkpoint is not None:
  checkpoint = last_checkpoint
 
@@ -1694,3 +1696,6 @@ def main():
 
  if __name__ == "__main__":
  main()
+ '''
+ accelerate launch --mixed_precision=bf16 run_distillation.py --model_name_or_path "./distil-large-v3-init" --teacher_model_name_or_path "openai/whisper-large-v3" --train_dataset_name "mozilla-foundation/common_voice_15_0" --train_dataset_config_name "de" --train_split_name "train" --text_column_name "sentence" --eval_dataset_name "mozilla-foundation/common_voice_15_0" --eval_dataset_config_name "de" --eval_split_name "validation" --eval_text_column_name "sentence" --eval_steps 500 --save_steps 50 --warmup_steps 500 --learning_rate 1e-4 --lr_scheduler_type "linear" --logging_steps 25 --save_total_limit 1 --max_steps 500 --per_device_train_batch_size 4 --per_device_eval_batch_size 2 --dataloader_num_workers 2 --preprocessing_num_workers 2 --ddp_timeout 7200 --dtype "bfloat16" --output_dir "./" --use_pseudo_labels "false" --condition_on_prev_probability "0.0" --do_train --do_eval --gradient_checkpointing --overwrite_output_dir --predict_with_generate --freeze_encoder --streaming --push_to_hub
+ '''
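The trailing docstring added to run_distillation.py records the accelerate launch command used for this run (German Common Voice 15, save_steps 50, max_steps 500). Checkpoints written during training follow the checkpoint-{step}-epoch-{epoch} naming that the resume logic parses back apart; a minimal sketch of that parsing, reusing the regex from the commented-out training loop above (the helper name is illustrative, not part of the script):

# Illustrative helper mirroring the checkpoint-name pattern used when resuming.
import re

def parse_checkpoint_name(path):
    match = re.search(r"checkpoint-(\d+)-epoch-(\d+)", path)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))  # (global step, epochs completed)

print(parse_checkpoint_name("./checkpoint-50-epoch-0"))  # -> (50, 0)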
run_evaluate.py ADDED
File without changes
tokenizer.json CHANGED
@@ -14501,6 +14501,12 @@
  "type_id": 0
  }
  },
+ {
+ "SpecialToken": {
+ "id": "<|de|>",
+ "type_id": 0
+ }
+ },
  {
  "SpecialToken": {
  "id": "<|transcribe|>",
@@ -14533,6 +14539,12 @@
  "type_id": 0
  }
  },
+ {
+ "SpecialToken": {
+ "id": "<|de|>",
+ "type_id": 0
+ }
+ },
  {
  "SpecialToken": {
  "id": "<|transcribe|>",
@@ -14565,6 +14577,15 @@
  }
  ],
  "special_tokens": {
+ "<|de|>": {
+ "id": "<|de|>",
+ "ids": [
+ 50261
+ ],
+ "tokens": [
+ "<|de|>"
+ ]
+ },
  "<|endoftext|>": {
  "id": "<|endoftext|>",
  "ids": [