Hugo Flores Garcia commited on
Commit
2257706
1 Parent(s): b862275

add save_epochs

Browse files
Files changed (1) hide show
  1. scripts/exp/train.py +4 -0
scripts/exp/train.py CHANGED
@@ -250,6 +250,7 @@ def train(
250
  max_epochs: int = int(100e3),
251
  epoch_length: int = 1000,
252
  save_audio_epochs: int = 10,
 
253
  batch_size: int = 48,
254
  grad_acc_steps: int = 1,
255
  val_idx: list = [0, 1, 2, 3, 4],
@@ -505,6 +506,9 @@ def train(
505
  loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
506
  self.print(f"Saving to {str(Path('.').absolute())}")
507
 
 
 
 
508
  if self.is_best(engine, loss_key):
509
  self.print(f"Best model so far")
510
  tags.append("best")
 
250
  max_epochs: int = int(100e3),
251
  epoch_length: int = 1000,
252
  save_audio_epochs: int = 10,
253
+ save_epochs: list = [10, 50, 100, 200, 300, 400,],
254
  batch_size: int = 48,
255
  grad_acc_steps: int = 1,
256
  val_idx: list = [0, 1, 2, 3, 4],
 
506
  loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
507
  self.print(f"Saving to {str(Path('.').absolute())}")
508
 
509
+ if self.state.epoch in save_epochs:
510
+ tags.append(f"epoch={self.state.epoch}")
511
+
512
  if self.is_best(engine, loss_key):
513
  self.print(f"Best model so far")
514
  tags.append("best")