# salad_bowl / app.py
# author: Hugo Flores Garcia
# last commit: fix periodic prompt (5f3fd3c)
from pathlib import Path
import yaml
import uuid
import numpy as np
import audiotools as at
import argbind
import shutil
import torch
from datetime import datetime
import gradio as gr
from vampnet.interface import Interface, signal_concat
from vampnet import mask as pmask
# Run on GPU when available; all model inference below happens on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Global vampnet interface (coarse, coarse-to-fine, and codec checkpoints),
# shared by every request handler; _vamp reloads checkpoints per request.
interface = Interface(
    device=device,
    coarse_ckpt="models/vampnet/coarse.pth",
    coarse2fine_ckpt="models/vampnet/c2f.pth",
    codec_ckpt="models/vampnet/codec.pth",
)
# Model registry for the UI dropdown. "default" mirrors the checkpoints the
# global `interface` was constructed with; any conf/generated/*/interface.yml
# whose checkpoints all exist on disk is added under its directory name.
MODEL_CHOICES = {
    "default": {
        "Interface.coarse_ckpt": str(interface.coarse_path),
        "Interface.coarse2fine_ckpt": str(interface.c2f_path),
        "Interface.codec_ckpt": str(interface.codec_path),
    }
}

# checkpoint keys every candidate config must resolve to an existing file
_CKPT_KEYS = (
    "Interface.coarse_ckpt",
    "Interface.coarse2fine_ckpt",
    "Interface.codec_ckpt",
)

generated_confs = Path("conf/generated")
for conf_file in generated_confs.glob("*/interface.yml"):
    with open(conf_file) as f:
        _conf = yaml.safe_load(f)
    # skip configs whose coarse, c2f, or codec checkpoint is missing
    if all(Path(_conf[key]).exists() for key in _CKPT_KEYS):
        MODEL_CHOICES[conf_file.parent.name] = _conf

# Root directory for everything this app writes (inputs, outputs, zips).
OUT_DIR = Path("gradio-outputs")
OUT_DIR.mkdir(exist_ok=True, parents=True)

MAX_DURATION_S = 60
def load_audio(file):
    """Gradio callback for the file-upload widget.

    Loads the uploaded audio, writes a copy to a unique temp directory under
    OUT_DIR, and returns a path for the `input_audio` widget.

    Args:
        file: gradio File payload; `.name` is the path on disk.

    Returns:
        The signal's `path_to_file` after `write` — presumably the written
        input.wav; confirm against the audiotools AudioSignal API.
    """
    filepath = file.name
    # NOTE(review): a previous revision computed
    # at.AudioSignal.salient_excerpt(filepath, duration=MAX_DURATION_S) and
    # then discarded it, so no trimming ever happened; the dead (and
    # expensive) excerpt computation has been removed. If trimming to
    # MAX_DURATION_S is actually desired, use the excerpt as `sig` instead.
    sig = at.AudioSignal(filepath)
    out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
    out_dir.mkdir(parents=True, exist_ok=True)
    sig.write(out_dir / "input.wav")
    return sig.path_to_file
def load_example_audio():
    """Return the path of the bundled example clip for the example button."""
    example_path = "./assets/example.wav"
    return example_path
from torch_pitch_shift import pitch_shift, get_fast_shifts
def shift_pitch(signal, interval: int):
    """Pitch-shift `signal` in place by `interval` semitones and return it.

    Mutates `signal.samples` via torch_pitch_shift.pitch_shift; the same
    signal object is returned for call-chaining.
    """
    shifted = pitch_shift(
        signal.samples,
        shift=interval,
        sample_rate=signal.sample_rate,
    )
    signal.samples = shifted
    return signal
