Hugo Flores Garcia commited on
Commit
3f6f517
1 Parent(s): 75a7169

critical sampling fix, two demoes for comparing old and new sampling

Browse files
conf/generated/bulgarian-tv-choir/c2f.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ AudioDataset.duration: 3.0
4
+ AudioDataset.loudness_cutoff: -40.0
5
+ VampNet.embedding_dim: 1280
6
+ VampNet.n_codebooks: 14
7
+ VampNet.n_conditioning_codebooks: 4
8
+ VampNet.n_heads: 20
9
+ VampNet.n_layers: 16
10
+ fine_tune: true
11
+ fine_tune_checkpoint: ./models/spotdl/c2f.pth
12
+ save_path: ./runs/bulgarian-tv-choir/c2f
13
+ train/AudioLoader.sources: &id001
14
+ - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
15
+ val/AudioLoader.sources: *id001
conf/generated/bulgarian-tv-choir/coarse.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ fine_tune: true
4
+ fine_tune_checkpoint: ./models/spotdl/coarse.pth
5
+ save_path: ./runs/bulgarian-tv-choir/coarse
6
+ train/AudioLoader.sources: &id001
7
+ - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
8
+ val/AudioLoader.sources: *id001
conf/generated/bulgarian-tv-choir/interface.yml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ AudioLoader.sources:
2
+ - - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
3
+ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
4
+ Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth
5
+ Interface.coarse_ckpt: ./models/spotdl/coarse.pth
6
+ Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth
7
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/panchos/c2f.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ AudioDataset.duration: 3.0
4
+ AudioDataset.loudness_cutoff: -40.0
5
+ VampNet.embedding_dim: 1280
6
+ VampNet.n_codebooks: 14
7
+ VampNet.n_conditioning_codebooks: 4
8
+ VampNet.n_heads: 20
9
+ VampNet.n_layers: 16
10
+ fine_tune: true
11
+ fine_tune_checkpoint: ./models/spotdl/c2f.pth
12
+ save_path: ./runs/panchos/c2f
13
+ train/AudioLoader.sources: &id001
14
+ - /media/CHONK/hugo/loras/panchos/
15
+ val/AudioLoader.sources: *id001
conf/generated/panchos/coarse.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ fine_tune: true
4
+ fine_tune_checkpoint: ./models/spotdl/coarse.pth
5
+ save_path: ./runs/panchos/coarse
6
+ train/AudioLoader.sources: &id001
7
+ - /media/CHONK/hugo/loras/panchos/
8
+ val/AudioLoader.sources: *id001
conf/generated/panchos/interface.yml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ AudioLoader.sources:
2
+ - - /media/CHONK/hugo/loras/panchos/
3
+ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
4
+ Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth
5
+ Interface.coarse_ckpt: ./models/spotdl/coarse.pth
6
+ Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth
7
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/generated/titi-monkey/c2f.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ AudioDataset.duration: 3.0
4
+ AudioDataset.loudness_cutoff: -40.0
5
+ VampNet.embedding_dim: 1280
6
+ VampNet.n_codebooks: 14
7
+ VampNet.n_conditioning_codebooks: 4
8
+ VampNet.n_heads: 20
9
+ VampNet.n_layers: 16
10
+ fine_tune: true
11
+ fine_tune_checkpoint: ./models/spotdl/c2f.pth
12
+ save_path: ./runs/titi-monkey/c2f
13
+ train/AudioLoader.sources: &id001
14
+ - /media/CHONK/hugo/loras/titi-monkey.mp3
15
+ val/AudioLoader.sources: *id001
conf/generated/titi-monkey/coarse.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ fine_tune: true
4
+ fine_tune_checkpoint: ./models/spotdl/coarse.pth
5
+ save_path: ./runs/titi-monkey/coarse
6
+ train/AudioLoader.sources: &id001
7
+ - /media/CHONK/hugo/loras/titi-monkey.mp3
8
+ val/AudioLoader.sources: *id001
conf/generated/titi-monkey/interface.yml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ AudioLoader.sources:
2
+ - - /media/CHONK/hugo/loras/titi-monkey.mp3
3
+ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
4
+ Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth
5
+ Interface.coarse_ckpt: ./models/spotdl/coarse.pth
6
+ Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth
7
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
conf/interface/spotdl.yml CHANGED
@@ -7,6 +7,6 @@ Interface.coarse2fine_chunk_size_s: 3
7
 
8
 
9
  AudioLoader.sources:
10
- # - /media/CHONK/hugo/spotdl/subsets/jazz-blues
11
  - /media/CHONK/null
12
 
 
7
 
8
 
9
  AudioLoader.sources:
10
+ # - /media/CHONK/hugo/spotdl/subsets/jazz-blues/
11
  - /media/CHONK/null
12
 
