# LarsNet drum-stem separation — PyHARP / Gradio app.
import math
import tempfile
from pathlib import Path
from typing import Optional, Tuple, Union

import gradio as gr
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio as ta
import yaml
from torch import Tensor
from tqdm import tqdm

from pyharp import ModelCard, build_endpoint
# ---------------------------------------------
# UNet Utilities
# ---------------------------------------------
class UNetUtils:
    """STFT helpers plus folding/unfolding of spectrograms into fixed-width
    chunks that match a UNet's (F, T) input size.

    Attributes F and T are the number of frequency bins / time frames the
    UNet consumes; they may be None when only the STFT helpers are used.
    """

    def __init__(self, F=None, T=None, n_fft=4096, win_length=None,
                 hop_length=None, center=True, device='cpu'):
        self.F = F
        self.T = T
        self.n_fft = n_fft
        # Defaults: window spans the full FFT, hop is a quarter window.
        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = win_length // 4
        self.win_length = win_length
        self.hop_length = hop_length
        self.center = center
        self.device = device
        self.hann_window = torch.hann_window(win_length, periodic=True).to(device)

    def fold_unet_inputs(self, x):
        """Right-pad the time axis to a multiple of self.T, then stack the
        T-frame chunks along the batch dimension."""
        n_frames = x.size(-1)
        target = math.ceil(n_frames / self.T) * self.T
        x = F.pad(x, (0, target - n_frames))
        if n_frames < self.T:
            # A single (padded) chunk — nothing to stack.
            return x
        return torch.cat(torch.split(x, self.T, dim=-1), dim=0)

    def unfold_unet_outputs(self, x, input_size):
        """Invert fold_unet_inputs: reassemble chunks along time and trim
        back to the original frame count recorded in `input_size`."""
        batch, n_frames = input_size[0], input_size[-1]
        if x.size(0) != batch:
            x = torch.cat(torch.split(x, batch, dim=0), dim=-1)
        return x[..., :n_frames]

    def trim_freq_dim(self, x):
        # Keep only the first self.F frequency bins for the UNet.
        return x[..., :self.F, :]

    def pad_freq_dim(self, x):
        # Zero-pad the frequency axis back to the full n_fft // 2 + 1 bins.
        missing = (self.n_fft // 2 + 1) - x.size(-2)
        return F.pad(x, (0, 0, 0, missing))

    def pad_stft_input(self, x):
        # Pad the signal so the hop grid lines up with the analysis window.
        extra = (-(x.size(-1) - self.win_length) % self.hop_length) % self.win_length
        return F.pad(x, (0, extra))

    def _stft(self, x):
        return torch.stft(input=x, n_fft=self.n_fft, window=self.hann_window,
                          win_length=self.win_length, hop_length=self.hop_length,
                          center=self.center, return_complex=True)

    def _istft(self, x, trim_length=None):
        return torch.istft(input=x, n_fft=self.n_fft, window=self.hann_window,
                           win_length=self.win_length, hop_length=self.hop_length,
                           center=self.center, length=trim_length)

    def batch_stft(self, x, pad=True, return_complex=False):
        """STFT over arbitrary leading dims.

        Returns the complex spectrogram when `return_complex`, otherwise a
        (magnitude, phase) pair, each shaped (*leading_dims, bins, frames).
        """
        lead = x.size()[:-1]
        flat = x.reshape(-1, x.size(-1))
        if pad:
            flat = self.pad_stft_input(flat)
        spec = self._stft(flat)
        spec = spec.reshape(lead + spec.shape[-2:])
        if return_complex:
            return spec
        return spec.abs(), spec.angle()

    def batch_istft(self, magnitude, phase, trim_length=None):
        """Inverse of batch_stft for a (magnitude, phase) pair."""
        spec = torch.polar(magnitude, phase)
        lead = spec.size()[:-2]
        flat = spec.reshape(-1, spec.size(-2), spec.size(-1))
        wav = self._istft(flat, trim_length)
        return wav.reshape(lead + wav.shape[-1:])
# ---------------------------------------------
# UNet Blocks
# ---------------------------------------------
class UNetEncoderBlock(nn.Module):
    """Strided Conv2d -> BatchNorm -> LeakyReLU downsampling block.

    forward() returns both the activated output (fed to the next encoder)
    and the raw pre-norm convolution output (used as the skip connection).
    """

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), relu_slope=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.activ = nn.LeakyReLU(relu_slope)
        # He-uniform init matched to the leaky-ReLU negative slope.
        nn.init.kaiming_uniform_(self.conv.weight, nonlinearity='leaky_relu', a=relu_slope)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        skip = self.conv(x)
        out = self.activ(self.bn(skip))
        return out, skip