def _vamp(seed, input_audio, model_choice, pitch_shift_amt, periodic_p, p2, n_mask_codebooks, n_mask_codebooks_2, rand_mask_intensity, prefix_s, suffix_s, periodic_w, onset_mask_width, dropout, masktemp, sampletemp, typical_filtering, typical_mass, typical_min_tokens, top_p, sample_cutoff, win_dur, num_feedback_steps, stretch_factor, api=False):
    """Run one vampnet generation pass and write all artifacts to disk.

    Shared backend for the UI button (`vamp`), the hidden API button
    (`api_vamp`), and the pyHARP endpoint (`harp_vamp`). Loads `input_audio`,
    optionally pitch-shifts it, builds mask/sampling kwargs from the prompt
    parameters, runs `interface.ez_vamp`, and saves input/output audio, mask
    text + image, codes, metadata, and a zip under a timestamped OUT_DIR
    subdirectory.

    Args mirror the Gradio widgets of the same names. Notable ones:
        seed: RNG seed; <= 0 means "draw a random 32-bit seed".
        model_choice: key into MODEL_CHOICES selecting checkpoints to reload.
        api: when True, batch size 1 and return only the output wav path;
            when False, batch size 4 and return the 7-tuple
            (out1, out2, out3, out4, mask audio, zip path, mask image path).
    """
    # seed <= 0 means "random": draw the seed explicitly so it can be logged
    # in metadata.yml and embedded in the output directory name.
    _seed = seed if seed > 0 else None
    if _seed is None:
        _seed = int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(_seed)

    # unique, human-readable run dir: <stem>-<timestamp>-seed-<s>-model-<m>
    datentime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    out_dir = OUT_DIR / f"{Path(input_audio).stem}-{datentime}-seed-{_seed}-model-{model_choice}"
    out_dir.mkdir(parents=True)

    sig = at.AudioSignal(input_audio)
    sig.write(out_dir / "input.wav")

    # reload the model if necessary (checkpoints come from MODEL_CHOICES)
    interface.reload(
        coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
        c2f_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
    )

    loudness = sig.loudness()
    print(f"input loudness is {loudness}")

    if pitch_shift_amt != 0:
        sig = shift_pitch(sig, pitch_shift_amt)

    # 0 means "same as the primary value" for the secondary periodic prompt
    # and the secondary codebook mask endpoint.
    _p2 = periodic_p if p2 == 0 else p2
    _n_codebooks_2 = n_mask_codebooks if n_mask_codebooks_2 == 0 else n_mask_codebooks_2

    build_mask_kwargs = dict(
        rand_mask_intensity=rand_mask_intensity,
        prefix_s=prefix_s,
        suffix_s=suffix_s,
        periodic_prompt=int(periodic_p),
        periodic_prompt2=int(_p2),
        periodic_prompt_width=periodic_w,
        onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=int(n_mask_codebooks),
        upper_codebook_mask_2=int(_n_codebooks_2),
    )

    vamp_kwargs = dict(
        mask_temperature=masktemp*10,  # UI slider is 1/10 of the model-facing value
        sampling_temperature=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=top_p if top_p > 0 else None,  # 0 in the UI disables top-p
        seed=_seed,
        sample_cutoff=sample_cutoff,
    )

    # configure the processing window size, then run the generation
    interface.set_chunk_size(win_dur)
    sig, mask, codes = interface.ez_vamp(
        sig,
        batch_size=4 if not api else 1,  # 4 variations for the UI, 1 for API
        feedback_steps=num_feedback_steps,
        time_stretch_factor=stretch_factor,
        build_mask_kwargs=build_mask_kwargs,
        vamp_kwargs=vamp_kwargs,
        return_mask=True,
    )

    if api:
        sig.write(out_dir / "out.wav")
        return sig.path_to_file

    if not api:
        # write codes to numpy file
        np.save(out_dir / "codes.npy", codes.cpu().numpy())

        # record everything needed to reproduce this run
        metadata = {}
        metadata["seed"] = _seed
        metadata["model_choice"] = model_choice
        metadata["mask_kwargs"] = build_mask_kwargs
        metadata["vamp_kwargs"] = vamp_kwargs
        metadata["loudness"] = loudness
        # save the metadata
        with open(out_dir / "metadata.yml", "w") as f:
            yaml.dump(metadata, f)

        # one wav per batch item (batch_size=4 above)
        sig0 = sig[0].write(out_dir / "out1.wav")
        sig1 = sig[1].write(out_dir / "out2.wav")
        sig2 = sig[2].write(out_dir / "out3.wav")
        sig3 = sig[3].write(out_dir / "out4.wav")

        # write the mask to txt, one time step per line
        with open(out_dir / "mask.txt", "w") as f:
            m = mask[0].cpu().numpy()
            for i in range(m.shape[-1]):
                f.write(f"{m[:, i]}\n")

        # save visualizations of the mask and the generated codes
        import matplotlib.pyplot as plt
        plt.clf()
        interface.visualize_codes(mask)
        plt.savefig(out_dir / "mask.png")
        plt.clf()
        interface.visualize_codes(codes)
        plt.savefig(out_dir / "codes.png")
        plt.close()

        # zip out dir, and return the path to the zip
        shutil.make_archive(out_dir, 'zip', out_dir)

        # decode the mask tokens to audio so the user can *hear* the hints;
        # chunk in groups of 1024 timesteps (presumably to bound memory —
        # TODO confirm)
        _mask_sigs = []
        for i in range(0, mask.shape[-1], 1024):
            _mask_sigs.append(interface.to_signal(mask[:, :, i:i+1024].to(interface.device)).cpu())
        mask = signal_concat(_mask_sigs)
        mask.write(out_dir / "mask.wav")

        return (
            sig0.path_to_file, sig1.path_to_file,
            sig2.path_to_file, sig3.path_to_file,
            mask.path_to_file, str(out_dir.with_suffix(".zip")), out_dir / "mask.png"
        )