demo-new.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Tuple
3
+ import yaml
4
+ import tempfile
5
+ import uuid
6
+ from dataclasses import dataclass, asdict
7
+
8
+ import numpy as np
9
+ import audiotools as at
10
+ import argbind
11
+
12
+ import gradio as gr
13
+ from vampnet.interface import Interface
14
+ from vampnet import mask as pmask
15
+
16
+ import logging
17
+ logger = logging.getLogger()
18
+ logger.setLevel(logging.CRITICAL)
19
+
20
+ Interface = argbind.bind(Interface)
21
+ AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
22
+
23
+ conf = argbind.parse_args()
24
+
25
+ with argbind.scope(conf):
26
+ interface = Interface()
27
+ loader = AudioLoader()
28
+ print(f"interface device is {interface.device}")
29
+
30
+ dataset = at.data.datasets.AudioDataset(
31
+ loader,
32
+ sample_rate=interface.codec.sample_rate,
33
+ duration=interface.coarse.chunk_size_s,
34
+ n_examples=5000,
35
+ without_replacement=True,
36
+ )
37
+
38
+
39
+ checkpoints = {
40
+ "spotdl": {
41
+ "coarse": "./models/spotdl/coarse.pth",
42
+ "c2f": "./models/spotdl/c2f.pth",
43
+ "codec": "./models/spotdl/codec.pth",
44
+ "full_ckpt": True
45
+ },
46
+ "berta": {
47
+ "coarse": "./models/finetuned/berta-goldman-speech/coarse.pth",
48
+ "c2f": "./models/finetuned/berta-goldman-speech/c2f.pth",
49
+ "codec": "./model/spotdl/codec.pth",
50
+ "full_ckpt": True
51
+ },
52
+ "xeno-canto-2": {
53
+ "coarse": "./models/finetuned/xeno-canto-2/coarse.pth",
54
+ "c2f": "./models/finetuned/xeno-canto-2/c2f.pth",
55
+ "codec": "./models/spotdl/codec.pth",
56
+ "full_ckpt": True
57
+ },
58
+ "panchos": {
59
+ "coarse": "./models/finetuned/panchos/coarse.pth",
60
+ "c2f": "./models/finetuned/panchos/c2f.pth",
61
+ "codec": "./models/spotdl/codec.pth",
62
+ "full_ckpt": False
63
+ },
64
+ "tv-choir": {
65
+ "coarse": "./models/finetuned/tv-choir/coarse.pth",
66
+ "c2f": "./models/finetuned/tv-choir/c2f.pth",
67
+ "codec": "./models/spotdl/codec.pth",
68
+ "full_ckpt": False
69
+ },
70
+ "titi": {
71
+ "coarse": "./models/finetuned/titi/coarse.pth",
72
+ "c2f": "./models/finetuned/titi/c2f.pth",
73
+ "codec": "./models/spotdl/codec.pth",
74
+ "full_ckpt": False
75
+ },
76
+ "titi-clean": {
77
+ "coarse": "./models/finetuned/titi-clean/coarse.pth",
78
+ "c2f": "./models/finetuned/titi-clean/c2f.pth",
79
+ "codec": "./models/spotdl/codec.pth",
80
+ "full_ckpt": False
81
+ }
82
+ }
83
+ interface.checkpoint_key = "spotdl"
84
+
85
+
86
+ OUT_DIR = Path("gradio-outputs")
87
+ OUT_DIR.mkdir(exist_ok=True, parents=True)
88
+
89
+
90
+ def load_audio(file):
91
+ print(file)
92
+ filepath = file.name
93
+ sig = at.AudioSignal.salient_excerpt(
94
+ filepath,
95
+ duration=interface.coarse.chunk_size_s
96
+ )
97
+ sig = interface.preprocess(sig)
98
+
99
+ out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
100
+ out_dir.mkdir(parents=True, exist_ok=True)
101
+ sig.write(out_dir / "input.wav")
102
+ return sig.path_to_file
103
+
104
+
105
+ def load_random_audio():
106
+ index = np.random.randint(0, len(dataset))
107
+ sig = dataset[index]["signal"]
108
+ sig = interface.preprocess(sig)
109
+
110
+ out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
111
+ out_dir.mkdir(parents=True, exist_ok=True)
112
+ sig.write(out_dir / "input.wav")
113
+ return sig.path_to_file
114
+
115
+
116
+ def _vamp(data, return_mask=False):
117
+
118
+ # if our checkpoint key is different, we need to load a new checkpoint
119
+ if data[checkpoint_key] != interface.checkpoint_key:
120
+ print(f"loading checkpoint {data[checkpoint_key]}")
121
+ interface.lora_load(
122
+ checkpoints[data[checkpoint_key]]["coarse"],
123
+ checkpoints[data[checkpoint_key]]["c2f"],
124
+ checkpoints[data[checkpoint_key]]["full_ckpt"],
125
+ )
126
+ interface.checkpoint_key = data[checkpoint_key]
127
+
128
+ out_dir = OUT_DIR / str(uuid.uuid4())
129
+ out_dir.mkdir()
130
+ sig = at.AudioSignal(data[input_audio])
131
+ #pitch shift input
132
+ sig = sig.shift_pitch(data[input_pitch_shift])
133
+
134
+ # TODO: random pitch shift of segments in the signal to prompt! window size should be a parameter, pitch shift width should be a parameter
135
+
136
+ z = interface.encode(sig)
137
+
138
+ ncc = data[n_conditioning_codebooks]
139
+
140
+ # build the mask
141
+ mask = pmask.linear_random(z, data[rand_mask_intensity])
142
+ mask = pmask.mask_and(
143
+ mask, pmask.inpaint(
144
+ z,
145
+ interface.s2t(data[prefix_s]),
146
+ interface.s2t(data[suffix_s])
147
+ )
148
+ )
149
+ mask = pmask.mask_and(
150
+ mask, pmask.periodic_mask(
151
+ z,
152
+ data[periodic_p],
153
+ data[periodic_w],
154
+ random_roll=True
155
+ )
156
+ )
157
+ if data[onset_mask_width] > 0:
158
+ mask = pmask.mask_or(
159
+ mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
160
+ )
161
+ # these should be the last two mask ops
162
+ mask = pmask.dropout(mask, data[dropout])
163
+ mask = pmask.codebook_unmask(mask, ncc)
164
+
165
+ print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[init_temp]}, final temp {data[final_temp]}, use coarse2fine {data[use_coarse2fine]}")
166
+ # save the mask as a txt file
167
+ np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
168
+
169
+ # if data[topk] is not None:
170
+ # top_k = data[topk] if data[topk] > 0 else None
171
+ # else:
172
+ # top_k = None
173
+
174
+ zv, mask_z = interface.coarse_vamp(
175
+ z,
176
+ mask=mask,
177
+ sampling_steps=data[num_steps],
178
+ temperature=(data[init_temp]*10, data[final_temp]*10),
179
+ return_mask=True,
180
+ # sample=data[sampling_strategy],
181
+ typical_filtering=data[typical_filtering],
182
+ typical_mass=data[typical_mass],
183
+ typical_min_tokens=data[typical_min_tokens],
184
+ # top_k=top_k,
185
+ gen_fn=interface.coarse.generate,
186
+ )
187
+
188
+ if use_coarse2fine:
189
+ zv = interface.coarse_to_fine(zv)
190
+
191
+ sig = interface.to_signal(zv).cpu()
192
+ print("done")
193
+
194
+
195
+
196
+ sig.write(out_dir / "output.wav")
197
+
198
+ if return_mask:
199
+ mask = interface.to_signal(mask_z).cpu()
200
+ mask.write(out_dir / "mask.wav")
201
+ return sig.path_to_file, mask.path_to_file
202
+ else:
203
+ return sig.path_to_file
204
+
205
+ def vamp(data):
206
+ return _vamp(data, return_mask=True)
207
+
208
+ def api_vamp(data):
209
+ return _vamp(data, return_mask=False)
210
+
211
+ def save_vamp(data):
212
+ out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
213
+ out_dir.mkdir(parents=True, exist_ok=True)
214
+
215
+ sig_in = at.AudioSignal(data[input_audio])
216
+ sig_out = at.AudioSignal(data[output_audio])
217
+
218
+ sig_in.write(out_dir / "input.wav")
219
+ sig_out.write(out_dir / "output.wav")
220
+
221
+ _data = {
222
+ "init_temp": data[init_temp],
223
+ "final_temp": data[final_temp],
224
+ "prefix_s": data[prefix_s],
225
+ "suffix_s": data[suffix_s],
226
+ "rand_mask_intensity": data[rand_mask_intensity],
227
+ "num_steps": data[num_steps],
228
+ "notes": data[notes_text],
229
+ "periodic_period": data[periodic_p],
230
+ "periodic_width": data[periodic_w],
231
+ "n_conditioning_codebooks": data[n_conditioning_codebooks],
232
+ "use_coarse2fine": data[use_coarse2fine],
233
+ "stretch_factor": data[stretch_factor],
234
+ }
235
+
236
+ # save with yaml
237
+ with open(out_dir / "data.yaml", "w") as f:
238
+ yaml.dump(_data, f)
239
+
240
+ import zipfile
241
+ zip_path = out_dir.with_suffix(".zip")
242
+ with zipfile.ZipFile(zip_path, "w") as zf:
243
+ for file in out_dir.iterdir():
244
+ zf.write(file, file.name)
245
+
246
+ return f"saved! your save code is {out_dir.stem}", zip_path
247
+
248
+
249
+
250
+ with gr.Blocks() as demo:
251
+
252
+ with gr.Row():
253
+ with gr.Column():
254
+ use_coarse2fine = gr.Checkbox(
255
+ label="use coarse2fine",
256
+ value=True
257
+ )
258
+
259
+ manual_audio_upload = gr.File(
260
+ label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
261
+ file_types=["audio"]
262
+ )
263
+ load_random_audio_button = gr.Button("or load random audio")
264
+
265
+ input_audio = gr.Audio(
266
+ label="input audio",
267
+ interactive=False,
268
+ type="filepath",
269
+ )
270
+
271
+ audio_mask = gr.Audio(
272
+ label="audio mask (listen to this to hear the mask hints)",
273
+ interactive=False,
274
+ type="filepath",
275
+ )
276
+
277
+ # connect widgets
278
+ load_random_audio_button.click(
279
+ fn=load_random_audio,
280
+ inputs=[],
281
+ outputs=[ input_audio]
282
+ )
283
+
284
+ manual_audio_upload.change(
285
+ fn=load_audio,
286
+ inputs=[manual_audio_upload],
287
+ outputs=[ input_audio]
288
+ )
289
+
290
+ # mask settings
291
+ with gr.Column():
292
+
293
+ input_pitch_shift = gr.Slider(
294
+ label="input pitch shift (semitones)",
295
+ minimum=-36,
296
+ maximum=36,
297
+ step=1,
298
+ value=0,
299
+ )
300
+
301
+ rand_mask_intensity = gr.Slider(
302
+ label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
303
+ minimum=0.0,
304
+ maximum=1.0,
305
+ value=1.0
306
+ )
307
+
308
+ periodic_p = gr.Slider(
309
+ label="periodic prompt (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
310
+ minimum=0,
311
+ maximum=128,
312
+ step=1,
313
+ value=3,
314
+ )
315
+ periodic_w = gr.Slider(
316
+ label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
317
+ minimum=1,
318
+ maximum=20,
319
+ step=1,
320
+ value=1,
321
+ )
322
+
323
+ onset_mask_width = gr.Slider(
324
+ label="onset mask width (steps, 1 step ~= 10milliseconds)",
325
+ minimum=0,
326
+ maximum=20,
327
+ step=1,
328
+ value=5,
329
+ )
330
+
331
+ with gr.Accordion("extras ", open=False):
332
+ n_conditioning_codebooks = gr.Number(
333
+ label="number of conditioning codebooks. probably 0",
334
+ value=0,
335
+ precision=0,
336
+ )
337
+
338
+ stretch_factor = gr.Slider(
339
+ label="time stretch factor",
340
+ minimum=0,
341
+ maximum=64,
342
+ step=1,
343
+ value=1,
344
+ )
345
+
346
+
347
+ with gr.Accordion("prefix/suffix hints", open=False):
348
+ prefix_s = gr.Slider(
349
+ label="prefix hint length (seconds)",
350
+ minimum=0.0,
351
+ maximum=10.0,
352
+ value=0.0
353
+ )
354
+ suffix_s = gr.Slider(
355
+ label="suffix hint length (seconds)",
356
+ minimum=0.0,
357
+ maximum=10.0,
358
+ value=0.0
359
+ )
360
+
361
+ with gr.Accordion("temperature settings", open=False):
362
+ init_temp = gr.Slider(
363
+ label="initial temperature (should probably stay between 0.6 and 1)",
364
+ minimum=0.0,
365
+ maximum=1.5,
366
+ value=0.8
367
+ )
368
+ final_temp = gr.Slider(
369
+ label="final temperature (should probably stay between 0.7 and 2)",
370
+ minimum=0.0,
371
+ maximum=2.0,
372
+ value=0.8
373
+ )
374
+
375
+ with gr.Accordion("sampling settings", open=False):
376
+ sampling_strategy = gr.Radio(
377
+ label="sampling strategy",
378
+ choices=["gumbel", "multinomial"],
379
+ value="gumbel"
380
+ )
381
+ typical_filtering = gr.Checkbox(
382
+ label="typical filtering (cannot be used with topk)",
383
+ value=False
384
+ )
385
+ typical_mass = gr.Slider(
386
+ label="typical mass (should probably stay between 0.1 and 0.5)",
387
+ minimum=0.01,
388
+ maximum=0.99,
389
+ value=0.2
390
+ )
391
+ typical_min_tokens = gr.Slider(
392
+ label="typical min tokens (should probably stay between 1 and 256)",
393
+ minimum=1,
394
+ maximum=256,
395
+ step=1,
396
+ value=1
397
+ )
398
+
399
+
400
+
401
+
402
+ num_steps = gr.Slider(
403
+ label="number of steps (should normally be between 12 and 36)",
404
+ minimum=1,
405
+ maximum=128,
406
+ step=1,
407
+ value=36
408
+ )
409
+
410
+ dropout = gr.Slider(
411
+ label="mask dropout",
412
+ minimum=0.0,
413
+ maximum=1.0,
414
+ step=0.01,
415
+ value=0.0
416
+ )
417
+
418
+
419
+ # mask settings
420
+ with gr.Column():
421
+ checkpoint_key = gr.Radio(
422
+ label="checkpoint",
423
+ choices=list(checkpoints.keys()),
424
+ value="spotdl"
425
+ )
426
+ vamp_button = gr.Button("vamp!!!")
427
+ output_audio = gr.Audio(
428
+ label="output audio",
429
+ interactive=False,
430
+ type="filepath"
431
+ )
432
+
433
+
434
+
435
+ # with gr.Column():
436
+ # with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
437
+ # use_beats = gr.Checkbox(
438
+ # label="use beat hints (helps the output stick to the beat structure of the input)",
439
+ # value=False
440
+ # )
441
+
442
+ # snap_to_beats = gr.Checkbox(
443
+ # label="trim to beat markers (uncheck if the output audio is too short.)",
444
+ # value=True
445
+ # )
446
+
447
+ # beat_unmask_dur = gr.Slider(
448
+ # label="duration",
449
+ # minimum=0.0,
450
+ # maximum=3.0,
451
+ # value=0.07
452
+ # )
453
+
454
+
455
+ notes_text = gr.Textbox(
456
+ label="type any notes about the generated audio here",
457
+ value="",
458
+ interactive=True
459
+ )
460
+ save_button = gr.Button("save vamp")
461
+ download_file = gr.File(
462
+ label="vamp to download will appear here",
463
+ interactive=False
464
+ )
465
+ use_as_input_button = gr.Button("use output as input")
466
+
467
+ thank_you = gr.Markdown("")
468
+
469
+
470
+ _inputs = {
471
+ input_audio,
472
+ num_steps,
473
+ init_temp, final_temp,
474
+ prefix_s, suffix_s,
475
+ rand_mask_intensity,
476
+ periodic_p, periodic_w,
477
+ n_conditioning_codebooks,
478
+ dropout,
479
+ use_coarse2fine,
480
+ stretch_factor,
481
+ onset_mask_width,
482
+ input_pitch_shift,
483
+ sampling_strategy,
484
+ typical_filtering,
485
+ typical_mass,
486
+ typical_min_tokens,
487
+ # topk,
488
+ checkpoint_key
489
+ }
490
+
491
+ # connect widgets
492
+ vamp_button.click(
493
+ fn=vamp,
494
+ inputs=_inputs,
495
+ outputs=[output_audio, audio_mask],
496
+ )
497
+
498
+ api_vamp_button = gr.Button("api vamp")
499
+ api_vamp_button.click(
500
+ fn=api_vamp,
501
+ inputs=_inputs,
502
+ outputs=[output_audio],
503
+ api_name="vamp"
504
+ )
505
+
506
+ use_as_input_button.click(
507
+ fn=lambda x: x,
508
+ inputs=[output_audio],
509
+ outputs=[input_audio]
510
+ )
511
+
512
+ save_button.click(
513
+ fn=save_vamp,
514
+ inputs=_inputs | {notes_text, output_audio},
515
+ outputs=[thank_you, download_file]
516
+ )
517
+
518
+ demo.launch(share=True, enable_queue=False, debug=True, server_name="0.0.0.0")
demo.py CHANGED
@@ -32,6 +32,47 @@ dataset = at.data.datasets.AudioDataset(
32
  )
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  OUT_DIR = Path("gradio-outputs")
36
  OUT_DIR.mkdir(exist_ok=True, parents=True)
