|
from pathlib import Path |
|
import yaml |
|
import uuid |
|
|
|
import numpy as np |
|
import audiotools as at |
|
import argbind |
|
import shutil |
|
import torch |
|
from datetime import datetime |
|
|
|
import gradio as gr |
|
from vampnet.interface import Interface, signal_concat |
|
from vampnet import mask as pmask |
|
|
|
# Run on GPU when available; all model inference below follows this device.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Module-level model interface: loads the "nesquik" coarse, coarse-to-fine,
# and codec checkpoints once at import time. Reused (and partially reloaded)
# by _vamp() for every generation request.
interface = Interface(
    device=device,
    coarse_ckpt="models/nesquik/coarse.pth",
    coarse2fine_ckpt="models/nesquik/c2f.pth",
    codec_ckpt="models/nesquik/codec.pth",
)
|
|
|
|
|
# Registry of selectable models for the "model choice" dropdown.
# "default" mirrors whatever checkpoints the module-level interface loaded;
# additional entries are discovered from generated interface.yml configs.
MODEL_CHOICES = {
    "default": {
        "Interface.coarse_ckpt": str(interface.coarse_path),
        "Interface.coarse2fine_ckpt": str(interface.c2f_path),
        "Interface.codec_ckpt": str(interface.codec_path),
    }
}

# Scan conf/generated/<name>/interface.yml and register every config whose
# three checkpoint files all exist on disk, keyed by the parent folder name.
generated_confs = Path("conf/generated")
for conf_file in generated_confs.glob("*/interface.yml"):
    with open(conf_file) as f:
        _conf = yaml.safe_load(f)

    _ckpt_keys = (
        "Interface.coarse_ckpt",
        "Interface.coarse2fine_ckpt",
        "Interface.codec_ckpt",
    )
    # Skip configs pointing at missing checkpoints (short-circuits like the
    # original chained `and` check).
    if all(Path(_conf[key]).exists() for key in _ckpt_keys):
        MODEL_CHOICES[conf_file.parent.name] = _conf


# All generated audio, metadata, and zips land under this directory.
OUT_DIR = Path("gradio-outputs")
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
|
|
# Maximum length (seconds) of audio kept from an upload.
MAX_DURATION_S = 60

def load_audio(file):
    """Stage an uploaded audio file for processing.

    Extracts a salient excerpt of at most ``MAX_DURATION_S`` seconds from the
    upload, writes it to a unique temp directory under ``OUT_DIR/tmp``, and
    returns the path to the staged file (fed into the ``input_audio`` widget).

    Parameters
    ----------
    file : gradio file object exposing the uploaded path via ``.name``.
    """
    print(file)
    filepath = file.name
    # BUG FIX: previously the salient excerpt was computed and then thrown
    # away by re-loading the full file (`sig = at.AudioSignal(filepath)`),
    # so MAX_DURATION_S had no effect. Keep the excerpt so long uploads are
    # actually capped, as the UI label promises.
    sig = at.AudioSignal.salient_excerpt(
        filepath, duration=MAX_DURATION_S
    )

    out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
    out_dir.mkdir(parents=True, exist_ok=True)
    sig.write(out_dir / "input.wav")
    return sig.path_to_file
|
|
|
|
|
def load_example_audio():
    """Return the path of the bundled example clip for the input widget."""
    example_path = "./assets/example.wav"
    return example_path
|
|
|
from torch_pitch_shift import pitch_shift, get_fast_shifts

def shift_pitch(signal, interval: int):
    """Pitch-shift ``signal`` in place by ``interval`` semitones.

    Replaces ``signal.samples`` with the shifted samples and returns the
    same (mutated) signal object.
    """
    shifted = pitch_shift(
        signal.samples,
        shift=interval,
        sample_rate=signal.sample_rate,
    )
    signal.samples = shifted
    return signal