def vamp(data):
    """UI-button handler: unpack the gradio {component: value} dict and run
    _vamp in non-API mode (batch of 4 output variations)."""
    components = dict(
        seed=seed, input_audio=input_audio, model_choice=model_choice,
        pitch_shift_amt=pitch_shift_amt, periodic_p=periodic_p, p2=p2,
        n_mask_codebooks=n_mask_codebooks,
        n_mask_codebooks_2=n_mask_codebooks_2,
        rand_mask_intensity=rand_mask_intensity, prefix_s=prefix_s,
        suffix_s=suffix_s, periodic_w=periodic_w,
        onset_mask_width=onset_mask_width, dropout=dropout,
        masktemp=masktemp, sampletemp=sampletemp,
        typical_filtering=typical_filtering, typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens, top_p=top_p,
        sample_cutoff=sample_cutoff, win_dur=win_dur,
        num_feedback_steps=num_feedback_steps, stretch_factor=stretch_factor,
    )
    kwargs = {name: data[widget] for name, widget in components.items()}
    return _vamp(api=False, **kwargs)
def api_vamp(data):
    """Handler behind the hidden api_name="vamp" button: same parameters as
    the UI handler, but runs _vamp in API mode (batch of 1, single path)."""
    widget_map = [
        ("seed", seed), ("input_audio", input_audio),
        ("model_choice", model_choice), ("pitch_shift_amt", pitch_shift_amt),
        ("periodic_p", periodic_p), ("p2", p2),
        ("n_mask_codebooks", n_mask_codebooks),
        ("n_mask_codebooks_2", n_mask_codebooks_2),
        ("rand_mask_intensity", rand_mask_intensity),
        ("prefix_s", prefix_s), ("suffix_s", suffix_s),
        ("periodic_w", periodic_w), ("onset_mask_width", onset_mask_width),
        ("dropout", dropout), ("masktemp", masktemp),
        ("sampletemp", sampletemp), ("typical_filtering", typical_filtering),
        ("typical_mass", typical_mass),
        ("typical_min_tokens", typical_min_tokens), ("top_p", top_p),
        ("sample_cutoff", sample_cutoff), ("win_dur", win_dur),
        ("num_feedback_steps", num_feedback_steps),
        ("stretch_factor", stretch_factor),
    ]
    kwargs = {name: data[widget] for name, widget in widget_map}
    return _vamp(**kwargs, api=True)
