#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)

from functools import lru_cache
from typing import Optional, Tuple

import ffmpeg
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from pydub import AudioSegment
from unet import UNet

# Sample rate used both for decoding the input and encoding the outputs.
_SAMPLE_RATE = 44100

# STFT parameters; the forward and inverse transforms must use the same ones.
_N_FFT = 4096
_HOP_LENGTH = 1024

# The networks only see the lowest _NUM_BINS frequency bins, cut into chunks
# of _SEGMENT_LEN frames, i.e. inputs of shape (N, 2, _SEGMENT_LEN, _NUM_BINS).
_NUM_BINS = 1024
_SEGMENT_LEN = 512


def load_audio(filename):
    """Decode an audio file with ffmpeg into a stereo float32 tensor.

    Args:
      filename: Path of the audio file to decode.
    Returns:
      A float32 tensor of shape (num_samples, 2), resampled to 44100 Hz.
      Mono input is duplicated to both channels; for input with more than
      two channels only the first two are kept.
    Raises:
      ValueError: If ffprobe finds no stream, or no audio stream.
      RuntimeError: If ffmpeg fails to decode the file.
    """
    probe = ffmpeg.probe(filename)
    streams = probe.get("streams") or []
    if not streams:
        raise ValueError("No stream was found with ffprobe")

    # Without a default, next() would leak a StopIteration when the file has
    # streams but none of them is audio (e.g. a video-only file).
    metadata = next(
        (stream for stream in streams if stream.get("codec_type") == "audio"),
        None,
    )
    if metadata is None:
        raise ValueError(f"No audio stream was found in {filename}")
    n_channels = int(metadata["channels"])

    process = (
        ffmpeg.input(filename)
        .output("pipe:", format="f32le", ar=_SAMPLE_RATE)
        .run_async(pipe_stdout=True, pipe_stderr=True)
    )
    buffer, stderr = process.communicate()
    if process.returncode != 0:
        raise RuntimeError(
            f"ffmpeg failed to decode {filename}:\n"
            f"{stderr.decode(errors='replace')}"
        )

    # f32le = little-endian float32 with interleaved channels. astype() makes a
    # native-endian, writable copy; np.frombuffer alone returns a read-only
    # view, which torch.from_numpy warns about.
    samples = np.frombuffer(buffer, dtype="<f4").astype(np.float32)
    waveform = torch.from_numpy(samples.reshape(-1, n_channels))

    if n_channels == 1:
        waveform = waveform.repeat(1, 2)
    elif n_channels > 2:
        waveform = waveform[:, :2]

    return waveform


