Hugo Flores Garcia committed
Commit 75a7169 • 1 Parent(s): 13b04cf

efficient lora ckpts
README.md CHANGED
@@ -33,41 +33,6 @@ Config files are stored in the `conf/` folder.
 
 Download the pretrained models from [this link](https://drive.google.com/file/d/1ZIBMJMt8QRE8MYYGjg4lH7v7BLbZneq2/view?usp=sharing). Then, extract the models to the `models/` folder.
 
- # How the code is structured
- 
- This code was written quickly to meet a publication deadline, so it can be messy and redundant at times. It is currently being cleaned up.
- 
- ```
- ├── conf <- (conf files for training, finetuning, etc)
- ├── demo.py <- (gradio UI for playing with vampnet)
- ├── env <- (environment variables)
- │   └── env.sh
- ├── models <- (extract pretrained models here)
- │   ├── spotdl
- │   │   ├── c2f.pth <- (coarse2fine checkpoint)
- │   │   ├── coarse.pth <- (coarse checkpoint)
- │   │   └── codec.pth <- (codec checkpoint)
- │   └── wavebeat.pth
- ├── README.md
- ├── scripts
- │   ├── exp
- │   │   ├── eval.py <- (eval script)
- │   │   └── train.py <- (training/finetuning script)
- │   └── utils
- ├── vampnet
- │   ├── beats.py <- (beat tracking logic)
- │   ├── __init__.py
- │   ├── interface.py <- (high-level programmatic interface)
- │   ├── mask.py
- │   ├── modules
- │   │   ├── activations.py
- │   │   ├── __init__.py
- │   │   ├── layers.py
- │   │   └── transformer.py <- (architecture + sampling code)
- │   ├── scheduler.py
- │   └── util.py
- ```
- 
 # Usage
 
 First, you'll want to set up your environment
@@ -90,12 +55,33 @@ python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/ch
 ```
 
 ## Fine-tuning
- To fine-tune a model, see the configuration files under `conf/lora/`.
- You just need to provide a list of audio files or folders to fine-tune on, then launch the training job as usual.
+ To fine-tune a model, use the script in `scripts/exp/fine_tune.py` to generate 3 configuration files: `c2f.yml`, `coarse.yml`, and `interface.yml`.
+ The first two fine-tune the coarse-to-fine (c2f) and coarse models, respectively. The last one is used to launch the Gradio interface with the fine-tuned checkpoints.
 
 ```bash
- python scripts/exp/train.py --args.load conf/lora/birds.yml --save_path /path/to/checkpoints
+ python scripts/exp/fine_tune.py "/path/to/audio1.mp3 /path/to/audio2/ /path/to/audio3.wav" <fine_tune_name>
 ```
 
+ This will create a folder under `conf/generated/<fine_tune_name>/` with the 3 configuration files.
+ 
+ The save paths will be set to `runs/<fine_tune_name>/coarse` and `runs/<fine_tune_name>/c2f`.
+ 
+ Launch the coarse job:
+ ```bash
+ python scripts/exp/train.py --args.load conf/generated/<fine_tune_name>/coarse.yml
+ ```
+ 
+ This will save the coarse model to `runs/<fine_tune_name>/coarse/ckpt/best/`.
+ 
+ Launch the c2f job:
+ ```bash
+ python scripts/exp/train.py --args.load conf/generated/<fine_tune_name>/c2f.yml
+ ```
+ 
+ Launch the interface:
+ ```bash
+ python demo.py --args.load conf/generated/<fine_tune_name>/interface.yml
+ ```
 
 
 ## Launching the Gradio Interface
conf/{generated β†’ generated-v0}/berta-goldman-speech/c2f.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/berta-goldman-speech/coarse.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/berta-goldman-speech/interface.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/gamelan-xeno-canto/c2f.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/gamelan-xeno-canto/coarse.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/gamelan-xeno-canto/interface.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/nasralla/c2f.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/nasralla/coarse.yml RENAMED
File without changes
conf/{generated β†’ generated-v0}/nasralla/interface.yml RENAMED
File without changes
conf/generated/musica-bolero-marimba/c2f.yml ADDED
@@ -0,0 +1,18 @@
+ $include:
+ - conf/lora/lora.yml
+ AudioDataset.duration: 3.0
+ AudioDataset.loudness_cutoff: -40.0
+ VampNet.embedding_dim: 1280
+ VampNet.n_codebooks: 14
+ VampNet.n_conditioning_codebooks: 4
+ VampNet.n_heads: 20
+ VampNet.n_layers: 16
+ fine_tune: true
+ fine_tune_checkpoint: ./models/spotdl/c2f.pth
+ save_path: ./runs/musica-bolero-marimba/c2f
+ train/AudioLoader.sources:
+ - /media/CHONK/hugo/loras/boleros
+ - /media/CHONK/hugo/loras/marimba-honduras
+ val/AudioLoader.sources:
+ - /media/CHONK/hugo/loras/boleros
+ - /media/CHONK/hugo/loras/marimba-honduras
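These generated files are plain Python dictionaries written out with PyYAML by `scripts/exp/fine_tune.py`. A minimal sketch of that round trip (the `my-lora` name and the subset of keys shown are illustrative, not the script's full output):

```python
import yaml

name = "my-lora"  # hypothetical fine-tune name
conf = {
    "$include": ["conf/lora/lora.yml"],
    "fine_tune": True,
    "fine_tune_checkpoint": "./models/spotdl/c2f.pth",
    "save_path": f"./runs/{name}/c2f",
}

text = yaml.dump(conf)         # keys are emitted in sorted order, as in the files above
loaded = yaml.safe_load(text)  # round-trips back to an equal dict
```

This is also why `$include` appears first in every generated file: `$` sorts before any letter.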
conf/generated/musica-bolero-marimba/coarse.yml ADDED
@@ -0,0 +1,11 @@
+ $include:
+ - conf/lora/lora.yml
+ fine_tune: true
+ fine_tune_checkpoint: ./models/spotdl/coarse.pth
+ save_path: ./runs/musica-bolero-marimba/coarse
+ train/AudioLoader.sources:
+ - /media/CHONK/hugo/loras/boleros
+ - /media/CHONK/hugo/loras/marimba-honduras
+ val/AudioLoader.sources:
+ - /media/CHONK/hugo/loras/boleros
+ - /media/CHONK/hugo/loras/marimba-honduras
conf/generated/musica-bolero-marimba/interface.yml ADDED
@@ -0,0 +1,8 @@
+ AudioLoader.sources:
+ - /media/CHONK/hugo/loras/boleros
+ - /media/CHONK/hugo/loras/marimba-honduras
+ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+ Interface.coarse2fine_lora_ckpt: ./runs/musica-bolero-marimba/c2f/latest/lora.pth
+ Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+ Interface.coarse_lora_ckpt: ./runs/musica-bolero-marimba/coarse/latest/lora.pth
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/xeno-canto/c2f.yml ADDED
@@ -0,0 +1,15 @@
+ $include:
+ - conf/lora/lora.yml
+ AudioDataset.duration: 3.0
+ AudioDataset.loudness_cutoff: -40.0
+ VampNet.embedding_dim: 1280
+ VampNet.n_codebooks: 14
+ VampNet.n_conditioning_codebooks: 4
+ VampNet.n_heads: 20
+ VampNet.n_layers: 16
+ fine_tune: true
+ fine_tune_checkpoint: ./models/spotdl/c2f.pth
+ save_path: ./runs/xeno-canto/c2f
+ train/AudioLoader.sources: &id001
+ - /media/CHONK/hugo/loras/xeno-canto-2/
+ val/AudioLoader.sources: *id001
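The `&id001` / `*id001` pair above is standard YAML anchor/alias syntax: PyYAML emits it when the same list object is referenced by two keys, and loading resolves both keys back to one shared list. A small sketch (the path is a placeholder):

```python
import yaml

doc = """
train/AudioLoader.sources: &id001
- /data/xeno-canto/
val/AudioLoader.sources: *id001
"""
conf = yaml.safe_load(doc)

# after loading, both keys point at the very same list object
shared = conf["train/AudioLoader.sources"] is conf["val/AudioLoader.sources"]
```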
conf/generated/xeno-canto/coarse.yml ADDED
@@ -0,0 +1,8 @@
+ $include:
+ - conf/lora/lora.yml
+ fine_tune: true
+ fine_tune_checkpoint: ./models/spotdl/coarse.pth
+ save_path: ./runs/xeno-canto/coarse
+ train/AudioLoader.sources: &id001
+ - /media/CHONK/hugo/loras/xeno-canto-2/
+ val/AudioLoader.sources: *id001
conf/generated/xeno-canto/interface.yml ADDED
@@ -0,0 +1,7 @@
+ AudioLoader.sources:
+ - /media/CHONK/hugo/loras/xeno-canto-2/
+ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
+ Interface.coarse2fine_lora_ckpt: ./runs/xeno-canto/c2f/latest/lora.pth
+ Interface.coarse_ckpt: ./models/spotdl/coarse.pth
+ Interface.coarse_lora_ckpt: ./runs/xeno-canto/coarse/latest/lora.pth
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/lora/lora.yml CHANGED
@@ -13,10 +13,10 @@ NoamScheduler.warmup: 500
 batch_size: 7
 num_workers: 7
 epoch_length: 100
- save_audio_epochs: 4
+ save_audio_epochs: 10
 
 AdamW.lr: 0.0001
 
 # lets us organize sound classes into folders and choose from those sound classes uniformly
 AudioDataset.without_replacement: False
- max_epochs: 1000
+ max_epochs: 500
conf/vampnet.yml CHANGED
@@ -1,5 +1,5 @@
 
- codec_ckpt: /home/hugo/descript/vampnet/models/spotdl/codec.pth
+ codec_ckpt: ./models/spotdl/codec.pth
 save_path: ckpt
 max_epochs: 1000
 epoch_length: 1000
demo.py CHANGED
@@ -104,7 +104,11 @@ def _vamp(data, return_mask=False):
     # save the mask as a txt file
     np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
-    top_k = data[topk] if data[topk] > 0 else None
+    if data[topk] is not None:
+        top_k = data[topk] if data[topk] > 0 else None
+    else:
+        top_k = None
+
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,
@@ -354,17 +358,16 @@ with gr.Blocks() as demo:
         value=0.0
     )
 
-    vamp_button = gr.Button("vamp!!!")
 
     # mask settings
     with gr.Column():
+        vamp_button = gr.Button("vamp!!!")
         output_audio = gr.Audio(
             label="output audio",
             interactive=False,
             type="filepath"
         )
 
-    use_as_input_button = gr.Button("use as input")
 
 
     # with gr.Column():
@@ -397,6 +400,7 @@ with gr.Blocks() as demo:
         label="vamp to download will appear here",
         interactive=False
     )
+    use_as_input_button = gr.Button("use output as input")
 
     thank_you = gr.Markdown("")
 
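The `top_k` change above guards against a `None` widget value before the `> 0` comparison, which would otherwise raise a `TypeError`. The same logic as a standalone helper (the function name is illustrative, not from the repo):

```python
def resolve_top_k(value):
    # gradio widgets can hand back None; treat None or non-positive as "disabled"
    if value is None:
        return None
    return value if value > 0 else None
```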
scripts/exp/fine_tune.py CHANGED
@@ -1,6 +1,7 @@
 import argbind
 from pathlib import Path
 import yaml
+from typing import List
 
 
 
@@ -10,7 +11,7 @@ import yaml
 """
 
 @argbind.bind(without_prefix=True, positional=True)
-def fine_tune(audio_file_or_folder: str, name: str):
+def fine_tune(audio_files_or_folders: List[str], name: str):
 
     conf_dir = Path("conf")
     assert conf_dir.exists(), "conf directory not found. are you in the vampnet directory?"
@@ -24,8 +25,8 @@ def fine_tune(audio_file_or_folder: str, name: str):
     finetune_c2f_conf = {
         "$include": ["conf/lora/lora.yml"],
         "fine_tune": True,
-        "train/AudioLoader.sources": [audio_file_or_folder],
-        "val/AudioLoader.sources": [audio_file_or_folder],
+        "train/AudioLoader.sources": audio_files_or_folders,
+        "val/AudioLoader.sources": audio_files_or_folders,
         "VampNet.n_codebooks": 14,
         "VampNet.n_conditioning_codebooks": 4,
         "VampNet.embedding_dim": 1280,
@@ -34,21 +35,27 @@ def fine_tune(audio_file_or_folder: str, name: str):
         "AudioDataset.duration": 3.0,
         "AudioDataset.loudness_cutoff": -40.0,
         "save_path": f"./runs/{name}/c2f",
+        "fine_tune_checkpoint": "./models/spotdl/c2f.pth"
     }
 
     finetune_coarse_conf = {
         "$include": ["conf/lora/lora.yml"],
         "fine_tune": True,
-        "train/AudioLoader.sources": [audio_file_or_folder],
-        "val/AudioLoader.sources": [audio_file_or_folder],
+        "train/AudioLoader.sources": audio_files_or_folders,
+        "val/AudioLoader.sources": audio_files_or_folders,
         "save_path": f"./runs/{name}/coarse",
+        "fine_tune_checkpoint": "./models/spotdl/coarse.pth"
     }
 
     interface_conf = {
-        "Interface.coarse_ckpt": f"./runs/{name}/coarse/best/vampnet/weights.pth",
-        "Interface.coarse2fine_ckpt": f"./runs/{name}/c2f/best/vampnet/weights.pth",
+        "Interface.coarse_ckpt": "./models/spotdl/coarse.pth",
+        "Interface.coarse_lora_ckpt": f"./runs/{name}/coarse/latest/lora.pth",
+
+        "Interface.coarse2fine_ckpt": "./models/spotdl/c2f.pth",
+        "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
+
         "Interface.codec_ckpt": "./models/spotdl/codec.pth",
-        "AudioLoader.sources": [audio_file_or_folder],
+        "AudioLoader.sources": audio_files_or_folders,
     }
@@ -61,18 +68,8 @@ def fine_tune(audio_file_or_folder: str, name: str):
     with open(finetune_dir / "interface.yml", "w") as f:
         yaml.dump(interface_conf, f)
 
-    # copy the starter weights to the save paths
-    import shutil
-
-    def pmkdir(path):
-        Path(path).parent.mkdir(exist_ok=True, parents=True)
-        return path
-
-    shutil.copy("./models/spotdl/c2f.pth", pmkdir(f"./runs/{name}/c2f/starter/vampnet/weights.pth"))
-    shutil.copy("./models/spotdl/coarse.pth", pmkdir(f"./runs/{name}/coarse/starter/vampnet/weights.pth"))
-
 
-    print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml --resume --load_weights --tag starter` ")
+    print(f"generated confs in {finetune_dir}. run training jobs with `python scripts/exp/train.py --args.load {finetune_dir}/<c2f/coarse>.yml` ")
 
 if __name__ == "__main__":
     args = argbind.parse_args()
scripts/exp/train.py CHANGED
@@ -107,7 +107,11 @@ def load(
     resume: bool = False,
     tag: str = "latest",
    load_weights: bool = False,
+    fine_tune_checkpoint: Optional[str] = None,
 ):
+    codec = LAC.load(args["codec_ckpt"], map_location="cpu")
+    codec.eval()
+
     model, v_extra = None, {}
 
     if resume:
@@ -123,8 +127,12 @@ def load(
             f"Could not find a VampNet checkpoint in {kwargs['folder']}"
         )
 
-    codec = LAC.load(args["codec_ckpt"], map_location="cpu")
-    codec.eval()
+    if args["fine_tune"]:
+        assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
+        model = VampNet.load(location=Path(fine_tune_checkpoint), map_location="cpu")
+
     model = VampNet() if model is None else model
 
     model = accel.prepare_model(model)
@@ -460,6 +468,15 @@ def train(
                 self.print(f"Best model so far")
                 tags.append("best")
 
+            if fine_tune:
+                for tag in tags:
+                    # save only the lora weights -- much smaller than a full checkpoint
+                    (Path(save_path) / tag).mkdir(parents=True, exist_ok=True)
+                    torch.save(
+                        lora.lora_state_dict(accel.unwrap(model)),
+                        f"{save_path}/{tag}/lora.pth"
+                    )
+
             for tag in tags:
                 model_extra = {
                     "optimizer.pth": optimizer.state_dict(),
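The `lora.lora_state_dict(...)` call above is what makes these checkpoints "efficient": only the adapter tensors are written, not the full model. The effect can be sketched with a plain name-based `state_dict` filter (a toy stand-in, not the actual `lora` package logic; the model and helper names are illustrative):

```python
import torch

class TinyAdapterModel(torch.nn.Module):
    """Toy model: one large base layer plus small low-rank adapter layers."""
    def __init__(self):
        super().__init__()
        self.base = torch.nn.Linear(512, 512)                 # pretrained weights
        self.lora_down = torch.nn.Linear(512, 4, bias=False)  # adapter: down-projection
        self.lora_up = torch.nn.Linear(4, 512, bias=False)    # adapter: up-projection

def lora_only_state_dict(model):
    # keep just the adapter tensors; the base weights live in the
    # original full checkpoint and never need to be re-saved
    return {k: v for k, v in model.state_dict().items() if "lora_" in k}

model = TinyAdapterModel()
small = lora_only_state_dict(model)

# strict=False tolerates the missing base keys, mirroring how the
# interface applies a lora.pth on top of a full checkpoint
model.load_state_dict(small, strict=False)
```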
vampnet/interface.py CHANGED
@@ -21,12 +21,40 @@ def signal_concat(
 
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
+def _load_model(
+    ckpt: str,
+    lora_ckpt: str = None,
+    device: str = "cpu",
+    chunk_size_s: int = 10,
+):
+    # we need to set strict to False if the model has lora weights to add later
+    model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)
+
+    # load lora weights if needed
+    if lora_ckpt is not None:
+        if not Path(lora_ckpt).exists():
+            should_cont = input(
+                f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
+            )
+            if should_cont != "y":
+                raise Exception("aborting")
+        else:
+            model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)
+
+    model.to(device)
+    model.eval()
+    model.chunk_size_s = chunk_size_s
+    return model
+
+
 
 class Interface(torch.nn.Module):
     def __init__(
         self,
         coarse_ckpt: str = None,
+        coarse_lora_ckpt: str = None,
         coarse2fine_ckpt: str = None,
+        coarse2fine_lora_ckpt: str = None,
         codec_ckpt: str = None,
         wavebeat_ckpt: str = None,
         device: str = "cpu",
@@ -40,18 +68,21 @@ class Interface(torch.nn.Module):
         self.codec.to(device)
 
         assert coarse_ckpt is not None, "must provide a coarse checkpoint"
-        self.coarse = VampNet.load(location=Path(coarse_ckpt), map_location="cpu")
-        self.coarse.to(device)
-        self.coarse.eval()
-        self.coarse.chunk_size_s = self.s2t2s(coarse_chunk_size_s)
+        self.coarse = _load_model(
+            ckpt=coarse_ckpt,
+            lora_ckpt=coarse_lora_ckpt,
+            device=device,
+            chunk_size_s=coarse_chunk_size_s,
+        )
 
+        # check if we have a coarse2fine ckpt
         if coarse2fine_ckpt is not None:
-            self.c2f = VampNet.load(
-                location=Path(coarse2fine_ckpt), map_location="cpu"
+            self.c2f = _load_model(
+                ckpt=coarse2fine_ckpt,
+                lora_ckpt=coarse2fine_lora_ckpt,
+                device=device,
+                chunk_size_s=coarse2fine_chunk_size_s,
             )
-            self.c2f.to(device)
-            self.c2f.eval()
-            self.c2f.chunk_size_s = self.s2t2s(coarse2fine_chunk_size_s)
         else:
             self.c2f = None
 
@@ -64,6 +95,21 @@ class Interface(torch.nn.Module):
 
         self.device = device
 
+    def lora_load(
+        self,
+        coarse_lora_ckpt: str = None,
+        coarse2fine_lora_ckpt: str = None,
+    ):
+        if coarse_lora_ckpt is not None:
+            self.coarse.to("cpu")
+            self.coarse.load_state_dict(torch.load(coarse_lora_ckpt, map_location="cpu"), strict=False)
+            self.coarse.to(self.device)
+        if coarse2fine_lora_ckpt is not None:
+            self.c2f.to("cpu")
+            self.c2f.load_state_dict(torch.load(coarse2fine_lora_ckpt, map_location="cpu"), strict=False)
+            self.c2f.to(self.device)
+
     def s2t(self, seconds: float):
        """seconds to tokens"""
        if isinstance(seconds, np.ndarray):