37
 
@@ -63,6 +104,19 @@ def load_random_audio():
63
 
64
 
65
  def _vamp(data, return_mask=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  out_dir = OUT_DIR / str(uuid.uuid4())
67
  out_dir.mkdir()
68
  sig = at.AudioSignal(data[input_audio])
@@ -229,8 +283,8 @@ with gr.Blocks() as demo:
229
 
230
  input_pitch_shift = gr.Slider(
231
  label="input pitch shift (semitones)",
232
- minimum=-12,
233
- maximum=12,
234
  step=1,
235
  value=0,
236
  )
@@ -247,7 +301,7 @@ with gr.Blocks() as demo:
247
  minimum=0,
248
  maximum=128,
249
  step=1,
250
- value=9,
251
  )
252
  periodic_w = gr.Slider(
253
  label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
@@ -262,7 +316,7 @@ with gr.Blocks() as demo:
262
  minimum=0,
263
  maximum=20,
264
  step=1,
265
- value=0,
266
  )
267
 
268
  with gr.Accordion("extras ", open=False):
@@ -361,6 +415,11 @@ with gr.Blocks() as demo:
361
 
362
  # mask settings
363
  with gr.Column():
 
 
 
 
 
364
  vamp_button = gr.Button("vamp!!!")
365
  output_audio = gr.Audio(
366
  label="output audio",
@@ -423,6 +482,7 @@ with gr.Blocks() as demo:
423
  typical_mass,
424
  typical_min_tokens,
425
  topk,
 
426
  }
