# LarsNet / app.py
# Author: Richard Zhu
# Add LarsNet drum separator (commit 94ce22b)
import math
import tempfile
from pathlib import Path
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
import soundfile as sf
import gradio as gr
from tqdm import tqdm
from typing import Union, Tuple, Optional
from torch import Tensor
from pyharp import build_endpoint, ModelCard
# ─────────────────────────────────────────────
# UNet Utilities
# ─────────────────────────────────────────────
class UNetUtils:
    """STFT helpers plus fold/unfold logic for fixed-size UNet inputs.

    The fold step right-pads a spectrogram's time axis and stacks the
    resulting T-frame chunks on the batch dimension, so inputs of any
    length can pass through a UNet that expects exactly T frames; the
    unfold step reverses this. F/T are the frequency-bin count the UNet
    keeps and the chunk length in frames.
    """

    def __init__(self, F=None, T=None, n_fft=4096, win_length=None,
                 hop_length=None, center=True, device='cpu'):
        self.n_fft = n_fft
        # Defaults follow the usual convention: window = n_fft, hop = window / 4.
        self.win_length = win_length if win_length is not None else n_fft
        self.hop_length = (hop_length if hop_length is not None
                           else self.win_length // 4)
        self.hann_window = torch.hann_window(self.win_length, periodic=True).to(device)
        self.center = center
        self.device = device
        self.F = F
        self.T = T

    def fold_unet_inputs(self, x):
        """Right-pad the time axis to a multiple of T and stack the
        T-frame chunks along the batch dimension."""
        n_frames = x.size(-1)
        target = math.ceil(n_frames / self.T) * self.T
        padded = F.pad(x, (0, target - n_frames))
        if n_frames < self.T:
            # A single (padded) chunk — nothing to split.
            return padded
        chunks = torch.split(padded, self.T, dim=-1)
        return torch.cat(chunks, dim=0)

    def unfold_unet_outputs(self, x, input_size):
        """Inverse of fold_unet_inputs: re-join chunks along time and
        trim the right padding back to the original frame count."""
        batch, n_frames = input_size[0], input_size[-1]
        if x.size(0) != batch:
            x = torch.cat(torch.split(x, batch, dim=0), dim=-1)
        return x[..., :n_frames]

    def trim_freq_dim(self, x):
        # Keep only the lowest F frequency bins.
        return x[..., :self.F, :]

    def pad_freq_dim(self, x):
        # Zero-pad the frequency axis back up to the full n_fft // 2 + 1 bins.
        missing = (self.n_fft // 2 + 1) - x.size(-2)
        return F.pad(x, (0, 0, 0, missing))

    def pad_stft_input(self, x):
        # Pad so that (len - win_length) lands on a hop boundary; the
        # outer "% win_length" is kept from the original formulation.
        on_hop_boundary = -(x.size(-1) - self.win_length) % self.hop_length
        return F.pad(x, (0, on_hop_boundary % self.win_length))

    def _stft(self, x):
        return torch.stft(input=x, n_fft=self.n_fft,
                          win_length=self.win_length,
                          hop_length=self.hop_length,
                          window=self.hann_window,
                          center=self.center, return_complex=True)

    def _istft(self, x, trim_length=None):
        return torch.istft(input=x, n_fft=self.n_fft,
                           win_length=self.win_length,
                           hop_length=self.hop_length,
                           window=self.hann_window,
                           center=self.center, length=trim_length)

    def batch_stft(self, x, pad=True, return_complex=False):
        """STFT over the last axis, preserving every leading batch axis.

        Returns the complex spectrogram when return_complex is True,
        otherwise a (magnitude, phase) pair.
        """
        leading = x.size()[:-1]
        flat = x.reshape(-1, x.size(-1))
        if pad:
            flat = self.pad_stft_input(flat)
        spec = self._stft(flat)
        spec = spec.reshape(leading + spec.shape[-2:])
        if return_complex:
            return spec
        return spec.abs(), spec.angle()

    def batch_istft(self, magnitude, phase, trim_length=None):
        """Inverse of batch_stft from a (magnitude, phase) pair."""
        spec = torch.polar(magnitude, phase)
        leading = spec.size()[:-2]
        flat = spec.reshape(-1, spec.size(-2), spec.size(-1))
        audio = self._istft(flat, trim_length)
        return audio.reshape(leading + audio.shape[-1:])
# ─────────────────────────────────────────────
# UNet Blocks
# ─────────────────────────────────────────────
class UNetEncoderBlock(nn.Module):
    """Downsampling block: Conv2d → BatchNorm → LeakyReLU.

    forward returns both the activated output (fed to the next encoder
    stage) and the raw pre-activation conv map (used by the decoder as
    the skip connection).
    """

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), relu_slope=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activ = nn.LeakyReLU(relu_slope)
        # He-style init matched to the leaky-relu negative slope.
        nn.init.kaiming_uniform_(self.conv.weight,
                                 nonlinearity='leaky_relu', a=relu_slope)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        pre_activation = self.conv(x)
        activated = self.activ(self.bn(pre_activation))
        return activated, pre_activation
class UNetDecoderBlock(nn.Module):
    """Upsampling block: ConvTranspose2d → ReLU → BatchNorm → Dropout.

    NOTE(review): the activation runs *before* batch-norm here (the
    reverse of the encoder block). Presumably this matches the
    pretrained LarsNet checkpoints, so the order must not be "fixed".
    """

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), output_padding=(1, 1),
                 dropout=0.0):
        super().__init__()
        self.conv_trans = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, output_padding=output_padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.activ = nn.ReLU()

    def forward(self, x):
        upsampled = self.conv_trans(x)
        return self.dropout(self.bn(self.activ(upsampled)))
