# vampnet / demo.py
# Last commit b61e699 by Hugo Flores Garcia: "more demo ctrls"
from pathlib import Path
from typing import Tuple
import yaml
import tempfile
import uuid
from dataclasses import dataclass, asdict
import numpy as np
import audiotools as at
import argbind
import gradio as gr
from vampnet.interface import Interface
from vampnet import mask as pmask
# Register the Interface and AudioLoader constructors with argbind so their
# arguments can be supplied from the command line / a config file. argbind
# builds its CLI from what has been bound, so this precedes parse_args().
Interface = argbind.bind(Interface)
AudioLoader = argbind.bind(at.data.datasets.AudioLoader)

conf = argbind.parse_args()

# Instantiate inside the parsed scope so the bound constructors pick up the
# parsed arguments.
with argbind.scope(conf):
    interface = Interface()
    loader = AudioLoader()
    print(f"interface device is {interface.device}")

    # Pool of excerpts that load_random_audio() draws from. sample_rate and
    # duration are taken from the codec / coarse model so excerpts match what
    # the interface expects (NOTE(review): exact resampling/trimming behavior
    # lives in audiotools' AudioDataset — confirm there).
    dataset = at.data.datasets.AudioDataset(
        loader,
        sample_rate=interface.codec.sample_rate,
        duration=interface.coarse.chunk_size_s,
        n_examples=5000,
        without_replacement=True,
    )

# Root directory for uploaded (tmp/), generated, and saved (saved/) audio.
OUT_DIR = Path("gradio-outputs")
OUT_DIR.mkdir(exist_ok=True, parents=True)
def load_audio(file):
    """Load an uploaded audio file, keep a salient excerpt, and save it.

    Args:
        file: value delivered by the ``gr.File`` upload component. This is
            either an object exposing ``.name`` (the temp-file path) or a
            plain path string, depending on the Gradio version / ``type``
            setting. It may be ``None``, e.g. when the upload is cleared.

    Returns:
        Path of the preprocessed ``input.wav`` written to a fresh
        ``OUT_DIR/tmp/<uuid>/`` directory, or ``None`` when no file was
        given (which also clears the input player).
    """
    print(file)
    if file is None:
        # The .change event also fires when the upload is cleared; without
        # this guard ``file.name`` raised AttributeError.
        return None

    # Accept both a tempfile-like wrapper and a plain path string.
    filepath = getattr(file, "name", file)

    # Keep only one excerpt of the model's chunk length from the upload.
    sig = at.AudioSignal.salient_excerpt(
        filepath,
        duration=interface.coarse.chunk_size_s,
    )
    sig = interface.preprocess(sig)

    out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
    out_dir.mkdir(parents=True, exist_ok=True)
    sig.write(out_dir / "input.wav")
    # The caller relies on write() having updated path_to_file.
    return sig.path_to_file
def load_random_audio():
    """Draw a random excerpt from ``dataset`` and save it as the input audio.

    Returns:
        Path of the preprocessed ``input.wav`` written to a fresh
        ``OUT_DIR/tmp/<uuid>/`` directory.
    """
    item = dataset[np.random.randint(0, len(dataset))]
    signal = interface.preprocess(item["signal"])

    tmp_dir = OUT_DIR.joinpath("tmp", str(uuid.uuid4()))
    tmp_dir.mkdir(parents=True, exist_ok=True)

    signal.write(tmp_dir / "input.wav")
    # Relies on AudioSignal.write() updating path_to_file to the new wav.
    return signal.path_to_file
