import sys

sys.path.append('.')

import argparse
import os
import pathlib
import time
from typing import Dict

import librosa
import numpy as np
import soundfile
import torch
import torch.nn as nn

from bytesep.models.lightning_modules import get_model_class
from bytesep.utils import read_yaml

class Separator:
    def __init__(
        self, model: nn.Module, segment_samples: int, batch_size: int, device: str
    ):
        r"""Separator is used to separate an audio clip into a target source.

        Args:
            model: nn.Module, trained model
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
        """
        self.model = model
        self.segment_samples = segment_samples
        self.batch_size = batch_size
        self.device = device
    def separate(self, input_dict: Dict) -> np.ndarray:
        r"""Separate an audio clip into a target source.

        Args:
            input_dict: dict, e.g., {
                'waveform': (channels_num, audio_samples),
                ...,
            }

        Returns:
            sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples)
        """
        audio = input_dict['waveform']
        audio_samples = audio.shape[-1]

        # Pad the audio with zeros at the end so that its length can be
        # evenly divided by segment_samples.
        audio = self.pad_audio(audio)

        # Enframe long audio into segments.
        segments = self.enframe(audio, self.segment_samples)
        # (segments_num, channels_num, segment_samples)

        segments_input_dict = {'waveform': segments}

        if 'condition' in input_dict.keys():
            segments_num = len(segments)
            segments_input_dict['condition'] = np.tile(
                input_dict['condition'][None, :], (segments_num, 1)
            )
            # (segments_num, condition_dim)

        # Separate in mini-batches.
        sep_segments = self._forward_in_mini_batches(
            self.model, segments_input_dict, self.batch_size
        )['waveform']
        # (segments_num, channels_num, segment_samples)

        # Deframe segments into long audio.
        sep_audio = self.deframe(sep_segments)
        # (channels_num, padded_audio_samples)

        sep_audio = sep_audio[:, 0:audio_samples]
        # (channels_num, audio_samples)

        return sep_audio
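
    # Illustrative sketch of calling separate() directly (names such as
    # `trained_model` and `stereo_audio` are placeholders, not part of this
    # module):
    #
    #   separator = Separator(
    #       model=trained_model, segment_samples=44100 * 30, batch_size=1, device='cuda'
    #   )
    #   sep_audio = separator.separate({'waveform': stereo_audio})
    #   # sep_audio has the same (channels_num, audio_samples) shape as the input.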
    def pad_audio(self, audio: np.ndarray) -> np.ndarray:
        r"""Pad the audio with zeros at the end so that its length can be
        evenly divided by segment_samples.

        Args:
            audio: (channels_num, audio_samples)

        Returns:
            padded_audio: (channels_num, padded_audio_samples)
        """
        channels_num, audio_samples = audio.shape

        # Number of segments.
        segments_num = int(np.ceil(audio_samples / self.segment_samples))
        pad_samples = segments_num * self.segment_samples - audio_samples

        padded_audio = np.concatenate(
            (audio, np.zeros((channels_num, pad_samples))), axis=1
        )
        # (channels_num, padded_audio_samples)

        return padded_audio
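
    # Worked example (illustrative numbers): with segment_samples = 44100 * 30
    # = 1,323,000 and a 70-second 44.1 kHz clip (3,087,000 samples),
    # segments_num = ceil(3,087,000 / 1,323,000) = 3, so pad_samples =
    # 3 * 1,323,000 - 3,087,000 = 882,000 zero samples (20 seconds) are appended.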
    def enframe(self, audio: np.ndarray, segment_samples: int) -> np.ndarray:
        r"""Enframe long audio into overlapping segments.

        Args:
            audio: (channels_num, audio_samples)
            segment_samples: int

        Returns:
            segments: (segments_num, channels_num, segment_samples)
        """
        audio_samples = audio.shape[1]
        assert audio_samples % segment_samples == 0

        hop_samples = segment_samples // 2
        segments = []

        pointer = 0
        while pointer + segment_samples <= audio_samples:
            segments.append(audio[:, pointer : pointer + segment_samples])
            pointer += hop_samples

        segments = np.array(segments)

        return segments
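
    # Worked example: because hop_samples = segment_samples // 2, consecutive
    # segments overlap by 50%. A padded clip of 3 * segment_samples therefore
    # yields segments starting at offsets 0, 0.5, 1.0, 1.5 and 2.0 segments,
    # i.e., an array of shape (5, channels_num, segment_samples).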
    def deframe(self, segments: np.ndarray) -> np.ndarray:
        r"""Deframe (overlap-add) segments into long audio.

        Args:
            segments: (segments_num, channels_num, segment_samples)

        Returns:
            output: (channels_num, audio_samples)
        """
        (segments_num, _, segment_samples) = segments.shape

        if segments_num == 1:
            return segments[0]

        assert self._is_integer(segment_samples * 0.25)
        assert self._is_integer(segment_samples * 0.75)

        output = []

        output.append(segments[0, :, 0 : int(segment_samples * 0.75)])

        for i in range(1, segments_num - 1):
            output.append(
                segments[
                    i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75)
                ]
            )

        output.append(segments[-1, :, int(segment_samples * 0.25) :])

        output = np.concatenate(output, axis=-1)

        return output
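
    # Overlap-add bookkeeping: the first segment contributes its first 75%,
    # each intermediate segment its central 50%, and the last segment its
    # final 75%. With 50% hop these spans abut exactly, and their total length
    # is 0.75*S + (n - 2)*0.5*S + 0.75*S = (n + 1)*S/2, which for the
    # n = 2 * padded_audio_samples / S - 1 segments produced by enframe()
    # equals the padded audio length.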
    def _is_integer(self, x: float) -> bool:
        if x - int(x) < 1e-10:
            return True
        else:
            return False
    def _forward_in_mini_batches(
        self, model: nn.Module, segments_input_dict: Dict, batch_size: int
    ) -> Dict:
        r"""Forward data to the model in mini-batches.

        Args:
            model: nn.Module
            segments_input_dict: dict, e.g., {
                'waveform': (segments_num, channels_num, segment_samples),
                ...,
            }
            batch_size: int

        Returns:
            output_dict: dict, e.g., {
                'waveform': (segments_num, channels_num, segment_samples),
            }
        """
        output_dict = {}

        pointer = 0
        segments_num = len(segments_input_dict['waveform'])

        while pointer < segments_num:
            batch_input_dict = {}

            for key in segments_input_dict.keys():
                batch_input_dict[key] = torch.Tensor(
                    segments_input_dict[key][pointer : pointer + batch_size]
                ).to(self.device)

            pointer += batch_size

            with torch.no_grad():
                model.eval()
                batch_output_dict = model(batch_input_dict)

            for key in batch_output_dict.keys():
                self._append_to_dict(
                    output_dict, key, batch_output_dict[key].data.cpu().numpy()
                )

        for key in output_dict.keys():
            output_dict[key] = np.concatenate(output_dict[key], axis=0)

        return output_dict
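
    # Worked example: with segments_num = 5 and batch_size = 2, the loop
    # forwards index ranges [0:2], [2:4] and [4:5], then concatenates the
    # three partial outputs back into arrays of length 5 along axis 0.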
    def _append_to_dict(self, output_dict, key, value):
        if key in output_dict.keys():
            output_dict[key].append(value)
        else:
            output_dict[key] = [value]

class SeparatorWrapper:
    def __init__(
        self, source_type='vocals', model=None, checkpoint_path=None, device='cuda'
    ):
        input_channels = 2
        target_sources_num = 1
        model_type = "ResUNet143_Subbandtime"
        segment_samples = 44100 * 10
        batch_size = 1

        self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type)

        if device == 'cuda' and torch.cuda.is_available():
            self.device = 'cuda'
        else:
            self.device = 'cpu'

        # Get model class.
        Model = get_model_class(model_type)

        # Create model.
        self.model = Model(
            input_channels=input_channels, target_sources_num=target_sources_num
        )

        # Load checkpoint.
        checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
        self.model.load_state_dict(checkpoint["model"])

        # Move model to device.
        self.model.to(self.device)

        # Create separator.
        self.separator = Separator(
            model=self.model,
            segment_samples=segment_samples,
            batch_size=batch_size,
            device=self.device,
        )
    def download_checkpoints(self, checkpoint_path, source_type):
        if source_type == "vocals":
            checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps"

        elif source_type == "accompaniment":
            checkpoint_bare_name = (
                "resunet143_subbtandtime_accompaniment_16.4dB_350k_steps"
            )

        else:
            raise NotImplementedError

        if not checkpoint_path:
            checkpoint_path = '{}/bytesep_data/{}.pth'.format(
                str(pathlib.Path.home()), checkpoint_bare_name
            )

        print('Checkpoint path: {}'.format(checkpoint_path))

        if (
            not os.path.exists(checkpoint_path)
            or os.path.getsize(checkpoint_path) < 4e8
        ):
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

            zenodo_dir = "https://zenodo.org/record/5507029/files"
            zenodo_path = os.path.join(
                zenodo_dir, "{}?download=1".format(checkpoint_bare_name)
            )

            os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))

        return checkpoint_path
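
    # With the defaults above, e.g. source_type='vocals' and no explicit
    # checkpoint_path, the checkpoint is cached at
    # ~/bytesep_data/resunet143_subbtandtime_vocals_8.8dB_350k_steps.pth and is
    # (re)downloaded from Zenodo whenever the file is missing or smaller than
    # 4e8 bytes.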
    def separate(self, audio):
        input_dict = {'waveform': audio}
        sep_wav = self.separator.separate(input_dict)
        return sep_wav
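
# Example usage of SeparatorWrapper (illustrative sketch; 'mixture.wav' and
# 'sep_vocals.wav' are placeholder file names, and a stereo 44.1 kHz input is
# assumed):
#
#   audio, _ = librosa.load('mixture.wav', sr=44100, mono=False)
#   separator = SeparatorWrapper(source_type='vocals', device='cuda')
#   sep_wav = separator.separate(audio)  # (channels_num, audio_samples)
#   soundfile.write(file='sep_vocals.wav', data=sep_wav.T, samplerate=44100)
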
def inference(args):
    # torch.distributed is needed if models contain inplace_abn.abn.InPlaceABNSync.
    import torch.distributed as dist

    dist.init_process_group(
        'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
    )

    # Arguments & parameters
    config_yaml = args.config_yaml
    checkpoint_path = args.checkpoint_path
    audio_path = args.audio_path
    output_path = args.output_path

    device = (
        torch.device('cuda')
        if args.cuda and torch.cuda.is_available()
        else torch.device('cpu')
    )

    configs = read_yaml(config_yaml)
    sample_rate = configs['train']['sample_rate']
    input_channels = configs['train']['channels']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    model_type = configs['train']['model_type']

    segment_samples = int(30 * sample_rate)
    batch_size = 1

    print("Using {} for separating ...".format(device))

    # Paths
    if os.path.dirname(output_path) != "":
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Get model class.
    Model = get_model_class(model_type)

    # Create model.
    model = Model(input_channels=input_channels, target_sources_num=target_sources_num)

    # Load checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model"])

    # Move model to device.
    model.to(device)

    # Create separator.
    separator = Separator(
        model=model,
        segment_samples=segment_samples,
        batch_size=batch_size,
        device=device,
    )

    # Load audio.
    audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False)

    # For mono input, add a channel dimension:
    # audio = audio[None, :]

    input_dict = {'waveform': audio}

    # Separate.
    separate_time = time.time()

    sep_wav = separator.separate(input_dict)
    # (channels_num, audio_samples)

    print('Separate time: {:.3f} s'.format(time.time() - separate_time))

    # Write out the separated audio to a temporary wav file, then convert it
    # to the requested output format with ffmpeg.
    soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
    os.system("ffmpeg -y -loglevel panic -i _zz.wav {}".format(output_path))
    print('Write out to {}'.format(output_path))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--config_yaml", type=str, required=True)
    parser.add_argument("--checkpoint_path", type=str, required=True)
    parser.add_argument("--audio_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--cuda", action='store_true', default=True)

    args = parser.parse_args()
    inference(args)
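
# Example command line (illustrative; the YAML path, checkpoint path and file
# names are placeholders, and this script is assumed to be saved as separate.py):
#
#   python3 separate.py \
#       --config_yaml="configs/vocals.yaml" \
#       --checkpoint_path="checkpoints/vocals_model.pth" \
#       --audio_path="mixture.wav" \
#       --output_path="sep_vocals.wav" \
#       --cuda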