# ─────────────────────────────────────────────
# UNet Models
# ─────────────────────────────────────────────
class UNet(nn.Module):
    """Spectrogram-masking UNet: six encoder / six decoder stages with
    skip connections, producing a sigmoid mask that is multiplied onto
    the input magnitude spectrogram.

    input_size is (audio_channels, F, T): channel count, number of
    frequency bins kept by the network, and the time-chunk length in
    STFT frames. Frequency bins above F receive a zero mask.
    """

    def __init__(self, input_size: Tuple[int, ...] = (2, 2048, 512),
                 power: float = 1.0, device: Optional[str] = None):
        super().__init__()
        self.input_size = input_size
        audio_channels, f_size, t_size = input_size
        self.utils = UNetUtils(F=f_size, T=t_size, device=device)
        # Normalizes per frequency bin: produce_mask transposes so the
        # frequency axis becomes the BatchNorm2d channel axis.
        self.input_norm = nn.BatchNorm2d(f_size)
        self.enc1 = UNetEncoderBlock(audio_channels, 16)
        self.enc2 = UNetEncoderBlock(16, 32)
        self.enc3 = UNetEncoderBlock(32, 64)
        self.enc4 = UNetEncoderBlock(64, 128)
        self.enc5 = UNetEncoderBlock(128, 256)
        self.enc6 = UNetEncoderBlock(256, 512)
        # Decoder inputs (except dec1) are doubled because each stage
        # consumes the matching encoder skip concatenated on channels.
        self.dec1 = UNetDecoderBlock(512, 256, dropout=0.5)
        self.dec2 = UNetDecoderBlock(512, 128, dropout=0.5)
        self.dec3 = UNetDecoderBlock(256, 64, dropout=0.5)
        self.dec4 = UNetDecoderBlock(128, 32)
        self.dec5 = UNetDecoderBlock(64, 16)
        self.dec6 = UNetDecoderBlock(32, audio_channels)
        self.mask_layer = nn.Sequential(
            nn.Conv2d(audio_channels, audio_channels,
                      kernel_size=(4, 4), dilation=(2, 2), padding=3),
            nn.Sigmoid()
        )
        nn.init.kaiming_uniform_(self.mask_layer[0].weight)
        nn.init.zeros_(self.mask_layer[0].bias)
        if device is not None:
            self.to(device)

    def produce_mask(self, x: Tensor) -> Tensor:
        """Run the encoder/decoder stack and return a (0, 1) mask with
        the same shape as the (trimmed, folded) input."""
        hidden = self.input_norm(x.transpose(1, 2)).transpose(1, 2)
        skips = []
        for enc in (self.enc1, self.enc2, self.enc3,
                    self.enc4, self.enc5, self.enc6):
            hidden, pre_activation = enc(hidden)
            skips.append(pre_activation)
        # The deepest pre-activation map seeds the decoder; each later
        # stage concatenates the matching skip on the channel axis first.
        up = self.dec1(skips[-1])
        decoders = (self.dec2, self.dec3, self.dec4, self.dec5, self.dec6)
        for dec, skip in zip(decoders, skips[-2::-1]):
            up = dec(torch.cat([skip, up], dim=1))
        return self.mask_layer(up)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Mask a magnitude spectrogram of arbitrary length.

        Returns (masked_spectrogram, mask), both shaped like x.
        """
        original_size = x.size()
        folded = self.utils.fold_unet_inputs(x)
        mask = self.produce_mask(self.utils.trim_freq_dim(folded))
        # Restore the full bin count; bins above self.utils.F get zeros.
        mask = self.utils.pad_freq_dim(mask)
        masked = self.utils.unfold_unet_outputs(folded * mask, original_size)
        return masked, self.utils.unfold_unet_outputs(mask, original_size)
class UNetWaveform(UNet):
    """UNet variant that accepts time-domain audio and returns audio:
    it wraps the parent's spectrogram masking in an STFT/iSTFT pair."""

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        # Promote mono to duplicated dual-channel, then add a batch axis
        # if one is missing.
        if x.dim() == 1:
            x = x.repeat(2, 1)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        magnitude, phase = self.utils.batch_stft(x)
        est_magnitude, mask = super().forward(magnitude)
        # Resynthesize with the mixture phase, trimmed to the input length.
        audio = self.utils.batch_istft(est_magnitude, phase,
                                       trim_length=x.size(-1))
        return audio, mask
