scripts/exp/fine_tune.py CHANGED
@@ -53,6 +53,7 @@ def fine_tune(audio_files_or_folders: List[str], name: str):
53
 
54
  "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
 
56
 
57
  "Interface.codec_ckpt": "./models/vampnet/codec.pth",
58
  "AudioLoader.sources": [audio_files_or_folders],
 
53
 
54
  "Interface.coarse2fine_ckpt": f"./models/vampnet/c2f.pth",
55
  "Interface.coarse2fine_lora_ckpt": f"./runs/{name}/c2f/latest/lora.pth",
56
+ "Interface.wavebeat_ckpt": "./models/wavebeat.pth",
57
 
58
  "Interface.codec_ckpt": "./models/vampnet/codec.pth",
59
  "AudioLoader.sources": [audio_files_or_folders],
vampnet/interface.py CHANGED
@@ -65,7 +65,7 @@ class Interface(torch.nn.Module):
65
  ):
66
  super().__init__()
67
  assert codec_ckpt is not None, "must provide a codec checkpoint"
68
- self.codec = DAC.load(codec_ckpt)
69
  self.codec.eval()
70
  self.codec.to(device)
71
 
 
65
  ):
66
  super().__init__()
67
  assert codec_ckpt is not None, "must provide a codec checkpoint"
68
+ self.codec = DAC.load(Path(codec_ckpt))
69
  self.codec.eval()
70
  self.codec.to(device)
71
 
vampnet/modules/transformer.py CHANGED
@@ -581,7 +581,7 @@ class VampNet(at.ml.BaseModel):
581
  sampling_steps: int = 24,
582
  start_tokens: Optional[torch.Tensor] = None,
583
  mask: Optional[torch.Tensor] = None,
584
- temperature: Union[float, Tuple[float, float]] = 2.5,
585
  typical_filtering=False,
586
  typical_mass=0.2,
587
  typical_min_tokens=1,
@@ -592,7 +592,7 @@ class VampNet(at.ml.BaseModel):
592
  #####################
593
  # resolve temperature #
594
  #####################
595
- assert isinstance(temperature, float)
596
  logging.debug(f"temperature: {temperature}")
597
 
598
 
 
581
  sampling_steps: int = 24,
582
  start_tokens: Optional[torch.Tensor] = None,
583
  mask: Optional[torch.Tensor] = None,
584
+ temperature: float = 2.5,
585
  typical_filtering=False,
586
  typical_mass=0.2,
587
  typical_min_tokens=1,
 
592
  #####################
593
  # resolve temperature #
594
  #####################
595
+
596
  logging.debug(f"temperature: {temperature}")
597
 
598