# nnstomps / app.py — Gradio Space entry point.
# Uploaded with huggingface_hub by intrect (commit 27ba96f, verified).
import os
from pathlib import Path
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
# ---------------------------------------------------------------------------
# Model definition (must match training code)
# ---------------------------------------------------------------------------
class NNStompGRU(nn.Module):
    """Single-layer GRU "stomp box" model.

    Each audio sample is concatenated with a fixed per-batch condition
    vector (the knob settings), fed through a GRU, and projected back to
    one sample per step, squashed into [-1, 1] by tanh.
    """

    def __init__(self, cond_dim: int, hidden_size: int = 40):
        super().__init__()
        self.cond_dim = cond_dim
        self.hidden_size = hidden_size
        # One audio sample plus cond_dim knob values per time step.
        self.gru = nn.GRU(
            input_size=1 + cond_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.dense = nn.Linear(hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(self, x, cond, hidden=None):
        """Run one (chunk of a) sequence through the network.

        x:      (batch, seq, 1) audio samples
        cond:   (batch, cond_dim) knob settings, broadcast over time
        hidden: optional GRU state carried across consecutive chunks
        Returns (out, hidden_out) with out shaped like x.
        """
        seq_len = x.shape[1]
        # Tile the condition vector along the time axis and stack it next
        # to the audio sample at every step.
        tiled_cond = cond.unsqueeze(1).expand(-1, seq_len, -1)
        features = torch.cat([x, tiled_cond], dim=-1)
        states, hidden_out = self.gru(features, hidden)
        return self.tanh(self.dense(states)), hidden_out
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# Maps UI display name -> model metadata:
#   repo_file: checkpoint path inside MODEL_REPO
#   cond_dim:  length of the condition (knob) vector fed to the GRU
#   controls:  slider label -> {idx: slot in the cond vector,
#              min/max/default: raw UI slider range}
MODELS = {
    "Blackstar (Drive A/B)": {
        "repo_file": "blackstar/best_model.pt",
        "cond_dim": 2,
        "controls": {
            "Drive A": {"idx": 0, "min": 0, "max": 100, "default": 50},
            "Drive B": {"idx": 1, "min": 0, "max": 100, "default": 0},
        },
    },
}
# Hugging Face Hub repo that hosts the checkpoints.
MODEL_REPO = "intrect/nnstomps-models"
# Instantiated models keyed by display name, so each checkpoint is
# downloaded and loaded at most once per process.
_model_cache: dict[str, NNStompGRU] = {}
def load_model(name: str) -> NNStompGRU | None:
    """Return the model registered under *name*, downloading it on first use.

    Unknown names yield None so callers can fall back to a dry bypass.
    Loaded models are memoized in _model_cache.
    """
    cached = _model_cache.get(name)
    if cached is not None:
        return cached
    cfg = MODELS.get(name)
    if cfg is None:
        return None
    # Fetch the checkpoint from the Hub (HF_TOKEN enables private repos).
    local_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=cfg["repo_file"],
        token=os.environ.get("HF_TOKEN"),
    )
    ckpt = torch.load(local_path, map_location="cpu", weights_only=True)
    # Rebuild the architecture from the checkpoint's own config so the
    # weights are guaranteed to match the module shapes.
    model = NNStompGRU(ckpt["config"]["cond_dim"], ckpt["config"]["hidden_size"])
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    _model_cache[name] = model
    return model
# ---------------------------------------------------------------------------
# Audio processing
# ---------------------------------------------------------------------------
def _to_float_mono(data: np.ndarray) -> np.ndarray:
    """Convert raw gradio audio into a 1-D float32 signal in roughly [-1, 1]."""
    # Integer PCM -> normalized float.
    if data.dtype == np.int16:
        data = data.astype(np.float32) / 32768.0
    elif data.dtype == np.int32:
        data = data.astype(np.float32) / 2147483648.0
    elif data.dtype != np.float32:
        data = data.astype(np.float32)
    # Down-mix multichannel audio. The channel axis is the smaller
    # dimension, which handles both (samples, channels) and
    # (channels, samples) layouts for any channel count — the previous
    # shape[1] <= 2 test mangled (samples, >2-channel) input.
    if data.ndim == 2:
        data = data.mean(axis=int(np.argmin(data.shape)), dtype=np.float32)
    return data


def _condition_vector(cfg: dict, knob_values: list) -> list:
    """Map raw slider values onto the model's normalized condition vector.

    Each control's value is rescaled from its UI [min, max] range into
    [0, 1] and written at the control's configured index. Knob values
    beyond the model's control count are ignored.
    """
    cond = [0.0] * cfg["cond_dim"]
    for ctrl, value in zip(cfg["controls"].values(), knob_values):
        span = ctrl["max"] - ctrl["min"]
        # Guard a degenerate (min == max) control against division by zero.
        cond[ctrl["idx"]] = (value - ctrl["min"]) / span if span else 0.0
    return cond


def _run_inference(model, mono: np.ndarray, cond: list,
                   chunk_size: int = 8192) -> np.ndarray:
    """Stream *mono* through the GRU in chunks, carrying hidden state over."""
    output = np.zeros_like(mono)
    hidden = None
    with torch.no_grad():
        cond_t = torch.tensor([cond], dtype=torch.float32)
        for start in range(0, len(mono), chunk_size):
            chunk = mono[start:start + chunk_size]
            # (samples,) -> (batch=1, seq, features=1)
            x = torch.from_numpy(np.ascontiguousarray(chunk)).unsqueeze(0).unsqueeze(-1)
            pred, hidden = model(x, cond_t, hidden)
            output[start:start + len(chunk)] = pred[0, :, 0].numpy()
    return output


def process_audio(
    audio_input,
    model_name: str,
    param1: float,
    param2: float,
    mix: float,
    input_gain_db: float,
):
    """Process one uploaded clip through the selected neural model.

    audio_input:   gradio numpy audio, i.e. (sample_rate, data), or None
    model_name:    key into MODELS
    param1/param2: raw slider values for the model's first/second control
    mix:           dry/wet blend in [0, 1] (1 = fully processed)
    input_gain_db: gain applied before the model, in dB
    Returns (sample_rate, float32 samples), or None when there is no input.
    """
    if audio_input is None:
        return None
    sr, data = audio_input
    mono = _to_float_mono(np.asarray(data))
    # Pre-model input gain.
    gain = 10 ** (input_gain_db / 20.0)
    mono = mono * gain
    model = load_model(model_name)
    if model is None:
        # Unknown model: pass the (gain-adjusted) signal through untouched.
        return (sr, mono)
    cfg = MODELS[model_name]
    cond = _condition_vector(cfg, [param1, param2])
    processed = _run_inference(model, mono, cond)
    # Dry/wet blend, then a simple peak limiter to avoid clipping.
    blended = mono * (1 - mix) + processed * mix
    if blended.size:  # np.max would raise ValueError on an empty clip
        peak = float(np.max(np.abs(blended)))
        if peak > 0.99:
            blended = blended * (0.99 / peak)
    return (sr, blended.astype(np.float32))
def update_controls(model_name: str):
    """Retarget the two parameter sliders to the selected model's controls.

    Returns a pair of gr.update objects. A slider beyond the model's
    control count is hidden (and the second one reset to 0).
    """
    controls = MODELS.get(model_name, {}).get("controls", {})
    entries = list(controls.items())

    def slider_update(pos: int, hidden_value=None):
        # Visible, relabeled and re-ranged when the model defines this slot.
        if pos < len(entries):
            label, spec = entries[pos]
            return gr.update(
                label=label, minimum=spec["min"], maximum=spec["max"],
                value=spec["default"], visible=True,
            )
        # Hidden otherwise, optionally forcing a neutral value.
        kwargs = {"visible": False}
        if hidden_value is not None:
            kwargs["value"] = hidden_value
        return gr.update(**kwargs)

    return slider_update(0), slider_update(1, hidden_value=0)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Build the Gradio UI: controls in the left column, audio I/O on the right.
with gr.Blocks(
    title="NNStomps — Neural Drive",
    theme=gr.themes.Soft(primary_hue="orange"),
) as demo:
    gr.Markdown(
        "# NNStomps — Neural Drive\n"
        "GRU neural network based saturation/distortion. "
        "Upload audio and tweak the drive to hear the neural model in action."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Model picker; defaults to the first registered model.
            model_sel = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            # Generic parameter sliders; update_controls() relabels and
            # re-ranges them whenever the selected model changes.
            param1 = gr.Slider(
                minimum=0, maximum=100, value=50, step=1, label="Drive A",
            )
            param2 = gr.Slider(
                minimum=0, maximum=100, value=0, step=1, label="Drive B",
            )
            input_gain = gr.Slider(
                minimum=-12, maximum=12, value=0, step=0.5,
                label="Input Gain (dB)",
            )
            mix_slider = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05,
                label="Dry/Wet Mix",
            )
            process_btn = gr.Button("Process", variant="primary", size="lg")
        with gr.Column(scale=2):
            audio_in = gr.Audio(label="Input Audio", type="numpy")
            audio_out = gr.Audio(label="Output Audio", type="numpy")
    # Re-sync the sliders when a different model is selected.
    model_sel.change(
        fn=update_controls, inputs=[model_sel], outputs=[param1, param2],
    )
    # Run the uploaded clip through the network on demand.
    process_btn.click(
        fn=process_audio,
        inputs=[audio_in, model_sel, param1, param2, mix_slider, input_gain],
        outputs=[audio_out],
    )
demo.launch()