# ─────────────────────────────────────────────
# LarsNet
# ─────────────────────────────────────────────
class LarsNet(nn.Module):
    """Loads one pretrained UNet per drum stem and exposes three
    separation modes: waveform (default), Wiener-filtered, and raw STFT
    output. Stem names and checkpoint paths come from a YAML config.
    """

    def __init__(self, wiener_filter=False, wiener_exponent=1.0,
                 config: Union[str, Path] = "config.yaml",
                 return_stft=False, device='cpu', **kwargs):
        super().__init__(**kwargs)
        with open(config, "r") as fh:
            cfg = yaml.safe_load(fh)
        self.device = device
        self.wiener_filter = wiener_filter
        self.wiener_exponent = wiener_exponent
        self.return_stft = return_stft
        self.stems = cfg['inference_models'].keys()
        self.utils = UNetUtils(device=self.device)
        self.sr = cfg['global']['sr']
        # Plain dict, not nn.ModuleDict: each sub-model is placed on its
        # device at construction time instead.
        self.models = {}
        print('Loading UNet models...')
        # Wiener/STFT modes operate on spectrograms, so they need the
        # bare UNet; otherwise the waveform wrapper handles STFT itself.
        model_cls = UNet if (wiener_filter or return_stft) else UNetWaveform
        for stem in tqdm(self.stems):
            ckpt_path = Path(cfg['inference_models'][stem])
            net = model_cls(input_size=(2, cfg[stem]['F'], cfg[stem]['T']),
                            device=self.device)
            # NOTE(review): torch.load unpickles arbitrary objects —
            # only load checkpoints from a trusted source.
            checkpoint = torch.load(str(ckpt_path), map_location=device)
            net.load_state_dict(checkpoint['model_state_dict'])
            net.eval()
            self.models[stem] = net

    @staticmethod
    def _fix_dim(x):
        """Coerce audio to (batch, channels, samples): duplicate mono to
        two channels, then add a batch axis if missing."""
        if x.dim() == 1:
            x = x.repeat(2, 1)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        return x

    def separate(self, x):
        """Waveform-domain separation: run each stem model directly."""
        x = x.to(self.device)
        stems_out = {}
        for name, net in tqdm(self.models.items()):
            audio, _ = net(x)
            stems_out[name] = audio.squeeze(0).detach()
        return stems_out

    def separate_wiener(self, x):
        """Spectrogram-domain separation using a soft Wiener mask built
        from every stem's (exponentiated) magnitude estimate."""
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)
        estimates = []
        for _, net in tqdm(self.models.items()):
            _, mask = net(mag)
            estimates.append((mask * mag) ** self.wiener_exponent)
        total = sum(estimates)
        stems_out = {}
        for name, estimate in zip(self.stems, estimates):
            # Epsilon keeps the ratio finite where every estimate is zero.
            wiener_mask = estimate / (total + 1e-7)
            audio = self.utils.batch_istft(mag * wiener_mask, phase,
                                           trim_length=x.size(-1))
            stems_out[name] = audio.squeeze(0).detach()
        return stems_out

    def separate_stft(self, x):
        """Return per-stem complex STFTs: predicted magnitude combined
        with the mixture phase."""
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)
        stems_out = {}
        for name, net in tqdm(self.models.items()):
            est_mag, _ = net(mag)
            stems_out[name] = torch.polar(est_mag, phase).squeeze(0).detach()
        return stems_out

    def forward(self, x):
        """Separate a waveform tensor or an audio file path into stems.

        File inputs are loaded and resampled to the configured sample
        rate; tensor inputs are assumed to already match it.
        """
        if isinstance(x, (str, Path)):
            x, file_sr = ta.load(str(x))
            if file_sr != self.sr:
                x = ta.functional.resample(x, file_sr, self.sr)
        if self.return_stft:
            return self.separate_stft(x)
        if self.wiener_filter:
            return self.separate_wiener(x)
        return self.separate(x)
