import os
import tempfile

import numpy as np
import soundfile as sf
import torch
import gradio as gr

from model import UFormer, UFormerConfig

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "checkpoints"

config = UFormerConfig()

# Loaded models keyed by checkpoint name, so each checkpoint is read from disk only once.
_model_cache = {}

VALID_CKPTS = [
    "acoustic_guitar", "bass", "electric_guitar", "guitars", "keyboards",
    "orchestra", "rhythm_section", "synth", "vocals",
]


def _get_model(ckpt_name: str): |
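    """Return the UFormer for ``ckpt_name``, loading and caching it on first use."""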
    if ckpt_name not in VALID_CKPTS:
        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
    if ckpt_name in _model_cache:
        return _model_cache[ckpt_name]
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
    model = UFormer(config).to(DEVICE).eval()
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)
    _model_cache[ckpt_name] = model
    return model


def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=2.5): |
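    """Run ``model`` over ``x`` (channels, samples) in overlapping windows.

    Chunks of ``chunk_s`` seconds are inferred every ``hop_s`` seconds and
    blended with a Hann window, so chunk boundaries crossfade instead of
    clicking.
    """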
    C, T = x.shape
    chunk, hop = int(sr * chunk_s), int(sr * hop_s)
    # Pad so the hops tile the signal exactly: (T + pad - chunk) is a multiple of hop.
    pad = (chunk - T) % hop if T > chunk else 0
    x_pad = np.pad(x, ((0, 0), (0, pad)), mode="reflect")
    win = np.hanning(chunk)[None, :]
    out = np.zeros_like(x_pad)
    norm = np.zeros((1, x_pad.shape[1]))
    n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
    print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")

    for i in range(n_chunks):
        s = i * hop
        seg = x_pad[:, s:s + chunk].astype(np.float32)
        with torch.no_grad():
            y = model(torch.from_numpy(seg[None]).to(DEVICE)).squeeze(0).cpu().numpy()
        out[:, s:s + chunk] += y * win
        norm[:, s:s + chunk] += win

    # Divide by the summed window weights; eps guards the zero-weight Hann endpoints.
    eps = 1e-8
    return (out / (norm + eps))[:, :T]


def restore_fn(audio_path, checkpoint): |
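    """Gradio handler: restore the uploaded file with the chosen checkpoint.

    Clips of up to 5 s go through the model in a single pass; longer clips use
    windowed overlap-add. Returns the path to the restored WAV file.
    """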
    audio, sr = sf.read(audio_path)
    if audio.ndim == 1:  # mono -> duplicate to stereo
        audio = np.stack([audio, audio], axis=1)
    x = audio.T  # (samples, channels) -> (channels, samples)

    model = _get_model(checkpoint)
    if x.shape[1] <= sr * 5:
        seg = x.astype(np.float32)[None]
        with torch.no_grad():
            y = model(torch.from_numpy(seg).to(DEVICE)).squeeze(0).cpu().numpy()
    else:
        y = _overlap_add(model, x, sr)

    # Write to a unique temp file so concurrent requests don't overwrite each other.
    fd, tmp = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    sf.write(tmp, y.T, sr, format="WAV")
    return tmp


demo = gr.Interface(
    fn=restore_fn,
    inputs=[
        gr.Audio(sources=["upload"], type="filepath", label="Your Input"),
        gr.Dropdown(VALID_CKPTS, label="Checkpoint"),
    ],
    outputs=gr.Audio(type="filepath", label="Restored Output"),
    title="🎵 Music Source Restoration",
    description=(
        "Upload a (stereo) audio file and choose an instrument/group checkpoint "
        "to restore. Note that these are baseline models for demonstration "
        "purposes only; most do not perform especially well yet."
    ),
    allow_flagging="never",
)


if __name__ == "__main__":
    # Local run: use Gradio's default host and port.
    demo.launch()
else:
    # When imported by a hosting runtime, bind to all interfaces on the configured port.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))