class UNetDecoderBlock(nn.Module):
    """ConvTranspose2d upsampling block.

    Note the forward order is ReLU -> BatchNorm -> Dropout (activation
    BEFORE normalization), which is preserved from the trained checkpoints.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(5, 5),
                 stride=(2, 2), padding=(2, 2), output_padding=(1, 1), dropout=0.0):
        super().__init__()
        self.conv_trans = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                             stride=stride, padding=padding, output_padding=output_padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.activ = nn.ReLU()

    def forward(self, x):
        up = self.conv_trans(x)
        up = self.activ(up)
        up = self.bn(up)
        return self.dropout(up)
# ---------------------------------------------
# UNet Models
# ---------------------------------------------
class UNet(nn.Module):
    """Spectrogram-domain UNet that predicts a soft mask in [0, 1].

    Input spectrograms are folded into fixed (F, T) chunks, masked, and
    unfolded back to the original time length.
    """

    def __init__(self, input_size: Tuple[int, ...] = (2, 2048, 512),
                 power: float = 1.0, device: Optional[str] = None):
        super().__init__()
        # NOTE(review): `power` is accepted but never used in this class —
        # presumably kept for config/checkpoint compatibility; confirm.
        self.input_size = input_size
        audio_channels, f_size, t_size = input_size
        self.utils = UNetUtils(F=f_size, T=t_size, device=device)
        # Normalizes across frequency bins (the input is transposed so the
        # frequency axis sits in the channel position).
        self.input_norm = nn.BatchNorm2d(f_size)
        # Encoder: six stride-2 blocks, doubling channels each time.
        # Attribute names enc1..enc6 are kept for state_dict compatibility.
        widths = [audio_channels, 16, 32, 64, 128, 256, 512]
        for i in range(6):
            setattr(self, f'enc{i + 1}', UNetEncoderBlock(widths[i], widths[i + 1]))
        # Decoder: mirror of the encoder; all but dec1 receive concatenated
        # skip connections, hence the doubled input channel counts.
        self.dec1 = UNetDecoderBlock(512, 256, dropout=0.5)
        self.dec2 = UNetDecoderBlock(512, 128, dropout=0.5)
        self.dec3 = UNetDecoderBlock(256, 64, dropout=0.5)
        self.dec4 = UNetDecoderBlock(128, 32)
        self.dec5 = UNetDecoderBlock(64, 16)
        self.dec6 = UNetDecoderBlock(32, audio_channels)
        # Final dilated conv + sigmoid produces the soft mask.
        self.mask_layer = nn.Sequential(
            nn.Conv2d(audio_channels, audio_channels, kernel_size=(4, 4), dilation=(2, 2), padding=3),
            nn.Sigmoid()
        )
        nn.init.kaiming_uniform_(self.mask_layer[0].weight)
        nn.init.zeros_(self.mask_layer[0].bias)
        if device is not None:
            self.to(device)

    def produce_mask(self, x: Tensor) -> Tensor:
        """Run the encoder/decoder stack on an (F, T)-sized chunk and return
        the sigmoid mask."""
        x = self.input_norm(x.transpose(1, 2)).transpose(1, 2)
        skips = []
        down = x
        for enc in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.enc6):
            down, pre_act = enc(down)
            skips.append(pre_act)
        # Deepest features go straight into dec1; the remaining decoders
        # consume the skip connections in reverse order.
        up = self.dec1(skips[-1])
        decoders = (self.dec2, self.dec3, self.dec4, self.dec5, self.dec6)
        for dec, skip in zip(decoders, reversed(skips[:-1])):
            up = dec(torch.cat([skip, up], dim=1))
        return self.mask_layer(up)

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (masked spectrogram, mask), both shaped like `x`."""
        original_size = x.size()
        folded = self.utils.fold_unet_inputs(x)
        trimmed = self.utils.trim_freq_dim(folded)
        mask = self.utils.pad_freq_dim(self.produce_mask(trimmed))
        masked = self.utils.unfold_unet_outputs(folded * mask, original_size)
        return masked, self.utils.unfold_unet_outputs(mask, original_size)
