Hugo Flores Garcia committed on
Commit
31b771c
1 Parent(s): a66dc9c

dropping torch.compile for now

Browse files
scripts/exp/train.py CHANGED
@@ -29,6 +29,9 @@ from audiotools.ml.decorators import (
29
 
30
  import loralib as lora
31
 
 
 
 
32
 
33
  # Enable cudnn autotuner to speed up training
34
  # (can be altered by the funcs.seed function)
@@ -510,14 +513,14 @@ def load(
510
 
511
  if args["fine_tune"]:
512
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
513
- model = torch.compile(
514
  VampNet.load(location=Path(fine_tune_checkpoint),
515
  map_location="cpu",
516
  )
517
  )
518
 
519
 
520
- model = torch.compile(VampNet()) if model is None else model
521
  model = accel.prepare_model(model)
522
 
523
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
@@ -601,7 +604,7 @@ def train(
601
  accel=accel,
602
  tracker=tracker,
603
  save_path=save_path)
604
-
605
 
606
  train_dataloader = accel.prepare_dataloader(
607
  state.train_data,
@@ -616,13 +619,15 @@ def train(
616
  num_workers=num_workers,
617
  batch_size=batch_size,
618
  collate_fn=state.val_data.collate,
619
- persistent_workers=True,
620
  )
 
621
 
622
 
623
 
624
  if fine_tune:
625
  lora.mark_only_lora_as_trainable(state.model)
 
626
 
627
  # Wrap the functions so that they neatly track in TensorBoard + progress bars
628
  # and only run when specific conditions are met.
@@ -637,6 +642,7 @@ def train(
637
  save_samples = when(lambda: accel.local_rank == 0)(save_samples)
638
  checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
639
 
 
640
  with tracker.live:
641
  for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
642
  train_loop(state, batch, accel)
 
29
 
30
  import loralib as lora
31
 
32
+ import torch._dynamo
33
+ torch._dynamo.config.verbose=True
34
+
35
 
36
  # Enable cudnn autotuner to speed up training
37
  # (can be altered by the funcs.seed function)
 
513
 
514
  if args["fine_tune"]:
515
  assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
516
+ model = (
517
  VampNet.load(location=Path(fine_tune_checkpoint),
518
  map_location="cpu",
519
  )
520
  )
521
 
522
 
523
+ model = VampNet() if model is None else model
524
  model = accel.prepare_model(model)
525
 
526
  # assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
 
604
  accel=accel,
605
  tracker=tracker,
606
  save_path=save_path)
607
+ print("initialized state.")
608
 
609
  train_dataloader = accel.prepare_dataloader(
610
  state.train_data,
 
619
  num_workers=num_workers,
620
  batch_size=batch_size,
621
  collate_fn=state.val_data.collate,
622
+ persistent_workers=num_workers > 0,
623
  )
624
+ print("initialized dataloader.")
625
 
626
 
627
 
628
  if fine_tune:
629
  lora.mark_only_lora_as_trainable(state.model)
630
+ print("marked only lora as trainable.")
631
 
632
  # Wrap the functions so that they neatly track in TensorBoard + progress bars
633
  # and only run when specific conditions are met.
 
642
  save_samples = when(lambda: accel.local_rank == 0)(save_samples)
643
  checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
644
 
645
+ print("starting training loop.")
646
  with tracker.live:
647
  for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
648
  train_loop(state, batch, accel)
scripts/utils/split_long_audio_file.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import argbind
3
+
4
+ import audiotools as at
5
+ import tqdm
6
+
7
+
8
+ @argbind.bind(without_prefix=True)
9
+ def split_long_audio_file(
10
+ file: str = None,
11
+ max_chunk_size_s: int = 60*10
12
+ ):
13
+ file = Path(file)
14
+ output_dir = file.parent / file.stem
15
+ output_dir.mkdir()
16
+
17
+ sig = at.AudioSignal(file)
18
+
19
+ # split into chunks
20
+ for i, sig in tqdm.tqdm(enumerate(sig.windows(
21
+ window_duration=max_chunk_size_s, hop_duration=max_chunk_size_s/2,
22
+ preprocess=True))
23
+ ):
24
+ sig.write(output_dir / f"{i}.wav")
25
+
26
+ print(f"wrote {len(list(output_dir.glob('*.wav')))} files to {output_dir}")
27
+
28
+ return output_dir
29
+
30
+ if __name__ == "__main__":
31
+ args = argbind.parse_args()
32
+
33
+ with argbind.scope(args):
34
+ split_long_audio_file()