|
|
|
def _vamp(seed, input_audio, model_choice, pitch_shift_amt, periodic_p, p2, n_mask_codebooks, n_mask_codebooks_2, rand_mask_intensity, prefix_s, suffix_s, periodic_w, onset_mask_width, dropout, masktemp, sampletemp, typical_filtering, typical_mass, typical_min_tokens, top_p, sample_cutoff, win_dur, num_feedback_steps, stretch_factor, api=False):
    """Core generation routine shared by the UI, API, and HARP endpoints.

    Loads ``input_audio``, reloads the checkpoints selected by
    ``model_choice``, optionally pitch-shifts the signal, builds the mask and
    sampling kwargs, and runs ``interface.ez_vamp``.

    Returns
    -------
    If ``api`` is True: the path to a single ``out.wav``.
    Otherwise: a 7-tuple of (four output audio paths, audible-mask audio
    path, zip archive path, mask visualization PNG path) matching the
    gradio output components.
    """
    # seed <= 0 means "random": draw one explicitly so it can be logged
    # into the run directory name and metadata for reproducibility.
    _seed = seed if seed > 0 else None
    if _seed is None:
        _seed = int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(_seed)

    # One output directory per run, named after input stem/time/seed/model.
    datentime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    out_dir = OUT_DIR / f"{Path(input_audio).stem}-{datentime}-seed-{_seed}-model-{model_choice}"
    # NOTE(review): no exist_ok here — two runs in the same second on the
    # same input/seed/model would raise FileExistsError. Confirm acceptable.
    out_dir.mkdir(parents=True)
    sig = at.AudioSignal(input_audio)
    sig.write(out_dir / "input.wav")

    # Swap in the chosen model's coarse / coarse2fine weights.
    # (The codec checkpoint from module init is not reloaded here.)
    interface.reload(
        coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
        c2f_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
    )

    loudness = sig.loudness()
    print(f"input loudness is {loudness}")

    if pitch_shift_amt != 0:
        sig = shift_pitch(sig, pitch_shift_amt)

    # A zero in the secondary controls means "fall back to the primary value".
    _p2 = periodic_p if p2 == 0 else p2
    _n_codebooks_2 = n_mask_codebooks if n_mask_codebooks_2 == 0 else n_mask_codebooks_2

    # Arguments that shape WHERE tokens are masked (prompt hints).
    build_mask_kwargs = dict(
        rand_mask_intensity=rand_mask_intensity,
        prefix_s=prefix_s,
        suffix_s=suffix_s,
        periodic_prompt=int(periodic_p),
        periodic_prompt2=int(_p2),
        periodic_prompt_width=periodic_w,
        onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=int(n_mask_codebooks),
        upper_codebook_mask_2=int(_n_codebooks_2),
    )

    # Arguments that shape HOW masked tokens are resampled.
    vamp_kwargs = dict(
        # NOTE(review): UI slider value is multiplied by 10 here — confirm
        # this scaling matches the expected mask_temperature range.
        mask_temperature=masktemp*10,
        sampling_temperature=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=top_p if top_p > 0 else None,  # 0.0 in the UI disables top-p
        seed=_seed,
        sample_cutoff=sample_cutoff,
    )

    interface.set_chunk_size(win_dur)
    sig, mask, codes = interface.ez_vamp(
        sig,
        batch_size=4 if not api else 1,  # UI shows 4 variations; API returns 1
        feedback_steps=num_feedback_steps,
        time_stretch_factor=stretch_factor,
        build_mask_kwargs=build_mask_kwargs,
        vamp_kwargs=vamp_kwargs,
        return_mask=True,
    )

    # API path: single file, no visualizations or metadata dump.
    if api:
        sig.write(out_dir / "out.wav")

        return sig.path_to_file

    if not api:

        # Persist raw codes plus a YAML record of everything needed to
        # reproduce this run.
        np.save(out_dir / "codes.npy", codes.cpu().numpy())
        metadata = {}
        metadata["seed"] = _seed
        metadata["model_choice"] = model_choice
        metadata["mask_kwargs"] = build_mask_kwargs
        metadata["vamp_kwargs"] = vamp_kwargs
        metadata["loudness"] = loudness

        with open(out_dir / "metadata.yml", "w") as f:
            yaml.dump(metadata, f)

        # Write the four batch variations as individual wav files.
        sig0 = sig[0].write(out_dir / "out1.wav")
        sig1 = sig[1].write(out_dir / "out2.wav")
        sig2 = sig[2].write(out_dir / "out3.wav")
        sig3 = sig[3].write(out_dir / "out4.wav")

        # Plain-text dump of the mask, one timestep per line.
        with open(out_dir / "mask.txt", "w") as f:
            m = mask[0].cpu().numpy()

            for i in range(m.shape[-1]):
                f.write(f"{m[:, i]}\n")

        # Render mask and codes visualizations as PNGs.
        import matplotlib.pyplot as plt
        plt.clf()
        interface.visualize_codes(mask)
        plt.savefig(out_dir / "mask.png")
        plt.clf()
        interface.visualize_codes(codes)
        plt.savefig(out_dir / "codes.png")
        plt.close()

        # Zip the whole run directory for the download widget.
        shutil.make_archive(out_dir, 'zip', out_dir)

        # Decode the mask itself to audio in 1024-step chunks so users can
        # hear which regions were kept as hints.
        _mask_sigs = []
        for i in range(0, mask.shape[-1], 1024):
            _mask_sigs.append(interface.to_signal(mask[:, :, i:i+1024].to(interface.device)).cpu())
        mask = signal_concat(_mask_sigs)
        mask.write(out_dir / "mask.wav")

        return (
            sig0.path_to_file, sig1.path_to_file,
            sig2.path_to_file, sig3.path_to_file,
            mask.path_to_file, str(out_dir.with_suffix(".zip")), out_dir / "mask.png"
        )