def harp_vamp(input_audio,
              periodic_p,
              n_mask_codebooks,
              pitch_shift_amt,
              win_dur):
    """pyHARP endpoint: expose a reduced parameter set and run _vamp in API
    mode with fixed defaults for everything the DAW UI doesn't surface."""
    defaults = dict(
        seed=0,                    # 0 -> _vamp draws a random seed
        model_choice="default",
        p2=0,                      # 0 -> same as periodic_p
        n_mask_codebooks_2=0,      # 0 -> same as n_mask_codebooks
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_w=1,
        onset_mask_width=0,
        dropout=0.0,
        masktemp=1.5,
        sampletemp=1.0,
        typical_filtering=True,
        typical_mass=0.15,
        typical_min_tokens=64,
        top_p=0.9,
        sample_cutoff=1.0,
        num_feedback_steps=1,
        stretch_factor=1.0,
    )
    return _vamp(
        input_audio=input_audio,
        periodic_p=periodic_p,
        n_mask_codebooks=n_mask_codebooks,
        pitch_shift_amt=pitch_shift_amt,
        win_dur=win_dur,
        api=True,
        **defaults,
    )
# Build the Gradio UI. Three columns: input audio, mask/prompt settings, and
# model selection + outputs. Components feed _vamp via the dict-style
# `inputs` set below.
with gr.Blocks() as demo:
    with gr.Row():
        # --- input column: upload / example audio and the mask preview ---
        with gr.Column():
            # NOTE(review): label says 100s but MAX_DURATION_S is 60, and
            # load_audio currently writes the full file untrimmed — confirm
            # the intended behavior.
            manual_audio_upload = gr.File(
                label=f"upload some audio (will be randomly trimmed to max of 100s)",
                file_types=["audio"]
            )
            load_example_audio_button = gr.Button("or load example audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="filepath",
            )

            # audible rendering of the mask hints, filled in by vamp()
            audio_mask = gr.Audio(
                label="audio mask (listen to this to hear the mask hints)",
                interactive=False,
                type="filepath",
            )

            # connect widgets
            load_example_audio_button.click(
                fn=load_example_audio,
                inputs=[],
                outputs=[ input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[ input_audio]
            )

        # mask settings
        with gr.Column():
            # NOTE(review): the accordion grouping below was reconstructed
            # from a whitespace-mangled source; grouping affects layout only —
            # every component is wired up by variable, not by position.
            with gr.Accordion("manual controls", open=True):
                periodic_p = gr.Slider(
                    label="periodic prompt",
                    minimum=0,
                    maximum=128,
                    step=1,
                    value=3,
                )
                p2 = gr.Slider(
                    label="periodic prompt 2 (0 - same as p1, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
                    minimum=0,
                    maximum=128,
                    step=1,
                    value=0,
                )
                onset_mask_width = gr.Slider(
                    label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )
                n_mask_codebooks = gr.Slider(
                    label="compression prompt ",
                    value=3,
                    minimum=0,
                    maximum=14,
                    step=1,
                )
                n_mask_codebooks_2 = gr.Number(
                    label="compression prompt 2 via linear interpolation (0 == constant)",
                    value=0,
                )

            with gr.Accordion("extras ", open=False):
                pitch_shift_amt = gr.Slider(
                    label="pitch shift amount (semitones)",
                    minimum=-12,
                    maximum=12,
                    step=1,
                    value=0,
                )
                stretch_factor = gr.Slider(
                    label="time stretch factor",
                    minimum=0,
                    maximum=64,
                    step=1,
                    value=1,
                )
                rand_mask_intensity = gr.Slider(
                    label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0
                )
                periodic_w = gr.Slider(
                    label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=1,
                )

            with gr.Accordion("prefix/suffix prompts", open=True):
                prefix_s = gr.Slider(
                    label="prefix hint length (seconds)",
                    minimum=0.0,
                    maximum=10.0,
                    value=0.0
                )
                suffix_s = gr.Slider(
                    label="suffix hint length (seconds)",
                    minimum=0.0,
                    maximum=10.0,
                    value=0.0
                )

            # temperatures feed _vamp; masktemp is multiplied by 10 there
            masktemp = gr.Slider(
                label="mask temperature",
                minimum=0.0,
                maximum=100.0,
                value=1.5
            )
            sampletemp = gr.Slider(
                label="sample temperature",
                minimum=0.1,
                maximum=10.0,
                value=1.0,
                step=0.001
            )

            with gr.Accordion("sampling settings", open=False):
                top_p = gr.Slider(
                    label="top p (0.0 = off)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9
                )
                typical_filtering = gr.Checkbox(
                    label="typical filtering ",
                    value=True
                )
                typical_mass = gr.Slider(
                    label="typical mass (should probably stay between 0.1 and 0.5)",
                    minimum=0.01,
                    maximum=0.99,
                    value=0.15
                )
                typical_min_tokens = gr.Slider(
                    label="typical min tokens (should probably stay between 1 and 256)",
                    minimum=1,
                    maximum=256,
                    step=1,
                    value=64
                )
                sample_cutoff = gr.Slider(
                    label="sample cutoff",
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.01
                )
                dropout = gr.Slider(
                    label="mask dropout",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0
                )

            # 0 asks _vamp to draw (and log) a fresh random seed
            seed = gr.Number(
                label="seed (0 for random)",
                value=0,
                precision=0,
            )

        # model selection, generation trigger, and outputs
        with gr.Column():
            model_choice = gr.Dropdown(
                label="model choice",
                choices=list(MODEL_CHOICES.keys()),
                value="default",
                visible=True
            )
            num_feedback_steps = gr.Slider(
                label="number of feedback steps (each one takes a while)",
                minimum=1,
                maximum=16,
                step=1,
                value=1
            )
            win_dur= gr.Slider(
                label="window duration (seconds)",
                minimum=2,
                maximum=10,
                value=6)

            vamp_button = gr.Button("generate (vamp)!!!")

            maskimg = gr.Image(
                label="mask image",
                interactive=False,
                type="filepath"
            )
            out1 = gr.Audio(
                label="output audio 1",
                interactive=False,
                type="filepath"
            )
            out2 = gr.Audio(
                label="output audio 2",
                interactive=False,
                type="filepath"
            )
            out3 = gr.Audio(
                label="output audio 3",
                interactive=False,
                type="filepath"
            )
            out4 = gr.Audio(
                label="output audio 4",
                interactive=False,
                type="filepath"
            )

            thank_you = gr.Markdown("")

            # download all the outputs
            # NOTE(review): gr.File(type="file") is removed in newer gradio
            # versions ("filepath"/"binary") — confirm the pinned version.
            download = gr.File(type="file", label="download outputs")

    # Passing a *set* of components makes gradio call the handler with a
    # single {component: value} dict instead of positional arguments.
    _inputs = {
        input_audio,
        masktemp,
        sampletemp,
        top_p,
        prefix_s, suffix_s,
        rand_mask_intensity,
        periodic_p, periodic_w,
        dropout,
        stretch_factor,
        onset_mask_width,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        seed,
        model_choice,
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
        num_feedback_steps,
        p2,
        n_mask_codebooks_2,
        win_dur
    }

    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[out1, out2, out3, out4, audio_mask, download, maskimg],
    )

    # hidden button exposing the same generation under api_name="vamp"
    api_vamp_button = gr.Button("api vamp", visible=False)
    api_vamp_button.click(
        fn=api_vamp,
        inputs=_inputs,
        outputs=[out1],
        api_name="vamp"
    )

    # pyHARP integration: a reduced-parameter endpoint for use from a DAW
    from pyharp import ModelCard, build_endpoint
    model_card = ModelCard(
        name="salad bowl",
        description="sounds",
        author="hugo flores garcía",
        tags=["generative","sound"],
    )
    build_endpoint(
        inputs=[
            input_audio,
            periodic_p,
            n_mask_codebooks,
            pitch_shift_amt,
            win_dur,
        ],
        output=out1,
        process_fn=harp_vamp,
        card=model_card
    )
# Launch the app; on Ctrl-C, clean up everything generated this session.
try:
    demo.queue()
    demo.launch(share=True)
except KeyboardInterrupt:
    # Use the OUT_DIR constant instead of the previous hard-coded
    # "gradio-outputs" string so the cleanup path can never drift from the
    # directory the app actually writes to.
    shutil.rmtree(OUT_DIR, ignore_errors=True)
    raise