Hugo Flores Garcia commited on
Commit
09b9691
1 Parent(s): 3f6f517

settling down on the new sampling routine

Browse files
conf/generated/breaks-steps/c2f.yml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ AudioDataset.duration: 3.0
4
+ AudioDataset.loudness_cutoff: -40.0
5
+ VampNet.embedding_dim: 1280
6
+ VampNet.n_codebooks: 14
7
+ VampNet.n_conditioning_codebooks: 4
8
+ VampNet.n_heads: 20
9
+ VampNet.n_layers: 16
10
+ fine_tune: true
11
+ fine_tune_checkpoint: ./models/spotdl/c2f.pth
12
+ save_path: ./runs/breaks-steps/c2f
13
+ train/AudioLoader.sources: &id001
14
+ - /media/CHONK/hugo/breaks-steps
15
+ val/AudioLoader.sources: *id001
conf/generated/breaks-steps/coarse.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ $include:
2
+ - conf/lora/lora.yml
3
+ fine_tune: true
4
+ fine_tune_checkpoint: ./models/spotdl/coarse.pth
5
+ save_path: ./runs/breaks-steps/coarse
6
+ train/AudioLoader.sources: &id001
7
+ - /media/CHONK/hugo/breaks-steps
8
+ val/AudioLoader.sources: *id001
conf/generated/breaks-steps/interface.yml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ AudioLoader.sources:
2
+ - - /media/CHONK/hugo/breaks-steps
3
+ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
4
+ Interface.coarse2fine_lora_ckpt: ./runs/breaks-steps/c2f/latest/lora.pth
5
+ Interface.coarse_ckpt: ./models/spotdl/coarse.pth
6
+ Interface.coarse_lora_ckpt: ./runs/breaks-steps/coarse/latest/lora.pth
7
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
demo-new.py DELETED
@@ -1,518 +0,0 @@
1
- from pathlib import Path
2
- from typing import Tuple
3
- import yaml
4
- import tempfile
5
- import uuid
6
- from dataclasses import dataclass, asdict
7
-
8
- import numpy as np
9
- import audiotools as at
10
- import argbind
11
-
12
- import gradio as gr
13
- from vampnet.interface import Interface
14
- from vampnet import mask as pmask
15
-
16
- import logging
17
- logger = logging.getLogger()
18
- logger.setLevel(logging.CRITICAL)
19
-
20
- Interface = argbind.bind(Interface)
21
- AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
22
-
23
- conf = argbind.parse_args()
24
-
25
- with argbind.scope(conf):
26
- interface = Interface()
27
- loader = AudioLoader()
28
- print(f"interface device is {interface.device}")
29
-
30
- dataset = at.data.datasets.AudioDataset(
31
- loader,
32
- sample_rate=interface.codec.sample_rate,
33
- duration=interface.coarse.chunk_size_s,
34
- n_examples=5000,
35
- without_replacement=True,
36
- )
37
-
38
-
39
- checkpoints = {
40
- "spotdl": {
41
- "coarse": "./models/spotdl/coarse.pth",
42
- "c2f": "./models/spotdl/c2f.pth",
43
- "codec": "./models/spotdl/codec.pth",
44
- "full_ckpt": True
45
- },
46
- "berta": {
47
- "coarse": "./models/finetuned/berta-goldman-speech/coarse.pth",
48
- "c2f": "./models/finetuned/berta-goldman-speech/c2f.pth",
49
- "codec": "./model/spotdl/codec.pth",
50
- "full_ckpt": True
51
- },
52
- "xeno-canto-2": {
53
- "coarse": "./models/finetuned/xeno-canto-2/coarse.pth",
54
- "c2f": "./models/finetuned/xeno-canto-2/c2f.pth",
55
- "codec": "./models/spotdl/codec.pth",
56
- "full_ckpt": True
57
- },
58
- "panchos": {
59
- "coarse": "./models/finetuned/panchos/coarse.pth",
60
- "c2f": "./models/finetuned/panchos/c2f.pth",
61
- "codec": "./models/spotdl/codec.pth",
62
- "full_ckpt": False
63
- },
64
- "tv-choir": {
65
- "coarse": "./models/finetuned/tv-choir/coarse.pth",
66
- "c2f": "./models/finetuned/tv-choir/c2f.pth",
67
- "codec": "./models/spotdl/codec.pth",
68
- "full_ckpt": False
69
- },
70
- "titi": {
71
- "coarse": "./models/finetuned/titi/coarse.pth",
72
- "c2f": "./models/finetuned/titi/c2f.pth",
73
- "codec": "./models/spotdl/codec.pth",
74
- "full_ckpt": False
75
- },
76
- "titi-clean": {
77
- "coarse": "./models/finetuned/titi-clean/coarse.pth",
78
- "c2f": "./models/finetuned/titi-clean/c2f.pth",
79
- "codec": "./models/spotdl/codec.pth",
80
- "full_ckpt": False
81
- }
82
- }
83
- interface.checkpoint_key = "spotdl"
84
-
85
-
86
- OUT_DIR = Path("gradio-outputs")
87
- OUT_DIR.mkdir(exist_ok=True, parents=True)
88
-
89
-
90
- def load_audio(file):
91
- print(file)
92
- filepath = file.name
93
- sig = at.AudioSignal.salient_excerpt(
94
- filepath,
95
- duration=interface.coarse.chunk_size_s
96
- )
97
- sig = interface.preprocess(sig)
98
-
99
- out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
100
- out_dir.mkdir(parents=True, exist_ok=True)
101
- sig.write(out_dir / "input.wav")
102
- return sig.path_to_file
103
-
104
-
105
- def load_random_audio():
106
- index = np.random.randint(0, len(dataset))
107
- sig = dataset[index]["signal"]
108
- sig = interface.preprocess(sig)
109
-
110
- out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
111
- out_dir.mkdir(parents=True, exist_ok=True)
112
- sig.write(out_dir / "input.wav")
113
- return sig.path_to_file
114
-
115
-
116
- def _vamp(data, return_mask=False):
117
-
118
- # if our checkpoint key is different, we need to load a new checkpoint
119
- if data[checkpoint_key] != interface.checkpoint_key:
120
- print(f"loading checkpoint {data[checkpoint_key]}")
121
- interface.lora_load(
122
- checkpoints[data[checkpoint_key]]["coarse"],
123
- checkpoints[data[checkpoint_key]]["c2f"],
124
- checkpoints[data[checkpoint_key]]["full_ckpt"],
125
- )
126
- interface.checkpoint_key = data[checkpoint_key]
127
-
128
- out_dir = OUT_DIR / str(uuid.uuid4())
129
- out_dir.mkdir()
130
- sig = at.AudioSignal(data[input_audio])
131
- #pitch shift input
132
- sig = sig.shift_pitch(data[input_pitch_shift])
133
-
134
- # TODO: random pitch shift of segments in the signal to prompt! window size should be a parameter, pitch shift width should be a parameter
135
-
136
- z = interface.encode(sig)
137
-
138
- ncc = data[n_conditioning_codebooks]
139
-
140
- # build the mask
141
- mask = pmask.linear_random(z, data[rand_mask_intensity])
142
- mask = pmask.mask_and(
143
- mask, pmask.inpaint(
144
- z,
145
- interface.s2t(data[prefix_s]),
146
- interface.s2t(data[suffix_s])
147
- )
148
- )
149
- mask = pmask.mask_and(
150
- mask, pmask.periodic_mask(
151
- z,
152
- data[periodic_p],
153
- data[periodic_w],
154
- random_roll=True
155
- )
156
- )
157
- if data[onset_mask_width] > 0:
158
- mask = pmask.mask_or(
159
- mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
160
- )
161
- # these should be the last two mask ops
162
- mask = pmask.dropout(mask, data[dropout])
163
- mask = pmask.codebook_unmask(mask, ncc)
164
-
165
- print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[init_temp]}, final temp {data[final_temp]}, use coarse2fine {data[use_coarse2fine]}")
166
- # save the mask as a txt file
167
- np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
168
-
169
- # if data[topk] is not None:
170
- # top_k = data[topk] if data[topk] > 0 else None
171
- # else:
172
- # top_k = None
173
-
174
- zv, mask_z = interface.coarse_vamp(
175
- z,
176
- mask=mask,
177
- sampling_steps=data[num_steps],
178
- temperature=(data[init_temp]*10, data[final_temp]*10),
179
- return_mask=True,
180
- # sample=data[sampling_strategy],
181
- typical_filtering=data[typical_filtering],
182
- typical_mass=data[typical_mass],
183
- typical_min_tokens=data[typical_min_tokens],
184
- # top_k=top_k,
185
- gen_fn=interface.coarse.generate,
186
- )
187
-
188
- if use_coarse2fine:
189
- zv = interface.coarse_to_fine(zv)
190
-
191
- sig = interface.to_signal(zv).cpu()
192
- print("done")
193
-
194
-
195
-
196
- sig.write(out_dir / "output.wav")
197
-
198
- if return_mask:
199
- mask = interface.to_signal(mask_z).cpu()
200
- mask.write(out_dir / "mask.wav")
201
- return sig.path_to_file, mask.path_to_file
202
- else:
203
- return sig.path_to_file
204
-
205
- def vamp(data):
206
- return _vamp(data, return_mask=True)
207
-
208
- def api_vamp(data):
209
- return _vamp(data, return_mask=False)
210
-
211
- def save_vamp(data):
212
- out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
213
- out_dir.mkdir(parents=True, exist_ok=True)
214
-
215
- sig_in = at.AudioSignal(data[input_audio])
216
- sig_out = at.AudioSignal(data[output_audio])
217
-
218
- sig_in.write(out_dir / "input.wav")
219
- sig_out.write(out_dir / "output.wav")
220
-
221
- _data = {
222
- "init_temp": data[init_temp],
223
- "final_temp": data[final_temp],
224
- "prefix_s": data[prefix_s],
225
- "suffix_s": data[suffix_s],
226
- "rand_mask_intensity": data[rand_mask_intensity],
227
- "num_steps": data[num_steps],
228
- "notes": data[notes_text],
229
- "periodic_period": data[periodic_p],
230
- "periodic_width": data[periodic_w],
231
- "n_conditioning_codebooks": data[n_conditioning_codebooks],
232
- "use_coarse2fine": data[use_coarse2fine],
233
- "stretch_factor": data[stretch_factor],
234
- }
235
-
236
- # save with yaml
237
- with open(out_dir / "data.yaml", "w") as f:
238
- yaml.dump(_data, f)
239
-
240
- import zipfile
241
- zip_path = out_dir.with_suffix(".zip")
242
- with zipfile.ZipFile(zip_path, "w") as zf:
243
- for file in out_dir.iterdir():
244
- zf.write(file, file.name)
245
-
246
- return f"saved! your save code is {out_dir.stem}", zip_path
247
-
248
-
249
-
250
- with gr.Blocks() as demo:
251
-
252
- with gr.Row():
253
- with gr.Column():
254
- use_coarse2fine = gr.Checkbox(
255
- label="use coarse2fine",
256
- value=True
257
- )
258
-
259
- manual_audio_upload = gr.File(
260
- label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
261
- file_types=["audio"]
262
- )
263
- load_random_audio_button = gr.Button("or load random audio")
264
-
265
- input_audio = gr.Audio(
266
- label="input audio",
267
- interactive=False,
268
- type="filepath",
269
- )
270
-
271
- audio_mask = gr.Audio(
272
- label="audio mask (listen to this to hear the mask hints)",
273
- interactive=False,
274
- type="filepath",
275
- )
276
-
277
- # connect widgets
278
- load_random_audio_button.click(
279
- fn=load_random_audio,
280
- inputs=[],
281
- outputs=[ input_audio]
282
- )
283
-
284
- manual_audio_upload.change(
285
- fn=load_audio,
286
- inputs=[manual_audio_upload],
287
- outputs=[ input_audio]
288
- )
289
-
290
- # mask settings
291
- with gr.Column():
292
-
293
- input_pitch_shift = gr.Slider(
294
- label="input pitch shift (semitones)",
295
- minimum=-36,
296
- maximum=36,
297
- step=1,
298
- value=0,
299
- )
300
-
301
- rand_mask_intensity = gr.Slider(
302
- label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
303
- minimum=0.0,
304
- maximum=1.0,
305
- value=1.0
306
- )
307
-
308
- periodic_p = gr.Slider(
309
- label="periodic prompt (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
310
- minimum=0,
311
- maximum=128,
312
- step=1,
313
- value=3,
314
- )
315
- periodic_w = gr.Slider(
316
- label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
317
- minimum=1,
318
- maximum=20,
319
- step=1,
320
- value=1,
321
- )
322
-
323
- onset_mask_width = gr.Slider(
324
- label="onset mask width (steps, 1 step ~= 10milliseconds)",
325
- minimum=0,
326
- maximum=20,
327
- step=1,
328
- value=5,
329
- )
330
-
331
- with gr.Accordion("extras ", open=False):
332
- n_conditioning_codebooks = gr.Number(
333
- label="number of conditioning codebooks. probably 0",
334
- value=0,
335
- precision=0,
336
- )
337
-
338
- stretch_factor = gr.Slider(
339
- label="time stretch factor",
340
- minimum=0,
341
- maximum=64,
342
- step=1,
343
- value=1,
344
- )
345
-
346
-
347
- with gr.Accordion("prefix/suffix hints", open=False):
348
- prefix_s = gr.Slider(
349
- label="prefix hint length (seconds)",
350
- minimum=0.0,
351
- maximum=10.0,
352
- value=0.0
353
- )
354
- suffix_s = gr.Slider(
355
- label="suffix hint length (seconds)",
356
- minimum=0.0,
357
- maximum=10.0,
358
- value=0.0
359
- )
360
-
361
- with gr.Accordion("temperature settings", open=False):
362
- init_temp = gr.Slider(
363
- label="initial temperature (should probably stay between 0.6 and 1)",
364
- minimum=0.0,
365
- maximum=1.5,
366
- value=0.8
367
- )
368
- final_temp = gr.Slider(
369
- label="final temperature (should probably stay between 0.7 and 2)",
370
- minimum=0.0,
371
- maximum=2.0,
372
- value=0.8
373
- )
374
-
375
- with gr.Accordion("sampling settings", open=False):
376
- sampling_strategy = gr.Radio(
377
- label="sampling strategy",
378
- choices=["gumbel", "multinomial"],
379
- value="gumbel"
380
- )
381
- typical_filtering = gr.Checkbox(
382
- label="typical filtering (cannot be used with topk)",
383
- value=False
384
- )
385
- typical_mass = gr.Slider(
386
- label="typical mass (should probably stay between 0.1 and 0.5)",
387
- minimum=0.01,
388
- maximum=0.99,
389
- value=0.2
390
- )
391
- typical_min_tokens = gr.Slider(
392
- label="typical min tokens (should probably stay between 1 and 256)",
393
- minimum=1,
394
- maximum=256,
395
- step=1,
396
- value=1
397
- )
398
-
399
-
400
-
401
-
402
- num_steps = gr.Slider(
403
- label="number of steps (should normally be between 12 and 36)",
404
- minimum=1,
405
- maximum=128,
406
- step=1,
407
- value=36
408
- )
409
-
410
- dropout = gr.Slider(
411
- label="mask dropout",
412
- minimum=0.0,
413
- maximum=1.0,
414
- step=0.01,
415
- value=0.0
416
- )
417
-
418
-
419
- # mask settings
420
- with gr.Column():
421
- checkpoint_key = gr.Radio(
422
- label="checkpoint",
423
- choices=list(checkpoints.keys()),
424
- value="spotdl"
425
- )
426
- vamp_button = gr.Button("vamp!!!")
427
- output_audio = gr.Audio(
428
- label="output audio",
429
- interactive=False,
430
- type="filepath"
431
- )
432
-
433
-
434
-
435
- # with gr.Column():
436
- # with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
437
- # use_beats = gr.Checkbox(
438
- # label="use beat hints (helps the output stick to the beat structure of the input)",
439
- # value=False
440
- # )
441
-
442
- # snap_to_beats = gr.Checkbox(
443
- # label="trim to beat markers (uncheck if the output audio is too short.)",
444
- # value=True
445
- # )
446
-
447
- # beat_unmask_dur = gr.Slider(
448
- # label="duration",
449
- # minimum=0.0,
450
- # maximum=3.0,
451
- # value=0.07
452
- # )
453
-
454
-
455
- notes_text = gr.Textbox(
456
- label="type any notes about the generated audio here",
457
- value="",
458
- interactive=True
459
- )
460
- save_button = gr.Button("save vamp")
461
- download_file = gr.File(
462
- label="vamp to download will appear here",
463
- interactive=False
464
- )
465
- use_as_input_button = gr.Button("use output as input")
466
-
467
- thank_you = gr.Markdown("")
468
-
469
-
470
- _inputs = {
471
- input_audio,
472
- num_steps,
473
- init_temp, final_temp,
474
- prefix_s, suffix_s,
475
- rand_mask_intensity,
476
- periodic_p, periodic_w,
477
- n_conditioning_codebooks,
478
- dropout,
479
- use_coarse2fine,
480
- stretch_factor,
481
- onset_mask_width,
482
- input_pitch_shift,
483
- sampling_strategy,
484
- typical_filtering,
485
- typical_mass,
486
- typical_min_tokens,
487
- # topk,
488
- checkpoint_key
489
- }
490
-
491
- # connect widgets
492
- vamp_button.click(
493
- fn=vamp,
494
- inputs=_inputs,
495
- outputs=[output_audio, audio_mask],
496
- )
497
-
498
- api_vamp_button = gr.Button("api vamp")
499
- api_vamp_button.click(
500
- fn=api_vamp,
501
- inputs=_inputs,
502
- outputs=[output_audio],
503
- api_name="vamp"
504
- )
505
-
506
- use_as_input_button.click(
507
- fn=lambda x: x,
508
- inputs=[output_audio],
509
- outputs=[input_audio]
510
- )
511
-
512
- save_button.click(
513
- fn=save_vamp,
514
- inputs=_inputs | {notes_text, output_audio},
515
- outputs=[thank_you, download_file]
516
- )
517
-
518
- demo.launch(share=True, enable_queue=False, debug=True, server_name="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo.py CHANGED
@@ -68,7 +68,19 @@ checkpoints = {
68
  "c2f": "./models/finetuned/titi/c2f.pth",
69
  "codec": "./models/spotdl/codec.pth",
70
  "full_ckpt": False
71
- }
 
 
 
 
 
 
 
 
 
 
 
 
72
  }
