Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) | |
from functools import lru_cache | |
from typing import Optional, Tuple | |
import ffmpeg | |
import numpy as np | |
import torch | |
from huggingface_hub import hf_hub_download | |
from pydub import AudioSegment | |
from unet import UNet | |
def load_audio(filename):
    """Decode an audio file with ffmpeg into a 44.1 kHz, 2-channel float tensor.

    Args:
      filename:
        Path (or any input ffmpeg accepts) of the file to decode.
    Returns:
      A float32 tensor of shape (num_samples, 2). Mono input is duplicated
      into both channels; for input with more than two channels only the
      first two are kept.
    Raises:
      ValueError:
        If ffprobe reports no stream at all, or no audio stream.
      ffmpeg.Error:
        If ffmpeg exits with a non-zero status while decoding.
    """
    probe = ffmpeg.probe(filename)
    streams = probe.get("streams") or []
    if not streams:
        raise ValueError("No stream was found with ffprobe")

    # ffprobe may also report non-audio streams (e.g. cover art); use the
    # first audio one.
    metadata = next(
        (stream for stream in streams if stream.get("codec_type") == "audio"),
        None,
    )
    if metadata is None:
        raise ValueError(f"No audio stream was found in {filename}")
    n_channels = metadata["channels"]

    # Resample to 44.1 kHz. No `ac` option is given, so ffmpeg keeps the
    # input channel count and `n_channels` still describes the output.
    sample_rate = 44100
    process = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=sample_rate)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, stderr = process.communicate()
    if process.returncode != 0:
        # Without this check a failed decode silently produces an empty
        # (or mis-shaped) waveform.
        raise ffmpeg.Error("ffmpeg", buffer, stderr)

    # astype() copies the read-only np.frombuffer view into a writable,
    # native-endian float32 array that torch.from_numpy can wrap safely.
    samples = np.frombuffer(buffer, dtype="<f4").astype(np.float32)
    waveform = torch.from_numpy(samples.reshape(-1, n_channels))

    if n_channels == 1:
        waveform = waveform.tile(1, 2)
    elif n_channels > 2:
        waveform = waveform[:, :2]
    return waveform
