Hugo Flores Garcia commited on
Commit
93ca721
1 Parent(s): a004369
Files changed (3) hide show
  1. .gitignore +2 -1
  2. scripts/exp/train.py +2 -2
  3. vampnet/interface.py +2 -2
.gitignore CHANGED
@@ -179,4 +179,5 @@ gradio-outputs/
179
  models/
180
  samples*/
181
  models-all/
182
- models.zip
 
 
179
  models/
180
  samples*/
181
  models-all/
182
+ models.zip
183
+ .git-old
scripts/exp/train.py CHANGED
@@ -20,7 +20,7 @@ import vampnet
20
  from vampnet.modules.transformer import VampNet
21
  from vampnet.util import codebook_unflatten, codebook_flatten
22
  from vampnet import mask as pmask
23
- from lac.model.lac import LAC
24
 
25
 
26
  # Enable cudnn autotuner to speed up training
@@ -109,7 +109,7 @@ def load(
109
  load_weights: bool = False,
110
  fine_tune_checkpoint: Optional[str] = None,
111
  ):
112
- codec = LAC.load(args["codec_ckpt"], map_location="cpu")
113
  codec.eval()
114
 
115
  model, v_extra = None, {}
 
20
  from vampnet.modules.transformer import VampNet
21
  from vampnet.util import codebook_unflatten, codebook_flatten
22
  from vampnet import mask as pmask
23
+ from dac.model.dac import DAC
24
 
25
 
26
  # Enable cudnn autotuner to speed up training
 
109
  load_weights: bool = False,
110
  fine_tune_checkpoint: Optional[str] = None,
111
  ):
112
+ codec = DAC.load(args["codec_ckpt"], map_location="cpu")
113
  codec.eval()
114
 
115
  model, v_extra = None, {}
vampnet/interface.py CHANGED
@@ -11,7 +11,7 @@ from .modules.transformer import VampNet
11
  from .beats import WaveBeat
12
  from .mask import *
13
 
14
- from lac.model.lac import LAC
15
 
16
 
17
  def signal_concat(
@@ -63,7 +63,7 @@ class Interface(torch.nn.Module):
63
  ):
64
  super().__init__()
65
  assert codec_ckpt is not None, "must provide a codec checkpoint"
66
- self.codec = LAC.load(Path(codec_ckpt))
67
  self.codec.eval()
68
  self.codec.to(device)
69
 
 
11
  from .beats import WaveBeat
12
  from .mask import *
13
 
14
+ from dac.model.dac import DAC
15
 
16
 
17
  def signal_concat(
 
63
  ):
64
  super().__init__()
65
  assert codec_ckpt is not None, "must provide a codec checkpoint"
66
+ self.codec = DAC.load(Path(codec_ckpt))
67
  self.codec.eval()
68
  self.codec.to(device)
69