73
  interface.checkpoint_key = "spotdl"
74
 
@@ -112,10 +124,8 @@ def _vamp(data, return_mask=False):
112
  checkpoints[data[checkpoint_key]]["coarse"],
113
  checkpoints[data[checkpoint_key]]["c2f"],
114
  checkpoints[data[checkpoint_key]]["full_ckpt"],
115
- reset=(data[checkpoint_key] == "spotdl")
116
  )
117
  interface.checkpoint_key = data[checkpoint_key]
118
-
119
 
120
  out_dir = OUT_DIR / str(uuid.uuid4())
121
  out_dir.mkdir()
@@ -154,30 +164,24 @@ def _vamp(data, return_mask=False):
154
  mask = pmask.dropout(mask, data[dropout])
155
  mask = pmask.codebook_unmask(mask, ncc)
156
 
157
- print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[init_temp]}, final temp {data[final_temp]}, use coarse2fine {data[use_coarse2fine]}")
158
  # save the mask as a txt file
159
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
160
 
161
- if data[topk] is not None:
162
- top_k = data[topk] if data[topk] > 0 else None
163
- else:
164
- top_k = None
165
-
166
  zv, mask_z = interface.coarse_vamp(
167
  z,
168
  mask=mask,
169
  sampling_steps=data[num_steps],
170
- temperature=(data[init_temp], data[final_temp]),
171
  return_mask=True,
172
- sample=data[sampling_strategy],
173
  typical_filtering=data[typical_filtering],
174
  typical_mass=data[typical_mass],
175
  typical_min_tokens=data[typical_min_tokens],
176
- top_k=top_k,
177
  )