def separate(
    vocals: torch.nn.Module,
    accompaniment: torch.nn.Module,
    waveform: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a stereo waveform into vocals and accompaniment.

    Pipeline:
      1. Take the STFT of the waveform.
      2. Feed the magnitude of the lowest 1024 bins to both models, in chunks
         of 512 frames shaped (num_splits, 2, 512, 1024).
      3. Turn the two model outputs into ratio masks.
      4. Apply each mask to the complex STFT and invert it.

    Bins above the first 1024 get a zero mask. At 44.1 kHz that means content
    above roughly 1024 * 44100 / 4096 ~= 11 kHz is removed from both outputs.

    Args:
      vocals:
        Model estimating the vocals magnitude. It is expected to return a
        tensor of the same shape as its input.
      accompaniment:
        Model estimating the accompaniment magnitude, with the same shape
        contract as ``vocals``.
      waveform:
        Tensor of shape (num_samples, 2). The STFT sizes presumably assume
        44.1 kHz audio, as produced by ``load_audio``.
    Returns:
      A tuple (vocals_wave, accompaniment_wave). Each has shape
      (2, num_samples) and is sample-aligned with ``waveform``.
    """
    n_fft = 4096
    hop_length = 1024
    num_bins = 1024  # only the lowest 1024 bins are fed to the models
    frames_per_split = 512  # STFT frames per model input chunk

    num_samples = waveform.shape[0]
    # Build the window once, on the input's device/dtype, for both transforms.
    window = torch.hann_window(
        n_fft, periodic=True, dtype=waveform.dtype, device=waveform.device
    )

    # torch.istft(center=True) below discards the first n_fft // 2 samples of
    # the overlap-added signal. Prepending exactly n_fft // 2 zeros makes
    # output sample k line up with input sample k. The trailing n_fft zeros
    # make the last full frames reach past the end of the signal.
    padded = torch.nn.functional.pad(waveform, (0, 0, n_fft // 2, n_fft))

    # torch.stft requires a 2-D input of shape (N, T), so we transpose.
    stft = torch.stft(
        padded.t(),
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=False,
        onesided=True,
        return_complex=True,
    )
    # stft: complex, (2, n_fft // 2 + 1, num_frames) == (2, 2049, num_frames)
    num_freqs, num_frames = stft.shape[1], stft.shape[2]

    y = stft.permute(2, 1, 0)[:, :num_bins, :]
    # (num_frames, 1024, 2)

    # Zero-pad the frame axis up to a multiple of frames_per_split. When
    # num_frames already is a multiple, nothing is added.
    pad_size = -num_frames % frames_per_split
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
    y = y.reshape(-1, frames_per_split, num_bins, y.shape[2])
    # (num_splits, 512, 1024, 2)
    y = y.abs().permute(0, 3, 1, 2)
    # (num_splits, 2, 512, 1024)

    # Inference only: don't build an autograd graph over the whole song.
    with torch.no_grad():
        vocals_spec = vocals(y)
        accompaniment_spec = accompaniment(y)

    # Ratio masks from the squared estimates. The epsilon keeps them defined
    # (0.5 each) where both estimates are zero.
    eps = 1e-10
    vocals_power = vocals_spec**2
    accompaniment_power = accompaniment_spec**2
    total_power = vocals_power + accompaniment_power + eps
    masks = (
        (vocals_power + eps / 2) / total_power,
        (accompaniment_power + eps / 2) / total_power,
    )

    ans = []
    for mask in masks:
        # (num_splits, 2, 512, 1024)
        # Bins the models never saw get a zero mask.
        mask = torch.nn.functional.pad(mask, (0, num_freqs - num_bins))
        # (num_splits, 2, 512, 2049)
        mask = mask.permute(0, 2, 3, 1)
        # (num_splits, 512, 2049, 2)
        mask = mask.reshape(-1, mask.shape[2], mask.shape[3])[:num_frames]
        # (num_frames, 2049, 2)
        mask = mask.permute(2, 1, 0)
        # (2, 2049, num_frames)

        # torch.istft already divides by the overlap-added squared window
        # (least-squares inverse), so no extra window-compensation gain is
        # applied. A 2/3 factor is only right for an inverse STFT that does
        # not normalise (sum of squared Hann windows at hop n_fft/4 is 1.5).
        wave = torch.istft(
            mask * stft,
            n_fft=n_fft,
            hop_length=hop_length,
            window=window,
            center=True,
            onesided=True,
            length=num_samples,
        )
        # (2, num_samples)
        ans.append(wave)
    return ans[0], ans[1]
def get_nn_model_filename(
    repo_id: str,
    filename: str,
    subfolder: str = "2stems",
) -> str:
    """Fetch one file of a Hugging Face Hub repository and return its local path.

    Args:
      repo_id:
        Hub repository to download from, e.g. "csukuangfj/spleeter-torch".
      filename:
        Name of the file inside ``subfolder``.
      subfolder:
        Folder of the repository that holds ``filename``.
    Returns:
      The path returned by ``hf_hub_download``.
    """
    return hf_hub_download(repo_id=repo_id, subfolder=subfolder, filename=filename)
def load_model(name: str):
    """Create a UNet in eval mode and load the weights file ``name`` into it.

    Args:
      name:
        File name inside the "2stems" folder of the
        "csukuangfj/spleeter-torch" Hub repository, e.g. "vocals.pt".
    Returns:
      The UNet instance with the downloaded state dict loaded. The state
      dict is read with map_location="cpu".
    """
    model = UNet()
    model.eval()

    weights_path = get_nn_model_filename(
        repo_id="csukuangfj/spleeter-torch", filename=name, subfolder="2stems"
    )
    # NOTE(review): torch.load unpickles the downloaded file, which can
    # execute arbitrary code if the file is malicious. On torch versions that
    # support it, weights_only=True restricts loading to tensors. Left
    # unchanged here to stay compatible with older torch releases.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    return model
def _export_mp3(wave: torch.Tensor, filename: str) -> None:
    """Encode a (2, num_samples) float waveform as 44.1 kHz 16-bit stereo MP3."""
    # Clamp before scaling: 1.0 * 32768 does not fit into int16 and would
    # wrap around to -32768, producing loud clicks.
    samples = (wave.t().clamp(-1.0, 1.0) * 32767).round().to(torch.int16)
    sound = AudioSegment(
        # (num_samples, 2) in C order == interleaved L/R PCM frames.
        data=samples.cpu().contiguous().numpy().tobytes(),
        sample_width=2,
        frame_rate=44100,
        channels=2,
    )
    sound.export(filename, format="mp3", bitrate="128k")


def main(filename: str = "./yesterday-once-more-carpenters.mp3"):
    """Separate ``filename`` and write vocals.mp3 / accompaniment.mp3.

    Both output files are written to the current working directory.

    Args:
      filename:
        Audio file to separate. The default keeps the original hard-coded
        input.
    """
    vocals = load_model("vocals.pt")
    accompaniment = load_model("accompaniment.pt")

    waveform = load_audio(filename)
    assert waveform.shape[1] == 2, waveform.shape

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        vocals_wave, accompaniment_wave = separate(
            vocals, accompaniment, waveform
        )

    _export_mp3(vocals_wave, "vocals.mp3")
    _export_mp3(accompaniment_wave, "accompaniment.mp3")


if __name__ == "__main__":
    main()