|
|
|
|
|
|
|
from functools import lru_cache |
|
from typing import Optional, Tuple |
|
|
|
import ffmpeg |
|
import numpy as np |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from pydub import AudioSegment |
|
|
|
from unet import UNet |
|
|
|
|
|
def load_audio(filename, sample_rate: int = 44100):
    """Decode an audio file into a stereo float32 waveform using ffmpeg.

    The file is decoded to little-endian float32 PCM at ``sample_rate``.
    Mono input is duplicated into two channels; input with more than two
    channels keeps only the first two.

    Args:
      filename: Path of the audio file to decode.
      sample_rate: Output sample rate passed to ffmpeg (``ar``).

    Returns:
      A float32 tensor of shape ``(num_samples, 2)``.

    Raises:
      ValueError: If ffprobe reports no streams or no audio stream.
      RuntimeError: If ffmpeg exits with a non-zero status while decoding.
    """
    probe = ffmpeg.probe(filename)
    streams = probe.get("streams") or []
    if not streams:
        raise ValueError("No stream was found with ffprobe")

    metadata = next(
        (stream for stream in streams if stream.get("codec_type") == "audio"),
        None,
    )
    if metadata is None:
        # A bare next() would leak StopIteration here (e.g. video-only files).
        raise ValueError(f"No audio stream was found in {filename}")

    n_channels = int(metadata["channels"])

    process = (
        ffmpeg.input(filename)
        # Force the probed channel count so the reshape below stays valid even
        # if ffmpeg's automatic stream selection picks another audio stream.
        .output("pipe:", format="f32le", ar=sample_rate, ac=n_channels)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, stderr = process.communicate()
    if process.returncode != 0:
        # Without this check a failed decode silently yields an empty waveform.
        raise RuntimeError(
            f"ffmpeg failed to decode {filename}: "
            f"{stderr.decode(errors='replace')}"
        )

    # np.frombuffer returns a read-only view of `buffer`; copy it so the
    # tensor created by torch.from_numpy owns writable memory.
    samples = np.frombuffer(buffer, dtype="<f4").reshape(-1, n_channels).copy()
    waveform = torch.from_numpy(samples)

    if n_channels == 1:
        waveform = waveform.tile(1, 2)
    elif n_channels > 2:
        waveform = waveform[:, :2]

    return waveform
|
|
|
|
|
def separate(
    vocals: torch.nn.Module,
    accompaniment: torch.nn.Module,
    waveform: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a stereo waveform into vocals and accompaniment.

    Each network maps the magnitude spectrogram of the input to its own
    magnitude estimate. The squared estimates are turned into soft ratio masks
    ``v^2 / (v^2 + a^2)``, which are applied to the complex STFT and inverted.

    Args:
      vocals: Network producing the vocals magnitude estimate.
      accompaniment: Network producing the accompaniment magnitude estimate.
      waveform: Float tensor of shape ``(num_samples, 2)``.

    Returns:
      ``(vocals_wave, accompaniment_wave)``, each of shape
      ``(2, num_samples)`` and time-aligned with ``waveform``.
    """
    n_fft = 4096
    hop_length = 1024
    # The networks consume chunks of 512 frames of the lowest 1024 bins.
    chunk_frames = 512
    model_bins = 1024

    num_samples = waveform.shape[0]
    window = torch.hann_window(
        n_fft, periodic=True, device=waveform.device, dtype=waveform.dtype
    )

    # torch.istft(center=True) below discards the first n_fft // 2 samples of
    # the overlap-add result, so prepend that many zeros to keep the output
    # aligned with the input. The trailing n_fft zeros make sure the tail of
    # the signal is covered by complete frames.
    padded = torch.nn.functional.pad(waveform, (0, 0, n_fft // 2, n_fft))

    # (channels, n_fft // 2 + 1, num_frames)
    stft = torch.stft(
        padded.t(),
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=False,
        onesided=True,
        return_complex=True,
    )
    num_channels, num_freqs, num_frames = stft.shape

    # (num_frames, model_bins, channels)
    y = stft.permute(2, 1, 0)[:, :model_bins, :]

    # Pad the frame axis up to a multiple of chunk_frames. Unlike
    # `512 - T % 512`, this adds no extra chunk when T is already a multiple.
    pad_frames = -num_frames % chunk_frames
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_frames))

    # (num_chunks, channels, chunk_frames, model_bins)
    num_chunks = y.shape[0] // chunk_frames
    y = y.reshape(num_chunks, chunk_frames, model_bins, num_channels)
    y = y.abs().permute(0, 3, 1, 2)

    vocals_power = vocals(y) ** 2
    accompaniment_power = accompaniment(y) ** 2

    # Soft ratio masks; the epsilons keep 0 / 0 out of the division.
    sum_power = vocals_power + accompaniment_power + 1e-10
    masks = (
        (vocals_power + 1e-10 / 2) / sum_power,
        (accompaniment_power + 1e-10 / 2) / sum_power,
    )

    outputs = []
    for mask in masks:
        # Bins above model_bins were never seen by the networks; mask them to 0.
        mask = torch.nn.functional.pad(mask, (0, num_freqs - model_bins))
        # -> (num_chunks * chunk_frames, num_freqs, channels)
        mask = mask.permute(0, 2, 3, 1).reshape(-1, num_freqs, num_channels)
        # Drop the padded frames -> (channels, num_freqs, num_frames)
        mask = mask[:num_frames].permute(2, 1, 0)

        # torch.istft already divides by the overlap-added squared window, so
        # no extra window-compensation factor (e.g. 2/3, which is only right
        # for an un-normalized overlap-add) is applied here.
        wave = torch.istft(
            mask * stft,
            n_fft,
            hop_length,
            window=window,
            onesided=True,
            length=num_samples,
        )
        outputs.append(wave)

    return outputs[0], outputs[1]
|
|
|
|
|
@lru_cache(maxsize=10)
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = "2stems",
) -> str:
    """Resolve ``filename`` inside ``subfolder`` of a Hugging Face repo.

    The lookup is delegated to ``hf_hub_download``, and the local path it
    returns is passed straight back. Results are memoized per argument tuple
    by ``lru_cache`` (at most 10 entries).
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
|
|
|
|
|
@lru_cache(maxsize=10)
def load_model(name: str):
    """Build a ``UNet`` in eval mode and load checkpoint ``name`` into it.

    The checkpoint is fetched from the ``2stems`` folder of the
    ``csukuangfj/spleeter-torch`` repo. Because of ``lru_cache``, repeated
    calls with the same ``name`` return the same module instance.
    """
    model = UNet().eval()
    checkpoint_path = get_file("csukuangfj/spleeter-torch", name, subfolder="2stems")
    # NOTE(review): torch.load unpickles this remotely downloaded file; consider
    # weights_only=True if the installed torch version supports it.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    return model
|
|
|
|
|
@torch.no_grad()
def main(filename: str = "./yesterday-once-more-carpenters.mp3"):
    """Separate ``filename`` into ``vocals.mp3`` and ``accompaniment.mp3``.

    Both outputs are written to the current directory as 44.1 kHz stereo
    MP3 files at 128 kbit/s.

    Args:
      filename: Path of the input audio file. Defaults to the demo song.
    """
    vocals = load_model("vocals.pt")
    accompaniment = load_model("accompaniment.pt")

    waveform = load_audio(filename)
    assert waveform.shape[1] == 2, waveform.shape

    vocals_wave, accompaniment_wave = separate(vocals, accompaniment, waveform)

    for out_name, wave in (
        ("vocals.mp3", vocals_wave),
        ("accompaniment.mp3", accompaniment_wave),
    ):
        # (2, num_samples) float -> (num_samples, 2) int16. The float->int16
        # cast does not clamp, so clamp first: a full-scale +1.0 sample would
        # otherwise map to 32768, which does not fit in int16.
        pcm = (wave.t() * 32768).clamp(-32768, 32767).to(torch.int16)
        sound = AudioSegment(
            data=pcm.contiguous().numpy().tobytes(),
            sample_width=2,
            frame_rate=44100,
            channels=2,
        )
        sound.export(out_name, format="mp3", bitrate="128k")


if __name__ == "__main__":
    main()
|
|