# (HuggingFace Spaces page-header residue removed; kept as a comment so the file parses)
# Spaces status: Running
import shutil
import uuid
from datetime import datetime
from pathlib import Path

import argbind
import audiotools as at
import gradio as gr
import numpy as np
import torch
import yaml

from vampnet import mask as pmask
from vampnet.interface import Interface, signal_concat
# Run on GPU when one is available; every model call below uses this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Global vampnet interface, built once at import time from the default
# checkpoints (coarse model, coarse-to-fine model, and the neural codec).
interface = Interface(
    device=device,
    coarse_ckpt="models/vampnet/coarse.pth",
    coarse2fine_ckpt="models/vampnet/c2f.pth",
    codec_ckpt="models/vampnet/codec.pth",
)
# Model registry offered in the UI dropdown. "default" mirrors whatever the
# global interface was constructed with; additional entries come from any
# conf/generated/*/interface.yml whose checkpoints are all present on disk.
MODEL_CHOICES = {
    "default": {
        "Interface.coarse_ckpt": str(interface.coarse_path),
        "Interface.coarse2fine_ckpt": str(interface.c2f_path),
        "Interface.codec_ckpt": str(interface.codec_path),
    }
}

generated_confs = Path("conf/generated")
for conf_file in generated_confs.glob("*/interface.yml"):
    with open(conf_file) as f:
        _conf = yaml.safe_load(f)
    # Only register the conf when every referenced checkpoint actually exists.
    ckpts_present = all(
        Path(_conf[key]).exists()
        for key in (
            "Interface.coarse_ckpt",
            "Interface.coarse2fine_ckpt",
            "Interface.codec_ckpt",
        )
    )
    if ckpts_present:
        MODEL_CHOICES[conf_file.parent.name] = _conf