|
|
|
def vamp(data):
    """Gradio UI callback: unpack the component->value dict and run _vamp.

    ``data`` maps each gradio component object to its current value (gradio
    passes this form because the click handler receives a *set* of inputs).
    """
    # Map each _vamp keyword to the gradio component that supplies its value.
    component_for = {
        "seed": seed,
        "input_audio": input_audio,
        "model_choice": model_choice,
        "pitch_shift_amt": pitch_shift_amt,
        "periodic_p": periodic_p,
        "p2": p2,
        "n_mask_codebooks": n_mask_codebooks,
        "n_mask_codebooks_2": n_mask_codebooks_2,
        "rand_mask_intensity": rand_mask_intensity,
        "prefix_s": prefix_s,
        "suffix_s": suffix_s,
        "periodic_w": periodic_w,
        "onset_mask_width": onset_mask_width,
        "dropout": dropout,
        "masktemp": masktemp,
        "sampletemp": sampletemp,
        "typical_filtering": typical_filtering,
        "typical_mass": typical_mass,
        "typical_min_tokens": typical_min_tokens,
        "top_p": top_p,
        "sample_cutoff": sample_cutoff,
        "win_dur": win_dur,
        "num_feedback_steps": num_feedback_steps,
        "stretch_factor": stretch_factor,
    }
    kwargs = {name: data[component] for name, component in component_for.items()}
    return _vamp(api=False, **kwargs)
|
|
|
def api_vamp(data):
    """API callback: identical unpacking to ``vamp`` but runs with api=True,
    which makes _vamp produce (and return) a single output file."""
    # Map each _vamp keyword to the gradio component that supplies its value.
    component_for = {
        "seed": seed,
        "input_audio": input_audio,
        "model_choice": model_choice,
        "pitch_shift_amt": pitch_shift_amt,
        "periodic_p": periodic_p,
        "p2": p2,
        "n_mask_codebooks": n_mask_codebooks,
        "n_mask_codebooks_2": n_mask_codebooks_2,
        "rand_mask_intensity": rand_mask_intensity,
        "prefix_s": prefix_s,
        "suffix_s": suffix_s,
        "periodic_w": periodic_w,
        "onset_mask_width": onset_mask_width,
        "dropout": dropout,
        "masktemp": masktemp,
        "sampletemp": sampletemp,
        "typical_filtering": typical_filtering,
        "typical_mass": typical_mass,
        "typical_min_tokens": typical_min_tokens,
        "top_p": top_p,
        "sample_cutoff": sample_cutoff,
        "win_dur": win_dur,
        "num_feedback_steps": num_feedback_steps,
        "stretch_factor": stretch_factor,
    }
    kwargs = {name: data[component] for name, component in component_for.items()}
    return _vamp(api=True, **kwargs)
