supawichwac committed on
Commit f544a5d
1 Parent(s): 12b47c1

Saving train state of step 5

Files changed (25)
  1. checkpoint-5-epoch-0/model.safetensors +3 -0
  2. checkpoint-5-epoch-0/model_1.safetensors +3 -0
  3. checkpoint-5-epoch-0/optimizer.bin +3 -0
  4. checkpoint-5-epoch-0/random_states_0.pkl +3 -0
  5. checkpoint-5-epoch-0/scheduler.bin +3 -0
  6. distil-whisper/events.out.tfevents.1715073979.server02.1433788.0 +3 -0
  7. distil-whisper/events.out.tfevents.1715074029.server02.1434198.0 +3 -0
  8. distil-whisper/events.out.tfevents.1715095796.server02.1514457.0 +3 -0
  9. distil-whisper/events.out.tfevents.1715137750.server02.1659182.0 +3 -0
  10. distil-whisper/events.out.tfevents.1715142860.server02.1688240.0 +3 -0
  11. distil-whisper/events.out.tfevents.1715144009.server02.1717420.0 +3 -0
  12. distil-whisper/events.out.tfevents.1715144142.server02.1721266.0 +3 -0
  13. distil-whisper/events.out.tfevents.1715144248.server02.1724677.0 +3 -0
  14. distil-whisper/events.out.tfevents.1715144329.server02.1726964.0 +3 -0
  15. distil-whisper/events.out.tfevents.1715144689.server02.1736871.0 +3 -0
  16. distil-whisper/events.out.tfevents.1715144766.server02.1739137.0 +3 -0
  17. distil-whisper/events.out.tfevents.1715145134.server02.1748391.0 +3 -0
  18. distil-whisper/events.out.tfevents.1715152989.server02.1776687.0 +3 -0
  19. distil-whisper/events.out.tfevents.1715153425.server02.1778557.0 +3 -0
  20. distil-whisper/events.out.tfevents.1715153634.server02.1779609.0 +3 -0
  21. distil-whisper/events.out.tfevents.1715153723.server02.1780155.0 +3 -0
  22. distil-whisper/events.out.tfevents.1715154461.server02.1782973.0 +3 -0
  23. distil-whisper/events.out.tfevents.1715160495.server02.1805047.0 +3 -0
  24. run_distillation.py +36 -8
  25. test_partial_function.py +41 -0
checkpoint-5-epoch-0/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7a21e3711ac40e9335e1f3f3996f60b973cd257c3f524366cf6b834e59d49f13
+size 3025686376

checkpoint-5-epoch-0/model_1.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6b395c8a7e2bda655c415580106288d0387c227efd641bf4e11c1cd735fdb37a
+size 4361070048

checkpoint-5-epoch-0/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2d29f0667d7e38b9abb98596a5a9348d8f95ae4e4a7715159e01a41ac9d2f620
+size 955539578

checkpoint-5-epoch-0/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:85d573cec64fffbd3f22840ac5142a2d5238117a2d0f909e2a3a64155fe22435
+size 14344

checkpoint-5-epoch-0/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:61c54c0f7915329263989409611568f153678f74fb6fe4366f23ad24844d158f
+size 1064

distil-whisper/events.out.tfevents.1715073979.server02.1433788.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5acad6483f543c7a7c6c1db549ee743b5cd298b504a7d47ab30b9f233fb919c4
+size 88

distil-whisper/events.out.tfevents.1715074029.server02.1434198.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b67893905a27526433ed6485bb3eaffe1e82b3c7da45cbaa27d9266c53433144
+size 88

distil-whisper/events.out.tfevents.1715095796.server02.1514457.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:979ff30080df67a4dcab044cf870600f43073d71992160952446f78b19dcf897
+size 88

distil-whisper/events.out.tfevents.1715137750.server02.1659182.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c806ece5e794b1027c129222cdd44b22a92e26332c899a6e4dc8583f757f7dc
+size 88

distil-whisper/events.out.tfevents.1715142860.server02.1688240.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e1d8a8e923945bfdcb5f29a077c4f7009484ae9a917bf9ea970492efde3c5aa
+size 88

distil-whisper/events.out.tfevents.1715144009.server02.1717420.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b4825318ef78b792ae93dc0cd60918637895c9833607e1918cfad58d83bff016
+size 88

distil-whisper/events.out.tfevents.1715144142.server02.1721266.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:980b48300850fc0e5190e9466fd1749a1ed461b5ff2fe918d3e3dfb3644625ef
+size 88

distil-whisper/events.out.tfevents.1715144248.server02.1724677.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4bdc23e6b0151e29c67e9069845f1218f686f7667fe3f3bdd1663eea19240cc6
+size 88

distil-whisper/events.out.tfevents.1715144329.server02.1726964.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:551fe2b2ff53ab0e30742564ca1935299589000ee897fb65612d7706002e701c
+size 88

distil-whisper/events.out.tfevents.1715144689.server02.1736871.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:631771d076472b08d53a7585ff812de4dd8e4e500b011da3a34c74ea6cc65d33
+size 88

distil-whisper/events.out.tfevents.1715144766.server02.1739137.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:57ce892831e97b587c4573c575dc1ba0e11317517f0b5e3ba4b41822d4eea0e6
+size 88

distil-whisper/events.out.tfevents.1715145134.server02.1748391.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89c5cccd2e339637b9564d94fb6abf49e6dcd9e481292d4d12deaa5367ac49bb
+size 88

distil-whisper/events.out.tfevents.1715152989.server02.1776687.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:02fc9bf530ac7fe72cc27f5be66569b22a8d0f20634adea1ffd8b9b8e084cefe
+size 88

distil-whisper/events.out.tfevents.1715153425.server02.1778557.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f66ee5f168a300d24bffa01d463604a5486270b89177b5579501bd69da02f864
+size 88

distil-whisper/events.out.tfevents.1715153634.server02.1779609.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eb46c3608e41a15b89df7edb7aa506521d7c2f5f9528f58b44be97b6a7a4b90
+size 88

distil-whisper/events.out.tfevents.1715153723.server02.1780155.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:697804e781696b9e3f46d180576ac88978f77f5493d86bc6f8591928a313daa1
+size 88

distil-whisper/events.out.tfevents.1715154461.server02.1782973.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3790ffdcde3556f8cb2531ac6ade1add4b4cce83bda933fa6f5cb5cdd68f3566
+size 88

distil-whisper/events.out.tfevents.1715160495.server02.1805047.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f4403d07051c2a9e92defb2b8ba4d895beda85e4d2a8b1c8a2fa816d0183ffb
+size 392
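The checkpoint and TensorBoard entries above are Git LFS pointer files (version / oid / size), so only the pointers live in the git history; the binaries themselves sit in LFS storage. As a minimal sketch of how one could fetch a real file at this revision with the `huggingface_hub` client (the repo id is not shown on this page, so it is a placeholder here):

from huggingface_hub import hf_hub_download

# Placeholder repo id; substitute the repository this commit belongs to.
local_path = hf_hub_download(
    repo_id="<user>/<repo>",
    filename="checkpoint-5-epoch-0/model.safetensors",
    revision="f544a5d",  # the commit shown above
)
print(local_path)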
run_distillation.py CHANGED
@@ -855,6 +855,9 @@ def main():
     )
     raw_datasets_train_features = list(raw_datasets["train"].features.keys())

+
+    print(f'858 raw_datasets["train"] : {raw_datasets["train"] }')
+
     if training_args.do_eval:
         dataset_names_dict = convert_dataset_str_to_list(
             data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
@@ -1074,6 +1077,7 @@ def main():
            else raw_datasets["train"].select(range(data_args.max_train_samples))
        )

+    #if we want to select first n samples , not entire validation set
    if training_args.do_eval and data_args.max_eval_samples is not None:
        for eval_split in all_eval_splits:
            raw_datasets[eval_split] = (
@@ -1101,6 +1105,13 @@ def main():
            function=is_wer_in_range,
            input_columns=["text", "whisper_transcript"],
        )
+
+
+
+
+    print(f' raw_datasets["train"].filter : {raw_datasets["train"].filter}')
+    print(f' raw_datasets["train"] : {raw_datasets["train"]}')
+

    if wer_threshold is not None and use_pseudo_labels:
        with accelerator.main_process_first():
@@ -1217,6 +1228,7 @@ def main():
            if not data_args.streaming
            else map_fn_eval()
        )
+

        # 10.5: Filter training data with inputs longer than `max_input_length`
        def is_audio_in_length_range(length):
@@ -1266,6 +1278,8 @@ def main():
    # 11. Define Evaluation Metrics
    def compute_metrics(preds, labels):
        # replace padded labels by the padding token
+        print(f" preds : {preds}")
+        print(f" labels : {labels}")
        for idx in range(len(labels)):
            labels[idx][labels[idx] == -100] = tokenizer.pad_token_id

@@ -1289,7 +1303,7 @@ def main():

    # 12. Define Training Schedule
    # Store some constants
-    per_device_train_batch_size = int(training_args.per_device_train_batch_size)
+    per_device_train_batch_size = int(training_args.per_device_train_batch_size)
    train_batch_size = per_device_train_batch_size * accelerator.num_processes
    gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
@@ -1306,8 +1320,8 @@ def main():
            num_epochs = int(np.ceil(total_train_steps / steps_per_epoch))
        else:
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
-            num_epochs = sys.maxsize
-            steps_per_epoch = total_train_steps
+            num_epochs = sys.maxsize #num_epochs as much as possible
+            steps_per_epoch = total_train_steps
    else:
        raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")

@@ -1318,7 +1332,9 @@ def main():
        eval_steps = steps_per_epoch
    else:
        eval_steps = training_args.eval_steps
-
+
+    print(f" num_epochs : {num_epochs}")
+    print(f" steps_per_epoch = total_train_steps : {steps_per_epoch}")
    # 13. Define optimizer, LR scheduler, collator
    decay_parameters = get_parameter_names(
        student_model,
@@ -1350,7 +1366,7 @@ def main():
        num_warmup_steps=training_args.warmup_steps * accelerator.num_processes,
        num_training_steps=total_train_steps * accelerator.num_processes,
    )
-
+    print()
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=decoder_start_token_id,
@@ -1382,11 +1398,16 @@ def main():
        }
    )
    print(f" gen_kwargs : {gen_kwargs}")
+    print(f" raw_datasets['eval']: {raw_datasets['eval']}")
+
    #15. Prepare everything with accelerate
    student_model, teacher_model, optimizer, lr_scheduler = accelerator.prepare(
        student_model, teacher_model, optimizer, lr_scheduler
    )

+
+
+
    def kl_divergence(target_distribution, log_predicted_distribution, labels):
        kl_loss = nn.KLDivLoss(reduction="none")
        divergence = kl_loss(log_predicted_distribution, target_distribution)
@@ -1415,8 +1436,8 @@ def main():
            teacher_outputs = teacher_model(encoder_outputs=encoder_outputs, labels=batch["labels"])
        else:
            # do the full forward pass for the teacher model (encoder + decoder)
-            teacher_outputs = teacher_model(**batch)
-
+            teacher_outputs = teacher_model(**batch)
+
        # CE (data) loss
        ce_loss = student_outputs.loss
        # rescale distribution by temperature to ensure gradients scale correctly
@@ -1519,6 +1540,13 @@ def main():
            vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
    else:
        resume_step = None
+    print(f" raw_datasets['train'] : {raw_datasets['train']} ")
+    print(f" raw_datasets['eval'] : {raw_datasets['eval']} ")
+
+    print(f" vectorized_datasets['eval'] : {vectorized_datasets['eval']}")
+    print(f" vectorized_datasets['train'] : {vectorized_datasets['train']}")
+
+

    for epoch in range(epochs_trained, num_epochs):
        vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
@@ -1596,7 +1624,7 @@ def main():
            eval_labels = []
            eval_start = time.time()

-            F = DataLoader(
+            validation_dataloader = DataLoader(
                vectorized_datasets[eval_split],
                collate_fn=data_collator,
                batch_size=per_device_eval_batch_size,
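Several of these hunks touch the distillation loss path (the `kl_divergence` helper, the teacher forward pass, the CE loss and the temperature rescaling comment). For orientation only, and not the repository's exact code, a temperature-scaled KL term is typically combined with the cross-entropy term along these lines; the rescaling by T**2 is presumably what the "rescale distribution by temperature" comment refers to:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, ce_loss, temperature=2.0, kl_weight=1.0):
    # Soften both distributions before comparing; rescale by T**2 so the KL
    # gradient magnitude stays comparable to the CE term.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2
    return ce_loss + kl_weight * kl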
test_partial_function.py ADDED
@@ -0,0 +1,41 @@
+from functools import partial
+
+# Mock dataset in a dictionary form, similar to what you might find in a data processing library
+dataset = {
+    "train": [
+        {"text": "Hello world", "id": 1},
+        {"text": "Partial functions are cool", "id": 2},
+    ]
+}
+
+# Function to preprocess the dataset
+def prepare_train_dataset(example):
+    # Let's say we just transform the text to uppercase for simplicity
+    return {"text": example["text"].upper()}
+
+# Columns to remove from the dataset after the transformation
+columns_to_remove = ['id']
+
+# Creating a mock map function for the dataset
+def dataset_map(batch, function, remove_columns, batched, batch_size):
+    # Process each batch
+    transformed_data = [function(example) for example in batch]
+    # Remove specified columns
+    for item in transformed_data:
+        for column in remove_columns:
+            item.pop(column, None)
+    return transformed_data
+
+# Using partial to pre-configure the map function
+map_fn_train = partial(
+    dataset_map,
+    batch=dataset["train"],
+    function=prepare_train_dataset,
+    remove_columns=columns_to_remove,
+    batched=True,
+    batch_size=2 # Assuming we process all data in one batch for simplicity
+)
+
+# Using the configured function
+transformed_dataset = map_fn_train()
+print(transformed_dataset)
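test_partial_function.py mocks the pre-binding pattern behind the map_fn_eval() call visible in the run_distillation.py diff above. As a minimal sketch (assuming the `datasets` library is installed; this snippet is not part of the commit), the same pattern applied to a real Dataset.map looks like this:

from functools import partial
from datasets import Dataset

ds = Dataset.from_dict({"text": ["Hello world", "Partial functions are cool"], "id": [1, 2]})

def prepare_train_dataset(example):
    return {"text": example["text"].upper()}

# Pre-bind the arguments of Dataset.map, then call the resulting function with no arguments.
map_fn = partial(ds.map, function=prepare_train_dataset, remove_columns=["id"], batched=False)

print(map_fn()["text"])  # ['HELLO WORLD', 'PARTIAL FUNCTIONS ARE COOL']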