def _vamp(data, return_mask=False):
    """Generate a new variation ("vamp") of the current input audio.

    Encodes the (optionally pitch-shifted) input, builds a token mask from
    the UI settings, resamples it with ``interface.coarse_vamp``, optionally
    runs coarse-to-fine, and writes the result to a new directory under
    ``OUT_DIR``.

    Args:
        data: dict mapping each Gradio component in ``_inputs`` to its
            current value. Gradio passes a single dict like this because the
            event's ``inputs`` are given as a set of components.
        return_mask: if True, also render the mask tokens returned by
            ``coarse_vamp`` to ``mask.wav`` and return its path.

    Returns:
        Path of ``output.wav``, or ``(output_path, mask_path)`` when
        ``return_mask`` is True.

    Raises:
        gr.Error: if no input audio has been loaded yet.
    """
    if data[input_audio] is None:
        raise gr.Error("no input audio! upload or load some audio first.")

    out_dir = OUT_DIR / str(uuid.uuid4())
    out_dir.mkdir(parents=True, exist_ok=True)

    sig = at.AudioSignal(data[input_audio])

    # pitch shift input
    sig = sig.shift_pitch(data[input_pitch_shift])

    # TODO: random pitch shift of segments in the signal to prompt! window
    # size should be a parameter, pitch shift width should be a parameter

    z = interface.encode(sig)
    ncc = data[n_conditioning_codebooks]

    # Build the mask: a random mask, combined with mask_and with the
    # prefix/suffix inpainting mask and with the periodic mask...
    mask = pmask.linear_random(z, data[rand_mask_intensity])
    mask = pmask.mask_and(
        mask,
        pmask.inpaint(
            z,
            interface.s2t(data[prefix_s]),
            interface.s2t(data[suffix_s]),
        ),
    )
    mask = pmask.mask_and(
        mask,
        pmask.periodic_mask(
            z,
            data[periodic_p],
            data[periodic_w],
            random_roll=True,
        ),
    )
    # ...then, if enabled, combined with mask_or with an onset-based mask.
    if data[onset_mask_width] > 0:
        mask = pmask.mask_or(
            mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
        )

    # these should be the last two mask ops
    mask = pmask.dropout(mask, data[dropout])
    mask = pmask.codebook_unmask(mask, ncc)

    print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[init_temp]}, final temp {data[final_temp]}, use coarse2fine {data[use_coarse2fine]}")

    # save the mask as a txt file
    np.savetxt(out_dir / "mask.txt", mask[:, 0, :].long().cpu().numpy())

    # A top-k slider value of 0 means "no top-k filtering". Cast to int
    # defensively in case the slider reports a float.
    top_k = int(data[topk]) if data[topk] > 0 else None

    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        sampling_steps=data[num_steps],
        temperature=(data[init_temp], data[final_temp]),
        return_mask=True,
        sample=data[sampling_strategy],
        typical_filtering=data[typical_filtering],
        typical_mass=data[typical_mass],
        typical_min_tokens=data[typical_min_tokens],
        top_k=top_k,
    )

    # Read the checkbox *value*. Testing the bare component object (always
    # truthy) made coarse2fine run no matter what the user chose.
    if data[use_coarse2fine]:
        zv = interface.coarse_to_fine(zv)

    sig = interface.to_signal(zv).cpu()
    print("done")

    sig.write(out_dir / "output.wav")

    if not return_mask:
        return sig.path_to_file

    # Render the mask tokens to audio so the user can hear the hints.
    mask_sig = interface.to_signal(mask_z).cpu()
    mask_sig.write(out_dir / "mask.wav")
    return sig.path_to_file, mask_sig.path_to_file
def vamp(data):
    """UI handler for the "vamp!!!" button.

    Runs ``_vamp`` and returns ``(output_path, mask_path)`` so both the
    output player and the audio-mask player are updated.
    """
    with_mask = True
    return _vamp(data, with_mask)
def api_vamp(data):
    """API handler (exposed as ``api_name="vamp"``).

    Runs ``_vamp`` without rendering the mask and returns only the path of
    the generated output audio.
    """
    with_mask = False
    return _vamp(data, with_mask)