# ─────────────────────────────────────────────
# App
# ─────────────────────────────────────────────
# pyHARP model card: metadata shown to HARP clients querying this endpoint.
model_card = ModelCard(
    name="LarsNet Drum Stem Separator",
    description="Separates a drum mix into individual drum stems: Kick, Snare, Toms, Hi-Hat, and Cymbals.",
    author="A. I. Mezza, et al.",
    tags=["drums", "demucs", "source-separation", "pyharp", "stems", "multi-output"],
)
# Single global model instance, loaded once at import time (CPU inference,
# no Wiener filtering; stem models come from config.yaml).
MODEL = LarsNet(wiener_filter=False, device="cpu", config="config.yaml")
@torch.inference_mode()
def process_fn(audio_path: str):
    """Separate the uploaded drum mix into five stem WAV files.

    Parameters:
        audio_path: filesystem path to the input audio (Gradio filepath).

    Returns:
        A 5-tuple of WAV paths in the fixed order the output components
        are wired in: kick, snare, toms, hihat, cymbals.
    """
    stems = MODEL(audio_path)
    # Fresh directory per request: the previous fixed "outputs/" dir let
    # concurrent requests overwrite each other's stem files. tempfile was
    # already imported at the top of the file for exactly this purpose.
    output_dir = Path(tempfile.mkdtemp(prefix="larsnet_"))
    output_paths = []
    # Order must match output_components in build_endpoint below.
    for stem_name in ["kick", "snare", "toms", "hihat", "cymbals"]:
        out_path = output_dir / f"{stem_name}.wav"
        # stems[...] is (channels, samples); soundfile expects (samples, channels).
        sf.write(out_path, stems[stem_name].cpu().numpy().T, MODEL.sr)
        output_paths.append(str(out_path))
    return tuple(output_paths)
# Gradio UI: one required audio input (the drum mix) and five stem outputs,
# wrapped into a pyHARP endpoint so HARP hosts can invoke process_fn.
with gr.Blocks() as demo:
    # .harp_required is a pyHARP extension marking the input as mandatory.
    input_audio = gr.Audio(type="filepath", label="Drum Mix (Input)").harp_required(True)
    # Output order here must match the tuple returned by process_fn.
    output_kick = gr.Audio(type="filepath", label="Kick")
    output_snare = gr.Audio(type="filepath", label="Snare")
    output_toms = gr.Audio(type="filepath", label="Toms")
    output_hihat = gr.Audio(type="filepath", label="Hi-Hat")
    output_cymbals = gr.Audio(type="filepath", label="Cymbals")
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_kick, output_snare, output_toms, output_hihat, output_cymbals],
        process_fn=process_fn,
    )
# share=True opens a public tunnel; show_error surfaces exceptions to the UI.
demo.queue().launch(show_error=True, share=True)