|
|
|
|
|
def harp_vamp(input_audio,
              periodic_p,
              n_mask_codebooks,
              pitch_shift_amt,
              win_dur,
              num_feedback_steps):
    """HARP endpoint: expose only six controls and pin the rest to defaults.

    Runs _vamp in api mode (single output file) with the default model and
    neutral settings for every control the HARP UI does not surface.
    """
    # Fixed values for all controls not exposed through HARP.
    pinned = dict(
        seed=0,  # 0 -> _vamp draws a random seed
        model_choice="default",
        p2=0,  # 0 -> falls back to periodic_p inside _vamp
        n_mask_codebooks_2=0,  # 0 -> falls back to n_mask_codebooks
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_w=1,
        onset_mask_width=0,
        dropout=0.0,
        masktemp=1.5,
        sampletemp=1.0,
        typical_filtering=True,
        typical_mass=0.15,
        typical_min_tokens=64,
        top_p=0.9,
        sample_cutoff=1.0,
        stretch_factor=1.0,
        api=True,
    )
    return _vamp(
        input_audio=input_audio,
        periodic_p=periodic_p,
        n_mask_codebooks=n_mask_codebooks,
        pitch_shift_amt=pitch_shift_amt,
        win_dur=win_dur,
        num_feedback_steps=num_feedback_steps,
        **pinned,
    )
