lgris committed on
Commit
0569275
1 Parent(s): 4ca7515

Saving train state of step 1000

Browse files
checkpoint-1000-epoch-0/model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5846539ce04532d21ae1018baad4bd0a3389ec6969bfa7327a995a96fe03ff44
3
  size 3024943976
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0872c0c697bae76960041656d4e8354cc69b767929e06c0f35615b5e9bc6ed4c
3
  size 3024943976
checkpoint-1000-epoch-0/optimizer.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c0d86c554f72013a349c5dd47f1d18b5ec93818136e8f7dbd644709a847184cd
3
  size 955529338
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:697196fece1aa67a773dc71ee5d622ea128185dbda531c1796ea76634c55d988
3
  size 955529338
distil-whisper/events.out.tfevents.1705598735.c066756f484e.20967.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:08d3c29719c5def41e6b7b23aae8e54d6851e848fa50a028b1940416efea4dee
3
+ size 12458
run_distillation.py CHANGED
@@ -458,7 +458,7 @@ def log_pred(
458
  ):
459
  """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
460
  if accelerator.is_main_process:
461
- wandb_tracker = accelerator.get_tracker("wandb")
462
  # pretty name for current step: step 50000 -> step 50k
463
  cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
464
  prefix_pretty = prefix.replace("/", "-")
@@ -466,23 +466,23 @@ def log_pred(
466
  # convert str data to a wandb compatible format
467
  str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
468
  # log as a table with the appropriate headers
469
- wandb_tracker.log_table(
470
- table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
471
- columns=["Target", "Pred", "Norm Target", "Norm Pred"],
472
- data=str_data[:num_lines],
473
- step=step,
474
- )
475
 
476
  # log incorrect normalised predictions
477
  str_data = np.asarray(str_data)
478
  str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
479
  # log as a table with the appropriate headers
480
- wandb_tracker.log_table(
481
- table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
482
- columns=["Target", "Pred", "Norm Target", "Norm Pred"],
483
- data=str_data_incorrect[:num_lines],
484
- step=step,
485
- )
486
 
487
 
488
  def convert_dataset_str_to_list(
 
458
  ):
459
  """Helper function to log target/predicted transcriptions to weights and biases (wandb)."""
460
  if accelerator.is_main_process:
461
+ # wandb_tracker = accelerator.get_tracker("wandb")
462
  # pretty name for current step: step 50000 -> step 50k
463
  cur_step_pretty = f"{int(step // 1000)}k" if step > 1000 else step
464
  prefix_pretty = prefix.replace("/", "-")
 
466
  # convert str data to a wandb compatible format
467
  str_data = [[label_str[i], pred_str[i], norm_label_str[i], norm_pred_str[i]] for i in range(len(pred_str))]
468
  # log as a table with the appropriate headers
469
+ # wandb_tracker.log_table(
470
+ # table_name=f"predictions/{prefix_pretty}-step-{cur_step_pretty}",
471
+ # columns=["Target", "Pred", "Norm Target", "Norm Pred"],
472
+ # data=str_data[:num_lines],
473
+ # step=step,
474
+ # )
475
 
476
  # log incorrect normalised predictions
477
  str_data = np.asarray(str_data)
478
  str_data_incorrect = str_data[str_data[:, -2] != str_data[:, -1]]
479
  # log as a table with the appropriate headers
480
+ # wandb_tracker.log_table(
481
+ # table_name=f"incorrect_predictions/{prefix_pretty}-step-{cur_step_pretty}",
482
+ # columns=["Target", "Pred", "Norm Target", "Norm Pred"],
483
+ # data=str_data_incorrect[:num_lines],
484
+ # step=step,
485
+ # )
486
 
487
 
488
  def convert_dataset_str_to_list(