427
 
428
  # connect widgets
@@ -452,4 +512,4 @@ with gr.Blocks() as demo:
452
  outputs=[thank_you, download_file]
453
  )
454
 
455
- demo.launch(share=True, enable_queue=False, debug=True)
 
32
  )
33
 
34
 
35
+ checkpoints = {
36
+ "spotdl": {
37
+ "coarse": "./models/spotdl/coarse.pth",
38
+ "c2f": "./models/spotdl/c2f.pth",
39
+ "codec": "./models/spotdl/codec.pth",
40
+ "full_ckpt": True
41
+ },
42
+ "berta": {
43
+ "coarse": "./models/finetuned/berta-goldman-speech/coarse.pth",
44
+ "c2f": "./models/finetuned/berta-goldman-speech/c2f.pth",
45
+ "codec": "./model/spotdl/codec.pth",
46
+ "full_ckpt": True
47
+ },
48
+ "xeno-canto-2": {
49
+ "coarse": "./models/finetuned/xeno-canto-2/coarse.pth",
50
+ "c2f": "./models/finetuned/xeno-canto-2/c2f.pth",
51
+ "codec": "./models/spotdl/codec.pth",
52
+ "full_ckpt": True
53
+ },
54
+ "panchos": {
55
+ "coarse": "./models/finetuned/panchos/coarse.pth",
56
+ "c2f": "./models/finetuned/panchos/c2f.pth",
57
+ "codec": "./models/spotdl/codec.pth",
58
+ "full_ckpt": False
59
+ },
60
+ "tv-choir": {
61
+ "coarse": "./models/finetuned/tv-choir/coarse.pth",
62
+ "c2f": "./models/finetuned/tv-choir/c2f.pth",
63
+ "codec": "./models/spotdl/codec.pth",
64
+ "full_ckpt": False
65
+ },
66
+ "titi": {
67
+ "coarse": "./models/finetuned/titi/coarse.pth",
68
+ "c2f": "./models/finetuned/titi/c2f.pth",
69
+ "codec": "./models/spotdl/codec.pth",
70
+ "full_ckpt": False
71
+ }
72
+ }
73
+ interface.checkpoint_key = "spotdl"
74
+
75
+
76
  OUT_DIR = Path("gradio-outputs")
77
  OUT_DIR.mkdir(exist_ok=True, parents=True)
78
 
 
104
 
105
 
106
  def _vamp(data, return_mask=False):