def save_vamp(data):
    """Save the current input/output pair together with the vamp settings.

    Writes ``input.wav``, ``output.wav`` and ``data.yaml`` into a new
    directory under ``OUT_DIR/saved`` and zips that directory's files into
    a sibling ``<uuid>.zip``.

    Args:
        data: dict mapping each Gradio component (``_inputs`` plus
            ``notes_text`` and ``output_audio``) to its current value.

    Returns:
        ``(message, zip_path)``: a message containing the save code (the
        directory name) and the path of the zip offered for download.

    Raises:
        gr.Error: if there is no input or output audio to save yet.
    """
    import zipfile

    if data[input_audio] is None or data[output_audio] is None:
        raise gr.Error("nothing to save yet! vamp some audio first.")

    out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
    out_dir.mkdir(parents=True, exist_ok=True)

    sig_in = at.AudioSignal(data[input_audio])
    sig_out = at.AudioSignal(data[output_audio])
    sig_in.write(out_dir / "input.wav")
    sig_out.write(out_dir / "output.wav")

    # Record the generation settings alongside the audio. The original keys
    # are kept as-is; the later keys cover settings _vamp also reads, which
    # were previously not saved.
    _data = {
        "init_temp": data[init_temp],
        "final_temp": data[final_temp],
        "prefix_s": data[prefix_s],
        "suffix_s": data[suffix_s],
        "rand_mask_intensity": data[rand_mask_intensity],
        "num_steps": data[num_steps],
        "notes": data[notes_text],
        "periodic_period": data[periodic_p],
        "periodic_width": data[periodic_w],
        "n_conditioning_codebooks": data[n_conditioning_codebooks],
        "use_coarse2fine": data[use_coarse2fine],
        # NOTE(review): stretch_factor is not read by _vamp in this file.
        "stretch_factor": data[stretch_factor],
        "onset_mask_width": data[onset_mask_width],
        "input_pitch_shift": data[input_pitch_shift],
        "dropout": data[dropout],
        "sampling_strategy": data[sampling_strategy],
        "typical_filtering": data[typical_filtering],
        "typical_mass": data[typical_mass],
        "typical_min_tokens": data[typical_min_tokens],
        "topk": data[topk],
    }

    # save with yaml
    with open(out_dir / "data.yaml", "w", encoding="utf-8") as f:
        yaml.dump(_data, f)

    # Bundle everything in the save directory for download.
    zip_path = out_dir.with_suffix(".zip")
    with zipfile.ZipFile(zip_path, "w") as zf:
        for file in out_dir.iterdir():
            zf.write(file, file.name)

    return f"saved! your save code is {out_dir.stem}", zip_path