# All generated artifacts land under this directory; create it eagerly so
# later writes never have to check for it.
OUT_DIR = Path("gradio-outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Hard cap (seconds) on the excerpt taken from uploaded audio.
MAX_DURATION_S = 60
def load_audio(file):
    """Load an uploaded file, trim it to a salient excerpt, and stage it.

    Args:
        file: gradio File object; its ``.name`` attribute holds the temp path.

    Returns:
        Path to the staged copy of the (trimmed) signal, handed back to the
        ``input_audio`` widget.
    """
    print(file)
    filepath = file.name
    # BUG FIX: previously the salient excerpt was computed and then immediately
    # overwritten by a full re-load of the file, so uploads were never trimmed
    # even though the UI promises trimming and MAX_DURATION_S was dead. Keep
    # the excerpt.
    sig = at.AudioSignal.salient_excerpt(
        filepath, duration=MAX_DURATION_S
    )
    # Stage under a unique temp dir so concurrent sessions cannot collide.
    out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
    out_dir.mkdir(parents=True, exist_ok=True)
    sig.write(out_dir / "input.wav")
    # NOTE(review): assumes AudioSignal.write updates path_to_file to the
    # written location — confirm against audiotools.
    return sig.path_to_file
def load_example_audio():
    """Return the path of the bundled example clip."""
    example_path = "./assets/example.wav"
    return example_path
from torch_pitch_shift import pitch_shift, get_fast_shifts

def shift_pitch(signal, interval: int):
    """Pitch-shift *signal* in place by *interval* semitones and return it."""
    shifted = pitch_shift(
        signal.samples,
        shift=interval,
        sample_rate=signal.sample_rate,
    )
    signal.samples = shifted
    return signal
def _vamp(seed, input_audio, model_choice, pitch_shift_amt, periodic_p, p2, n_mask_codebooks, n_mask_codebooks_2, rand_mask_intensity, prefix_s, suffix_s, periodic_w, onset_mask_width, dropout, masktemp, sampletemp, typical_filtering, typical_mass, typical_min_tokens, top_p, sample_cutoff, win_dur, num_feedback_steps, stretch_factor, api=False):
    """Core generation routine shared by the UI button, the hidden API button,
    and the HARP endpoint.

    Seeds the RNG, optionally pitch-shifts the input, builds mask/sampling
    kwargs from the widget values, runs ``interface.ez_vamp``, and writes all
    artifacts to a per-run directory under OUT_DIR.

    Args:
        seed: RNG seed; 0 (or negative) means "pick one at random".
        input_audio: path to the staged input wav.
        model_choice: key into MODEL_CHOICES selecting the checkpoints.
        api: when True, render a single output and return just its path;
            when False, render a batch of 4 and return all UI outputs.

    Returns:
        api=True  -> path to ``out.wav``.
        api=False -> 7-tuple of (4 output paths, mask audio path, zip path,
        mask image path) matching the gradio ``outputs`` list.
    """
    # Resolve the seed: 0 means "random", and we record the drawn value so the
    # run is reproducible from its metadata/dirname.
    _seed = seed if seed > 0 else None
    if _seed is None:
        _seed = int(torch.randint(0, 2**32, (1,)).item())
    at.util.seed(_seed)

    # One directory per run, named so the run can be identified at a glance.
    datentime = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    out_dir = OUT_DIR / f"{Path(input_audio).stem}-{datentime}-seed-{_seed}-model-{model_choice}"
    out_dir.mkdir(parents=True)

    sig = at.AudioSignal(input_audio)
    sig.write(out_dir / "input.wav")

    # Reload the model weights if the dropdown choice changed.
    interface.reload(
        coarse_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse_ckpt"],
        c2f_ckpt=MODEL_CHOICES[model_choice]["Interface.coarse2fine_ckpt"],
    )

    loudness = sig.loudness()
    print(f"input loudness is {loudness}")

    if pitch_shift_amt != 0:
        sig = shift_pitch(sig, pitch_shift_amt)

    # A value of 0 for the secondary controls means "inherit the primary".
    _p2 = periodic_p if p2 == 0 else p2
    _n_codebooks_2 = n_mask_codebooks if n_mask_codebooks_2 == 0 else n_mask_codebooks_2

    build_mask_kwargs = dict(
        rand_mask_intensity=rand_mask_intensity,
        prefix_s=prefix_s,
        suffix_s=suffix_s,
        periodic_prompt=int(periodic_p),
        periodic_prompt2=int(_p2),
        periodic_prompt_width=periodic_w,
        onset_mask_width=onset_mask_width,
        _dropout=dropout,
        upper_codebook_mask=int(n_mask_codebooks),
        upper_codebook_mask_2=int(_n_codebooks_2),
    )

    vamp_kwargs = dict(
        mask_temperature=masktemp*10,
        sampling_temperature=sampletemp,
        typical_filtering=typical_filtering,
        typical_mass=typical_mass,
        typical_min_tokens=typical_min_tokens,
        top_p=top_p if top_p > 0 else None,  # 0.0 in the UI disables top-p
        seed=_seed,
        sample_cutoff=sample_cutoff,
    )

    interface.set_chunk_size(win_dur)
    sig, mask, codes = interface.ez_vamp(
        sig,
        batch_size=1 if api else 4,  # API callers only need one rendering
        feedback_steps=num_feedback_steps,
        time_stretch_factor=stretch_factor,
        build_mask_kwargs=build_mask_kwargs,
        vamp_kwargs=vamp_kwargs,
        return_mask=True,
    )

    if api:
        sig.write(out_dir / "out.wav")
        return sig.path_to_file

    # ---- full UI output path (the early return above replaces the old,
    # always-true `if not api:` guard) ----

    # Raw token codes, for offline inspection.
    np.save(out_dir / "codes.npy", codes.cpu().numpy())

    # Record everything needed to reproduce this run.
    metadata = {
        "seed": _seed,
        "model_choice": model_choice,
        "mask_kwargs": build_mask_kwargs,
        "vamp_kwargs": vamp_kwargs,
        "loudness": loudness,
    }
    with open(out_dir / "metadata.yml", "w") as f:
        yaml.dump(metadata, f)

    # NOTE(review): AudioSignal.write appears to return the signal (the results
    # are reused below for their path_to_file) — confirm against audiotools.
    sig0 = sig[0].write(out_dir / "out1.wav")
    sig1 = sig[1].write(out_dir / "out2.wav")
    sig2 = sig[2].write(out_dir / "out3.wav")
    sig3 = sig[3].write(out_dir / "out4.wav")

    # Dump the mask as text, one time step per line.
    with open(out_dir / "mask.txt", "w") as f:
        m = mask[0].cpu().numpy()
        for i in range(m.shape[-1]):
            f.write(f"{m[:, i]}\n")

    import matplotlib.pyplot as plt
    plt.clf()
    interface.visualize_codes(mask)
    plt.savefig(out_dir / "mask.png")
    plt.clf()
    interface.visualize_codes(codes)
    plt.savefig(out_dir / "codes.png")
    plt.close()

    # Render the mask codes to audio, decoding in chunks of 1024 timesteps to
    # bound memory use. Note: `mask` is rebound from codes to an AudioSignal.
    _mask_sigs = []
    for i in range(0, mask.shape[-1], 1024):
        _mask_sigs.append(interface.to_signal(mask[:, :, i:i+1024].to(interface.device)).cpu())
    mask = signal_concat(_mask_sigs)
    mask.write(out_dir / "mask.wav")

    # BUG FIX: the archive used to be created before mask.wav was written, so
    # the downloadable zip silently omitted it. Zip last, after every artifact
    # exists.
    shutil.make_archive(str(out_dir), 'zip', out_dir)

    return (
        sig0.path_to_file, sig1.path_to_file,
        sig2.path_to_file, sig3.path_to_file,
        mask.path_to_file, str(out_dir.with_suffix(".zip")), out_dir / "mask.png"
    )
def vamp(data):
    """UI callback: unpack the gradio component->value dict *data* and run
    _vamp in batch (non-API) mode, returning all seven UI outputs."""
    kwargs = dict(
        seed=data[seed],
        input_audio=data[input_audio],
        model_choice=data[model_choice],
        pitch_shift_amt=data[pitch_shift_amt],
        periodic_p=data[periodic_p],
        p2=data[p2],
        n_mask_codebooks=data[n_mask_codebooks],
        n_mask_codebooks_2=data[n_mask_codebooks_2],
        rand_mask_intensity=data[rand_mask_intensity],
        prefix_s=data[prefix_s],
        suffix_s=data[suffix_s],
        periodic_w=data[periodic_w],
        onset_mask_width=data[onset_mask_width],
        dropout=data[dropout],
        masktemp=data[masktemp],
        sampletemp=data[sampletemp],
        typical_filtering=data[typical_filtering],
        typical_mass=data[typical_mass],
        typical_min_tokens=data[typical_min_tokens],
        top_p=data[top_p],
        sample_cutoff=data[sample_cutoff],
        win_dur=data[win_dur],
        num_feedback_steps=data[num_feedback_steps],
        stretch_factor=data[stretch_factor],
    )
    return _vamp(api=False, **kwargs)
def api_vamp(data):
    """Hidden-button callback: identical to vamp() but api=True, so _vamp
    renders a single output and returns just its file path."""
    widget_for = {
        "seed": seed,
        "input_audio": input_audio,
        "model_choice": model_choice,
        "pitch_shift_amt": pitch_shift_amt,
        "periodic_p": periodic_p,
        "p2": p2,
        "n_mask_codebooks": n_mask_codebooks,
        "n_mask_codebooks_2": n_mask_codebooks_2,
        "rand_mask_intensity": rand_mask_intensity,
        "prefix_s": prefix_s,
        "suffix_s": suffix_s,
        "periodic_w": periodic_w,
        "onset_mask_width": onset_mask_width,
        "dropout": dropout,
        "masktemp": masktemp,
        "sampletemp": sampletemp,
        "typical_filtering": typical_filtering,
        "typical_mass": typical_mass,
        "typical_min_tokens": typical_min_tokens,
        "top_p": top_p,
        "sample_cutoff": sample_cutoff,
        "win_dur": win_dur,
        "num_feedback_steps": num_feedback_steps,
        "stretch_factor": stretch_factor,
    }
    return _vamp(api=True, **{name: data[w] for name, w in widget_for.items()})
def harp_vamp(input_audio,
              periodic_p,
              n_mask_codebooks,
              pitch_shift_amt,
              win_dur):
    """HARP plugin entry point: exposes only a handful of knobs and pins the
    rest of _vamp's parameters to sensible defaults (API mode, single output)."""
    fixed = dict(
        seed=0,
        model_choice="default",
        p2=0,
        n_mask_codebooks_2=0,
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_w=1,
        onset_mask_width=0,
        dropout=0.0,
        masktemp=1.5,
        sampletemp=1.0,
        typical_filtering=True,
        typical_mass=0.15,
        typical_min_tokens=64,
        top_p=0.9,
        sample_cutoff=1.0,
        num_feedback_steps=1,
        stretch_factor=1.0,
    )
    return _vamp(
        input_audio=input_audio,
        periodic_p=periodic_p,
        n_mask_codebooks=n_mask_codebooks,
        pitch_shift_amt=pitch_shift_amt,
        win_dur=win_dur,
        api=True,
        **fixed,
    )
# NOTE(review): the source this was recovered from had its indentation
# stripped; the widget nesting below is reconstructed from reading order —
# verify the accordion grouping matches the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        # ---- left column: audio input ----
        with gr.Column():
            manual_audio_upload = gr.File(
                # BUG FIX: label hard-coded "100s" (in an f-string with no
                # placeholder) while the code trims to MAX_DURATION_S=60;
                # interpolate the real constant instead.
                label=f"upload some audio (will be randomly trimmed to max of {MAX_DURATION_S}s)",
                file_types=["audio"]
            )
            load_example_audio_button = gr.Button("or load example audio")

            input_audio = gr.Audio(
                label="input audio",
                interactive=False,
                type="filepath",
            )

            audio_mask = gr.Audio(
                label="audio mask (listen to this to hear the mask hints)",
                interactive=False,
                type="filepath",
            )

            # connect widgets
            load_example_audio_button.click(
                fn=load_example_audio,
                inputs=[],
                outputs=[ input_audio]
            )

            manual_audio_upload.change(
                fn=load_audio,
                inputs=[manual_audio_upload],
                outputs=[ input_audio]
            )

        # ---- middle column: mask settings ----
        with gr.Column():
            with gr.Accordion("manual controls", open=True):
                periodic_p = gr.Slider(
                    label="periodic prompt",
                    minimum=0,
                    maximum=128,
                    step=1,
                    value=3,
                )
                p2 = gr.Slider(
                    label="periodic prompt 2 (0 - same as p1, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
                    minimum=0,
                    maximum=128,
                    step=1,
                    value=0,
                )
                onset_mask_width = gr.Slider(
                    label="onset mask width (multiplies with the periodic mask, 1 step ~= 10milliseconds) ",
                    minimum=0,
                    maximum=100,
                    step=1,
                    value=0,
                )
                n_mask_codebooks = gr.Slider(
                    label="compression prompt ",
                    value=3,
                    minimum=0,
                    maximum=14,
                    step=1,
                )
                n_mask_codebooks_2 = gr.Number(
                    label="compression prompt 2 via linear interpolation (0 == constant)",
                    value=0,
                )

            with gr.Accordion("extras ", open=False):
                pitch_shift_amt = gr.Slider(
                    label="pitch shift amount (semitones)",
                    minimum=-12,
                    maximum=12,
                    step=1,
                    value=0,
                )
                stretch_factor = gr.Slider(
                    label="time stretch factor",
                    minimum=0,
                    maximum=64,
                    step=1,
                    value=1,
                )
                rand_mask_intensity = gr.Slider(
                    label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0
                )
                periodic_w = gr.Slider(
                    label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
                    minimum=1,
                    maximum=20,
                    step=1,
                    value=1,
                )

            with gr.Accordion("prefix/suffix prompts", open=True):
                prefix_s = gr.Slider(
                    label="prefix hint length (seconds)",
                    minimum=0.0,
                    maximum=10.0,
                    value=0.0
                )
                suffix_s = gr.Slider(
                    label="suffix hint length (seconds)",
                    minimum=0.0,
                    maximum=10.0,
                    value=0.0
                )

            masktemp = gr.Slider(
                label="mask temperature",
                minimum=0.0,
                maximum=100.0,
                value=1.5
            )
            sampletemp = gr.Slider(
                label="sample temperature",
                minimum=0.1,
                maximum=10.0,
                value=1.0,
                step=0.001
            )

            with gr.Accordion("sampling settings", open=False):
                top_p = gr.Slider(
                    label="top p (0.0 = off)",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9
                )
                typical_filtering = gr.Checkbox(
                    label="typical filtering ",
                    value=True
                )
                typical_mass = gr.Slider(
                    label="typical mass (should probably stay between 0.1 and 0.5)",
                    minimum=0.01,
                    maximum=0.99,
                    value=0.15
                )
                typical_min_tokens = gr.Slider(
                    label="typical min tokens (should probably stay between 1 and 256)",
                    minimum=1,
                    maximum=256,
                    step=1,
                    value=64
                )
                sample_cutoff = gr.Slider(
                    label="sample cutoff",
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.01
                )

            dropout = gr.Slider(
                label="mask dropout",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=0.0
            )
            seed = gr.Number(
                label="seed (0 for random)",
                value=0,
                precision=0,
            )

        # ---- right column: model choice + outputs ----
        with gr.Column():
            model_choice = gr.Dropdown(
                label="model choice",
                choices=list(MODEL_CHOICES.keys()),
                value="default",
                visible=True
            )
            num_feedback_steps = gr.Slider(
                label="number of feedback steps (each one takes a while)",
                minimum=1,
                maximum=16,
                step=1,
                value=1
            )
            win_dur = gr.Slider(
                label="window duration (seconds)",
                minimum=2,
                maximum=10,
                value=6)

            vamp_button = gr.Button("generate (vamp)!!!")

            maskimg = gr.Image(
                label="mask image",
                interactive=False,
                type="filepath"
            )
            out1 = gr.Audio(
                label="output audio 1",
                interactive=False,
                type="filepath"
            )
            out2 = gr.Audio(
                label="output audio 2",
                interactive=False,
                type="filepath"
            )
            out3 = gr.Audio(
                label="output audio 3",
                interactive=False,
                type="filepath"
            )
            out4 = gr.Audio(
                label="output audio 4",
                interactive=False,
                type="filepath"
            )

            thank_you = gr.Markdown("")

            # download all the outputs as a single zip
            download = gr.File(type="file", label="download outputs")

    # Set of components whose values are passed (as a dict) to the callbacks.
    _inputs = {
        input_audio,
        masktemp,
        sampletemp,
        top_p,
        prefix_s, suffix_s,
        rand_mask_intensity,
        periodic_p, periodic_w,
        dropout,
        stretch_factor,
        onset_mask_width,
        typical_filtering,
        typical_mass,
        typical_min_tokens,
        seed,
        model_choice,
        n_mask_codebooks,
        pitch_shift_amt,
        sample_cutoff,
        num_feedback_steps,
        p2,
        n_mask_codebooks_2,
        win_dur
    }

    # connect widgets
    vamp_button.click(
        fn=vamp,
        inputs=_inputs,
        outputs=[out1, out2, out3, out4, audio_mask, download, maskimg],
    )

    # Invisible button exposing the single-output API endpoint ("/vamp").
    api_vamp_button = gr.Button("api vamp", visible=False)
    api_vamp_button.click(
        fn=api_vamp,
        inputs=_inputs,
        outputs=[out1],
        api_name="vamp"
    )

    # HARP plugin endpoint: a reduced-knob wrapper around the same pipeline.
    from pyharp import ModelCard, build_endpoint
    model_card = ModelCard(
        name="salad bowl",
        description="sounds",
        author="hugo flores garcía",
        tags=["generative","sound"],
    )
    build_endpoint(
        inputs=[
            input_audio,
            periodic_p,
            n_mask_codebooks,
            pitch_shift_amt,
            win_dur,
        ],
        output=out1,
        process_fn=harp_vamp,
        card=model_card
    )
try:
    demo.queue()
    demo.launch(share=True)
except KeyboardInterrupt:
    # Best-effort cleanup of generated artifacts on Ctrl-C, then re-raise so
    # the process still exits via the interrupt.
    # FIX: reuse the OUT_DIR constant instead of duplicating the magic string
    # "gradio-outputs" (they could silently drift apart).
    shutil.rmtree(OUT_DIR, ignore_errors=True)
    raise