178
 
179
  if use_coarse2fine:
180
- zv = interface.coarse_to_fine(zv)
181
 
182
  sig = interface.to_signal(zv).cpu()
183
  print("done")
@@ -210,8 +214,7 @@ def save_vamp(data):
210
  sig_out.write(out_dir / "output.wav")
211
 
212
  _data = {
213
- "init_temp": data[init_temp],
214
- "final_temp": data[final_temp],
215
  "prefix_s": data[prefix_s],
216
  "suffix_s": data[suffix_s],
217
  "rand_mask_intensity": data[rand_mask_intensity],
@@ -349,53 +352,32 @@ with gr.Blocks() as demo:
349
  value=0.0
350
  )
351
 
352
- with gr.Accordion("temperature settings", open=False):
353
- init_temp = gr.Slider(
354
- label="initial temperature (should probably stay between 0.6 and 1)",
355
- minimum=0.0,
356
- maximum=1.5,
357
- value=0.8
358
- )
359
- final_temp = gr.Slider(
360
- label="final temperature (should probably stay between 0.7 and 2)",
361
- minimum=0.0,
362
- maximum=2.0,
363
- value=1.0
364
- )
365
 
366
  with gr.Accordion("sampling settings", open=False):
367
- sampling_strategy = gr.Radio(
368
- label="sampling strategy",
369
- choices=["gumbel", "multinomial"],
370
- value="gumbel"
371
- )
372
  typical_filtering = gr.Checkbox(
373
- label="typical filtering (cannot be used with topk)",
374
  value=True
375
  )
376
  typical_mass = gr.Slider(
377
  label="typical mass (should probably stay between 0.1 and 0.5)",
378
  minimum=0.01,
379
  maximum=0.99,
380
- value=0.2
381
  )
382
  typical_min_tokens = gr.Slider(
383
  label="typical min tokens (should probably stay between 1 and 256)",
384
  minimum=1,
385
  maximum=256,
386
  step=1,
387
- value=1
388
- )
389
- topk = gr.Slider(
390
- label="topk (cannot be used with typical filtering). 0 = None",
391
- minimum=0,
392
- maximum=256,
393
- step=1,
394
- value=0
395
  )
396
 
397
-
398
-
399
  num_steps = gr.Slider(
400
  label="number of steps (should normally be between 12 and 36)",
401
  minimum=1,
@@ -427,28 +409,6 @@ with gr.Blocks() as demo:
427
  type="filepath"
428
  )
429
 
430
-
431
-
432
- # with gr.Column():
433
- # with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
434
- # use_beats = gr.Checkbox(
435
- # label="use beat hints (helps the output stick to the beat structure of the input)",
436
- # value=False
437
- # )
438
-
439
- # snap_to_beats = gr.Checkbox(
440
- # label="trim to beat markers (uncheck if the output audio is too short.)",
441
- # value=True
442
- # )
443
-
444
- # beat_unmask_dur = gr.Slider(
445
- # label="duration",
446
- # minimum=0.0,
447
- # maximum=3.0,
448
- # value=0.07
449
- # )
450
-
451
-
452
  notes_text = gr.Textbox(
453
  label="type any notes about the generated audio here",
454
  value="",
@@ -467,7 +427,7 @@ with gr.Blocks() as demo:
467
  _inputs = {
468
  input_audio,
469
  num_steps,
470
- init_temp, final_temp,
471
  prefix_s, suffix_s,
472
  rand_mask_intensity,
473
  periodic_p, periodic_w,
@@ -477,11 +437,9 @@ with gr.Blocks() as demo:
477
  stretch_factor,
478
  onset_mask_width,
479
  input_pitch_shift,
480
- sampling_strategy,
481
  typical_filtering,
482
  typical_mass,
483
  typical_min_tokens,
484
- topk,
485
  checkpoint_key
486
  }
487
 
 
68
  "c2f": "./models/finetuned/titi/c2f.pth",
69
  "codec": "./models/spotdl/codec.pth",
70
  "full_ckpt": False
71
+ },
72
+ "titi-clean": {
73
+ "coarse": "./models/finetuned/titi-clean/coarse.pth",
74
+ "c2f": "./models/finetuned/titi-clean/c2f.pth",
75
+ "codec": "./models/spotdl/codec.pth",
76
+ "full_ckpt": False
77
+ },
78
+ "breaks-steps": {
79
+ "coarse": "./models/finetuned/breaks-steps/coarse.pth",
80
+ "c2f": None, #"./models/finetuned/breaks-steps/c2f.pth",
81
+ "codec": "./models/spotdl/codec.pth",
82
+ "full_ckpt": False
83
+ },
84
  }
85
  interface.checkpoint_key = "spotdl"
86
 
 
124
  checkpoints[data[checkpoint_key]]["coarse"],
125
  checkpoints[data[checkpoint_key]]["c2f"],
126
  checkpoints[data[checkpoint_key]]["full_ckpt"],
 
127
  )
128
  interface.checkpoint_key = data[checkpoint_key]
 
129
 
130
  out_dir = OUT_DIR / str(uuid.uuid4())
131
  out_dir.mkdir()
 
164
  mask = pmask.dropout(mask, data[dropout])
165
  mask = pmask.codebook_unmask(mask, ncc)
166
 
167
+ print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
168
  # save the mask as a txt file
169
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
170
 
 
 
 
 
 
171
  zv, mask_z = interface.coarse_vamp(
172
  z,
173
  mask=mask,
174
  sampling_steps=data[num_steps],
175
+ temperature=data[temp]*10,
176
  return_mask=True,
 
177
  typical_filtering=data[typical_filtering],
178
  typical_mass=data[typical_mass],
179
  typical_min_tokens=data[typical_min_tokens],
180
+ gen_fn=interface.coarse.generate,
181
  )
182
 
183
  if use_coarse2fine:
184
+ zv = interface.coarse_to_fine(zv, temperature=data[temp])
185
 
186
  sig = interface.to_signal(zv).cpu()
187
  print("done")
 
214
  sig_out.write(out_dir / "output.wav")
215
 
216
  _data = {
217
+ "temp": data[temp],
 
218
  "prefix_s": data[prefix_s],
219
  "suffix_s": data[suffix_s],
220
  "rand_mask_intensity": data[rand_mask_intensity],
 
352
  value=0.0
353
  )
354
 
355
+ temp = gr.Slider(
356
+ label="temperature",
357
+ minimum=0.0,
358
+ maximum=1.5,
359
+ value=0.8
360
+ )
 
 
 
 
 
 
 
361
 
362
  with gr.Accordion("sampling settings", open=False):
 
 
 
 
 
363
  typical_filtering = gr.Checkbox(
364
+ label="typical filtering ",
365
  value=True
366
  )
367
  typical_mass = gr.Slider(
368
  label="typical mass (should probably stay between 0.1 and 0.5)",
369
  minimum=0.01,
370
  maximum=0.99,
371
+ value=0.15
372
  )
373
  typical_min_tokens = gr.Slider(
374
  label="typical min tokens (should probably stay between 1 and 256)",
375
  minimum=1,
376
  maximum=256,
377
  step=1,
378
+ value=64
 
 
 
 
 
 
 
379
  )
380
 
 
 
381
  num_steps = gr.Slider(
382
  label="number of steps (should normally be between 12 and 36)",
383
  minimum=1,
 
409
  type="filepath"
410
  )
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  notes_text = gr.Textbox(
413
  label="type any notes about the generated audio here",
414
  value="",
 
427
  _inputs = {
428
  input_audio,
429
  num_steps,
430
+ temp,
431
  prefix_s, suffix_s,
432
  rand_mask_intensity,
433
  periodic_p, periodic_w,
 
437
  stretch_factor,
438
  onset_mask_width,
439
  input_pitch_shift,
 
440
  typical_filtering,
441
  typical_mass,
442
  typical_min_tokens,
 
443
  checkpoint_key
444
  }
445
 
scripts/exp/train.py CHANGED
@@ -491,7 +491,7 @@ def train(
491
  num_samples = z.shape[0]
492
 
493
  for i in range(num_samples):
494
- sampled = accel.unwrap(model).sample(
495
  codec=codec,
496
  time_steps=z.shape[-1],
497
  start_tokens=z[i : i + 1],
@@ -503,29 +503,6 @@ def train(
503
  plot_fn=None,
504
  )
505
 
506
- # sample in 1 step (only for coarse2fine models)
507
- if accel.unwrap(model).n_conditioning_codebooks > 0:
508
- sampled_argmax = accel.unwrap(model).sample(
509
- codec=codec,
510
- time_steps=z.shape[-1],
511
- start_tokens=z[i : i + 1],
512
- sample="argmax",
513
- sampling_steps=1,
514
- )
515
- sampled_argmax.cpu().write_audio_to_tb(
516
- f"sampled_1step-argmax/{i}",
517
- self.writer,
518
- step=self.state.epoch,
519
- plot_fn=None,
520
- )
521
- conditioning = z[i:i+1, : accel.unwrap(model).n_conditioning_codebooks, :]
522
- conditioning = accel.unwrap(model).to_signal(conditioning, codec)
523
- conditioning.cpu().write_audio_to_tb(
524
- f"conditioning/{i}",
525
- self.writer,
526
- step=self.state.epoch,
527
- plot_fn=None,
528
- )
529
 
530
  def save_imputation(self, z: torch.Tensor):
531
  n_prefix = int(z.shape[-1] * 0.25)
@@ -543,7 +520,7 @@ def train(
543
  imputed = []
544
  for i in range(len(z)):
545
  imputed.append(
546
- vn.sample(
547
  codec=codec,
548
  time_steps=z.shape[-1],
549
  start_tokens=z[i][None, ...],
 
491
  num_samples = z.shape[0]
492
 
493
  for i in range(num_samples):
494
+ sampled = accel.unwrap(model).generate(
495
  codec=codec,
496
  time_steps=z.shape[-1],
497
  start_tokens=z[i : i + 1],
 
503
  plot_fn=None,
504
  )
505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
 
507
  def save_imputation(self, z: torch.Tensor):
508
  n_prefix = int(z.shape[-1] * 0.25)
 
520
  imputed = []
521
  for i in range(len(z)):
522
  imputed.append(
523
+ vn.generate(
524
  codec=codec,
525
  time_steps=z.shape[-1],
526
  start_tokens=z[i][None, ...],
vampnet/interface.py CHANGED
@@ -292,7 +292,7 @@ class Interface(torch.nn.Module):
292
  fine_z = []
293
  for i in range(n_chunks):
294
  chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len]
295
- chunk = self.c2f.sample(
296
  codec=self.codec,
297
  time_steps=chunk_len,
298
  start_tokens=chunk,
@@ -343,6 +343,7 @@ if __name__ == "__main__":
343
  logger = logging.getLogger()
344
  logger.setLevel(logging.INFO)
345
  torch.set_printoptions(threshold=10000)
 
346
 
347
  interface = Interface(
348
  coarse_ckpt="./models/spotdl/coarse.pth",
@@ -372,20 +373,20 @@ if __name__ == "__main__":
372
  z,
373
  mask=mask,
374
  sampling_steps=36,
375
- temperature=6.0,
376
  return_mask=True,
377
- # gen_fn=interface.coarse.generate
378
  )
379
 
380
- use_coarse2fine = False
381
  if use_coarse2fine:
382
- zv = interface.coarse_to_fine(zv)
383
 
384
  mask = interface.to_signal(mask_z).cpu()
385
 
386
  sig = interface.to_signal(zv).cpu()
387
  print("done")
388
 
389
- sig.write("output.wav")
390
  mask.write("mask.wav")
391
 
 
292
  fine_z = []
293
  for i in range(n_chunks):
294
  chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len]
295
+ chunk = self.c2f.generate(
296
  codec=self.codec,
297
  time_steps=chunk_len,
298
  start_tokens=chunk,
 
343
  logger = logging.getLogger()
344
  logger.setLevel(logging.INFO)
345
  torch.set_printoptions(threshold=10000)
346
+ at.util.seed(42)
347
 
348
  interface = Interface(
349
  coarse_ckpt="./models/spotdl/coarse.pth",
 
373
  z,
374
  mask=mask,
375
  sampling_steps=36,
376
+ temperature=8.0,
377
  return_mask=True,
378
+ gen_fn=interface.coarse.generate
379
  )
380
 
381
+ use_coarse2fine = True
382
  if use_coarse2fine:
383
+ zv = interface.coarse_to_fine(zv, temperature=0.8)
384
 
385
  mask = interface.to_signal(mask_z).cpu()
386
 
387
  sig = interface.to_signal(zv).cpu()
388
  print("done")
389
 
390
+ sig.write("output3.wav")
391
  mask.write("mask.wav")
392
 
vampnet/modules/transformer.py CHANGED
@@ -741,7 +741,7 @@ class VampNet(at.ml.BaseModel):
741
  sampling_steps: int = 36,
742
  start_tokens: Optional[torch.Tensor] = None,
743
  mask: Optional[torch.Tensor] = None,
744
- temperature: Union[float, Tuple[float, float]] = 0.8,
745
  typical_filtering=False,
746
  typical_mass=0.2,
747
  typical_min_tokens=1,
@@ -848,26 +848,27 @@ class VampNet(at.ml.BaseModel):
848
  probs = torch.softmax(logits, dim=-1)
849
  logging.info(f"computed probs with shape: {probs.shape}")
850
 
851
- # flatten z_masked and mask, so we can deal with the sampling logic
852
- # we'll unflatten them at the end of the loop for the next forward pass
853
- z_masked = codebook_flatten(z_masked)
854
 
855
  # sample from logits with multinomial sampling
856
  b = probs.shape[0]
857
  probs = rearrange(probs, "b seq prob -> (b seq) prob")
858
 
859
-
860
-
861
  sampled_z = torch.multinomial(probs, 1).squeeze(-1)
862
 
863
  sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
864
  probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
865
  logging.info(f"sampled z with shape: {sampled_z.shape}")
866
 
867
- # update the mask
 
 
 
 
 
868
  mask = (z_masked == self.mask_token).int()
 
 
869
  logging.info(f"updated mask with shape: {mask.shape}")
870
-
871
  # add z back into sampled z where the mask was false
872
  sampled_z = torch.where(
873
  mask.bool(), sampled_z, z_masked
@@ -902,17 +903,9 @@ class VampNet(at.ml.BaseModel):
902
 
903
 
904
  # get our new mask
905
- # print(tmpt * (1-_gamma(r)))
906
  mask = mask_by_random_topk(
907
  num_to_mask, selected_probs, tmpt * (1-r)
908
- )
909
-
910
- # print(f"most confident tokens: ")
911
- # print(torch.take_along_dim(
912
- # sampled_z, selected_probs.argsort(descending=False), dim=-1)
913
- # )
914
- # print(sampled_z[~mask.bool()])
915
-
916
 
917
  # update the mask
918
  z_masked = torch.where(
@@ -920,22 +913,29 @@ class VampNet(at.ml.BaseModel):
920
  )
921
  logging.info(f"updated z_masked with shape: {z_masked.shape}")
922
 
923
-
924
- z_masked = codebook_unflatten(z_masked, self.n_codebooks)
925
- mask = codebook_unflatten(mask, self.n_codebooks)
926
  logging.info(f"unflattened z_masked with shape: {z_masked.shape}")
927
 
 
 
 
 
 
928
 
929
- logging.info(f"updated z_masked with shape: {z_masked.shape}")
930
 
 
 
 
 
 
931
 
932
  logging.info(f"finished sampling")
933
- z = codebook_unflatten(sampled_z, self.n_codebooks)
934
 
935
  if return_signal:
936
- return self.to_signal(z, codec)
937
  else:
938
- return z
939
 
940
 
941
  def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
 
741
  sampling_steps: int = 36,
742
  start_tokens: Optional[torch.Tensor] = None,
743
  mask: Optional[torch.Tensor] = None,
744
+ temperature: Union[float, Tuple[float, float]] = 8.0,
745
  typical_filtering=False,
746
  typical_mass=0.2,
747
  typical_min_tokens=1,
 
848
  probs = torch.softmax(logits, dim=-1)
849
  logging.info(f"computed probs with shape: {probs.shape}")
850
 
 
 
 
851
 
852
  # sample from logits with multinomial sampling
853
  b = probs.shape[0]
854
  probs = rearrange(probs, "b seq prob -> (b seq) prob")
855
 
 
 
856
  sampled_z = torch.multinomial(probs, 1).squeeze(-1)
857
 
858
  sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
859
  probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
860
  logging.info(f"sampled z with shape: {sampled_z.shape}")
861
 
862
+
863
+ # flatten z_masked and mask, so we can deal with the sampling logic
864
+ # we'll unflatten them at the end of the loop for the next forward pass
865
+ # remove conditioning codebooks, we'll add them back at the end
866
+ z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
867
+
868
  mask = (z_masked == self.mask_token).int()
869
+
870
+ # update the mask, remove conditioning codebooks from the mask
871
  logging.info(f"updated mask with shape: {mask.shape}")
 
872
  # add z back into sampled z where the mask was false
873
  sampled_z = torch.where(
874
  mask.bool(), sampled_z, z_masked
 
903
 
904
 
905
  # get our new mask
 
906
  mask = mask_by_random_topk(
907
  num_to_mask, selected_probs, tmpt * (1-r)
908
+ )
 
 
 
 
 
 
 
909
 
910
  # update the mask
911
  z_masked = torch.where(
 
913
  )
914
  logging.info(f"updated z_masked with shape: {z_masked.shape}")
915
 
916
+ z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
917
+ mask = codebook_unflatten(mask, n_infer_codebooks)
 
918
  logging.info(f"unflattened z_masked with shape: {z_masked.shape}")
919
 
920
+ # add conditioning codebooks back to z_masked
921
+ z_masked = torch.cat(
922
+ (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
923
+ )
924
+ logging.info(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
925
 
 
926
 
927
+ # add conditioning codebooks back to sampled_z
928
+ sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
929
+ sampled_z = torch.cat(
930
+ (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
931
+ )
932
 
933
  logging.info(f"finished sampling")
 
934
 
935
  if return_signal:
936
+ return self.to_signal(sampled_z, codec)
937
  else:
938
+ return sampled_z
939
 
940
 
941
  def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):