107
+
108
+ # if our checkpoint key is different, we need to load a new checkpoint
109
+ if data[checkpoint_key] != interface.checkpoint_key:
110
+ print(f"loading checkpoint {data[checkpoint_key]}")
111
+ interface.lora_load(
112
+ checkpoints[data[checkpoint_key]]["coarse"],
113
+ checkpoints[data[checkpoint_key]]["c2f"],
114
+ checkpoints[data[checkpoint_key]]["full_ckpt"],
115
+ reset=(data[checkpoint_key] == "spotdl")
116
+ )
117
+ interface.checkpoint_key = data[checkpoint_key]
118
+
119
+
120
  out_dir = OUT_DIR / str(uuid.uuid4())
121
  out_dir.mkdir()
122
  sig = at.AudioSignal(data[input_audio])
 
283
 
284
  input_pitch_shift = gr.Slider(
285
  label="input pitch shift (semitones)",
286
+ minimum=-36,
287
+ maximum=36,
288
  step=1,
289
  value=0,
290
  )
 
301
  minimum=0,
302
  maximum=128,
303
  step=1,
304
+ value=3,
305
  )
306
  periodic_w = gr.Slider(
307
  label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
 
316
  minimum=0,
317
  maximum=20,
318
  step=1,
319
+ value=5,
320
  )
321
 
322
  with gr.Accordion("extras ", open=False):
 
415
 
416
  # mask settings
417
  with gr.Column():
418
+ checkpoint_key = gr.Radio(
419
+ label="checkpoint",
420
+ choices=list(checkpoints.keys()),
421
+ value="spotdl"
422
+ )
423
  vamp_button = gr.Button("vamp!!!")
424
  output_audio = gr.Audio(
425
  label="output audio",
 
482
  typical_mass,
483
  typical_min_tokens,
484
  topk,
485
+ checkpoint_key
486
  }
487
 
488
  # connect widgets
 
512
  outputs=[thank_you, download_file]
513
  )
514
 
515
+ demo.launch(share=True, enable_queue=False, debug=True, server_name="0.0.0.0")
scripts/exp/train.py CHANGED
@@ -353,12 +353,9 @@ def train(
353
  mask[:, vn.n_conditioning_codebooks :, :],
354
  )
355
 
356
- if vn.noise_mode == "mask":
357
- # replace target with ignore index for masked tokens
358
- t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
359
- output["loss"] = criterion(z_hat, t_masked)
360
- else:
361
- output["loss"] = criterion(z_hat, target)
362
 
363
  self._metrics(
364
  vn=vn,
@@ -429,12 +426,9 @@ def train(
429
  )
430
 
431
  output = {}
432
- if vn.noise_mode == "mask":
433
- # replace target with ignore index for masked tokens
434
- t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
435
- output["loss"] = criterion(z_hat, t_masked)
436
- else:
437
- output["loss"] = criterion(z_hat, target)
438
 
439
  self._metrics(
440
  vn=vn,
 
353
  mask[:, vn.n_conditioning_codebooks :, :],
354
  )
355
 
356
+ # replace target with ignore index for masked tokens
357
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
358
+ output["loss"] = criterion(z_hat, t_masked)
 
 
 
359
 
360
  self._metrics(
361
  vn=vn,
 
426
  )
427
 
428
  output = {}
429
+ # replace target with ignore index for masked tokens
430
+ t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
431
+ output["loss"] = criterion(z_hat, t_masked)
 
 
 
432
 
433
  self._metrics(
434
  vn=vn,
scripts/utils/augment.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import audiotools as at
4
+ from audiotools import AudioSignal
5
+
6
+ import argbind
7
+ import tqdm
8
+
9
+
10
+ from pedalboard import (
11
+ Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
12
+ )
13
+ from pedalboard.io import AudioFile
14
+
15
+ # Read in a whole file, resampling to our desired sample rate:
16
+ samplerate = 44100.0
17
+ with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
18
+ audio = f.read(f.frames)
19
+
20
+ # Make a pretty interesting sounding guitar pedalboard:
21
+ board = Pedalboard([
22
+ Compressor(threshold_db=-50, ratio=25),
23
+ Gain(gain_db=30),
24
+ Chorus(),
25
+ LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
26
+ Phaser(),
27
+ Convolution("./guitar_amp.wav", 1.0),
28
+ Reverb(room_size=0.25),
29
+ ])
30
+
31
+
32
+ @argbind.bind(without_prefix=True)
33
+ def augment(
34
+ audio_folder: Path,
35
+ dest_folder: Path,
36
+ n_augmentations: int = 10,
37
+ ):
38
+ """
39
+ Augment a folder of audio files by applying audiotools and pedalboard transforms.
40
+
41
+ The dest foler will contain a folder for each of the clean dataset's files.
42
+ Under each of these folders, there will be a clean file and many augmented files.
43
+ """
44
+
45
+ audio_files = at.util.find_audio(audio_folder)
46
+
47
+ for audio_file in tqdm.tqdm(audio_files):
48
+ subtree = dest_folder / audio_file.relative_to(audio_folder).parent
49
+ subdir = subtree / audio_file.stem
50
+ subdir.mkdir(parents=True, exist_ok=True)
51
+
52
+ # apply pedalboard transforms
53
+ for i in range(n_augmentations):
vampnet/interface.py CHANGED
@@ -97,17 +97,36 @@ class Interface(torch.nn.Module):
97
 
98
  def lora_load(
99
  self,
100
- coarse_lora_ckpt: str = None,
101
- coarse2fine_lora_ckpt: str = None,
 
102
  ):
103
- if coarse_lora_ckpt is not None:
104
- self.coarse.to("cpu")
105
- self.coarse.load_state_dict(torch.load(coarse_lora_ckpt, map_location="cpu"))
106
- self.coarse.to(self.device)
107
- if coarse2fine_lora_ckpt is not None:
108
- self.c2f.to("cpu")
109
- self.c2f.load_state_dict(torch.load(coarse2fine_lora_ckpt, map_location="cpu"))
110
- self.c2f.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
  def s2t(self, seconds: float):
@@ -290,6 +309,7 @@ class Interface(torch.nn.Module):
290
  z,
291
  mask,
292
  return_mask=False,
 
293
  **kwargs
294
  ):
295
  # coarse z
@@ -301,7 +321,8 @@ class Interface(torch.nn.Module):
301
  cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
302
  cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
303
 
304
- c_vamp = self.coarse.sample(
 
305
  codec=self.codec,
306
  time_steps=cz.shape[-1],
307
  start_tokens=cz,
@@ -310,8 +331,6 @@ class Interface(torch.nn.Module):
310
  **kwargs
311
  )
312
 
313
- # replace the mask token in cz_masked with random tokens
314
- # so that we can decode it
315
  if return_mask:
316
  return c_vamp, cz_masked
317
 
@@ -320,53 +339,48 @@ class Interface(torch.nn.Module):
320
 
321
  if __name__ == "__main__":
322
  import audiotools as at
 
 
 
 
323
 
324
  interface = Interface(
325
  coarse_ckpt="./models/spotdl/coarse.pth",
326
  coarse2fine_ckpt="./models/spotdl/c2f.pth",
327
  codec_ckpt="./models/spotdl/codec.pth",
328
- device="cpu"
329
  )
330
 
331
- sig = at.AudioSignal('cali.mp3', duration=10)
332
 
333
  z = interface.encode(sig)
334
 
335
- mask = linear_random(z, 0.8)
336
- print(mask)
337
- mask = mask_and(
338
- mask, inpaint(
339
- z,
340
- interface.s2t(3),
341
- interface.s2t(3)
342
- )
343
- )
344
- print(mask)
345
  mask = mask_and(
346
  mask, periodic_mask(
347
  z,
348
- 7,
349
  1,
350
  random_roll=True
351
  )
352
  )
353
- mask = dropout(mask, 0.0)
354
- mask = codebook_unmask(mask, 0)
355
 
356
 
357
  zv, mask_z = interface.coarse_vamp(
358
  z,
359
  mask=mask,
360
- sampling_steps=1,
361
- temperature=(0.8,1),
362
- return_mask=True
 
363
  )
364
 
365
  use_coarse2fine = False
366
  if use_coarse2fine:
367
  zv = interface.coarse_to_fine(zv)
368
 
369
- print(mask_z)
370
  mask = interface.to_signal(mask_z).cpu()
371
 
372
  sig = interface.to_signal(zv).cpu()
 
97
 
98
  def lora_load(
99
  self,
100
+ coarse_ckpt: str = None,
101
+ c2f_ckpt: str = None,
102
+ full_ckpts: bool = False,
103
  ):
104
+ if full_ckpts:
105
+ if coarse_ckpt is not None:
106
+ self.coarse = _load_model(
107
+ ckpt=coarse_ckpt,
108
+ device=self.device,
109
+ chunk_size_s=self.coarse.chunk_size_s,
110
+ )
111
+ if c2f_ckpt is not None:
112
+ self.c2f = _load_model(
113
+ ckpt=c2f_ckpt,
114
+ device=self.device,
115
+ chunk_size_s=self.c2f.chunk_size_s,
116
+ )
117
+ else:
118
+ if coarse_ckpt is not None:
119
+ self.coarse.to("cpu")
120
+ state_dict = torch.load(coarse_ckpt, map_location="cpu")
121
+
122
+ self.coarse.load_state_dict(state_dict, strict=False)
123
+ self.coarse.to(self.device)
124
+ if c2f_ckpt is not None:
125
+ self.c2f.to("cpu")
126
+ state_dict = torch.load(c2f_ckpt, map_location="cpu")
127
+
128
+ self.c2f.load_state_dict(state_dict, strict=False)
129
+ self.c2f.to(self.device)
130
 
131
 
132
  def s2t(self, seconds: float):
 
309
  z,
310
  mask,
311
  return_mask=False,
312
+ gen_fn=None,
313
  **kwargs
314
  ):
315
  # coarse z
 
321
  cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
322
  cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
323
 
324
+ gen_fn = gen_fn or self.coarse.sample
325
+ c_vamp = gen_fn(
326
  codec=self.codec,
327
  time_steps=cz.shape[-1],
328
  start_tokens=cz,
 
331
  **kwargs
332
  )
333
 
 
 
334
  if return_mask:
335
  return c_vamp, cz_masked
336
 
 
339
 
340
  if __name__ == "__main__":
341
  import audiotools as at
342
+ import logging
343
+ logger = logging.getLogger()
344
+ logger.setLevel(logging.INFO)
345
+ torch.set_printoptions(threshold=10000)
346
 
347
  interface = Interface(
348
  coarse_ckpt="./models/spotdl/coarse.pth",
349
  coarse2fine_ckpt="./models/spotdl/c2f.pth",
350
  codec_ckpt="./models/spotdl/codec.pth",
351
+ device="cuda"
352
  )
353
 
354
+ sig = at.AudioSignal('introspection ii-1.mp3', duration=10)
355
 
356
  z = interface.encode(sig)
357
 
358
+ mask = linear_random(z, 1.0)
 
 
 
 
 
 
 
 
 
359
  mask = mask_and(
360
  mask, periodic_mask(
361
  z,
362
+ 32,
363
  1,
364
  random_roll=True
365
  )
366
  )
367
+ # mask = dropout(mask, 0.0)
368
+ # mask = codebook_unmask(mask, 0)
369
 
370
 
371
  zv, mask_z = interface.coarse_vamp(
372
  z,
373
  mask=mask,
374
+ sampling_steps=36,
375
+ temperature=6.0,
376
+ return_mask=True,
377
+ # gen_fn=interface.coarse.generate
378
  )
379
 
380
  use_coarse2fine = False
381
  if use_coarse2fine:
382
  zv = interface.coarse_to_fine(zv)
383
 
 
384
  mask = interface.to_signal(mask_z).cpu()
385
 
386
  sig = interface.to_signal(zv).cpu()
vampnet/mask.py CHANGED
@@ -6,7 +6,7 @@ from audiotools import AudioSignal
6
  from .util import scalar_to_batch_tensor
7
 
8
  def _gamma(r):
9
- return (r * torch.pi / 2).cos()
10
 
11
  def _invgamma(y):
12
  if not torch.is_tensor(y):
 
6
  from .util import scalar_to_batch_tensor
7
 
8
  def _gamma(r):
9
+ return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
10
 
11
  def _invgamma(y):
12
  if not torch.is_tensor(y):
vampnet/modules/transformer.py CHANGED
@@ -1,4 +1,5 @@
1
  import math
 
2
  from typing import Optional, Tuple, Union
3
 
4
  import numpy as np
@@ -19,17 +20,17 @@ from ..mask import _gamma
19
 
20
  LORA_R = 8
21
 
22
- def log(t, eps=1e-20):
23
- return torch.log(t + eps)
24
 
25
 
26
- def gumbel_noise(t):
27
- noise = torch.zeros_like(t).uniform_(0, 1)
28
- return -log(-log(noise))
29
 
30
 
31
  def gumbel_sample(t, temperature=1.0, dim=-1):
32
- return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
33
 
34
 
35
  class RMSNorm(nn.Module):
@@ -477,23 +478,16 @@ class VampNet(at.ml.BaseModel):
477
  self.flash_attn = flash_attn
478
  self.noise_mode = noise_mode
479
 
480
- if noise_mode == "mask":
481
- special_tokens = ["MASK"]
482
- elif noise_mode == "random":
483
- special_tokens = None
484
- else:
485
- raise ValueError(f"Unknown noise mode: {noise_mode}")
486
 
487
  self.embedding = CodebookEmbedding(
488
  latent_dim=latent_dim,
489
  n_codebooks=n_codebooks,
490
  vocab_size=vocab_size,
491
  emb_dim=embedding_dim,
492
- special_tokens=special_tokens,
493
  )
494
-
495
- if noise_mode == "mask":
496
- self.mask_token = self.embedding.special_idxs["MASK"]
497
 
498
  self.transformer = TransformerStack(
499
  d_model=embedding_dim,
@@ -584,23 +578,20 @@ class VampNet(at.ml.BaseModel):
584
  z_hat,
585
  mask,
586
  ):
587
- if self.noise_mode == "mask":
588
- z_true = z_true[:, self.n_conditioning_codebooks :, :]
589
- mask = mask[:, self.n_conditioning_codebooks :, :]
590
-
591
- truth = F.one_hot(z_true, self.vocab_size)
592
- mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
593
- z_hat = rearrange(
594
- z_hat,
595
- "b p (t c) -> b c t p",
596
- c=self.n_codebooks - self.n_conditioning_codebooks,
597
- )
598
 
599
- z_hat = z_hat * mask + truth * (1 - mask)
600
 
601
- z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
602
- else:
603
- raise ValueError(f"invalid noise mode for adding truth to logits {self.noise_mode}")
604
 
605
  return z_hat
606
 
@@ -742,6 +733,272 @@ class VampNet(at.ml.BaseModel):
742
  else:
743
  return z
744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745
  def sample_from_logits(
746
  logits,
747
  top_k: int = None,
@@ -798,7 +1055,6 @@ def sample_from_logits(
798
  return inferred
799
 
800
 
801
-
802
  if __name__ == "__main__":
803
  # import argbind
804
  from .layers import num_params
 
1
  import math
2
+ import logging
3
  from typing import Optional, Tuple, Union
4
 
5
  import numpy as np
 
20
 
21
  LORA_R = 8
22
 
23
+ # def log(t, eps=1e-20):
24
+ # return torch.log(t + eps)
25
 
26
 
27
+ def gumbel_noise_like(t):
28
+ noise = torch.zeros_like(t).uniform_(1e-20, 1)
29
+ return -torch.log(-torch.log(noise))
30
 
31
 
32
  def gumbel_sample(t, temperature=1.0, dim=-1):
33
+ return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim)
34
 
35
 
36
  class RMSNorm(nn.Module):
 
478
  self.flash_attn = flash_attn
479
  self.noise_mode = noise_mode
480
 
481
+ assert self.noise_mode == "mask", "deprecated"
 
 
 
 
 
482
 
483
  self.embedding = CodebookEmbedding(
484
  latent_dim=latent_dim,
485
  n_codebooks=n_codebooks,
486
  vocab_size=vocab_size,
487
  emb_dim=embedding_dim,
488
+ special_tokens=["MASK"],
489
  )
490
+ self.mask_token = self.embedding.special_idxs["MASK"]
 
 
491
 
492
  self.transformer = TransformerStack(
493
  d_model=embedding_dim,
 
578
  z_hat,
579
  mask,
580
  ):
581
+ z_true = z_true[:, self.n_conditioning_codebooks :, :]
582
+ mask = mask[:, self.n_conditioning_codebooks :, :]
583
+
584
+ truth = F.one_hot(z_true, self.vocab_size)
585
+ mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
586
+ z_hat = rearrange(
587
+ z_hat,
588
+ "b p (t c) -> b c t p",
589
+ c=self.n_codebooks - self.n_conditioning_codebooks,
590
+ )
 
591
 
592
+ z_hat = z_hat * mask + truth * (1 - mask)
593
 
594
+ z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
 
 
595
 
596
  return z_hat
597
 
 
733
  else:
734
  return z
735
 
736
+ @torch.no_grad()
737
+ def generate(
738
+ self,
739
+ codec,
740
+ time_steps: int = 300,
741
+ sampling_steps: int = 36,
742
+ start_tokens: Optional[torch.Tensor] = None,
743
+ mask: Optional[torch.Tensor] = None,
744
+ temperature: Union[float, Tuple[float, float]] = 0.8,
745
+ typical_filtering=False,
746
+ typical_mass=0.2,
747
+ typical_min_tokens=1,
748
+ return_signal=True,
749
+ ):
750
+ logging.info(f"beginning generation with {sampling_steps} steps")
751
+
752
+ #####################
753
+ # resolve temperature #
754
+ #####################
755
+ if isinstance(temperature, float):
756
+ temperature = torch.tensor(temperature).repeat(sampling_steps)
757
+ elif isinstance(temperature, tuple):
758
+ assert len(temperature) == 2
759
+ l, h = temperature
760
+ temperature = torch.linspace(l, h, sampling_steps)
761
+ else:
762
+ raise TypeError(f"invalid type for temperature")
763
+
764
+ logging.info(f"temperature: {temperature}")
765
+
766
+
767
+ #####################
768
+ # resolve initial z #
769
+ #####################
770
+ z = start_tokens
771
+
772
+ if z is None:
773
+ z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
774
+ self.device
775
+ )
776
+
777
+ logging.info(f"created z with shape {z.shape}")
778
+
779
+
780
+ #################
781
+ # resolve mask #
782
+ #################
783
+
784
+ if mask is None:
785
+ mask = torch.ones_like(z).to(self.device).int()
786
+ mask[:, : self.n_conditioning_codebooks, :] = 0.0
787
+ if mask.ndim == 2:
788
+ mask = mask[:, None, :].repeat(1, z.shape[1], 1)
789
+ # init_mask = mask.clone()
790
+
791
+ logging.info(f"created mask with shape {mask.shape}")
792
+
793
+
794
+ ###########
795
+ # set up #
796
+ ##########
797
+ # apply the mask to z
798
+ z_masked = z.masked_fill(mask.bool(), self.mask_token)
799
+ # logging.info(f"z_masked: {z_masked}")
800
+
801
+ # how many mask tokens to begin with?
802
+ num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
803
+ logging.info(f"num mask tokens at start: {num_mask_tokens_at_start}")
804
+
805
+ # our r steps
806
+ r_steps = torch.linspace(1e-10, 1, sampling_steps+1)[1:].to(self.device)
807
+ logging.info(f"r steps: {r_steps}")
808
+
809
+ # how many codebooks are we inferring vs conditioning on?
810
+ n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
811
+ logging.info(f"n infer codebooks: {n_infer_codebooks}")
812
+
813
+ #################
814
+ # begin sampling #
815
+ #################
816
+
817
+ for i in range(sampling_steps):
818
+ logging.info(f"step {i} of {sampling_steps}")
819
+
820
+ # our current temperature
821
+ tmpt = temperature[i]
822
+ logging.info(f"temperature: {tmpt}")
823
+
824
+ # our current schedule step
825
+ r = r_steps[i : i + 1]
826
+ logging.info(f"r: {r}")
827
+
828
+ # get latents
829
+ latents = self.embedding.from_codes(z_masked, codec)
830
+ logging.info(f"computed latents with shape: {latents.shape}")
831
+
832
+
833
+ # infer from latents
834
+ # NOTE: this collapses the codebook dimension into the sequence dimension
835
+ logits = self.forward(latents, r) # b, prob, seq
836
+ logits = logits.permute(0, 2, 1) # b, seq, prob
837
+ if typical_filtering:
838
+ typical_filter(logits,
839
+ typical_mass=typical_mass,
840
+ typical_min_tokens=typical_min_tokens
841
+ )
842
+
843
+
844
+ logging.info(f"permuted logits with shape: {logits.shape}")
845
+
846
+
847
+ # logits2probs
848
+ probs = torch.softmax(logits, dim=-1)
849
+ logging.info(f"computed probs with shape: {probs.shape}")
850
+
851
+ # flatten z_masked and mask, so we can deal with the sampling logic
852
+ # we'll unflatten them at the end of the loop for the next forward pass
853
+ z_masked = codebook_flatten(z_masked)
854
+
855
+ # sample from logits with multinomial sampling
856
+ b = probs.shape[0]
857
+ probs = rearrange(probs, "b seq prob -> (b seq) prob")
858
+
859
+
860
+
861
+ sampled_z = torch.multinomial(probs, 1).squeeze(-1)
862
+
863
+ sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
864
+ probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
865
+ logging.info(f"sampled z with shape: {sampled_z.shape}")
866
+
867
+ # update the mask
868
+ mask = (z_masked == self.mask_token).int()
869
+ logging.info(f"updated mask with shape: {mask.shape}")
870
+
871
+ # add z back into sampled z where the mask was false
872
+ sampled_z = torch.where(
873
+ mask.bool(), sampled_z, z_masked
874
+ )
875
+ logging.info(f"added z back into sampled z with shape: {sampled_z.shape}")
876
+
877
+
878
+ # get the confidences: which tokens did we sample?
879
+ selected_probs = (
880
+ torch.take_along_dim(
881
+ probs, sampled_z.long().unsqueeze(-1),
882
+ dim=-1
883
+ ).squeeze(-1)
884
+ )
885
+
886
+ # ignore any tokens that weren't masked
887
+ selected_probs = torch.where(
888
+ mask.bool(), selected_probs, torch.inf
889
+ )
890
+
891
+ # get the num tokens to mask, according to the schedule
892
+ num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
893
+ logging.info(f"num to mask: {num_to_mask}")
894
+
895
+ num_to_mask = torch.maximum(
896
+ torch.tensor(1),
897
+ torch.minimum(
898
+ mask.sum(dim=-1, keepdim=True) - 1,
899
+ num_to_mask
900
+ )
901
+ )
902
+
903
+
904
+ # get our new mask
905
+ # print(tmpt * (1-_gamma(r)))
906
+ mask = mask_by_random_topk(
907
+ num_to_mask, selected_probs, tmpt * (1-r)
908
+ )
909
+
910
+ # print(f"most confident tokens: ")
911
+ # print(torch.take_along_dim(
912
+ # sampled_z, selected_probs.argsort(descending=False), dim=-1)
913
+ # )
914
+ # print(sampled_z[~mask.bool()])
915
+
916
+
917
+ # update the mask
918
+ z_masked = torch.where(
919
+ mask.bool(), self.mask_token, sampled_z
920
+ )
921
+ logging.info(f"updated z_masked with shape: {z_masked.shape}")
922
+
923
+
924
+ z_masked = codebook_unflatten(z_masked, self.n_codebooks)
925
+ mask = codebook_unflatten(mask, self.n_codebooks)
926
+ logging.info(f"unflattened z_masked with shape: {z_masked.shape}")
927
+
928
+
929
+ logging.info(f"updated z_masked with shape: {z_masked.shape}")
930
+
931
+
932
+ logging.info(f"finished sampling")
933
+ z = codebook_unflatten(sampled_z, self.n_codebooks)
934
+
935
+ if return_signal:
936
+ return self.to_signal(z, codec)
937
+ else:
938
+ return z
939
+
940
+
941
+ def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
942
+ """
943
+ Args:
944
+ num_to_mask (int): number of tokens to mask
945
+ probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
946
+ temperature (float, optional): temperature. Defaults to 1.0.
947
+ """
948
+ logging.info(f"masking by random topk")
949
+ logging.info(f"num to mask: {num_to_mask}")
950
+ logging.info(f"probs shape: {probs.shape}")
951
+ logging.info(f"temperature: {temperature}")
952
+ logging.info("")
953
+
954
+ confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
955
+ logging.info(f"confidence shape: {confidence.shape}")
956
+
957
+ sorted_confidence, sorted_idx = confidence.sort(dim=-1)
958
+ logging.info(f"sorted confidence shape: {sorted_confidence.shape}")
959
+ logging.info(f"sorted idx shape: {sorted_idx.shape}")
960
+
961
+ # get the cut off threshold, given the mask length
962
+ cut_off = torch.take_along_dim(
963
+ sorted_confidence, num_to_mask, axis=-1
964
+ )
965
+ logging.info(f"cut off shape: {cut_off.shape}")
966
+
967
+ # mask out the tokens
968
+ mask = confidence < cut_off
969
+ logging.info(f"mask shape: {mask.shape}")
970
+
971
+ return mask
972
+
973
+ def typical_filter(
974
+ logits,
975
+ typical_mass: float = 0.95,
976
+ typical_min_tokens: int = 1,):
977
+ nb, nt, _ = logits.shape
978
+ x_flat = rearrange(logits, "b t l -> (b t ) l")
979
+ x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
980
+ x_flat_norm_p = torch.exp(x_flat_norm)
981
+ entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
982
+
983
+ c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
984
+ c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
985
+ x_flat_cumsum = (
986
+ x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
987
+ )
988
+
989
+ last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
990
+ sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
991
+ 1, last_ind.view(-1, 1)
992
+ )
993
+ if typical_min_tokens > 1:
994
+ sorted_indices_to_remove[..., :typical_min_tokens] = 0
995
+ indices_to_remove = sorted_indices_to_remove.scatter(
996
+ 1, x_flat_indices, sorted_indices_to_remove
997
+ )
998
+ x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
999
+ logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
1000
+ return logits
1001
+
1002
  def sample_from_logits(
1003
  logits,
1004
  top_k: int = None,
 
1055
  return inferred
1056
 
1057
 
 
1058
  if __name__ == "__main__":
1059
  # import argbind
1060
  from .layers import num_params