Hugo Flores committed on
Commit
3d08285
2 Parent(s): 79bcce6 d6b9d5b

Merge branch 'main' of github.com:descriptinc/lyrebird-vampnet into main

Browse files
requirements.txt CHANGED
@@ -2,12 +2,12 @@ argbind>=0.3.1
2
  pytorch-ignite
3
  rich
4
  audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@0.6.3
 
5
  tqdm
6
  tensorboard
7
  google-cloud-logging==2.2.0
8
  pytest
9
  pytest-cov
10
- papaya_client @ git+https://github.com/descriptinc/lyrebird-papaya.git@master
11
  pynvml
12
  psutil
13
  pandas
 
2
  pytorch-ignite
3
  rich
4
  audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@0.6.3
5
+ lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main
6
  tqdm
7
  tensorboard
8
  google-cloud-logging==2.2.0
9
  pytest
10
  pytest-cov
 
11
  pynvml
12
  psutil
13
  pandas
scripts/exp/train.py CHANGED
@@ -59,7 +59,7 @@ IGNORE_INDEX = -100
59
  @argbind.bind("train", "val", without_prefix=True)
60
  def build_transform():
61
  transform = transforms.Compose(
62
- tfm.VolumeNorm(("uniform", -32, -14)),
63
  tfm.VolumeChange(("uniform", -6, 3)),
64
  tfm.RescaleAudio(),
65
  )
@@ -250,6 +250,7 @@ def train(
250
  max_epochs: int = int(100e3),
251
  epoch_length: int = 1000,
252
  save_audio_epochs: int = 10,
 
253
  batch_size: int = 48,
254
  grad_acc_steps: int = 1,
255
  val_idx: list = [0, 1, 2, 3, 4],
@@ -506,6 +507,9 @@ def train(
506
  loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
507
  self.print(f"Saving to {str(Path('.').absolute())}")
508
 
 
 
 
509
  if self.is_best(engine, loss_key):
510
  self.print(f"Best model so far")
511
  tags.append("best")
 
59
  @argbind.bind("train", "val", without_prefix=True)
60
  def build_transform():
61
  transform = transforms.Compose(
62
+ tfm.VolumeNorm(("uniform", -32, -20)),
63
  tfm.VolumeChange(("uniform", -6, 3)),
64
  tfm.RescaleAudio(),
65
  )
 
250
  max_epochs: int = int(100e3),
251
  epoch_length: int = 1000,
252
  save_audio_epochs: int = 10,
253
+ save_epochs: list = [10, 50, 100, 200, 300, 400,],
254
  batch_size: int = 48,
255
  grad_acc_steps: int = 1,
256
  val_idx: list = [0, 1, 2, 3, 4],
 
507
  loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
508
  self.print(f"Saving to {str(Path('.').absolute())}")
509
 
510
+ if self.state.epoch in save_epochs:
511
+ tags.append(f"epoch={self.state.epoch}")
512
+
513
  if self.is_best(engine, loss_key):
514
  self.print(f"Best model so far")
515
  tags.append("best")
setup.py CHANGED
@@ -30,11 +30,13 @@ setup(
30
  "argbind>=0.3.2",
31
  "pytorch-ignite",
32
  "rich",
33
- "audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@0.6.0",
 
34
  "tqdm",
35
  "tensorboard",
36
  "google-cloud-logging==2.2.0",
37
  "torchmetrics>=0.7.3",
38
  "einops",
 
39
  ],
40
  )
 
30
  "argbind>=0.3.2",
31
  "pytorch-ignite",
32
  "rich",
33
+ "audiotools @ git+https://github.com/descriptinc/lyrebird-audiotools.git@0.6.3",
34
+ "lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main",
35
  "tqdm",
36
  "tensorboard",
37
  "google-cloud-logging==2.2.0",
38
  "torchmetrics>=0.7.3",
39
  "einops",
40
+ "flash-attn",
41
  ],
42
  )
vampnet/modules/base.py CHANGED
@@ -153,7 +153,7 @@ class VampBase(at.ml.BaseModel):
153
  sampling_steps: int = 12,
154
  start_tokens: Optional[torch.Tensor] = None,
155
  mask: Optional[torch.Tensor] = None,
156
- temperature: Union[float, Tuple[float, float]] = 1.0,
157
  top_k: int = None,
158
  sample: str = "gumbel",
159
  renoise_mode: str = "start",
@@ -262,7 +262,7 @@ class VampBase(at.ml.BaseModel):
262
  sampling_steps: int = 24,
263
  start_tokens: Optional[torch.Tensor] = None,
264
  mask: Optional[torch.Tensor] = None,
265
- temperature: Union[float, Tuple[float, float]] = 1.0,
266
  top_k: int = None,
267
  sample: str = "multinomial",
268
  typical_filtering=False,
 
153
  sampling_steps: int = 12,
154
  start_tokens: Optional[torch.Tensor] = None,
155
  mask: Optional[torch.Tensor] = None,
156
+ temperature: Union[float, Tuple[float, float]] = 0.8,
157
  top_k: int = None,
158
  sample: str = "gumbel",
159
  renoise_mode: str = "start",
 
262
  sampling_steps: int = 24,
263
  start_tokens: Optional[torch.Tensor] = None,
264
  mask: Optional[torch.Tensor] = None,
265
+ temperature: Union[float, Tuple[float, float]] = 0.8,
266
  top_k: int = None,
267
  sample: str = "multinomial",
268
  typical_filtering=False,