Hugo Flores Garcia commited on
Commit
405226b
1 Parent(s): 88c78e1

use torch.compile for training

Browse files
Files changed (1) hide show
  1. scripts/exp/train.py +7 -5
scripts/exp/train.py CHANGED
@@ -485,7 +485,6 @@ def load(
485
  save_path: str,
486
  resume: bool = False,
487
  tag: str = "latest",
488
- load_weights: bool = False,
489
  fine_tune_checkpoint: Optional[str] = None,
490
  grad_clip_val: float = 5.0,
491
  ) -> State:
@@ -498,7 +497,7 @@ def load(
498
  kwargs = {
499
  "folder": f"{save_path}/{tag}",
500
  "map_location": "cpu",
501
- "package": not load_weights,
502
  }
503
  tracker.print(f"Loading checkpoint from {kwargs['folder']}")
504
  if (Path(kwargs["folder"]) / "vampnet").exists():
@@ -511,11 +510,14 @@ def load(
511
 
512
  if args["fine_tune"]:
513
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
514
- model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
515
-
 
 
 
516
 
517
- model = VampNet() if model is None else model
518
 
 
519
  model = accel.prepare_model(model)
520
 
521
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
 
485
  save_path: str,
486
  resume: bool = False,
487
  tag: str = "latest",
 
488
  fine_tune_checkpoint: Optional[str] = None,
489
  grad_clip_val: float = 5.0,
490
  ) -> State:
 
497
  kwargs = {
498
  "folder": f"{save_path}/{tag}",
499
  "map_location": "cpu",
500
+ "package": False,
501
  }
502
  tracker.print(f"Loading checkpoint from {kwargs['folder']}")
503
  if (Path(kwargs["folder"]) / "vampnet").exists():
 
510
 
511
  if args["fine_tune"]:
512
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
513
+ model = torch.compile(
514
+ VampNet.load(location=Path(fine_tune_checkpoint),
515
+ map_location="cpu",
516
+ )
517
+ )
518
 
 
519
 
520
+ model = torch.compile(VampNet()) if model is None else model
521
  model = accel.prepare_model(model)
522
 
523
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks