Hugo Flores Garcia committed
Commit c6f0e5a · 1 parent: a022742

new stuff

Changed files:
- app.py +136 -32
- conf/generated/church-bells/c2f.yml +15 -0
- conf/generated/church-bells/coarse.yml +8 -0
- conf/generated/church-bells/interface.yml +6 -0
- conf/generated/copepod/c2f.yml +15 -0
- conf/generated/copepod/coarse.yml +8 -0
- conf/generated/copepod/interface.yml +6 -0
- conf/generated/growl/c2f.yml +16 -0
- conf/generated/growl/coarse.yml +9 -0
- conf/generated/growl/interface.yml +7 -0
- conf/generated/sample-instrument/c2f.yml +15 -0
- conf/generated/sample-instrument/coarse.yml +8 -0
- conf/generated/sample-instrument/interface.yml +6 -0
- models/models/vampnet/c2f.pth +3 -0
- models/models/vampnet/coarse.pth +3 -0
- models/models/vampnet/codec.pth +3 -0
- models/models/wavebeat.pth +3 -0
- scripts/utils/split.py +5 -3
- vampnet/interface.py +101 -12
- vampnet/mask.py +10 -3
- vampnet/modules/transformer.py +211 -9
app.py
CHANGED
@@ -7,13 +7,15 @@ import audiotools as at
 import argbind
 import shutil
 import torch
+from datetime import datetime
 
 import gradio as gr
-from vampnet.interface import Interface
+from vampnet.interface import Interface, signal_concat
 from vampnet import mask as pmask
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+
 interface = Interface(
     device=device,
     coarse_ckpt="models/vampnet/coarse.pth",
@@ -78,11 +80,20 @@ def shift_pitch(signal, interval: int):
     )
     return signal
 
-def _vamp(data, return_mask=False):
+def _vamp(data, api: bool=False, # bypasses mask audio generation and other things to be faster
+    ):
+
+    _seed = data[seed] if data[seed] > 0 else None
+    if _seed is None:
+        # create a random seed
+        _seed = int(torch.randint(0, 2**32, (1,)).item())
+    at.util.seed(_seed)
 
-    …
+    datentime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
+    out_dir = OUT_DIR / f"{Path(data[input_audio]).stem}-{datentime}-seed-{_seed}-model-{data[model_choice]}"
     out_dir.mkdir(parents=True)
     sig = at.AudioSignal(data[input_audio])
+    sig.write(out_dir / "input.wav")
 
     # reload the model if necessary
     interface.reload(
@@ -96,18 +107,21 @@ def _vamp(data, return_mask=False):
     if data[pitch_shift_amt] != 0:
         sig = shift_pitch(sig, data[pitch_shift_amt])
 
+    _p2 = data[periodic_p] if data[p2] == 0 else data[p2]
+    _n_codebooks_2 = data[n_mask_codebooks] if data[n_mask_codebooks_2] == 0 else data[n_mask_codebooks_2]
     build_mask_kwargs = dict(
         rand_mask_intensity=data[rand_mask_intensity],
         prefix_s=data[prefix_s],
         suffix_s=data[suffix_s],
         periodic_prompt=data[periodic_p],
+        periodic_prompt2=_p2,
         periodic_prompt_width=data[periodic_w],
         onset_mask_width=data[onset_mask_width],
         _dropout=data[dropout],
-        upper_codebook_mask=int(data[n_mask_codebooks])
+        upper_codebook_mask=int(data[n_mask_codebooks]),
+        upper_codebook_mask_2=int(_n_codebooks_2),
     )
 
-    _seed = data[seed] if data[seed] > 0 else None
     vamp_kwargs = dict(
         # _sampling_steps=[data[num_steps], 8, 8, 4, 4, 2, 2, 1, 1],
         mask_temperature=data[masktemp]*10,
@@ -121,31 +135,81 @@ def _vamp(data, return_mask=False):
     )
 
     # save the mask as a txt file
-    …
+    interface.set_chunk_size(data[win_dur])
+    sig, mask, codes = interface.ez_vamp(
         sig,
-        batch_size=1,
+        batch_size=4 if not api else 1,
         feedback_steps=data[num_feedback_steps],
         time_stretch_factor=data[stretch_factor],
         build_mask_kwargs=build_mask_kwargs,
         vamp_kwargs=vamp_kwargs,
-        return_mask=…
+        return_mask=True,
     )
 
-    …
+    if api:
+        sig.write(out_dir / "out.wav")
 
-    if return_mask:
-        mask = interface.to_signal(mask.cuda()).cpu()
-        mask.write(out_dir / "mask.wav")
-        return sig.path_to_file, mask.path_to_file
-    else:
         return sig.path_to_file
 
+    if not api:
+        # write codes to numpy file
+        np.save(out_dir / "codes.npy", codes.cpu().numpy())
+        metadata = {}
+        metadata["seed"] = _seed
+        metadata["model_choice"] = data[model_choice]
+        metadata["mask_kwargs"] = build_mask_kwargs
+        metadata["vamp_kwargs"] = vamp_kwargs
+        metadata["loudness"] = loudness
+        # save the metadata
+        with open(out_dir / "metadata.yml", "w") as f:
+            yaml.dump(metadata, f)
+
+        sig0 = sig[0].write(out_dir / "out1.wav")
+        sig1 = sig[1].write(out_dir / "out2.wav")
+        sig2 = sig[2].write(out_dir / "out3.wav")
+        sig3 = sig[3].write(out_dir / "out4.wav")
+
+        # write the mask to txt
+        with open(out_dir / "mask.txt", "w") as f:
+            m = mask[0].cpu().numpy()
+            # write to txt, each time step on a new line
+            for i in range(m.shape[-1]):
+                f.write(f"{m[:, i]}\n")
+
+
+        import matplotlib.pyplot as plt
+        plt.clf()
+        interface.visualize_codes(mask)
+        plt.savefig(out_dir / "mask.png")
+        plt.clf()
+        interface.visualize_codes(codes)
+        plt.savefig(out_dir / "codes.png")
+        plt.close()
+
+        # zip out dir, and return the path to the zip
+        shutil.make_archive(out_dir, 'zip', out_dir)
+
+        # chunk in groups of 1024 timesteps
+        _mask_sigs = []
+        for i in range(0, mask.shape[-1], 1024):
+            _mask_sigs.append(interface.to_signal(mask[:, :, i:i+1024].to(interface.device)).cpu())
+        mask = signal_concat(_mask_sigs)
+        mask.write(out_dir / "mask.wav")
+
+
+
+
+        return (
+            sig0.path_to_file, sig1.path_to_file,
+            sig2.path_to_file, sig3.path_to_file,
+            mask.path_to_file, str(out_dir.with_suffix(".zip")), out_dir / "mask.png"
+        )
 
 def vamp(data):
-    return _vamp(data, …
+    return _vamp(data, api=False)
 
 def api_vamp(data):
-    return _vamp(data, …
+    return _vamp(data, api=True)
 
 with gr.Blocks() as demo:
     with gr.Row():
@@ -193,7 +257,13 @@ with gr.Blocks() as demo:
             step=1,
             value=3,
         )
-        …
+        p2 = gr.Slider(
+            label="periodic prompt 2 (0 - same as p1, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
+            minimum=0,
+            maximum=128,
+            step=1,
+            value=0,
+        )
 
         onset_mask_width = gr.Slider(
             label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
@@ -204,9 +274,13 @@ with gr.Blocks() as demo:
         )
 
         n_mask_codebooks = gr.Number(
-            label="…
+            label="compression prompt (masks entire upper codebook levels above the specified level)",
            value=3,
        )
+        n_mask_codebooks_2 = gr.Number(
+            label="compression prompt 2 via linear interpolation (0 == constant)",
+            value=0,
+        )
 
         with gr.Accordion("extras ", open=False):
             pitch_shift_amt = gr.Slider(
@@ -337,17 +411,45 @@ with gr.Blocks() as demo:
             value=1
         )
 
+        win_dur= gr.Slider(
+            label="window duration (seconds)",
+            minimum=2,
+            maximum=10,
+            value=6)
+
+
         vamp_button = gr.Button("generate (vamp)!!!")
-        …
-            label="…
+        maskimg = gr.Image(
+            label="mask image",
+            interactive=False,
+            type="filepath"
+        )
+        out1 = gr.Audio(
+            label="output audio 1",
+            interactive=False,
+            type="filepath"
+        )
+        out2 = gr.Audio(
+            label="output audio 2",
+            interactive=False,
+            type="filepath"
+        )
+        out3 = gr.Audio(
+            label="output audio 3",
+            interactive=False,
+            type="filepath"
+        )
+        out4 = gr.Audio(
+            label="output audio 4",
             interactive=False,
             type="filepath"
         )
-        …
-        use_as_input_button = gr.Button("use output as input")
 
         thank_you = gr.Markdown("")
 
+        # download all the outputs
+        download = gr.File(type="filepath", label="download outputs")
+
 
     _inputs = {
         input_audio,
@@ -368,29 +470,31 @@ with gr.Blocks() as demo:
         n_mask_codebooks,
         pitch_shift_amt,
         sample_cutoff,
-        num_feedback_steps
+        num_feedback_steps,
+        p2,
+        n_mask_codebooks_2,
+        win_dur
     }
 
     # connect widgets
     vamp_button.click(
         fn=vamp,
         inputs=_inputs,
-        outputs=[…
+        outputs=[out1, out2, out3, out4, audio_mask, download, maskimg],
     )
 
     api_vamp_button = gr.Button("api vamp", visible=False)
     api_vamp_button.click(
         fn=api_vamp,
         inputs=_inputs,
-        outputs=[…
+        outputs=[out1],
         api_name="vamp"
    )
 
-    use_as_input_button.click(
-        fn=lambda x: x,
-        inputs=[output_audio],
-        outputs=[input_audio]
-    )
 
-…
-demo.…
+try:
+    demo.queue()
+    demo.launch(share=True, debug=True)
+except KeyboardInterrupt:
+    shutil.rmtree("gradio-outputs", ignore_errors=True)
+    raise
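Note: `_vamp` now resolves its random seed up front, so the seed that actually produced the output is baked into the run's directory name and recorded in `metadata.yml`. A minimal sketch of that idiom (standalone; `resolve_seed` is an illustrative name, not a function in this repo, and the app itself seeds via `at.util.seed`):

import torch

def resolve_seed(requested: int) -> int:
    # non-positive means "pick one for me", but always return the seed
    # that was actually used, so the run stays reproducible
    if requested > 0:
        return requested
    return int(torch.randint(0, 2**32, (1,)).item())

seed = resolve_seed(0)
torch.manual_seed(seed)
print(f"my-input-{seed}")  # the seed lands in the output dir name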
conf/generated/church-bells/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/church-bells/c2f
+train/AudioLoader.sources: &id001
+- data/church-bells
+val/AudioLoader.sources: *id001
conf/generated/church-bells/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/church-bells/coarse
+train/AudioLoader.sources: &id001
+- data/church-bells
+val/AudioLoader.sources: *id001
conf/generated/church-bells/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - data/church-bells
+Interface.coarse2fine_ckpt: ./runs/church-bells/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/church-bells/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/copepod/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/copepod/c2f
+train/AudioLoader.sources: &id001
+- data/copepod
+val/AudioLoader.sources: *id001
conf/generated/copepod/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/copepod/coarse
+train/AudioLoader.sources: &id001
+- data/copepod
+val/AudioLoader.sources: *id001
conf/generated/copepod/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - data/copepod
+Interface.coarse2fine_ckpt: ./runs/copepod/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/copepod/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/growl/c2f.yml
ADDED
@@ -0,0 +1,16 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/growl/c2f
+train/AudioLoader.sources: &id001
+- data/growly
+- animals/
+val/AudioLoader.sources: *id001
conf/generated/growl/coarse.yml
ADDED
@@ -0,0 +1,9 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/growl/coarse
+train/AudioLoader.sources: &id001
+- data/growly
+- animals/
+val/AudioLoader.sources: *id001
conf/generated/growl/interface.yml
ADDED
@@ -0,0 +1,7 @@
+AudioLoader.sources:
+- - data/growly
+  - animals/
+Interface.coarse2fine_ckpt: ./runs/growl/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/growl/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
conf/generated/sample-instrument/c2f.yml
ADDED
@@ -0,0 +1,15 @@
+$include:
+- conf/lora/lora.yml
+AudioDataset.duration: 3.0
+AudioDataset.loudness_cutoff: -40.0
+VampNet.embedding_dim: 1280
+VampNet.n_codebooks: 14
+VampNet.n_conditioning_codebooks: 4
+VampNet.n_heads: 20
+VampNet.n_layers: 16
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/c2f.pth
+save_path: ./runs/sample-instrument/c2f
+train/AudioLoader.sources: &id001
+- data/sample-instrument/
+val/AudioLoader.sources: *id001
conf/generated/sample-instrument/coarse.yml
ADDED
@@ -0,0 +1,8 @@
+$include:
+- conf/lora/lora.yml
+fine_tune: true
+fine_tune_checkpoint: ./models/vampnet/coarse.pth
+save_path: ./runs/sample-instrument/coarse
+train/AudioLoader.sources: &id001
+- data/sample-instrument/
+val/AudioLoader.sources: *id001
conf/generated/sample-instrument/interface.yml
ADDED
@@ -0,0 +1,6 @@
+AudioLoader.sources:
+- - data/sample-instrument/
+Interface.coarse2fine_ckpt: ./runs/sample-instrument/c2f/latest/vampnet/weights.pth
+Interface.coarse_ckpt: ./runs/sample-instrument/coarse/latest/vampnet/weights.pth
+Interface.codec_ckpt: ./models/vampnet/codec.pth
+Interface.wavebeat_ckpt: ./models/wavebeat.pth
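Note: the generated configs rely on a YAML anchor/alias pair (`&id001` / `*id001`) so the train and val loaders always share a single source list; editing the anchored list updates both. A quick sketch of the alias resolving (assumes PyYAML is installed; not part of the repo):

import yaml

doc = """
train/AudioLoader.sources: &id001
- data/church-bells
val/AudioLoader.sources: *id001
"""

cfg = yaml.safe_load(doc)
# the alias resolves to the very same list object as the anchor
assert cfg["train/AudioLoader.sources"] is cfg["val/AudioLoader.sources"]
print(cfg["val/AudioLoader.sources"])  # ['data/church-bells']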
models/models/vampnet/c2f.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b10ea2d45459d34edb773cbacd71f40f7baa1f4e75ac8bcd93b022ac69f8fa63
+size 1101898865
models/models/vampnet/coarse.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:78e4ad4f8398e8ec3651bc5e5c6ea2995e1080b6226be186723ccf4320c9756c
+size 1332182321
models/models/vampnet/codec.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3db3fa43ab5d160439ddb81fc540b5573ad5ae962230de3fc5b47d218845b855
+size 600996465
models/models/wavebeat.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5ff1066a4470cb98b20edf1e489f6995b19e0435b9cfd5a190bf90a954d0cadb
+size 33248861
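Note: the .pth entries above are Git LFS pointer files, not the weights themselves; the repo tracks only a spec version, a sha256 oid, and a byte size, and git-lfs fetches the real payload at checkout. A tiny parser for the pointer format (a sketch for illustration; the format is defined by the git-lfs spec URL in each file):

def parse_lfs_pointer(text: str) -> dict:
    # each line of a pointer file is "key value", e.g. "size 33248861"
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "version": fields["version"],
        "sha256": fields["oid"].split(":", 1)[1],
        "size_bytes": int(fields["size"]),
    }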
scripts/utils/split.py
CHANGED
@@ -16,12 +16,11 @@ def train_test_split(
     audio_folder: str = ".",
     test_size: float = 0.2,
     seed: int = 42,
-    pattern: str = "**/*.mp3",
 ):
     print(f"finding audio")
 
     audio_folder = Path(audio_folder)
-    audio_files = …
+    audio_files = util.find_audio(audio_folder)
     print(f"found {len(audio_files)} audio files")
 
     # split according to test_size
@@ -49,7 +48,10 @@ def train_test_split(
     for file in tqdm(files):
         out_file = audio_folder.parent / f"{audio_folder.name}-{split}" / Path(file).name
         out_file.parent.mkdir(exist_ok=True, parents=True)
-        …
+        try:
+            os.symlink(file, out_file)
+        except FileExistsError:
+            print(f"File {out_file} already exists, skipping")
 
     # save split as json
     with open(Path(audio_folder) / f"{split}.json", "w") as f:
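Note: with the `pattern` argument gone, file discovery is delegated to `util.find_audio`, and the split itself is a seeded shuffle-and-slice followed by symlinks into `<folder>-train` / `<folder>-test`, so no audio is copied. A self-contained sketch of the split step (illustrative names, not the script's exact code):

import random

def split_files(files, test_size=0.2, seed=42):
    files = sorted(files)               # stable order before shuffling
    random.Random(seed).shuffle(files)  # seeded, so the split is reproducible
    n_test = int(len(files) * test_size)
    return {"test": files[:n_test], "train": files[n_test:]}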
vampnet/interface.py
CHANGED
@@ -273,11 +273,15 @@ class Interface(torch.nn.Module):
         else:
             mask = mask.repeat(1, self.coarse.n_codebooks, 1)
         return mask
+
+    def set_chunk_size(self, chunk_size_s: float):
+        self.coarse.chunk_size_s = chunk_size_s
 
     def coarse_to_fine(
         self,
         z: torch.Tensor,
         mask: torch.Tensor = None,
+        return_mask: bool = False,
         **kwargs
     ):
         assert self.c2f is not None, "No coarse2fine model loaded"
@@ -320,6 +324,9 @@ class Interface(torch.nn.Module):
             fine_z.append(chunk)
 
         fine_z = torch.cat(fine_z, dim=-1)
+        if return_mask:
+            return fine_z[:, :, :length].clone(), apply_mask(fine_z, mask, self.c2f.mask_token)[0][:, :, :length].clone()
+
         return fine_z[:, :, :length].clone()
 
     def coarse_vamp(
@@ -397,10 +404,12 @@ class Interface(torch.nn.Module):
         prefix_s: float = 0.0,
         suffix_s: float = 0.0,
         periodic_prompt: int = 7,
+        periodic_prompt2: int = 7,
         periodic_prompt_width: int = 1,
         onset_mask_width: int = 0,
         _dropout: float = 0.0,
         upper_codebook_mask: int = 3,
+        upper_codebook_mask_2: int = None,
         ncc: int = 0,
     ):
@@ -409,10 +418,17 @@ class Interface(torch.nn.Module):
             mask,
             inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
         )
-        …
-        …
-        …
-        …
+
+        pmask1 = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
+        pmask2 = periodic_mask(z, periodic_prompt2, periodic_prompt_width, random_roll=True)
+        # interpolate the two masks
+        pmask = torch.round(
+            pmask1 * torch.linspace(1, 0, pmask1.shape[-1], device=pmask1.device) +
+            pmask2 * torch.linspace(0, 1, pmask2.shape[-1], device=pmask2.device)
+        ).long()
+
+        mask = mask_and(mask, pmask)
+
         if onset_mask_width > 0:
             assert sig is not None, f"must provide a signal to use onset mask"
             mask = mask_or(
@@ -424,7 +440,8 @@ class Interface(torch.nn.Module):
 
         mask = dropout(mask, _dropout)
         mask = codebook_unmask(mask, ncc)
-        …
+
+        mask = codebook_mask(mask, int(upper_codebook_mask), upper_codebook_mask_2)
         return mask
 
     def ez_vamp(
@@ -440,8 +457,8 @@ class Interface(torch.nn.Module):
         build_mask_kwargs = build_mask_kwargs or {}
         vamp_kwargs = vamp_kwargs or {}
 
-        sig = self.preprocess(sig)
         loudness = sig.loudness()
+        sig = self.preprocess(sig)
 
         z = self.encode(sig)
 
@@ -488,26 +505,60 @@ class Interface(torch.nn.Module):
 
             # now, coarse2fine
             print(f"coarse2fine!")
-            zv = self.coarse_to_fine(
+            zv, fine_zv_mask = self.coarse_to_fine(
                 zv,
                 mask=mask,
                 **vamp_kwargs,
-                _sampling_steps=[…
+                _sampling_steps=[2, 2, 1, 1],
+                return_mask=True
+            )
+            mask_z = torch.cat(
+                [mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
+                dim=1
             )
 
             prev_zvs.append(zv)
             z = zv
 
-        …
+        # perform to_signal batch item by batch
+        sigs = []
+        for zv in prev_zvs:
+            # do it in timestep chunks of 1024
+            _sigs = []
+            for i in range(0, zv.shape[-1], 1024):
+                _sigs.append(self.to_signal(zv[:, :, i:i+1024]).cpu())
+            sigs.append(signal_concat(_sigs))
         print("done")
 
+        sig = AudioSignal.batch(sigs)
+
+        # sig = self.to_signal(zv).cpu()
+        # print("done")
+
         sig = sig.normalize(loudness)
 
         if return_mask:
-            return sig, …
+            return sig, mask_z.cpu(), zv.cpu()
         else:
             return sig
 
+    def visualize_codes(self, z: torch.Tensor):
+        import matplotlib.pyplot as plt
+        # make sure the figsize is square when imshow is called
+        fig = plt.figure(figsize=(10, 7))
+        # in subplots, plot z[0] and the mask
+        # set title to "codes" and "mask"
+        fig.add_subplot(2, 1, 1)
+        plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
+        plt.title("codes")
+        plt.ylabel("codebook index")
+        # set the xticks to seconds
+        plt.xticks(
+            np.arange(0, z.shape[-1], self.s2t(1)),
+            np.arange(0, self.t2s(z.shape[-1]), 1)
+        )
+        plt.xlabel("time (s)")
+
 if __name__ == "__main__":
     import audiotools as at
     import logging
@@ -528,8 +579,6 @@ if __name__ == "__main__":
     sig = at.AudioSignal('assets/example.wav')
 
     z = interface.encode(sig)
-    breakpoint()
-
     # mask = linear_random(z, 1.0)
     # mask = mask_and(
     #     mask, periodic_mask(
@@ -569,3 +618,43 @@ if __name__ == "__main__":
     print("done")
 
 
+
+
+    # example plotting code
+    # import matplotlib.pyplot as plt
+    # from pathlib import Path
+    # Path(".vampnet").mkdir(exist_ok=True)
+    # plt.clf()
+    # # close all figs
+    # plt.close('all')
+    # # set the fig size
+    # plt.subplot(4, 1, 1)
+    # # sig = self.to_signal(sampled_z, codec)
+    # # sig.cpu().specshow()
+
+    # plt.subplot(4, 1, 2)
+    # # since z_masked is a codebook, we want to plot the colormap
+    # # with distinct colors for each codebook index
+    # # plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
+    # # make it so that anywhere where the mask is 1, we make that pixel black
+    # plt.imshow(_debug_z_masked_before_forward[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r',)
+
+
+    # plt.subplot(4, 1, 3)
+    # # plot the mask (which is a matrix)
+    # plt.imshow(mask[0].cpu().numpy(), aspect='auto', origin='lower', cmap='gray_r')
+    # plt.subplot(4, 1, 4)
+    # # replace any inf or -inf with 0
+    # _selected_probs = torch.where(
+    #     selected_probs == torch.inf, torch.zeros_like(selected_probs), selected_probs
+    # )
+    # _selected_probs = torch.where(
+    #     selected_probs == -torch.inf, torch.zeros_like(selected_probs), selected_probs
+    # )
+    # # fig = plt.gcf()
+    # # fig.set_figheight(15)
+    # # fig.set_figwidth(15)
+    # plt.imshow(codebook_unflatten(_selected_probs, n_infer_codebooks)[0].cpu().numpy(), aspect='auto', origin='lower', cmap="viridis" )
+    # # plt.show()
+    # plt.savefig(f".vampnet/c={codebook_level}_{i}.png")
+    # plt.close('all')
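Note: the key addition in `build_mask` above is a time-varying periodic prompt. Two periodic masks are crossfaded with opposing `torch.linspace` ramps, so the density of prompt hints drifts from `periodic_prompt` at the start of the clip to `periodic_prompt2` at the end. The same crossfade, standalone (assuming vampnet's mask convention of 1 = regenerate, 0 = keep as prompt):

import torch

def crossfade_masks(m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor:
    # m1, m2: (batch, n_codebooks, time) binary masks
    t = m1.shape[-1]
    ramp_down = torch.linspace(1, 0, t, device=m1.device)  # m1 dominates early
    ramp_up = torch.linspace(0, 1, t, device=m2.device)    # m2 dominates late
    # blend, then round back to a hard 0/1 mask
    return torch.round(m1 * ramp_down + m2 * ramp_up).long()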
vampnet/mask.py
CHANGED
@@ -60,6 +60,7 @@ def linear_random(
     assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
     if not isinstance(r, torch.Tensor):
         r = scalar_to_batch_tensor(r, x.shape[0]).to(x.device).float()
+    r = r[:, None, None]
 
     probs = torch.ones_like(x).to(x.device).float()
     # expand to batch and codebook dims
@@ -98,7 +99,7 @@ def inpaint(x: torch.Tensor,
     return mask
 
 def periodic_mask(x: torch.Tensor,
-    period: int,
+    period: int, width: int = 1,
     random_roll=False,
 ):
     mask = full_mask(x)
@@ -140,9 +141,15 @@ def codebook_unmask(
     mask[:, :n_conditioning_codebooks, :] = 0
     return mask
 
-def codebook_mask(mask: torch.Tensor, …
+def codebook_mask(mask: torch.Tensor, val1: int, val2: int = None):
     mask = mask.clone()
-    mask[:, …
+    mask[:, val1:, :] = 1
+    # val2 = val2 or val1
+    # vs = torch.linspace(val1, val2, mask.shape[1])
+    # for t, v in enumerate(vs):
+    #     v = int(v)
+    #     mask[:, v:, t] = 1
+
     return mask
 
 def mask_and(
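Note: `codebook_mask(mask, val1, val2)` currently just forces every codebook level from `val1` upward to be regenerated; the commented-out lines sketch the planned linear ramp toward `val2` over time. In the repo's convention (1 = masked/regenerate, 0 = kept as prompt, as consumed by `masked_fill` in the transformer), combining it with a periodic prompt looks like this (illustrative values):

import torch

b, n_codebooks, t = 1, 14, 12
mask = torch.ones(b, n_codebooks, t, dtype=torch.long)  # start fully masked

mask[:, :, ::4] = 0  # periodic prompt: keep every 4th timestep as a hint
mask[:, 3:, :] = 1   # codebook mask: regenerate levels 3 and up everywhere,
                     # even at the hint positions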
vampnet/modules/transformer.py
CHANGED
@@ -572,6 +572,8 @@ class VampNet(at.ml.BaseModel):
         """
         assert z.ndim == 3
 
+        # remove mask token
+        z = z.masked_fill(z == self.mask_token, 0)
         signal = at.AudioSignal(
             codec.decode(
                 codec.quantizer.from_latents(self.embedding.from_codes(z, codec))[0]
@@ -581,15 +583,13 @@ class VampNet(at.ml.BaseModel):
 
         # find where the mask token is and replace it with silence in the audio
         for tstep in range(z.shape[-1]):
-            if torch.…
+            if torch.all(z[:, :, tstep] == self.mask_token):
                 sample_idx_0 = tstep * codec.hop_length
                 sample_idx_1 = sample_idx_0 + codec.hop_length
                 signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
 
         return signal
-
-
-
+
     @torch.no_grad()
     def generate(
         self,
@@ -600,17 +600,37 @@ class VampNet(at.ml.BaseModel):
         sampling_temperature: float = 1.0,
         mask: Optional[torch.Tensor] = None,
         mask_temperature: float = 10.5,
-        typical_filtering=…
+        typical_filtering=True,
         typical_mass=0.2,
         typical_min_tokens=1,
-        top_p=…
+        top_p=0.9,
         seed: int = None,
-        sample_cutoff: float = …
+        sample_cutoff: float = 0.9,
         return_signal=True,
         debug=False,
         causal_weight: float = 0.0,
+        use_og_method: bool = False,
     ):
-        …
+        if use_og_method:
+            return self.og_method(
+                codec,
+                time_steps,
+                _sampling_steps,
+                start_tokens,
+                sampling_temperature,
+                mask,
+                mask_temperature,
+                typical_filtering,
+                typical_mass,
+                typical_min_tokens,
+                top_p,
+                seed,
+                sample_cutoff,
+                return_signal,
+                debug,
+                causal_weight,
+            )
+
         if seed is not None:
             at.util.seed(seed)
 
@@ -749,7 +769,7 @@ class VampNet(at.ml.BaseModel):
                     num_to_mask
                 )
             )
-            …
+            logging.debug(f"will mask {num_to_mask.sum()} tokens")
             mask = codebook_flatten(mask)
 
             # ignore any tokens that weren't masked
@@ -812,6 +832,188 @@ class VampNet(at.ml.BaseModel):
         else:
             return sampled_z
 
+
+
+    def og_method(
+        self,
+        codec,
+        time_steps: int = 300,
+        _sampling_steps: List[int] = [16, 8, 8, 2, 2, 2, 2, 1, 1],
+        start_tokens: Optional[torch.Tensor] = None,
+        sampling_temperature: float = 1.0,
+        mask: Optional[torch.Tensor] = None,
+        mask_temperature: float = 10.5,
+        typical_filtering=True,
+        typical_mass=0.2,
+        typical_min_tokens=1,
+        top_p=0.9,
+        seed: int = None,
+        sample_cutoff: float = 0.75,
+        return_signal=True,
+        debug=False,
+        causal_weight: float = 0.0,
+    ):
+        if seed is not None:
+            at.util.seed(seed)
+        sampling_steps = sum(_sampling_steps)
+        logging.debug(f"beginning generation with {sampling_steps} steps")
+
+
+        #####################
+        # resolve initial z #
+        #####################
+        z = start_tokens
+
+        if z is None:
+            z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
+                self.device
+            )
+
+        logging.debug(f"created z with shape {z.shape}")
+
+
+        #################
+        # resolve mask #
+        #################
+
+        if mask is None:
+            mask = torch.ones_like(z).to(self.device).int()
+            mask[:, : self.n_conditioning_codebooks, :] = 0.0
+        if mask.ndim == 2:
+            mask = mask[:, None, :].repeat(1, z.shape[1], 1)
+        # init_mask = mask.clone()
+
+        logging.debug(f"created mask with shape {mask.shape}")
+
+
+        ###########
+        # set up #
+        ##########
+        # apply the mask to z
+        z_masked = z.masked_fill(mask.bool(), self.mask_token)
+        # logging.debug(f"z_masked: {z_masked}")
+
+        # how many mask tokens to begin with?
+        num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
+        logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
+
+        # how many codebooks are we inferring vs conditioning on?
+        n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
+        logging.debug(f"n infer codebooks: {n_infer_codebooks}")
+
+        #################
+        # begin sampling #
+        #################
+
+        for i in range(sampling_steps):
+            logging.debug(f"step {i} of {sampling_steps}")
+
+            # our current schedule step
+            r = scalar_to_batch_tensor(
+                (i + 1) / sampling_steps,
+                z.shape[0]
+            ).to(z.device)
+            logging.debug(f"r: {r}")
+
+            # get latents
+            latents = self.embedding.from_codes(z_masked, codec)
+            logging.debug(f"computed latents with shape: {latents.shape}")
+
+
+            # infer from latents
+            # NOTE: this collapses the codebook dimension into the sequence dimension
+            logits = self.forward(latents)  # b, prob, seq
+            logits = logits.permute(0, 2, 1)  # b, seq, prob
+            b = logits.shape[0]
+
+            logging.debug(f"permuted logits with shape: {logits.shape}")
+
+            sampled_z, selected_probs = sample_from_logits(
+                logits, sample=(
+                    (i / sampling_steps) <= sample_cutoff
+                ),
+                temperature=sampling_temperature,
+                typical_filtering=typical_filtering, typical_mass=typical_mass,
+                typical_min_tokens=typical_min_tokens,
+                top_k=None, top_p=top_p, return_probs=True,
+            )
+
+            logging.debug(f"sampled z with shape: {sampled_z.shape}")
+
+            # flatten z_masked and mask, so we can deal with the sampling logic
+            # we'll unflatten them at the end of the loop for the next forward pass
+            # remove conditioning codebooks, we'll add them back at the end
+            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
+
+            mask = (z_masked == self.mask_token).int()
+
+            # update the mask, remove conditioning codebooks from the mask
+            logging.debug(f"updated mask with shape: {mask.shape}")
+            # add z back into sampled z where the mask was false
+            sampled_z = torch.where(
+                mask.bool(), sampled_z, z_masked
+            )
+            logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
+
+            # ignore any tokens that weren't masked
+            selected_probs = torch.where(
+                mask.bool(), selected_probs, torch.inf
+            )
+
+            # get the num tokens to mask, according to the schedule
+            num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
+            logging.debug(f"num to mask: {num_to_mask}")
+
+            if i != (sampling_steps - 1):
+                num_to_mask = torch.maximum(
+                    torch.tensor(1),
+                    torch.minimum(
+                        mask.sum(dim=-1, keepdim=True) - 1,
+                        num_to_mask
+                    )
+                )
+
+
+            # get our new mask
+            mask = mask_by_random_topk(
+                num_to_mask, selected_probs, mask_temperature * (1-r)
+            )
+
+            # update the mask
+            z_masked = torch.where(
+                mask.bool(), self.mask_token, sampled_z
+            )
+            logging.debug(f"updated z_masked with shape: {z_masked.shape}")
+
+            z_masked = codebook_unflatten(z_masked, n_infer_codebooks)
+            mask = codebook_unflatten(mask, n_infer_codebooks)
+            logging.debug(f"unflattened z_masked with shape: {z_masked.shape}")
+
+            # add conditioning codebooks back to z_masked
+            z_masked = torch.cat(
+                (z[:, :self.n_conditioning_codebooks, :], z_masked), dim=1
+            )
+            logging.debug(f"added conditioning codebooks back to z_masked with shape: {z_masked.shape}")
+
+
+        # add conditioning codebooks back to sampled_z
+        sampled_z = codebook_unflatten(sampled_z, n_infer_codebooks)
+        sampled_z = torch.cat(
+            (z[:, :self.n_conditioning_codebooks, :], sampled_z), dim=1
+        )
+
+        logging.debug(f"finished sampling")
+
+        if return_signal:
+            return self.to_signal(sampled_z, codec)
+        else:
+            return sampled_z
+
+
+
+
 def sample_from_logits(
     logits,
     sample: bool = True,
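Note: `og_method` restores the original MaskGIT-style loop: sample every masked position, keep the most confident predictions, and re-mask the rest according to `_gamma(r)` before the next pass. A sketch of how the re-mask budget shrinks over steps (assuming the usual cosine schedule, gamma(r) = cos(r * pi / 2), typical of MaskGIT-style samplers; the repo's own `_gamma` may differ):

import math

def gamma(r: float) -> float:
    # fraction of the initially-masked tokens still masked at schedule position r
    return math.cos(r * math.pi / 2)

num_mask_tokens_at_start = 1000
sampling_steps = 8
for i in range(sampling_steps):
    r = (i + 1) / sampling_steps
    print(f"step {i}: re-mask {math.floor(gamma(r) * num_mask_tokens_at_start)} tokens")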