# Gradio UI: one row of three columns (input audio, vamp settings, output and
# saving), followed by the event wiring. Component variables defined here are
# module globals that _vamp() and save_vamp() use as keys into `data`.
with gr.Blocks() as demo:

    with gr.Row():
        # Column 1: choose the input audio.
        with gr.Column():
            use_coarse2fine = gr.Checkbox(
                label="use coarse2fine",
                value=True
            )

            manual_audio_upload = gr.File(
                label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
                file_types=["audio"]
            )
            load_random_audio_button = gr.Button("or load random audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="filepath",
            )

            # Filled by vamp() with the rendered mask tokens.
            audio_mask = gr.Audio(
                label="audio mask (listen to this to hear the mask hints)",
                interactive=False,
                type="filepath",
            )

            # connect widgets
            load_random_audio_button.click(
                fn=load_random_audio,
                inputs=[],
                outputs=[ input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[ input_audio]
            )

        # Column 2: mask settings
        with gr.Column():

            input_pitch_shift = gr.Slider(
                label="input pitch shift (semitones)",
                minimum=-12,
                maximum=12,
                step=1,
                value=0,
            )

            rand_mask_intensity = gr.Slider(
                label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                minimum=0.0,
                maximum=1.0,
                value=1.0
            )

            periodic_p = gr.Slider(
                label="periodic prompt (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
                minimum=0,
                maximum=128,
                step=1,
                value=9,
            )

            periodic_w = gr.Slider(
                label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
                minimum=1,
                maximum=20,
                step=1,
                value=1,
            )

            # 0 disables the onset mask in _vamp().
            onset_mask_width = gr.Slider(
                label="onset mask width (steps, 1 step ~= 10milliseconds)",
                minimum=0,
                maximum=20,
                step=1,
                value=0,
            )

            with gr.Accordion("extras ", open=False):
                n_conditioning_codebooks = gr.Number(
                    label="number of conditioning codebooks. probably 0",
                    value=0,
                    precision=0,
                )

                # NOTE(review): saved by save_vamp() but not read by _vamp().
                stretch_factor = gr.Slider(
                    label="time stretch factor",
                    minimum=0,
                    maximum=64,
                    step=1,
                    value=1,
                )

            with gr.Accordion("prefix/suffix hints", open=False):
                prefix_s = gr.Slider(
                    label="prefix hint length (seconds)",
                    minimum=0.0,
                    maximum=10.0,
                    value=0.0
                )
                suffix_s = gr.Slider(
                    label="suffix hint length (seconds)",
                    minimum=0.0,
                    maximum=10.0,
                    value=0.0
                )

            with gr.Accordion("temperature settings", open=False):
                init_temp = gr.Slider(
                    label="initial temperature (should probably stay between 0.6 and 1)",
                    minimum=0.0,
                    maximum=1.5,
                    value=0.8
                )
                final_temp = gr.Slider(
                    label="final temperature (should probably stay between 0.7 and 2)",
                    minimum=0.0,
                    maximum=2.0,
                    value=1.0
                )

            with gr.Accordion("sampling settings", open=False):
                sampling_strategy = gr.Radio(
                    label="sampling strategy",
                    choices=["gumbel", "multinomial"],
                    value="gumbel"
                )
                typical_filtering = gr.Checkbox(
                    label="typical filtering (cannot be used with topk)",
                    value=True
                )
                typical_mass = gr.Slider(
                    label="typical mass (should probably stay between 0.1 and 0.5)",
                    minimum=0.01,
                    maximum=0.99,
                    value=0.2
                )
                typical_min_tokens = gr.Slider(
                    label="typical min tokens (should probably stay between 1 and 256)",
                    minimum=1,
                    maximum=256,
                    step=1,
                    value=1
                )
                # 0 is mapped to top_k=None in _vamp().
                topk = gr.Slider(
                    label="topk (cannot be used with typical filtering). 0 = None",
                    minimum=0,
                    maximum=256,
                    step=1,
                    value=0
                )

            num_steps = gr.Slider(
                label="number of steps (should normally be between 12 and 36)",
                minimum=1,
                maximum=128,
                step=1,
                value=36
            )

            dropout = gr.Slider(
                label="mask dropout",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.0
            )

            vamp_button = gr.Button("vamp!!!")

        # Column 3: output audio, notes, and saving
        with gr.Column():
            output_audio = gr.Audio(
                label="output audio",
                interactive=False,
                type="filepath"
            )

            use_as_input_button = gr.Button("use as input")

            # Disabled beat-unmask controls (not wired to anything):
            # with gr.Column():
            #     with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
            #         use_beats = gr.Checkbox(
            #             label="use beat hints (helps the output stick to the beat structure of the input)",
            #             value=False
            #         )

            #         snap_to_beats = gr.Checkbox(
            #             label="trim to beat markers (uncheck if the output audio is too short.)",
            #             value=True
            #         )

            #         beat_unmask_dur = gr.Slider(
            #             label="duration",
            #             minimum=0.0,
            #             maximum=3.0,
            #             value=0.07
            #         )

            notes_text = gr.Textbox(
                label="type any notes about the generated audio here",
                value="",
                interactive=True
            )
            save_button = gr.Button("save vamp")
            download_file = gr.File(
                label="vamp to download will appear here",
                interactive=False
            )

            # Shows the status message returned by save_vamp().
            thank_you = gr.Markdown("")

    # NOTE: _inputs is a *set* of components, so Gradio calls the handlers
    # with a single dict {component: value}. That is why _vamp()/save_vamp()
    # index `data` by these component objects.
    _inputs = {
        input_audio,
        num_steps,
        init_temp, final_temp,
        prefix_s, suffix_s,
        rand_mask_intensity,
        periodic_p, periodic_w,
        n_conditioning_codebooks,
        dropout,
        use_coarse2fine,
        stretch_factor,
        onset_mask_width,
        input_pitch_shift,
        sampling_strategy,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        topk,
    }

    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[output_audio, audio_mask],
    )

    # Same generation exposed under the API name "vamp", returning only the
    # output audio.
    api_vamp_button = gr.Button("api vamp")
    api_vamp_button.click(
        fn=api_vamp,
        inputs=_inputs,
        outputs=[output_audio],
        api_name="vamp"
    )

    # Copy the generated output back into the input player.
    use_as_input_button.click(
        fn=lambda x: x,
        inputs=[output_audio],
        outputs=[input_audio]
    )

    # Set union: save_vamp() additionally needs the notes and the output.
    save_button.click(
        fn=save_vamp,
        inputs=_inputs | {notes_text, output_audio},
        outputs=[thank_you, download_file]
    )

demo.launch(share=True, enable_queue=False, debug=True)