Hugo Flores Garcia committed on
Commit 99122c4
1 Parent(s): 5a343f4

basic readme stuff
README.md CHANGED
@@ -1,80 +1,58 @@
-# Lyrebird VampNet
-
-This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.
-
-## Development
-### Setting everything up
-
-Run the setup script to set up your environment via:
-
-```bash
-python env/setup.py
-```
-
-The setup script does not require any dependencies beyond just Python.
-Once run, follow the instructions it prints out to create your
-environment file, which will be at `env/env.sh`.
-
-Note that if this is a new machine, and
-the data is not downloaded somewhere on it already, it will ask you
-for a directory to download the data to.
-
-For Github setup, if you don't have a .netrc token, create one by going to your Github profile -> Developer settings -> Personal access tokens -> Generate new token. Copy the token and [keep it secret, keep it safe](https://www.youtube.com/watch?v=iThtELZvfPs).
-
-When complete, run:
-
-```bash
-source env/env.sh
-```
-
-Now build and launch the Docker containers:
-
-```bash
-docker compose up -d
-```
-
-This builds and runs a Jupyter notebook and Tensorboard
-in the background, which points to your `TENSORBOARD_PATH`
-env. variable.
-
-Now, launch your development environment via:
-
-```bash
-docker compose run dev
-```
-
-To tear down your development environment, just do
-
-```bash
-docker compose down
-```
-
-### Launching an experiment
-
-Experiments are first _staged_ by running the `stage` command (which corresponds to the script `scripts/exp/stage.py`).
-
-`stage` creates a directory with a copy of all of the Git-tracked files in the root repository. `stage` launches a shell into said directory, so all commands are run on the
-copy of the original repository code. This is useful for rewinding to an old experiment
-and resuming it, for example. Even if the repository code changes, the snapshot in the experiment directory is unchanged from the original run, so it can be re-used.
-
-Then, the experiment can be run via:
-
-```bash
-torchrun --nproc_per_node gpu \
-  scripts/exp/train.py \
-  --args.load=conf/args.yml
-```
-
-The full settings are in [conf/daps/train.yml](conf/daps/train.yml).
-
-### Useful commands
-
-#### Cleaning up after a run
-
-Sometimes DDP runs fail to clear themselves out of the machine. To fix this, run
-
-```bash
-cleanup
-```
+# VampNet
+
+This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.
+
+# Setting up
+
+## Installation
+
+First, install AudioTools:
+
+```bash
+git clone https://github.com/hugofloresgarcia/audiotools.git
+pip install -e ./audiotools
+```
+
+Next, install the LAC library:
+
+```bash
+git clone https://github.com/hugofloresgarcia/lac.git
+pip install -e ./lac
+```
+
+Finally, install VampNet itself:
+
+```bash
+git clone https://github.com/hugofloresgarcia/vampnet2.git
+pip install -e ./vampnet2
+```
+
+## A note on Argbind
+This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
+Config files are stored in the `conf/` folder.
+
+# Usage
+
+## Staging a Run
+
+Staging a run makes a copy of all the git-tracked files in the codebase and saves them to a folder for reproducibility. You can then run the training script from the staged folder.
+
+coming soon
+
+## Training a model
+
+```bash
+python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints
+```
+
+## Fine-tuning
+To fine-tune a model, see the configuration files under `conf/lora/`.
+Provide a list of audio files or folders to fine-tune on, then launch the training job as usual:
+
+```bash
+python scripts/exp/train.py --args.load conf/lora/birds.yml --save_path /path/to/checkpoints
+```
+
+## Launching the Gradio Interface
+```bash
+python demo.py --args.load conf/interface/spotdl.yml --Interface.device cuda
+```
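
A note for readers unfamiliar with argbind, which the new README leans on: any function decorated with `argbind.bind` can have its keyword arguments set from the CLI or from a YAML file passed via `--args.load`, which is how `scripts/exp/train.py` is driven (it uses this exact decorator, as the diff below shows). A minimal sketch; the function name and values here are illustrative, not from this repo:

```python
import argbind

@argbind.bind(without_prefix=True)  # same decorator train() uses in scripts/exp/train.py
def run(lr: float = 1e-4, save_path: str = "ckpt"):
    # lr and save_path can now be set via --lr / --save_path or a YAML file
    print(f"lr={lr}, saving to {save_path}")

if __name__ == "__main__":
    args = argbind.parse_args()   # parses CLI flags, including --args.load my.yml
    with argbind.scope(args):     # applies the parsed values to bound functions
        run()
```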
conf/lora/gas-station.yml ADDED
@@ -0,0 +1,10 @@
+$include:
+  - conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
+
+val/AudioLoader.sources:
+  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
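
The pattern above generalizes: a fine-tuning config includes the shared LoRA defaults, flips `fine_tune` on, and points the `train/` and `val/` loaders at your audio. A sketch with placeholder paths (only the keys shown in this commit are used):

```yaml
$include:
  - conf/lora/lora.yml

fine_tune: True

train/AudioLoader.sources:
  - /path/to/your/audio/

val/AudioLoader.sources:
  - /path/to/your/audio/
```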
demo.py CHANGED
@@ -48,6 +48,7 @@ def load_audio(file):
     sig.write(out_dir / "input.wav")
     return sig.path_to_file
 
+
 def load_random_audio():
     index = np.random.randint(0, len(dataset))
     sig = dataset[index]["signal"]
@@ -68,7 +69,7 @@ def ez_vamp(
     sig = at.AudioSignal(input_audio)
 
     print(f"running standard vampnet with {num_vamps} vamps")
-    zv = interface.coarse_vamp_v2(
+    zv = interface.coarse_vamp(
        sig,
        sampling_steps=num_steps,
        temperature=(init_temp, final_temp),
@@ -140,7 +141,7 @@ def vamp(
 
     if mode == "standard":
        print(f"running standard vampnet with {num_vamps} vamps")
-        zv, mask_z = interface.coarse_vamp_v2(
+        zv, mask_z = interface.coarse_vamp(
            sig,
            sampling_steps=num_steps,
            temperature=(init_temp, final_temp),
scripts/exp/train.py CHANGED
@@ -115,6 +115,10 @@ def load(
     }
     if (Path(kwargs["folder"]) / "vampnet").exists():
         model, v_extra = VampNet.load_from_folder(**kwargs)
+    else:
+        raise ValueError(
+            f"Could not find a VampNet checkpoint in {kwargs['folder']}"
+        )
 
     codec = LAC.load(args["codec_ckpt"], map_location="cpu")
     codec.eval()
@@ -149,25 +153,6 @@ def load(
     }
 
 
-def get_gpu_memory_map():
-    """Get the current gpu usage.
-
-    Returns
-    -------
-    usage: dict
-        Keys are device ids as integers.
-        Values are memory usage as integers in MB.
-    """
-    result = subprocess.check_output(
-        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
-        encoding="utf-8",
-    )
-    # Convert lines into a dictionary
-    gpu_memory = [int(x) for x in result.strip().split("\n")]
-    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
-    gpu_memory_map = {f"gpu/{k}": v / 1024 for k, v in gpu_memory_map.items()}
-    return gpu_memory_map
-
 
 def num_params_hook(o, p):
     return o + f" {p/1e6:<.3f}M params."
@@ -189,7 +174,6 @@ def accuracy(
     target: torch.Tensor,
     top_k: int = 1,
     ignore_index: Optional[int] = None,
-    **kwargs,
 ) -> torch.Tensor:
     # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
     preds = rearrange(preds, "b p s -> (b s) p")
@@ -214,30 +198,6 @@ def accuracy(
 
     return accuracy
 
-def sample_prefix_suffix_amt(
-    z,
-    n_batch,
-    prefix_amt,
-    suffix_amt,
-    prefix_dropout,
-    suffix_dropout,
-    rng
-):
-    """
-    Sample the number of prefix and suffix tokens to drop.
-    """
-    if prefix_amt > 0.0:
-        prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
-        n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
-    else:
-        n_prefix = None
-    if suffix_amt > 0.0:
-        suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
-        n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
-    else:
-        n_suffix = None
-    return n_prefix, n_suffix
-
 
 @argbind.bind(without_prefix=True)
 def train(
@@ -256,10 +216,6 @@ def train(
     num_workers: int = 10,
     detect_anomaly: bool = False,
     grad_clip_val: float = 5.0,
-    prefix_amt: float = 0.0,
-    suffix_amt: float = 0.0,
-    prefix_dropout: float = 0.1,
-    suffix_dropout: float = 0.1,
     fine_tune: bool = False,
     quiet: bool = False,
 ):
@@ -342,16 +298,12 @@ def train(
             target=r_unmasked_target,
             ignore_index=IGNORE_INDEX,
             top_k=topk,
-            task="multiclass",
-            num_classes=vn.vocab_size,
         )
         output[f"{tag}/masked"] = accuracy(
             preds=r_z_hat,
             target=r_masked_target,
             ignore_index=IGNORE_INDEX,
             top_k=topk,
-            task="multiclass",
-            num_classes=vn.vocab_size,
         )
 
     def train_loop(self, engine, batch):
@@ -370,15 +322,7 @@ def train(
         n_batch = z.shape[0]
         r = rng.draw(n_batch)[:, 0].to(accel.device)
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(
-            z, r, n_prefix=n_prefix, n_suffix=n_suffix
-        )
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         dtype = torch.bfloat16 if accel.amp else None
@@ -454,13 +398,7 @@ def train(
         n_batch = z.shape[0]
         r = rng.draw(n_batch)[:, 0].to(accel.device)
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)
@@ -574,17 +512,8 @@ def train(
         )
 
     def save_imputation(self, z: torch.Tensor):
-        # imputations
-        _prefix_amt = prefix_amt
-        _suffix_amt = suffix_amt
-
-        if _prefix_amt == 0:
-            _prefix_amt = 0.25
-        if _suffix_amt == 0:
-            _suffix_amt = 0.25
-
-        n_prefix = int(z.shape[-1] * _prefix_amt)
-        n_suffix = int(z.shape[-1] * _suffix_amt)
+        n_prefix = int(z.shape[-1] * 0.25)
+        n_suffix = int(z.shape[-1] * 0.25)
         downsample_factor = None
 
         vn = accel.unwrap(model)
@@ -647,13 +576,7 @@ def train(
 
         n_batch = z.shape[0]
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(z=z,
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)
@@ -664,7 +587,6 @@ def train(
         z_pred = vn.embedding.unflatten(z_pred, n_codebooks=vn.n_predict_codebooks)
         z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
 
-        print("z_mask", z_mask.shape)
        generated = vn.to_signal(z_pred, codec)
        reconstructed = vn.to_signal(z, codec)
        masked = vn.to_signal(z_mask.squeeze(1), codec)
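
A note on the `accuracy` helper trimmed above: with the `**kwargs` escape hatch and the extra `task`/`num_classes` arguments at the call sites removed, the function now takes exactly the four parameters shown. The snippet below is a self-contained sketch of what it computes, following the shapes implied by its `rearrange` pattern (the tensors are made up, and the `ignore_index` handling is omitted for brevity):

```python
import torch
from einops import rearrange

preds = torch.randn(2, 1024, 100)           # (batch, n_class, seq_len) logits
target = torch.randint(0, 1024, (2, 100))   # (batch, seq_len) token ids

# same flattening as accuracy(): "b p s -> (b s) p"
flat_preds = rearrange(preds, "b p s -> (b s) p")
flat_target = rearrange(target, "b s -> (b s)")

topk = flat_preds.topk(5, dim=-1).indices              # (batch*seq, 5)
correct = (topk == flat_target[:, None]).any(dim=-1)   # hit if target is in the top 5
print(correct.float().mean())                          # top-5 token accuracy
```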
scripts/utils/vamp_folder.py CHANGED
@@ -56,7 +56,7 @@ class CoarseCond:
 
     def __call__(self, sig, interface):
         n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
-        zv = interface.coarse_vamp_v2(sig,
+        zv = interface.coarse_vamp(sig,
             n_conditioning_codebooks=n_conditioning_codebooks,
             downsample_factor=self.downsample_factor,
         )
@@ -113,7 +113,7 @@ def mask_ratio_1_step(ratio=1.0):
     r = interface.coarse.invgamma(ratio).to(interface.device)
     intensity = 1-r
 
-    zv = interface.coarse_vamp_v2(
+    zv = interface.coarse_vamp(
         sig,
         sample='argmax',
         sampling_steps=1,
@@ -125,7 +125,7 @@ def mask_ratio_1_step(ratio=1.0):
 
 def num_sampling_steps(num_steps=1):
     def wrapper(sig, interface):
-        zv = interface.coarse_vamp_v2(
+        zv = interface.coarse_vamp(
             sig,
             downsample_factor=16,
             sampling_steps=num_steps,
@@ -143,7 +143,7 @@ def beat_mask(ctx_time):
             after_beat_s=ctx_time,
             invert=True
         )
-        zv = interface.coarse_vamp_v2(
+        zv = interface.coarse_vamp(
             sig,
             ext_mask=beat_mask,
         )
@@ -154,7 +154,7 @@ def beat_mask(ctx_time):
 
 def inpaint(ctx_time):
     def wrapper(sig, interface):
-        zv = interface.coarse_vamp_v2(
+        zv = interface.coarse_vamp(
             sig,
             prefix_dur_s=ctx_time,
             suffix_dur_s=ctx_time,
vampnet/interface.py CHANGED
@@ -20,6 +20,14 @@ def signal_concat(
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
 
+class SignalPrompt:
+
+    def __init__(self, signal: AudioSignal):
+        self.sig = signal
+
+
+
+
 class Interface(torch.nn.Module):
     def __init__(
         self,
@@ -28,7 +36,7 @@ class Interface(torch.nn.Module):
         codec_ckpt: str = None,
         wavebeat_ckpt: str = None,
         device: str = "cpu",
-        coarse_chunk_size_s: int = 5,
+        coarse_chunk_size_s: int = 10,
         coarse2fine_chunk_size_s: int = 3,
     ):
         super().__init__()
@@ -141,7 +149,7 @@ class Interface(torch.nn.Module):
         """make a beat synced mask. that is, make a mask that
         places 1s at and around the beat, and 0s everywhere else.
         """
-        assert hasattr(self, "beat_tracker"), "No beat tracker loaded"
+        assert self.beat_tracker is not None, "No beat tracker loaded"
 
         # get the beat times
         beats, downbeats = self.beat_tracker.extract_beats(signal)
@@ -242,7 +250,7 @@ class Interface(torch.nn.Module):
         return fine_z[:, :, :length].clone()
 
 
-    def coarse_vamp_v2(
+    def coarse_vamp(
         self,
         signal,
         prefix_dur_s: float = 0.0,
@@ -471,7 +479,7 @@ class Interface(torch.nn.Module):
         else:
             ext_mask = None
 
-        out_z = self.coarse_vamp_v2(
+        out_z = self.coarse_vamp(
             sig,
             num_vamps=1,
             swap_prefix_suffix=False,
@@ -520,7 +528,7 @@ class Interface(torch.nn.Module):
         range_fn = range if not verbose else tqdm.trange
         for i in range_fn(num_loops):
             is_flipped = i % 2 == 0
-            vamped = self.coarse_vamp_v2(
+            vamped = self.coarse_vamp(
                 signal,
                 prefix_dur_s=prefix_dur_s,
                 suffix_dur_s=suffix_dur_s,
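
To situate the rename: every downstream caller (demo.py and scripts/utils/vamp_folder.py) now uses `coarse_vamp`. A rough usage sketch follows; the checkpoint path is a placeholder, only constructor arguments visible in this diff are shown, and a real invocation likely needs the model checkpoints configured under conf/interface/:

```python
import audiotools as at
from vampnet.interface import Interface

# placeholder checkpoint path; real configs live under conf/interface/
interface = Interface(
    codec_ckpt="/path/to/codec.pth",
    device="cuda",
    coarse_chunk_size_s=10,  # the new default in this commit (was 5)
)

sig = at.AudioSignal("input.wav")

# keyword arguments below are the ones exercised by demo.py and vamp_folder.py
zv = interface.coarse_vamp(
    sig,
    sampling_steps=36,
    temperature=(0.8, 1.0),
    downsample_factor=16,
)
```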
vampnet/modules/base.py CHANGED
@@ -10,6 +10,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from tqdm import tqdm
 
+from ..util import scalar_to_batch_tensor
+
 
 def log(t, eps=1e-20):
     return torch.log(t + eps)
@@ -24,9 +26,6 @@ def gumbel_sample(t, temperature=1.0, dim=-1):
     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
 
 
-def scalar_to_batch_tensor(x, batch_size):
-    return torch.tensor(x).repeat(batch_size)
-
 class VampBase(at.ml.BaseModel):
     def forward(self, x: torch.Tensor, r: torch.Tensor):
         raise NotImplementedError
@@ -150,6 +149,8 @@ class VampBase(at.ml.BaseModel):
             z_hat = z_hat * mask + truth * (1 - mask)
 
             z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
+        else:
+            raise ValueError(f"invalid noise mode for adding truth to logits: {self.noise_mode}")
 
         return z_hat
 
@@ -186,6 +187,9 @@ class VampBase(at.ml.BaseModel):
 
     @torch.no_grad()
     def to_signal(self, z, codec):
+        """
+        convert a sequence of tokens to an audio signal.
+        """
         if z.ndim == 2:
             z = self.embedding.unflatten(z)
         assert z.ndim == 3
@@ -207,122 +211,7 @@ class VampBase(at.ml.BaseModel):
         return signal
 
     @torch.no_grad()
-    def sample(self, **kwargs):
-        if self.noise_mode == "mask":
-            return self.maskgit_sample(**kwargs)
-        else:
-            return self.paella_sample(**kwargs)
-
-    def paella_sample(
-        self,
-        codec,
-        time_steps: int = 400,
-        sampling_steps: int = 36,
-        start_tokens: Optional[torch.Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
-        temperature: Union[float, Tuple[float, float]] = 0.8,
-        top_k: int = None,
-        sample: str = "gumbel",
-        renoise_mode: str = "start",
-        renoise_steps=None,
-        typical_filtering=True,
-        typical_mass=0.2,
-        typical_min_tokens=1,
-        return_signal=True,
-    ):
-        r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(self.device)
-        if renoise_steps == None:
-            renoise_steps = sampling_steps - 1
-
-        if isinstance(temperature, float):
-            temperature = torch.tensor(temperature).repeat(sampling_steps)
-        elif isinstance(temperature, tuple):
-            assert len(temperature) == 2
-            l, h = temperature
-            temperature = torch.linspace(l, h, sampling_steps)
-        else:
-            raise TypeError(f"invalid type for temperature")
-
-        if self.n_conditioning_codebooks > 0:
-            assert (
-                start_tokens is not None
-            ), "must provide start_tokens if n_conditioning_codebooks > 0"
-
-        if start_tokens is None:
-            if self.noise_mode == "noise":
-                z = torch.randint(
-                    0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
-                ).to(self.device)
-            elif self.noise_mode == "mask":
-                z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
-        else:
-            z = start_tokens
-            assert (
-                z.ndim == 3
-            ), f"start_tokens must be shape (batch, n_codebooks, seq_len), got {z.shape}"
-            assert z.shape[0] == 1, f"batch size must be 1"
-
-        if mask is None:
-            mask = torch.ones(z.shape[0], z.shape[-1]).to(self.device).int()
-            mask = mask[:, None, :]
-            mask = mask.repeat(1, z.shape[1], 1)
-
-        mask[:, : self.n_conditioning_codebooks, :] = 0.0
-
-        z_true = z.clone()
-
-        z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
-        z_init = z.clone()
-        for i, tmpt in enumerate(temperature):
-            if renoise_mode == "prev":
-                z_prev = z.clone()
-
-            latents = self.embedding.from_codes(z, codec)
-            logits = self.forward(latents, r[i])
-
-            # for mask mode
-            logits = self.add_truth_to_logits(z_true, logits, mask)
-
-            # Apply topk sampling
-            logits = logits.permute(0, 2, 1)
-
-            z = self.sample_from_logits(
-                logits,
-                top_k=top_k,
-                temperature=tmpt,
-                sample=sample,
-                typical_filtering=typical_filtering,
-                typical_mass=typical_mass,
-                typical_min_tokens=typical_min_tokens,
-            )
-
-            # add back in conditioning codebooks
-            z = self.embedding.unflatten(z, n_codebooks=self.n_predict_codebooks)
-            z = torch.cat(
-                [z_init[:, : self.n_conditioning_codebooks, :], z], dim=1
-            ).int()
-
-            if i < renoise_steps:
-                if renoise_mode == "prev":
-                    z, _ = self.add_noise(z, r[i + 1], random_x=z_prev)
-                elif renoise_mode == "start":
-                    z, _ = self.add_noise(z, r[i + 1], random_x=z_init)
-                elif renoise_mode == "rand":
-                    z, _ = self.add_noise(z, r[i + 1])
-                else:
-                    raise ValueError(f"Invalid renoise_mode: {renoise_mode}")
-
-        if mask is not None:
-            z = start_tokens * (1 - mask) + z * mask
-
-        if return_signal:
-            return self.to_signal(z, codec)
-        else:
-            return z
-
-    def maskgit_sample(
+    def sample(
         self,
         codec,
         time_steps: int = 300,
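
With `paella_sample` removed, `sample` is now the maskgit-style path outright. The gumbel-sampling primitive it builds on appears in the hunk context above; below is a self-contained sketch of it. The `gumbel_noise` helper is reconstructed from the standard Gumbel-max formulation and is an assumption, since it is not shown in this diff:

```python
import torch

def log(t, eps=1e-20):
    return torch.log(t + eps)

def gumbel_noise(t):
    # Gumbel(0, 1) noise with the same shape as t (assumed implementation)
    return -log(-log(torch.zeros_like(t).uniform_(0, 1)))

def gumbel_sample(t, temperature=1.0, dim=-1):
    # perturb temperature-scaled logits with Gumbel noise, then argmax:
    # equivalent to drawing one sample from softmax(t / temperature)
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)

logits = torch.randn(1, 1024)  # toy logits over a 1024-token vocabulary
token = gumbel_sample(logits, temperature=0.8)
```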
vampnet/modules/layers.py CHANGED
@@ -132,6 +132,11 @@ class CodebookEmbedding(nn.Module):
         self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
 
     def from_codes(self, codes: torch.Tensor, codec):
+        """
+        get a sequence of continuous embeddings from a sequence of discrete codes.
+        unlike its counterpart in the original VQ-VAE, this function accounts for any
+        special tokens needed by the language model, like <MASK>.
+        """
         n_codebooks = codes.shape[1]
         latent = []
         for i in range(n_codebooks):
@@ -151,14 +156,23 @@ class CodebookEmbedding(nn.Module):
         return latent
 
     def forward(self, latents: torch.Tensor):
+        """
+        project a sequence of latents to a sequence of embeddings
+        """
         x = self.out_proj(latents)
         return x
 
     def flatten(self, tokens: torch.Tensor, n_codebooks: int = None):
+        """
+        flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
+        """
         n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
         return rearrange(tokens, "b c t -> b (t c)", c=n_c)
 
     def unflatten(self, flat_tokens: torch.Tensor, n_codebooks: int = None):
+        """
+        unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
+        """
         nb, nt = flat_tokens.shape
 
         n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
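
One subtlety worth noting about the docstrings just added: the `b c t -> b (t c)` pattern interleaves codebooks per timestep rather than concatenating whole codebooks end to end. A quick self-contained check:

```python
import torch
from einops import rearrange

tokens = torch.arange(6).reshape(1, 2, 3)         # (batch=1, codebooks=2, time=3)
flat = rearrange(tokens, "b c t -> b (t c)", c=2)
# tokens: [[[0, 1, 2], [3, 4, 5]]]
# flat:   [[0, 3, 1, 4, 2, 5]]  -- both codebooks at t=0, then t=1, then t=2
restored = rearrange(flat, "b (t c) -> b c t", c=2)
assert torch.equal(restored, tokens)              # flatten/unflatten round-trip
```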
vampnet/signal.py ADDED
@@ -0,0 +1,5 @@
+import torch
+from typing import Optional, Tuple
+
+from .util import scalar_to_batch_tensor
+
vampnet/util.py CHANGED
@@ -1,40 +1,9 @@
 import tqdm
-# import pathos
 
-def process_map(fn, *iterables, **tqdm_kwargs):
-    """
-    Equivalent of `list(map(fn, *iterables))`
-    driven by `concurrent.futures.ProcessPoolExecutor`.
-
-    Parameters
-    ----------
-    tqdm_class : optional
-        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
-    max_workers : int, optional
-        Maximum number of workers to spawn; passed to
-        `concurrent.futures.ProcessPoolExecutor.__init__`.
-        [default: min(32, cpu_count() + 4)].
-    chunksize : int, optional
-        Size of chunks sent to worker processes; passed to
-        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
-    lock_name : str, optional
-        Member of `tqdm_class.get_lock()` to use [default: mp_lock].
-    """
-    from concurrent.futures import ProcessPoolExecutor
-    if iterables and "chunksize" not in tqdm_kwargs:
-        # default `chunksize=1` has poor performance for large iterables
-        # (most time spent dispatching items to workers).
-        longest_iterable_len = max(map(length_hint, iterables))
-        if longest_iterable_len > 1000:
-            from warnings import warn
-            warn("Iterable length %d > 1000 but `chunksize` is not set."
-                 " This may seriously degrade multiprocess performance."
-                 " Set `chunksize=1` or more." % longest_iterable_len,
-                 TqdmWarning, stacklevel=2)
-    if "lock_name" not in tqdm_kwargs:
-        tqdm_kwargs = tqdm_kwargs.copy()
-        tqdm_kwargs["lock_name"] = "mp_lock"
-    return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
+import torch
 
+def scalar_to_batch_tensor(x, batch_size):
+    return torch.tensor(x).repeat(batch_size)
 
 
 def parallelize(
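
The relocated helper is small enough to demo inline; this mirrors the definition above exactly:

```python
import torch
from vampnet.util import scalar_to_batch_tensor

r = scalar_to_batch_tensor(0.5, batch_size=4)
print(r)  # tensor([0.5000, 0.5000, 0.5000, 0.5000]) -- one value per batch item
```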