|
|
|
|
|
|
|
# Gradio UI layout. Three columns: input audio, generation controls, outputs.
# NOTE(review): nesting below was reconstructed from context (source
# indentation was lost in extraction) — verify accordion grouping.
with gr.Blocks() as demo:
    with gr.Row():
        # --- Column 1: audio input ---
        with gr.Column():
            manual_audio_upload = gr.File(
                # NOTE(review): label says "max of 100s" but load_audio trims
                # to MAX_DURATION_S = 60 — label and constant disagree.
                label=f"upload some audio (will be randomly trimmed to max of 100s)",
                file_types=["audio"]
            )
            load_example_audio_button = gr.Button("or load example audio")

            # Populated by the upload handler or the example button; feeds _vamp.
            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="filepath",
            )

            # Decoded audio of the mask, so users can hear the kept hints.
            audio_mask = gr.Audio(
                label="audio mask (listen to this to hear the mask hints)",
                interactive=False,
                type="filepath",
            )

            load_example_audio_button.click(
                fn=load_example_audio,
                inputs=[],
                outputs=[ input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[ input_audio]
            )

        # --- Column 2: generation controls ---
        with gr.Column():
            with gr.Accordion("manual controls", open=True):
                # Spacing (in steps) between periodic prompt hints.
                periodic_p = gr.Slider(
                    label="periodic prompt",
                    minimum=0,
                    maximum=128,
                    step=1,
                    value=3,
                )
                # Secondary periodic prompt; 0 means "same as periodic_p".
                p2 = gr.Slider(
                    label="periodic prompt 2 (0 - same as p1, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
                    minimum=0,
                    maximum=128,
                    step=1,
                    value=0,
                )

                onset_mask_width = gr.Slider(
                    label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )

                # How many codebook levels to keep as prompt ("compression").
                n_mask_codebooks = gr.Slider(
                    label="compression prompt ",
                    value=3,
                    minimum=0,
                    maximum=14,
                    step=1,
                )
                # Secondary codebook count; 0 means "same as n_mask_codebooks".
                n_mask_codebooks_2 = gr.Number(
                    label="compression prompt 2 via linear interpolation (0 == constant)",
                    value=0,
                )

                with gr.Accordion("extras ", open=False):
                    pitch_shift_amt = gr.Slider(
                        label="pitch shift amount (semitones)",
                        minimum=-12,
                        maximum=12,
                        step=1,
                        value=0,
                    )

                    stretch_factor = gr.Slider(
                        label="time stretch factor",
                        minimum=0,
                        maximum=64,
                        step=1,
                        value=1,
                    )

                    rand_mask_intensity = gr.Slider(
                        label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0
                    )

                    periodic_w = gr.Slider(
                        label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=1,
                    )

                    with gr.Accordion("prefix/suffix prompts", open=True):
                        prefix_s = gr.Slider(
                            label="prefix hint length (seconds)",
                            minimum=0.0,
                            maximum=10.0,
                            value=0.0
                        )
                        suffix_s = gr.Slider(
                            label="suffix hint length (seconds)",
                            minimum=0.0,
                            maximum=10.0,
                            value=0.0
                        )

                    # Scaled by 10 inside _vamp before reaching the sampler.
                    masktemp = gr.Slider(
                        label="mask temperature",
                        minimum=0.0,
                        maximum=100.0,
                        value=1.5
                    )
                    sampletemp = gr.Slider(
                        label="sample temperature",
                        minimum=0.1,
                        maximum=10.0,
                        value=1.0,
                        step=0.001
                    )

                with gr.Accordion("sampling settings", open=False):
                    # 0.0 disables top-p (mapped to None in _vamp).
                    top_p = gr.Slider(
                        label="top p (0.0 = off)",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9
                    )
                    typical_filtering = gr.Checkbox(
                        label="typical filtering ",
                        value=True
                    )
                    typical_mass = gr.Slider(
                        label="typical mass (should probably stay between 0.1 and 0.5)",
                        minimum=0.01,
                        maximum=0.99,
                        value=0.15
                    )
                    typical_min_tokens = gr.Slider(
                        label="typical min tokens (should probably stay between 1 and 256)",
                        minimum=1,
                        maximum=256,
                        step=1,
                        value=64
                    )
                    sample_cutoff = gr.Slider(
                        label="sample cutoff",
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.01
                    )

                dropout = gr.Slider(
                    label="mask dropout",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.0
                )

                # 0 means "random seed" (see _vamp).
                seed = gr.Number(
                    label="seed (0 for random)",
                    value=0,
                    precision=0,
                )

        # --- Column 3: model selection and outputs ---
        with gr.Column():

            model_choice = gr.Dropdown(
                label="model choice",
                choices=list(MODEL_CHOICES.keys()),
                value="default",
                visible=True
            )

            num_feedback_steps = gr.Slider(
                label="number of feedback steps (each one takes a while)",
                minimum=1,
                maximum=16,
                step=1,
                value=3
            )

            win_dur= gr.Slider(
                label="window duration (seconds)",
                minimum=2,
                maximum=10,
                value=6)

            vamp_button = gr.Button("generate (vamp)!!!")
            maskimg = gr.Image(
                label="mask image",
                interactive=False,
                type="filepath"
            )
            out1 = gr.Audio(
                label="output audio 1",
                interactive=False,
                type="filepath"
            )
            out2 = gr.Audio(
                label="output audio 2",
                interactive=False,
                type="filepath"
            )
            out3 = gr.Audio(
                label="output audio 3",
                interactive=False,
                type="filepath"
            )
            out4 = gr.Audio(
                label="output audio 4",
                interactive=False,
                type="filepath"
            )

            thank_you = gr.Markdown("")

            # NOTE(review): gr.File(type="file") is deprecated in newer
            # gradio versions (use type="filepath") — confirm pinned version.
            download = gr.File(type="file", label="download outputs")

    # Inputs are passed as a SET, so the callbacks receive a dict mapping
    # each component to its value (see vamp()/api_vamp()).
    _inputs = {
        input_audio,
        masktemp,
        sampletemp,
        top_p,
        prefix_s, suffix_s,
        rand_mask_intensity,
        periodic_p, periodic_w,
        dropout,
        stretch_factor,
        onset_mask_width,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        seed,
        model_choice,
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
        num_feedback_steps,
        p2,
        n_mask_codebooks_2,
        win_dur
    }

    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[out1, out2, out3, out4, audio_mask, download, maskimg],
    )

    # Hidden button exposing the single-output API endpoint ("/vamp").
    api_vamp_button = gr.Button("api vamp", visible=False)
    api_vamp_button.click(
        fn=api_vamp,
        inputs=_inputs,
        outputs=[out1],
        api_name="vamp"
    )

    # HARP integration: publishes harp_vamp as a plugin endpoint.
    from pyharp import ModelCard, build_endpoint

    model_card = ModelCard(
        name="nesquik",
        description="the ultimate 8-bit crusher",
        author="hugo flores garcía",
        tags=["generative","sound"],
    )

    build_endpoint(
        inputs=[
            input_audio,
            periodic_p,
            n_mask_codebooks,
            pitch_shift_amt,
            win_dur,
            num_feedback_steps
        ],
        output=out1,
        process_fn=harp_vamp,
        card=model_card
    )
|
|
|
|
|
# Launch the app with a public share link; on Ctrl-C, best-effort cleanup of
# the output directory before re-raising so the process still exits normally.
try:
    demo.queue()
    demo.launch(share=True)
except KeyboardInterrupt:
    shutil.rmtree("gradio-outputs", ignore_errors=True)
    raise