supawichwac committed on
Commit 7591c75
1 Parent(s): 1d63fc4

Saving train state of step 30000

checkpoint-30000-epoch-0/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c31d18417e3e13a2b79e96d44b8d2606c5959da8b343e76537d86e347ef699e0
+ size 3025686376
checkpoint-30000-epoch-0/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6b395c8a7e2bda655c415580106288d0387c227efd641bf4e11c1cd735fdb37a
+ size 4361070048
checkpoint-30000-epoch-0/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cf15242062ca5376a7a4b6d4c62824351fc03bf226f26e3ebce4c39d0fda992c
+ size 955539578
checkpoint-30000-epoch-0/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9aba43da2b6b6c5db39f9e95c1de6261bae932477b796a2c7647da423d6f691b
+ size 14344
checkpoint-30000-epoch-0/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e57843c87ed32da9817e4fc3151d8fac1890f0df43086ca762177a37f6f342d
+ size 1064
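
The five checkpoint files above are committed as Git LFS pointers: the repository stores only the spec version, the sha256 object id, and the byte size, while the actual tensors (roughly 3.0 GB and 4.4 GB for the two safetensors shards, plus optimizer, RNG, and scheduler state) live in LFS storage. A minimal sketch of reading those three fields back in Python, assuming the file on disk is still the 3-line pointer (i.e. the LFS smudge filter has not already replaced it with the full blob):

# Sketch: parse a Git LFS pointer file such as the ones added in this commit.
# Assumes the checkout still contains the pointer rather than the real weights.
def parse_lfs_pointer(path):
    fields = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

pointer = parse_lfs_pointer("checkpoint-30000-epoch-0/model.safetensors")
print(pointer["version"])          # https://git-lfs.github.com/spec/v1
print(pointer["oid"])              # sha256:c31d1841...
print(int(pointer["size"]))        # 3025686376 bytes (~3 GB)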
distil-whisper/events.out.tfevents.1715222264.server02.2131186.0 CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:fcd7dd5689696438e3575d4938acdac4d83528f1b877b5d3513be5844455f9b1
- size 313523
+ oid sha256:be15824155b2dfea692fe22799842b0a069874a3c52f787c080d056a07612fbe
+ size 377077
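
The distil-whisper events file is the TensorBoard log for this run; its pointer changes because new scalars were appended up to step 30000, growing the file from 313523 to 377077 bytes. A small sketch of inspecting it, assuming the tensorboard package is installed and the file has been pulled from LFS (the available tag names depend on what log_metric actually writes):

# Sketch: read scalar series from the updated TensorBoard log. Tag names are not
# guaranteed here; list them first and pick whichever the training script logged.
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

acc = EventAccumulator("distil-whisper/events.out.tfevents.1715222264.server02.2131186.0")
acc.Reload()                                   # parse the event records
scalar_tags = acc.Tags()["scalars"]
print(scalar_tags)                             # e.g. the train/eval metrics written by log_metric
for event in acc.Scalars(scalar_tags[0]):
    print(event.step, event.value)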
run_distillation.py CHANGED
@@ -1219,7 +1219,7 @@ def main():
     if training_args.do_eval:
         for eval_split in all_eval_splits:
             raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
-            map_fn_eval = partial(
                 raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
             )
             with accelerator.main_process_first():
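
The only change in this hunk is the map_fn_eval line (its replacement further down merely adds an inline comment): functools.partial binds raw_datasets[eval_split].map to prepare_eval_dataset up front, so the same callable can later be invoked with num_proc and desc for a cached dataset, or with no arguments when streaming. A minimal sketch of that pattern; the toy dataset and prepare_example function below are illustrative stand-ins, not objects from the script:

# Sketch of the partial() pattern above: pre-bind the arguments that never change,
# then call the bound function differently depending on the code path.
from functools import partial

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "bb", "ccc", "dddd"]})    # stand-in for raw_datasets[eval_split]

def prepare_example(example):                                   # stand-in for prepare_eval_dataset
    return {"n_chars": len(example["text"])}

map_fn_eval = partial(ds.map, function=prepare_example, remove_columns=["text"])

streaming = False
mapped = map_fn_eval(num_proc=2, desc="preprocess eval dataset") if not streaming else map_fn_eval()
print(mapped[0])   # {'n_chars': 1}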
@@ -1229,327 +1229,430 @@ def main():
1229
  else map_fn_eval()
1230
  )
1231
 
1232
-
1233
- # 10.5: Filter training data with inputs longer than `max_input_length`
1234
- def is_audio_in_length_range(length):
1235
- return min_input_length < length < max_input_length
1236
-
1237
- filter_by_audio_fn = partial(
1238
- vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1239
- )
1240
- with accelerator.main_process_first():
1241
- vectorized_datasets = (
1242
- filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1243
- if not data_args.streaming
1244
- else filter_by_audio_fn()
1245
- )
1246
-
1247
- # 10.6: Filter training data with labels longer than `max_label_length`
1248
- def is_labels_in_length_range(labels):
1249
- return 0 < len(labels) <= max_label_length
1250
-
1251
- filter_by_labels_fn = partial(
1252
- vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1253
- )
1254
- with accelerator.main_process_first():
1255
- vectorized_datasets = (
1256
- filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1257
- if not data_args.streaming
1258
- else filter_by_labels_fn()
1259
- )
1260
-
1261
- # Pre-processing complete!
1262
- # For large datasets it is advised to run the preprocessing on a
1263
- # single machine first with `--preprocessing_only` since there will mostly likely
1264
- # be a timeout when running the script in distributed mode.
1265
- # In a second step, `--preprocessing_only` can then be set to `False` to load the
1266
- # cached dataset
1267
- if data_args.preprocessing_only:
1268
- if data_args.streaming:
1269
- raise ValueError(
1270
- "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1271
- "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1272
- "on the fly with streaming mode."
1273
- )
1274
- cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1275
- logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1276
- return
1277
-
1278
- # 11. Define Evaluation Metrics
1279
- def compute_metrics(preds, labels):
1280
- # replace padded labels by the padding token
 
1281
 
1282
- for idx in range(len(labels)):
1283
- labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1284
-
1285
- pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1286
- print(f" pred_str : {pred_str}")
1287
- # we do not want to group tokens when computing the metrics
1288
-
1289
- label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1290
- wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1291
- print(f" label_str : {label_str}")
1292
- # normalize everything and re-compute the WER
1293
- norm_pred_str = [normalizer(pred) for pred in pred_str]
1294
- norm_label_str = [normalizer(label) for label in label_str]
1295
- # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1296
- pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1297
- label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1298
- # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1299
- norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1300
- norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1301
-
1302
- wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1303
- return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1304
-
1305
- # 12. Define Training Schedule
1306
- # Store some constants
1307
- per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1308
- train_batch_size = per_device_train_batch_size * accelerator.num_processes
1309
- gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1310
- per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1311
-
1312
- if not data_args.streaming and training_args.max_steps < 0:
1313
- num_epochs = int(training_args.num_train_epochs)
1314
- steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1315
- total_train_steps = steps_per_epoch * num_epochs
1316
-
1317
- elif training_args.max_steps > 0: #since we use data streaming , this condition is satisfied
1318
- logger.info("max_steps is given, it will override any value given in num_train_epochs")
1319
- total_train_steps = int(training_args.max_steps)
1320
- if not data_args.streaming:
1321
- steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1322
- num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1323
- else:
1324
- # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1325
- num_epochs = sys.maxsize #num_epochs as much as possible
1326
- steps_per_epoch = total_train_steps
1327
- else:
1328
- raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1329
-
1330
- if training_args.eval_steps is None:
1331
- logger.info(
1332
- f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1333
- )
1334
- eval_steps = steps_per_epoch
1335
- else:
1336
- eval_steps = training_args.eval_steps
1337
 
1338
- print(f" num_epochs : {num_epochs}")
1339
- print(f" steps_per_epoch = total_train_steps : {steps_per_epoch}")
1340
- # 13. Define optimizer, LR scheduler, collator
1341
- decay_parameters = get_parameter_names(
1342
- student_model,
1343
- [nn.LayerNorm],
1344
- forbidden_module=[student_model.model.encoder] if training_args.freeze_encoder else None,
1345
- )
1346
- decay_parameters = [name for name in decay_parameters if "bias" not in name]
1347
- optimizer_grouped_parameters = [
1348
- {
1349
- "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1350
- "weight_decay": training_args.weight_decay,
1351
- },
1352
- {
1353
- "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1354
- "weight_decay": 0.0,
1355
- },
1356
- ]
1357
- optimizer = torch.optim.AdamW(
1358
- params=optimizer_grouped_parameters,
1359
- lr=training_args.learning_rate,
1360
- betas=(training_args.adam_beta1, training_args.adam_beta2),
1361
- eps=training_args.adam_epsilon,
1362
- )
1363
-
1364
- # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1365
- lr_scheduler = get_scheduler(
1366
- name=training_args.lr_scheduler_type,
1367
- optimizer=optimizer,
1368
- num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1369
- num_training_steps=total_train_steps * accelerator.num_processes,
1370
- )
1371
- print()
1372
- data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1373
- processor=processor,
1374
- decoder_start_token_id=decoder_start_token_id,
1375
- decoder_prev_token_id=decoder_prev_token_id,
1376
- input_padding="longest",
1377
- target_padding="max_length",
1378
- max_target_length=max_label_length,
1379
- )
1380
-
1381
- # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1382
- # so that we can still access the configs
1383
- num_beams = (
1384
- training_args.generation_num_beams
1385
- if training_args.generation_num_beams is not None
1386
- else getattr(student_model.generation_config, "num_beams", 1)
1387
- )
1388
-
1389
- gen_kwargs = {
1390
- "max_length": max_label_length,
1391
- "num_beams": num_beams,
1392
- "return_timestamps": return_timestamps,
1393
- }
1394
- if is_multilingual:
1395
- # forcing the language and task tokens helps multilingual models in their generations
1396
- gen_kwargs.update(
1397
- {
1398
- "language": data_args.language,
1399
- "task": data_args.task,
1400
- }
1401
- )
1402
- print(f" gen_kwargs : {gen_kwargs}")
1403
- print(f" raw_datasets['eval']: {raw_datasets['eval']}")
1404
-
1405
- #15. Prepare everything with accelerate
1406
- student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1407
- student_model, teacher_model, optimizer, lr_scheduler
1408
- )
1409
 
1410
 
1411
 
1412
 
1413
- def kl_divergence(target_distribution, log_predicted_distribution, labels):
1414
- kl_loss = nn.KLDivLoss(reduction="none")
1415
- divergence = kl_loss(log_predicted_distribution, target_distribution)
1416
- # ignore padded tokens from divergence, i.e. where labels are not set to -100
1417
- padding_mask = labels >= 0
1418
- padding_mask = padding_mask.unsqueeze(-1)
1419
- divergence = divergence * padding_mask
1420
- # take the average over the mini-batch
1421
- divergence = divergence.sum() / padding_mask.sum()
1422
- return divergence
1423
-
1424
- # Define gradient update step fn
1425
- def train_step(
1426
- batch,
1427
- temperature=2.0,
1428
- ):
1429
- student_model.train()
1430
- teacher_model.eval()
1431
-
1432
- student_outputs = student_model(**batch) # __call__ is overidden for forward function , note : student_model and teacher model both are whisperforconditionalgeneration object
1433
- with torch.no_grad():
1434
- if share_hidden_states:
1435
- # if the student and teacher share the same frozen encoder then we don't have to recompute the
1436
- # encoder hidden-states for the teacher model, we can just re-use from the student
1437
- encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1438
- teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1439
- else:
1440
- # do the full forward pass for the teacher model (encoder + decoder)
1441
- teacher_outputs = teacher_model(**batch)
1442
 
1443
- # CE (data) loss
1444
- ce_loss = student_outputs.loss
1445
- # rescale distribution by temperature to ensure gradients scale correctly
1446
- teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1447
- # log softmax of student predictions for numerical stability
1448
- student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1449
- # KL-divergence loss (scaled by temperature)
1450
- kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1451
-
1452
- # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1453
- loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1454
- metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1455
- return loss, metrics
1456
-
1457
- # Define eval fn
1458
- def eval_step(batch):
1459
- student_model.eval()
1460
- teacher_model.eval()
1461
-
1462
- with torch.no_grad():
1463
- student_outputs = student_model(**batch)
1464
- if share_hidden_states:
1465
- encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1466
- teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1467
- else:
1468
- teacher_outputs = teacher_model(**batch)
1469
-
1470
- # CE (data) loss
1471
- ce_loss = student_outputs.loss
1472
-
1473
- # log softmax / softmax for numerical stability
1474
- student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1475
- teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1476
- # temperature is always 1 for eval
1477
- kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1478
-
1479
- # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1480
- loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1481
- metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1482
- return metrics
1483
-
1484
- def generate_step(batch):
1485
- student_model.eval()
1486
- output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1487
- output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1488
- return output_ids
1489
-
1490
- logger.info("***** Running training *****")
1491
- logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}") #num examples that actually are trained
1492
- if not data_args.streaming:
1493
- logger.info(f" Num epochs = {num_epochs}")
1494
- logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1495
- logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1496
- logger.info(
1497
- f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1498
- )
1499
- logger.info(f" Total optimization steps = {total_train_steps}")
1500
 
1501
- # ======================== Training ================================
1502
- train_time = 0
1503
- train_start = time.time()
1504
- steps_trained_progress_bar = tqdm(
1505
- range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1506
- )
1507
- continue_training = True
1508
- epochs_trained = 0
1509
- cur_step = 0
1510
-
1511
- checkpoint = None
1512
- if training_args.resume_from_checkpoint is not None:
1513
- checkpoint = training_args.resume_from_checkpoint
1514
- elif last_checkpoint is not None:
1515
- checkpoint = last_checkpoint
1516
-
1517
- if checkpoint is not None:
1518
- accelerator.load_state(checkpoint)
1519
- # Find num steps and epoch from saved state string pattern
1520
- pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1521
- match = re.search(pattern, checkpoint)
1522
- cur_step = int(match.group(1))
1523
- epochs_trained = int(match.group(2))
1524
-
1525
- logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1526
- logger.info(f" Continuing training from epoch {epochs_trained}")
1527
- logger.info(f" Continuing training from global step {cur_step}")
1528
-
1529
- steps_trained_progress_bar.update(cur_step)
1530
-
1531
- for epoch in range(0, epochs_trained):
1532
- vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1533
-
1534
- if not data_args.streaming and training_args.max_steps < 0:
1535
- # we know exactly the number of steps per epoch, so can skip through the required number of batches
1536
- resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1537
- else:
1538
- # Currently we don't know how many steps we've taken in the current epoch
1539
- # So we just shuffle the dataset one extra time and start from a fresh epoch
1540
- # This is "good enough" for our purposes but not fully correct
1541
- resume_step = None
1542
- vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1543
- else:
1544
- resume_step = None
1545
- print(f" raw_datasets['train'] : {raw_datasets['train']} ")
1546
- print(f" raw_datasets['eval'] : {raw_datasets['eval']} ")
1547
 
- print(f" vectorized_datasets['eval'] : {vectorized_datasets['eval']}")
- print(f" vectorized_datasets['train'] : {vectorized_datasets['train']}")
1550
 
1551
- #see example of validation dataloader
1552
- # validation_dataloader = DataLoader(
1553
  # vectorized_datasets[eval_split],
1554
  # collate_fn=data_collator,
1555
  # batch_size=per_device_eval_batch_size,
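
The block removed above (and re-added in commented form below) carries the core of the distillation objective: a KLDivLoss between the temperature-scaled teacher and student distributions, masked wherever the label is -100, averaged over the unmasked positions, scaled by temperature squared, and combined with the student's cross-entropy loss as 0.8 * ce_loss + kl_weight * kl_loss. A self-contained sketch of that arithmetic on random logits; the shapes, temperature and kl_weight below are illustrative, not the values used in this run:

# Sketch of the Distil-Whisper style loss defined in the removed block:
# loss = 0.8 * CE + kl_weight * T^2 * KL(teacher || student), with -100 labels masked out.
import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 6, 51865          # illustrative shapes only
temperature, kl_weight = 2.0, 1.0            # kl_weight comes from training_args in the script

student_logits = torch.randn(batch, seq_len, vocab)
teacher_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, -1] = -100                         # pretend the last position is padding

teacher_distribution = nn.functional.softmax(teacher_logits / temperature, dim=-1)
student_log_distribution = nn.functional.log_softmax(student_logits / temperature, dim=-1)

divergence = nn.KLDivLoss(reduction="none")(student_log_distribution, teacher_distribution)
padding_mask = (labels >= 0).unsqueeze(-1)   # ignore positions whose label is -100
kl_loss = (divergence * padding_mask).sum() / padding_mask.sum() * temperature**2

ce_loss = nn.functional.cross_entropy(
    student_logits.view(-1, vocab), labels.view(-1), ignore_index=-100
)
loss = 0.8 * ce_loss + kl_weight * kl_loss
print(float(loss))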
@@ -1559,198 +1662,96 @@ def main():
1559
  # pin_memory=training_args.dataloader_pin_memory,
1560
  # )
1561
 
1562
- # for batch in validation_dataloader:
1563
- # print(batch['input_features'].shape)
1564
-
1565
 
1566
- print(f" student_model : {type(student_model)}")
1567
-
1568
-
1569
- for epoch in range(epochs_trained, num_epochs):
1570
- vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1571
- train_dataloader = DataLoader(
1572
- vectorized_datasets["train"],
1573
- collate_fn=data_collator,
1574
- batch_size=per_device_train_batch_size,
1575
- num_workers=dataloader_num_workers,
1576
- prefetch_factor=prefetch_factor,
1577
- pin_memory=training_args.dataloader_pin_memory,
1578
- )
1579
- train_dataloader = accelerator.prepare(train_dataloader)
1580
- if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1581
- train_dataloader.dataset.set_epoch(epoch)
1582
 
1583
- if resume_step is not None:
1584
- # Skip the first N batches in the dataloader when resuming from a checkpoint
1585
- train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1586
- resume_step = None
1587
 
1588
-
1589
- for batch in train_dataloader:
1590
- with accelerator.accumulate(student_model):
1591
- #they are updated their parameters every batch
1592
- loss, train_metric = train_step(batch, temperature=training_args.temperature)
1593
- #backward pass with loss
1594
- accelerator.backward(loss)
1595
- if accelerator.sync_gradients:
1596
- accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1597
- #update after forward method
1598
- optimizer.step()
1599
- lr_scheduler.step()
1600
- optimizer.zero_grad()
1601
-
1602
- # Check if the accelerator has performed an optimization step behind the scenes
1603
- if accelerator.sync_gradients:
1604
- steps_trained_progress_bar.update(1)
1605
- cur_step += 1
1606
-
1607
 
1608
- #logging timing
1609
- if cur_step % training_args.logging_steps == 0:
1610
- steps_trained_progress_bar.write(
1611
- f"Step... ({cur_step} / {total_train_steps} | Loss:"
1612
- f" {train_metric['loss']}, Learning Rate:"
1613
- f" {lr_scheduler.get_last_lr()[0]})"
1614
- )
1615
- log_metric(
1616
- accelerator,
1617
- metrics=train_metric,
1618
- learning_rate=lr_scheduler.get_last_lr()[0],
1619
- train_time=train_time + time.time() - train_start,
1620
- step=cur_step,
1621
- epoch=epoch,
1622
- prefix="train",
1623
- )
1624
 
1625
- # save checkpoint and weights after each save_steps and at the end of training
1626
- if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1627
- intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1628
- accelerator.save_state(output_dir=intermediate_dir)
1629
- accelerator.wait_for_everyone()
1630
- if accelerator.is_main_process:
1631
- rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1632
-
1633
- if training_args.push_to_hub:
1634
- upload_folder(
1635
- folder_path=training_args.output_dir,
1636
- repo_id=repo_name,
1637
- repo_type="model",
1638
- commit_message=f"Saving train state of step {cur_step}",
1639
- )
1640
-
1641
- if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1642
- print("evaluating dsakdlaskdfl;skl;afksdl;fdasl;fkdl;askfl;asdkfldskfl;das")
1643
- train_time += time.time() - train_start
1644
- student_model.eval()
1645
-
1646
- # ======================== Evaluating ==============================
1647
-
1648
- for eval_split in all_eval_splits:
1649
- eval_metrics = []
1650
- eval_preds = []
1651
- eval_labels = []
1652
- eval_start = time.time()
1653
-
1654
- validation_dataloader = DataLoader(
1655
- vectorized_datasets[eval_split],
1656
- collate_fn=data_collator,
1657
- batch_size=per_device_eval_batch_size,
1658
- drop_last=False,
1659
- num_workers=dataloader_num_workers,
1660
- prefetch_factor=prefetch_factor,
1661
- pin_memory=training_args.dataloader_pin_memory,
1662
- )
1663
-
1664
-
1665
- validation_dataloader = accelerator.prepare(validation_dataloader)
1666
-
1667
- for batch in tqdm(
1668
- validation_dataloader,
1669
- desc=f"Evaluating {eval_split}...",
1670
- position=2,
1671
- disable=not accelerator.is_local_main_process,
1672
- ):
1673
- print(f"type(batch) : {type(batch)}")
1674
- # Model forward
1675
- eval_metric = eval_step(batch)
1676
- eval_metric = accelerator.gather_for_metrics(eval_metric)
1677
- eval_metrics.append(eval_metric)
1678
-
1679
- # generation
1680
- if training_args.predict_with_generate:
1681
-
1682
- generated_ids = generate_step(batch)
1683
- # Gather all predictions and targets
1684
- generated_ids, labels = accelerator.gather_for_metrics(
1685
- (generated_ids, batch["labels"])
1686
- )
1687
- eval_preds.extend(generated_ids)
1688
- eval_labels.extend(labels)
1689
-
1690
- eval_time = time.time() - eval_start
1691
- # normalize eval metrics
1692
- eval_metrics = {
1693
- key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1694
- }
1695
-
1696
- # compute WER metric
1697
- wer_desc = ""
1698
- if training_args.predict_with_generate:
1699
- wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1700
- eval_preds, eval_labels
1701
- )
1702
- eval_metrics.update(wer_metric)
1703
- wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1704
- log_pred(
1705
- accelerator,
1706
- pred_str,
1707
- label_str,
1708
- norm_pred_str,
1709
- norm_label_str,
1710
- step=cur_step,
1711
- prefix=eval_split,
1712
- )
1713
-
1714
- # Print metrics and update progress bar
1715
- steps_trained_progress_bar.write(
1716
- f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1717
- f" {wer_desc})"
1718
- )
1719
-
1720
- log_metric(
1721
- accelerator,
1722
- metrics=eval_metrics,
1723
- train_time=eval_time,
1724
- step=cur_step,
1725
- epoch=epoch,
1726
- prefix=eval_split,
1727
- )
1728
-
1729
- # flush the train metrics
1730
- train_start = time.time()
1731
-
1732
- # break condition
1733
- if cur_step == total_train_steps:
1734
-
1735
- # un-wrap student model for save
1736
- student_model = accelerator.unwrap_model(student_model)
1737
- student_model.save_pretrained(training_args.output_dir)
1738
-
1739
- if training_args.push_to_hub:
1740
- upload_folder(
1741
- folder_path=training_args.output_dir,
1742
- repo_id=repo_name,
1743
- repo_type="model",
1744
- commit_message=f"Saving final weights of step {cur_step}",
1745
- )
1746
-
1747
- continue_training = False
1748
- break
1749
-
1750
- if not continue_training:
1751
- break
1752
-
1753
- accelerator.end_training()
  if __name__ == "__main__":
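
The resume logic removed above restores accelerator state from a saved directory and recovers the global step and epoch from its name with the pattern r"checkpoint-(\d+)-epoch-(\d+)", which is exactly how the folder committed here, checkpoint-30000-epoch-0, is named. A small sketch of that parsing:

# Sketch of the resume-state parsing used in the removed block: the step and epoch
# are read straight out of the checkpoint directory name.
import re

def parse_checkpoint_name(path):
    match = re.search(r"checkpoint-(\d+)-epoch-(\d+)", path)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))

print(parse_checkpoint_name("checkpoint-30000-epoch-0"))   # (30000, 0)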
 
1219
  if training_args.do_eval:
1220
  for eval_split in all_eval_splits:
1221
  raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
1222
+ map_fn_eval = partial( #partial is predefined argument for a function in this case is map function with prepare_eval_dataset function as a predefined argument
1223
  raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
1224
  )
1225
  with accelerator.main_process_first():
 
1229
  else map_fn_eval()
1230
  )
1231
 
1232
+ print(f' vectorized_datasets["train"] : {vectorized_datasets["train"]}')
1233
+
1234
+ # # 10.5: Filter training data with inputs longer than `max_input_length`
1235
+ # def is_audio_in_length_range(length):
1236
+ # return min_input_length < length < max_input_length
1237
+
1238
+ # filter_by_audio_fn = partial(
1239
+ # vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
1240
+ # )
1241
+ # with accelerator.main_process_first():
1242
+ # vectorized_datasets = (
1243
+ # filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
1244
+ # if not data_args.streaming
1245
+ # else filter_by_audio_fn()
1246
+ # )
1247
+
1248
+ # # 10.6: Filter training data with labels longer than `max_label_length`
1249
+ # def is_labels_in_length_range(labels):
1250
+ # return 0 < len(labels) <= max_label_length
1251
+
1252
+ # filter_by_labels_fn = partial(
1253
+ # vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
1254
+ # )
1255
+ # with accelerator.main_process_first():
1256
+ # vectorized_datasets = (
1257
+ # filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
1258
+ # if not data_args.streaming
1259
+ # else filter_by_labels_fn()
1260
+ # )
1261
+
1262
+ # # Pre-processing complete!
1263
+ # # For large datasets it is advised to run the preprocessing on a
1264
+ # # single machine first with `--preprocessing_only` since there will mostly likely
1265
+ # # be a timeout when running the script in distributed mode.
1266
+ # # In a second step, `--preprocessing_only` can then be set to `False` to load the
1267
+ # # cached dataset
1268
+ # if data_args.preprocessing_only:
1269
+ # if data_args.streaming:
1270
+ # raise ValueError(
1271
+ # "When using streaming mode, dataset pre-processing is performed on the fly, hence there is no notion"
1272
+ # "of a cached pre-processed dataset. Remove the argument `--preprocessing_only` to run pre-processing "
1273
+ # "on the fly with streaming mode."
1274
+ # )
1275
+ # cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
1276
+ # logger.info(f"Data preprocessing finished. Files cached at {cache}.")
1277
+ # return
1278
+
1279
+ # # 11. Define Evaluation Metrics
1280
+ # def compute_metrics(preds, labels):
1281
+ # # replace padded labels by the padding token
1282
 
1283
+ # for idx in range(len(labels)):
1284
+ # labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
1285
+
1286
+ # pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
1287
+ # print(f" pred_str : {pred_str}")
1288
+ # # we do not want to group tokens when computing the metrics
1289
+
1290
+ # label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
1291
+ # wer_ortho = 100 * metric.compute(predictions=pred_str, references=label_str)
1292
+ # print(f" label_str : {label_str}")
1293
+ # # normalize everything and re-compute the WER
1294
+ # norm_pred_str = [normalizer(pred) for pred in pred_str]
1295
+ # norm_label_str = [normalizer(label) for label in label_str]
1296
+ # # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
1297
+ # pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1298
+ # label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1299
+ # # filtering step to only evaluate the samples that correspond to non-zero normalized references:
1300
+ # norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
1301
+ # norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
1302
+
1303
+ # wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
1304
+ # return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str
1305
+
1306
+ # # 12. Define Training Schedule
1307
+ # # Store some constants
1308
+ # per_device_train_batch_size = int(training_args.per_device_train_batch_size)
1309
+ # train_batch_size = per_device_train_batch_size * accelerator.num_processes
1310
+ # gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
1311
+ # per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
1312
+
1313
+ # if not data_args.streaming and training_args.max_steps < 0:
1314
+ # num_epochs = int(training_args.num_train_epochs)
1315
+ # steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1316
+ # total_train_steps = steps_per_epoch * num_epochs
1317
+
1318
+ # elif training_args.max_steps > 0: #since we use data streaming , this condition is satisfied
1319
+ # logger.info("max_steps is given, it will override any value given in num_train_epochs")
1320
+ # total_train_steps = int(training_args.max_steps)
1321
+ # if not data_args.streaming:
1322
+ # steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
1323
+ # num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
1324
+ # else:
1325
+ # # Setting a very large number of epochs so we go as many times as necessary over the iterator.
1326
+ # num_epochs = sys.maxsize #num_epochs as much as possible
1327
+ # steps_per_epoch = total_train_steps
1328
+ # else:
1329
+ # raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")
1330
+
1331
+ # if training_args.eval_steps is None:
1332
+ # logger.info(
1333
+ # f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
1334
+ # )
1335
+ # eval_steps = steps_per_epoch
1336
+ # else:
1337
+ # eval_steps = training_args.eval_steps
1338
 
1339
+ # print(f" num_epochs : {num_epochs}")
1340
+ # print(f" steps_per_epoch = total_train_steps : {steps_per_epoch}")
1341
+ # # 13. Define optimizer, LR scheduler, collator
1342
+ # decay_parameters = get_parameter_names(
1343
+ # student_model,
1344
+ # [nn.LayerNorm],
1345
+ # forbidden_module=[student_model.model.encoder] if training_args.freeze_encoder else None,
1346
+ # )
1347
+ # decay_parameters = [name for name in decay_parameters if "bias" not in name]
1348
+ # optimizer_grouped_parameters = [
1349
+ # {
1350
+ # "params": [param for name, param in student_model.named_parameters() if name in decay_parameters],
1351
+ # "weight_decay": training_args.weight_decay,
1352
+ # },
1353
+ # {
1354
+ # "params": [param for name, param in student_model.named_parameters() if name not in decay_parameters],
1355
+ # "weight_decay": 0.0,
1356
+ # },
1357
+ # ]
1358
+ # optimizer = torch.optim.AdamW(
1359
+ # params=optimizer_grouped_parameters,
1360
+ # lr=training_args.learning_rate,
1361
+ # betas=(training_args.adam_beta1, training_args.adam_beta2),
1362
+ # eps=training_args.adam_epsilon,
1363
+ # )
1364
+
1365
+ # # LR scheduler gets stepped by `num_processes` each time -> account for this in warmup / total steps
1366
+ # lr_scheduler = get_scheduler(
1367
+ # name=training_args.lr_scheduler_type,
1368
+ # optimizer=optimizer,
1369
+ # num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
1370
+ # num_training_steps=total_train_steps * accelerator.num_processes,
1371
+ # )
1372
+ # print()
1373
+ # data_collator = DataCollatorSpeechSeq2SeqWithPadding(
1374
+ # processor=processor,
1375
+ # decoder_start_token_id=decoder_start_token_id,
1376
+ # decoder_prev_token_id=decoder_prev_token_id,
1377
+ # input_padding="longest",
1378
+ # target_padding="max_length",
1379
+ # max_target_length=max_label_length,
1380
+ # )
1381
+
1382
+ # # 14. Define generation arguments - we need to do this before we wrap the models in DDP
1383
+ # # so that we can still access the configs
1384
+ # num_beams = (
1385
+ # training_args.generation_num_beams
1386
+ # if training_args.generation_num_beams is not None
1387
+ # else getattr(student_model.generation_config, "num_beams", 1)
1388
+ # )
1389
+
1390
+ # gen_kwargs = {
1391
+ # "max_length": max_label_length,
1392
+ # "num_beams": num_beams,
1393
+ # "return_timestamps": return_timestamps,
1394
+ # }
1395
+ # if is_multilingual:
1396
+ # # forcing the language and task tokens helps multilingual models in their generations
1397
+ # gen_kwargs.update(
1398
+ # {
1399
+ # "language": data_args.language,
1400
+ # "task": data_args.task,
1401
+ # }
1402
+ # )
1403
+ # print(f" gen_kwargs : {gen_kwargs}")
1404
+ # print(f" raw_datasets['eval']: {raw_datasets['eval']}")
1405
+
1406
+ # #15. Prepare everything with accelerate
1407
+ # student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
1408
+ # student_model, teacher_model, optimizer, lr_scheduler
1409
+ # )
1410
 
1411
 
1412
 
1413
 
1414
+ # def kl_divergence(target_distribution, log_predicted_distribution, labels):
1415
+ # kl_loss = nn.KLDivLoss(reduction="none")
1416
+ # divergence = kl_loss(log_predicted_distribution, target_distribution)
1417
+ # # ignore padded tokens from divergence, i.e. where labels are not set to -100
1418
+ # padding_mask = labels >= 0
1419
+ # padding_mask = padding_mask.unsqueeze(-1)
1420
+ # divergence = divergence * padding_mask
1421
+ # # take the average over the mini-batch
1422
+ # divergence = divergence.sum() / padding_mask.sum()
1423
+ # return divergence
1424
+
1425
+ # # Define gradient update step fn
1426
+ # def train_step(
1427
+ # batch,
1428
+ # temperature=2.0,
1429
+ # ):
1430
+ # student_model.train()
1431
+ # teacher_model.eval()
1432
+
1433
+ # student_outputs = student_model(**batch) # __call__ is overidden for forward function , note : student_model and teacher model both are whisperforconditionalgeneration object
1434
+ # with torch.no_grad():
1435
+ # if share_hidden_states:
1436
+ # # if the student and teacher share the same frozen encoder then we don't have to recompute the
1437
+ # # encoder hidden-states for the teacher model, we can just re-use from the student
1438
+ # encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1439
+ # teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1440
+ # else:
1441
+ # # do the full forward pass for the teacher model (encoder + decoder)
1442
+ # teacher_outputs = teacher_model(**batch)
1443
 
1444
+ # # CE (data) loss
1445
+ # ce_loss = student_outputs.loss
1446
+ # # rescale distribution by temperature to ensure gradients scale correctly
1447
+ # teacher_distribution = nn.functional.softmax(teacher_outputs.logits / temperature, dim=-1)
1448
+ # # log softmax of student predictions for numerical stability
1449
+ # student_distribution = nn.functional.log_softmax(student_outputs.logits / temperature, dim=-1)
1450
+ # # KL-divergence loss (scaled by temperature)
1451
+ # kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"]) * temperature**2
1452
+
1453
+ # # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1454
+ # loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1455
+ # metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1456
+ # return loss, metrics
1457
+
1458
+ # # Define eval fn
1459
+ # def eval_step(batch):
1460
+ # student_model.eval()
1461
+ # teacher_model.eval()
1462
+
1463
+ # with torch.no_grad():
1464
+ # student_outputs = student_model(**batch)
1465
+ # if share_hidden_states:
1466
+ # encoder_outputs = BaseModelOutput(student_outputs.encoder_last_hidden_state.to(dtype=teacher_dtype))
1467
+ # teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
1468
+ # else:
1469
+ # teacher_outputs = teacher_model(**batch)
1470
+
1471
+ # # CE (data) loss
1472
+ # ce_loss = student_outputs.loss
1473
+
1474
+ # # log softmax / softmax for numerical stability
1475
+ # student_distribution = nn.functional.log_softmax(student_outputs.logits, dim=-1)
1476
+ # teacher_distribution = nn.functional.softmax(teacher_outputs.logits, dim=-1)
1477
+ # # temperature is always 1 for eval
1478
+ # kl_loss = kl_divergence(teacher_distribution, student_distribution, batch["labels"])
1479
+
1480
+ # # use Distil-Whisper formulation (fix weight of CE loss and tune KL weight)
1481
+ # loss = 0.8 * ce_loss + training_args.kl_weight * kl_loss
1482
+ # metrics = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss}
1483
+ # return metrics
1484
+
1485
+ # def generate_step(batch):
1486
+ # student_model.eval()
1487
+ # output_ids = accelerator.unwrap_model(student_model).generate(batch["input_features"], **gen_kwargs)
1488
+ # output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
1489
+ # return output_ids
1490
+
1491
+ # logger.info("***** Running training *****")
1492
+ # logger.info(f" Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}") #num examples that actually are trained
1493
+ # if not data_args.streaming:
1494
+ # logger.info(f" Num epochs = {num_epochs}")
1495
+ # logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
1496
+ # logger.info(" Gradient accumulation steps =" f" {gradient_accumulation_steps}")
1497
+ # logger.info(
1498
+ # f" Total train batch size (w. parallel & distributed) = {train_batch_size * gradient_accumulation_steps}"
1499
+ # )
1500
+ # logger.info(f" Total optimization steps = {total_train_steps}")
1501
+
1502
+ # # ======================== Training ================================
1503
+ # train_time = 0
1504
+ # train_start = time.time()
1505
+ # steps_trained_progress_bar = tqdm(
1506
+ # range(total_train_steps), desc="Train steps ... ", position=0, disable=not accelerator.is_local_main_process
1507
+ # )
1508
+ # continue_training = True
1509
+ # epochs_trained = 0
1510
+ # cur_step = 0
1511
+
1512
+ # checkpoint = None
1513
+ # if training_args.resume_from_checkpoint is not None:
1514
+ # checkpoint = training_args.resume_from_checkpoint
1515
+ # elif last_checkpoint is not None:
1516
+ # checkpoint = last_checkpoint
1517
+
1518
+ # if checkpoint is not None:
1519
+ # accelerator.load_state(checkpoint)
1520
+ # # Find num steps and epoch from saved state string pattern
1521
+ # pattern = r"checkpoint-(\d+)-epoch-(\d+)"
1522
+ # match = re.search(pattern, checkpoint)
1523
+ # cur_step = int(match.group(1))
1524
+ # epochs_trained = int(match.group(2))
1525
+
1526
+ # logger.info(" Continuing training from checkpoint, will skip to saved global_step")
1527
+ # logger.info(f" Continuing training from epoch {epochs_trained}")
1528
+ # logger.info(f" Continuing training from global step {cur_step}")
1529
+
1530
+ # steps_trained_progress_bar.update(cur_step)
1531
+
1532
+ # for epoch in range(0, epochs_trained):
1533
+ # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1534
+
1535
+ # if not data_args.streaming and training_args.max_steps < 0:
1536
+ # # we know exactly the number of steps per epoch, so can skip through the required number of batches
1537
+ # resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
1538
+ # else:
1539
+ # # Currently we don't know how many steps we've taken in the current epoch
1540
+ # # So we just shuffle the dataset one extra time and start from a fresh epoch
1541
+ # # This is "good enough" for our purposes but not fully correct
1542
+ # resume_step = None
1543
+ # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1544
+ # else:
1545
+ # resume_step = None
1546
+ # print(f" raw_datasets['train'] : {raw_datasets['train']} ")
1547
+ # print(f" raw_datasets['eval'] : {raw_datasets['eval']} ")
1548
+
1549
+ # print(f" vectorized_datasets['eval'] : {vectorized_datasets['eval']}")
1550
+ # print(f" vectorized_datasets['train'] : {vectorized_datasets['train']}")
1551
+
1552
+ # #see example of validation dataloader
1553
+ # # validation_dataloader = DataLoader(
1554
+ # # vectorized_datasets[eval_split],
1555
+ # # collate_fn=data_collator,
1556
+ # # batch_size=per_device_eval_batch_size,
1557
+ # # drop_last=False,
1558
+ # # num_workers=dataloader_num_workers,
1559
+ # # prefetch_factor=prefetch_factor,
1560
+ # # pin_memory=training_args.dataloader_pin_memory,
1561
+ # # )
1562
+
1563
+ # # for batch in validation_dataloader:
1564
+ # # print(batch['input_features'].shape)
1565
+
1566
 
1567
+ # print(f" student_model : {type(student_model)}")
1568
+
1569
+
1570
+ # for epoch in range(epochs_trained, num_epochs):
1571
+ # vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
1572
+ # train_dataloader = DataLoader(
1573
+ # vectorized_datasets["train"],
1574
+ # collate_fn=data_collator,
1575
+ # batch_size=per_device_train_batch_size,
1576
+ # num_workers=dataloader_num_workers,
1577
+ # prefetch_factor=prefetch_factor,
1578
+ # pin_memory=training_args.dataloader_pin_memory,
1579
+ # )
1580
+ # train_dataloader = accelerator.prepare(train_dataloader)
1581
+ # if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
1582
+ # train_dataloader.dataset.set_epoch(epoch)
1583
+
1584
+ # if resume_step is not None:
1585
+ # # Skip the first N batches in the dataloader when resuming from a checkpoint
1586
+ # train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
1587
+ # resume_step = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1588
 
1589
+
1590
+ # for batch in train_dataloader:
1591
+ # with accelerator.accumulate(student_model):
1592
+ # #they are updated their parameters every batch
1593
+ # loss, train_metric = train_step(batch, temperature=training_args.temperature)
1594
+ # #backward pass with loss
1595
+ # accelerator.backward(loss)
1596
+ # if accelerator.sync_gradients:
1597
+ # accelerator.clip_grad_norm_(student_model.parameters(), training_args.max_grad_norm)
1598
+ # #update after forward method
1599
+ # optimizer.step()
1600
+ # lr_scheduler.step()
1601
+ # optimizer.zero_grad()
1602
+
1603
+ # # Check if the accelerator has performed an optimization step behind the scenes
1604
+ # if accelerator.sync_gradients:
1605
+ # steps_trained_progress_bar.update(1)
1606
+ # cur_step += 1
1607
+
1608
 
1609
+ # #logging timing
1610
+ # if cur_step % training_args.logging_steps == 0:
1611
+ # steps_trained_progress_bar.write(
1612
+ # f"Step... ({cur_step} / {total_train_steps} | Loss:"
1613
+ # f" {train_metric['loss']}, Learning Rate:"
1614
+ # f" {lr_scheduler.get_last_lr()[0]})"
1615
+ # )
1616
+ # log_metric(
1617
+ # accelerator,
1618
+ # metrics=train_metric,
1619
+ # learning_rate=lr_scheduler.get_last_lr()[0],
1620
+ # train_time=train_time + time.time() - train_start,
1621
+ # step=cur_step,
1622
+ # epoch=epoch,
1623
+ # prefix="train",
1624
+ # )
1625
+
1626
+ # # save checkpoint and weights after each save_steps and at the end of training
1627
+ # if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
1628
+ # intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
1629
+ # accelerator.save_state(output_dir=intermediate_dir)
1630
+ # accelerator.wait_for_everyone()
1631
+ # if accelerator.is_main_process:
1632
+ # rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir)
1633
+
1634
+ # if training_args.push_to_hub:
1635
+ # upload_folder(
1636
+ # folder_path=training_args.output_dir,
1637
+ # repo_id=repo_name,
1638
+ # repo_type="model",
1639
+ # commit_message=f"Saving train state of step {cur_step}",
1640
+ # )
1641
+
1642
+ # if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
1643
+ # print("evaluating dsakdlaskdfl;skl;afksdl;fdasl;fkdl;askfl;asdkfldskfl;das")
1644
+ # train_time += time.time() - train_start
1645
+ # student_model.eval()
1646
+
1647
+ # # ======================== Evaluating ==============================
1648
+
1649
+ # for eval_split in all_eval_splits:
1650
+ # eval_metrics = []
1651
+ # eval_preds = []
1652
+ # eval_labels = []
1653
+ # eval_start = time.time()
1654
+
1655
+ # validation_dataloader = DataLoader(
1656
  # vectorized_datasets[eval_split],
1657
  # collate_fn=data_collator,
1658
  # batch_size=per_device_eval_batch_size,
 
1662
  # pin_memory=training_args.dataloader_pin_memory,
1663
  # )
1664
 
 
 
 
1665
 
1666
+ # validation_dataloader = accelerator.prepare(validation_dataloader)
1667
+
1668
+ # for batch in tqdm(
1669
+ # validation_dataloader,
1670
+ # desc=f"Evaluating {eval_split}...",
1671
+ # position=2,
1672
+ # disable=not accelerator.is_local_main_process,
1673
+ # ):
1674
+ # print(f"type(batch) : {type(batch)}")
1675
+ # # Model forward
1676
+ # eval_metric = eval_step(batch)
1677
+ # eval_metric = accelerator.gather_for_metrics(eval_metric)
1678
+ # eval_metrics.append(eval_metric)
1679
+
1680
+ # # generation
1681
+ # if training_args.predict_with_generate:
1682
+
1683
+ # generated_ids = generate_step(batch)
1684
+ # # Gather all predictions and targets
1685
+ # generated_ids, labels = accelerator.gather_for_metrics(
1686
+ # (generated_ids, batch["labels"])
1687
+ # )
1688
+ # eval_preds.extend(generated_ids)
1689
+ # eval_labels.extend(labels)
1690
+
1691
+ # eval_time = time.time() - eval_start
1692
+ # # normalize eval metrics
1693
+ # eval_metrics = {
1694
+ # key: torch.mean(torch.stack([d[key] for d in eval_metrics])) for key in eval_metrics[0]
1695
+ # }
1696
+
1697
+ # # compute WER metric
1698
+ # wer_desc = ""
1699
+ # if training_args.predict_with_generate:
1700
+ # wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
1701
+ # eval_preds, eval_labels
1702
+ # )
1703
+ # eval_metrics.update(wer_metric)
1704
+ # wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
1705
+ # log_pred(
1706
+ # accelerator,
1707
+ # pred_str,
1708
+ # label_str,
1709
+ # norm_pred_str,
1710
+ # norm_label_str,
1711
+ # step=cur_step,
1712
+ # prefix=eval_split,
1713
+ # )
1714
+
1715
+ # # Print metrics and update progress bar
1716
+ # steps_trained_progress_bar.write(
1717
+ # f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
1718
+ # f" {wer_desc})"
1719
+ # )
1720
 
1721
+ # log_metric(
1722
+ # accelerator,
1723
+ # metrics=eval_metrics,
1724
+ # train_time=eval_time,
1725
+ # step=cur_step,
1726
+ # epoch=epoch,
1727
+ # prefix=eval_split,
1728
+ # )
1729
 
1730
+ # # flush the train metrics
1731
+ # train_start = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1732
 
1733
+ # # break condition
1734
+ # if cur_step == total_train_steps:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1735
 
1736
+ # # un-wrap student model for save
1737
+ # student_model = accelerator.unwrap_model(student_model)
1738
+ # student_model.save_pretrained(training_args.output_dir)
1739
+
1740
+ # if training_args.push_to_hub:
1741
+ # upload_folder(
1742
+ # folder_path=training_args.output_dir,
1743
+ # repo_id=repo_name,
1744
+ # repo_type="model",
1745
+ # commit_message=f"Saving final weights of step {cur_step}",
1746
+ # )
1747
+
1748
+ # continue_training = False
1749
+ # break
1750
+
1751
+ # if not continue_training:
1752
+ # break
1753
+
1754
+ # accelerator.end_training()
  if __name__ == "__main__":
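
For reference, the compute_metrics function that this commit comments out reports two scores: wer_ortho on the raw decoded strings and wer on text passed through the Whisper normalizer, after dropping samples whose normalized reference is empty. A dependency-free sketch of the word error rate itself (word-level edit distance divided by the number of reference words), with str.lower() standing in, purely for illustration, for the real normalizer:

# Sketch of the WER reported by the (now commented-out) compute_metrics: word-level
# edit distance over the total number of reference words, scaled by 100. The raw
# strings give the "orthographic" WER; lower-casing stands in for the real normalizer.
def word_error_rate(predictions, references):
    errors, total = 0, 0
    for pred, ref in zip(predictions, references):
        p, r = pred.split(), ref.split()
        d = [[0] * (len(p) + 1) for _ in range(len(r) + 1)]
        for i in range(len(r) + 1):
            d[i][0] = i
        for j in range(len(p) + 1):
            d[0][j] = j
        for i in range(1, len(r) + 1):
            for j in range(1, len(p) + 1):
                cost = 0 if r[i - 1] == p[j - 1] else 1
                d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
        errors += d[len(r)][len(p)]
        total += len(r)
    return 100 * errors / total

preds = ["hello world", "This is a Test"]
refs = ["hello word", "this is a test"]
print(word_error_rate(preds, refs))                                            # orthographic WER
print(word_error_rate([p.lower() for p in preds], [r.lower() for r in refs]))  # normalized WER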