class UNetWaveform(UNet):
    """UNet variant that operates end-to-end on waveforms via STFT/iSTFT."""

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Separate a waveform; returns (estimated waveform, mask)."""
        # Promote mono (T,) to stereo (2, T), then add a batch axis.
        if x.dim() == 1:
            x = x.repeat(2, 1)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        magnitude, phase = self.utils.batch_stft(x)
        est_magnitude, mask = super().forward(magnitude)
        # Resynthesize with the mixture phase, trimmed to the input length.
        waveform = self.utils.batch_istft(est_magnitude, phase, trim_length=x.size(-1))
        return waveform, mask
# ---------------------------------------------
# LarsNet
# ---------------------------------------------
class LarsNet(nn.Module):
    """Loads one pretrained UNet per drum stem and separates a drum mix.

    Depending on the flags, separation runs per-stem in the waveform domain
    (default), through a multi-stem alpha-Wiener filter, or returns complex
    STFTs per stem.
    """

    def __init__(self, wiener_filter=False, wiener_exponent=1.0,
                 config: Union[str, Path] = "config.yaml",
                 return_stft=False, device='cpu', **kwargs):
        super().__init__(**kwargs)
        with open(config, "r") as f:
            config = yaml.safe_load(f)
        self.device = device
        self.wiener_filter = wiener_filter
        self.wiener_exponent = wiener_exponent
        self.return_stft = return_stft
        # Materialize as a list so the stem order is fixed and independent of
        # the config dict's lifetime (separate_wiener relies on this order).
        self.stems = list(config['inference_models'].keys())
        self.utils = UNetUtils(device=self.device)
        self.sr = config['global']['sr']
        self.models = {}
        print('Loading UNet models...')
        for stem in tqdm(self.stems):
            checkpoint_path = Path(config['inference_models'][stem])
            F = config[stem]['F']
            T = config[stem]['T']
            # Wiener / STFT modes need the spectrogram-domain UNet; otherwise
            # use the waveform wrapper that does STFT/iSTFT itself.
            model = (UNet if (wiener_filter or return_stft) else UNetWaveform)(
                input_size=(2, F, T), device=self.device
            )
            # NOTE(security): torch.load unpickles arbitrary objects — only
            # load checkpoints from trusted sources.
            checkpoint = torch.load(str(checkpoint_path), map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            self.models[stem] = model

    @staticmethod
    def _fix_dim(x):
        """Promote a waveform to batched stereo: (T,) -> (2, T) -> (1, 2, T).

        Fix: this was previously defined in the class body without `self`,
        so every `self._fix_dim(x)` call raised a TypeError; it is now a
        staticmethod so both `self._fix_dim` and `LarsNet._fix_dim` work.
        """
        if x.dim() == 1:
            x = x.repeat(2, 1)
        if x.dim() == 2:
            x = x.unsqueeze(0)
        return x

    def separate(self, x):
        """Waveform-domain separation: run each stem model independently.

        Returns a dict mapping stem name -> separated waveform tensor.
        """
        out = {}
        x = x.to(self.device)
        for stem, model in tqdm(self.models.items()):
            y, _ = model(x)
            out[stem] = y.squeeze(0).detach()
        return out

    def separate_wiener(self, x):
        """Separate via an alpha-Wiener filter built from all stem masks."""
        out = {}
        mag_pred = []
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)
        for stem, model in tqdm(self.models.items()):
            _, mask = model(mag)
            mag_pred.append((mask * mag) ** self.wiener_exponent)
        pred_sum = sum(mag_pred)
        # self.stems and mag_pred share the same (insertion) order.
        for stem, pred in zip(self.stems, mag_pred):
            wiener_mask = pred / (pred_sum + 1e-7)  # epsilon avoids div-by-zero
            y = self.utils.batch_istft(mag * wiener_mask, phase, trim_length=x.size(-1))
            out[stem] = y.squeeze(0).detach()
        return out

    def separate_stft(self, x):
        """Return each stem as a complex STFT (mixture phase is reused)."""
        out = {}
        x = self._fix_dim(x).to(self.device)
        mag, phase = self.utils.batch_stft(x)
        for stem, model in tqdm(self.models.items()):
            mag_pred, _ = model(mag)
            out[stem] = torch.polar(mag_pred, phase).squeeze(0).detach()
        return out

    def forward(self, x):
        """Separate `x`, which may be a tensor or a path to an audio file.

        Files are loaded with torchaudio and resampled to the config rate.
        Dispatch order: return_stft takes precedence over wiener_filter.
        """
        if isinstance(x, (str, Path)):
            x, sr_ = ta.load(str(x))
            if sr_ != self.sr:
                x = ta.functional.resample(x, sr_, self.sr)
        if self.return_stft:
            return self.separate_stft(x)
        elif self.wiener_filter:
            return self.separate_wiener(x)
        else:
            return self.separate(x)
# ---------------------------------------------
# App
# ---------------------------------------------
# HARP model card describing this endpoint to the host application.
model_card = ModelCard(
    name="LarsNet Drum Stem Separator",
    description="Separates a drum mix into individual drum stems: Kick, Snare, Toms, Hi-Hat, and Cymbals.",
    author="A. I. Mezza, et al.",
    tags=["drums", "demucs", "source-separation", "pyharp", "stems", "multi-output"],
)
# Load all per-stem UNets once at import time; inference runs on CPU.
MODEL = LarsNet(wiener_filter=False, device="cpu", config="config.yaml")
def process_fn(audio_path: str):
    """PyHARP processing callback: separate `audio_path` into drum stems.

    Returns a tuple of five wav file paths in the fixed order expected by
    the output components: kick, snare, toms, hihat, cymbals.
    """
    stems = MODEL(audio_path)
    # Fix: write into a fresh temporary directory per call so concurrent or
    # successive requests cannot overwrite each other's output files (the
    # previous fixed "outputs/" directory was shared by all requests).
    output_dir = Path(tempfile.mkdtemp(prefix="larsnet_"))
    output_paths = []
    for stem_name in ["kick", "snare", "toms", "hihat", "cymbals"]:
        out_path = output_dir / f"{stem_name}.wav"
        # Stems are (channels, samples); soundfile expects (samples, channels).
        sf.write(out_path, stems[stem_name].cpu().numpy().T, MODEL.sr)
        output_paths.append(str(out_path))
    return tuple(output_paths)
# Build the Gradio UI: one audio input, five stem outputs, wired into a
# PyHARP endpoint so HARP hosts can call process_fn remotely.
with gr.Blocks() as demo:
    # NOTE(review): .harp_required appears to be a method pyharp adds to
    # Gradio components to mark mandatory inputs — confirm against the
    # installed pyharp version.
    input_audio = gr.Audio(type="filepath", label="Drum Mix (Input)").harp_required(True)
    # One output component per stem, in the same order process_fn returns them.
    output_kick = gr.Audio(type="filepath", label="Kick")
    output_snare = gr.Audio(type="filepath", label="Snare")
    output_toms = gr.Audio(type="filepath", label="Toms")
    output_hihat = gr.Audio(type="filepath", label="Hi-Hat")
    output_cymbals = gr.Audio(type="filepath", label="Cymbals")
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_kick, output_snare, output_toms, output_hihat, output_cymbals],
        process_fn=process_fn,
    )
# share=True exposes a public Gradio link; show_error surfaces tracebacks.
demo.queue().launch(show_error=True, share=True)