def separate(
    vocals: torch.nn.Module,
    accompaniment: torch.nn.Module,
    waveform: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a waveform into vocals and accompaniment with two mask networks.

    Both networks receive the magnitude spectrogram in chunks of shape
    (N, 2, 512, 1024); their outputs are turned into soft ratio masks that
    are applied to the complex STFT before inverting it.

    Args:
      vocals: Network whose output is used as the vocals estimate.
      accompaniment: Network whose output is used as the accompaniment
        estimate.
      waveform: Tensor of shape (num_samples, channels); the networks expect
        2 channels.
    Returns:
      (vocals_wave, accompaniment_wave), each of shape
      (channels, num_samples) and aligned sample-for-sample with the input.
    Raises:
      ValueError: If waveform is not a non-empty 2-D tensor.
    """
    if waveform.ndim != 2 or waveform.shape[0] == 0:
        raise ValueError(
            "Expected a non-empty (num_samples, channels) waveform, "
            f"got shape {tuple(waveform.shape)}"
        )
    num_samples = waveform.shape[0]
    window = torch.hann_window(_N_FFT, periodic=True)

    # torch.stft wants (channels, time). center=True zero-pads n_fft // 2 on
    # both sides; istft below uses center=True too, so the two transforms
    # agree and the output is not time-shifted relative to the input.
    stft = torch.stft(
        waveform.t(),
        n_fft=_N_FFT,
        hop_length=_HOP_LENGTH,
        window=window,
        center=True,
        pad_mode="constant",
        onesided=True,
        return_complex=True,
    )
    # stft: complex, (channels, _N_FFT // 2 + 1, num_frames)
    num_channels, num_freq, num_frames = stft.shape

    # (num_frames, _NUM_BINS, channels): only the low bins go to the networks.
    y = stft.permute(2, 1, 0)[:, :_NUM_BINS, :]

    # Pad the frame axis up to a multiple of _SEGMENT_LEN. (-n) % k is 0 when
    # n is already a multiple, so no all-zero segment is appended then.
    pad_size = (-num_frames) % _SEGMENT_LEN
    y = torch.nn.functional.pad(y, (0, 0, 0, 0, 0, pad_size))
    num_segments = y.shape[0] // _SEGMENT_LEN
    y = y.reshape(num_segments, _SEGMENT_LEN, _NUM_BINS, num_channels)

    # Magnitude, laid out as (num_segments, channels, _SEGMENT_LEN, _NUM_BINS)
    y = y.abs().permute(0, 3, 1, 2)

    vocals_spec = vocals(y)
    accompaniment_spec = accompaniment(y)

    # Soft ratio masks; eps avoids division by zero, and eps / 2 in each
    # numerator keeps the two masks summing to 1.
    eps = 1e-10
    sum_spec = vocals_spec**2 + accompaniment_spec**2 + eps
    vocals_mask = (vocals_spec**2 + eps / 2) / sum_spec
    accompaniment_mask = (accompaniment_spec**2 + eps / 2) / sum_spec

    waves = []
    for mask in (vocals_mask, accompaniment_mask):
        # Bins above _NUM_BINS were not modeled; their mask is zero.
        mask = torch.nn.functional.pad(mask, (0, num_freq - _NUM_BINS))
        # (segments, channels, seg_len, freq) -> (frames, freq, channels),
        # dropping the frames that were only added as padding.
        mask = mask.permute(0, 2, 3, 1).reshape(-1, num_freq, num_channels)
        mask = mask[:num_frames].permute(2, 1, 0)  # (channels, freq, frames)

        # torch.istft already normalizes by the squared-window envelope, so
        # no extra window-compensation factor is applied. length= trims the
        # result to exactly the input length.
        wave = torch.istft(
            mask * stft,
            n_fft=_N_FFT,
            hop_length=_HOP_LENGTH,
            window=window,
            center=True,
            onesided=True,
            length=num_samples,
        )
        waves.append(wave)

    return waves[0], waves[1]


@lru_cache(maxsize=10)
def get_file(
    repo_id: str,
    filename: str,
    subfolder: str = "2stems",
) -> str:
    """Download (or reuse the local copy of) a file from the Hugging Face Hub.

    Returns:
      The local path returned by hf_hub_download.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )


@lru_cache(maxsize=10)
def load_model(name: str):
    """Build a UNet in eval mode and load the weights stored in `name`.

    Args:
      name: File name inside the "2stems" folder of the
        csukuangfj/spleeter-torch repository, e.g. "vocals.pt".
    Returns:
      The loaded UNet; repeated calls with the same name return the cached
      instance.
    """
    net = UNet()
    net.eval()

    filename = get_file("csukuangfj/spleeter-torch", name, subfolder="2stems")

    # NOTE(review): torch.load unpickles the downloaded file. Consider
    # weights_only=True (torch >= 1.13) if the repository cannot be trusted.
    state_dict = torch.load(filename, map_location="cpu")
    net.load_state_dict(state_dict)
    return net


def _save_mp3(wave: torch.Tensor, filename: str) -> None:
    """Encode a (2, num_samples) float waveform as a 128 kbps stereo MP3."""
    # Scale to the int16 range and clamp first: samples >= 1.0 would
    # otherwise overflow int16 and wrap to large negative values (clicks).
    samples = (wave.t() * 32768).clamp(-32768, 32767).to(torch.int16)
    sound = AudioSegment(
        data=samples.numpy().tobytes(),
        sample_width=2,
        frame_rate=_SAMPLE_RATE,
        channels=2,
    )
    sound.export(filename, format="mp3", bitrate="128k")


@torch.no_grad()
def main():
    """Separate a hard-coded MP3 into vocals.mp3 and accompaniment.mp3."""
    vocals = load_model("vocals.pt")
    accompaniment = load_model("accompaniment.pt")

    filename = "./yesterday-once-more-carpenters.mp3"
    waveform = load_audio(filename)
    assert waveform.shape[1] == 2, waveform.shape

    vocals_wave, accompaniment_wave = separate(vocals, accompaniment, waveform)

    _save_mp3(vocals_wave, "vocals.mp3")
    _save_mp3(accompaniment_wave, "accompaniment.mp3")


if __name__ == "__main__":
    main()