import sys
sys.path.append('.')
import argparse
import os
import time
from typing import Dict
import pathlib
import librosa
import numpy as np
import soundfile
import torch
import torch.nn as nn
from bytesep.models.lightning_modules import get_model_class
from bytesep.utils import read_yaml
class Separator:
def __init__(
self, model: nn.Module, segment_samples: int, batch_size: int, device: str
):
r"""Separate to separate an audio clip into a target source.
Args:
model: nn.Module, trained model
segment_samples: int, length of segments to be input to a model, e.g., 44100*30
batch_size, int, e.g., 12
device: str, e.g., 'cuda'
"""
self.model = model
self.segment_samples = segment_samples
self.batch_size = batch_size
self.device = device
def separate(self, input_dict: Dict) -> np.ndarray:
r"""Separate an audio clip into a target source.
Args:
input_dict: dict, e.g., {
waveform: (channels_num, audio_samples),
...,
}
Returns:
sep_audio: (channels_num, audio_samples) | (target_sources_num, channels_num, audio_samples)
"""
audio = input_dict['waveform']
audio_samples = audio.shape[-1]
# Pad the audio with zero in the end so that the length of audio can be
# evenly divided by segment_samples.
audio = self.pad_audio(audio)
# Enframe long audio into segments.
segments = self.enframe(audio, self.segment_samples)
# (segments_num, channels_num, segment_samples)
segments_input_dict = {'waveform': segments}
if 'condition' in input_dict.keys():
segments_num = len(segments)
segments_input_dict['condition'] = np.tile(
input_dict['condition'][None, :], (segments_num, 1)
)
# (segments_num, condition_dim)
# Separate in mini-batches.
sep_segments = self._forward_in_mini_batches(
self.model, segments_input_dict, self.batch_size
)['waveform']
# (segments_num, channels_num, segment_samples)
# Deframe segments into long audio.
sep_audio = self.deframe(sep_segments)
# (channels_num, padded_audio_samples)
sep_audio = sep_audio[:, 0:audio_samples]
# (channels_num, audio_samples)
return sep_audio
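# Illustrative shape trace (example numbers only, not fixed by the code):
# a 25 s stereo clip at 44.1 kHz with segment_samples = 44100 * 10 is padded
# to 30 s, enframed with a 5 s hop into 5 half-overlapping segments of shape
# (5, 2, 441000), separated in mini-batches, deframed back to (2, 30 * 44100),
# and finally trimmed to the original (2, 25 * 44100).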
def pad_audio(self, audio: np.ndarray) -> np.ndarray:
r"""Pad the audio with zero in the end so that the length of audio can
be evenly divided by segment_samples.
Args:
audio: (channels_num, audio_samples)
Returns:
padded_audio: (channels_num, padded_audio_samples)
"""
channels_num, audio_samples = audio.shape
# Number of segments
segments_num = int(np.ceil(audio_samples / self.segment_samples))
pad_samples = segments_num * self.segment_samples - audio_samples
padded_audio = np.concatenate(
(audio, np.zeros((channels_num, pad_samples))), axis=1
)
# (channels_num, padded_audio_samples)
return padded_audio
def enframe(self, audio: np.ndarray, segment_samples: int) -> np.ndarray:
r"""Enframe long audio into segments.
Args:
audio: (channels_num, audio_samples)
segment_samples: int
Returns:
segments: (segments_num, channels_num, segment_samples)
"""
audio_samples = audio.shape[1]
assert audio_samples % segment_samples == 0
hop_samples = segment_samples // 2
segments = []
pointer = 0
while pointer + segment_samples <= audio_samples:
segments.append(audio[:, pointer : pointer + segment_samples])
pointer += hop_samples
segments = np.array(segments)
return segments
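# For example (sizes chosen only for illustration), a (2, 8)-shaped audio with
# segment_samples = 4 uses hop_samples = 2, so segments start at samples 0, 2
# and 4 and the result has shape (3, 2, 4): 50%-overlapping frames.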
def deframe(self, segments: np.ndarray) -> np.ndarray:
r"""Deframe segments into long audio.
Args:
segments: (segments_num, channels_num, segment_samples)
Returns:
output: (channels_num, audio_samples)
"""
(segments_num, _, segment_samples) = segments.shape
if segments_num == 1:
return segments[0]
assert self._is_integer(segment_samples * 0.25)
assert self._is_integer(segment_samples * 0.75)
output = []
output.append(segments[0, :, 0 : int(segment_samples * 0.75)])
for i in range(1, segments_num - 1):
output.append(
segments[
i, :, int(segment_samples * 0.25) : int(segment_samples * 0.75)
]
)
output.append(segments[-1, :, int(segment_samples * 0.25) :])
output = np.concatenate(output, axis=-1)
return output
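# Because consecutive segments overlap by 50%, deframe keeps the first 75% of
# the first segment, the middle 50% of every inner segment, and the last 75%
# of the final segment; the kept pieces tile the padded audio exactly, e.g.
# 3 segments of length L reconstruct 0.75L + 0.5L + 0.75L = 2L samples.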
def _is_integer(self, x: float) -> bool:
if x - int(x) < 1e-10:
return True
else:
return False
def _forward_in_mini_batches(
self, model: nn.Module, segments_input_dict: Dict, batch_size: int
) -> Dict:
r"""Forward data to model in mini-batch.
Args:
model: nn.Module
segments_input_dict: dict, e.g., {
'waveform': (segments_num, channels_num, segment_samples),
...,
}
batch_size: int
Returns:
output_dict: dict, e.g. {
'waveform': (segments_num, channels_num, segment_samples),
}
"""
output_dict = {}
pointer = 0
segments_num = len(segments_input_dict['waveform'])
while True:
if pointer >= segments_num:
break
batch_input_dict = {}
for key in segments_input_dict.keys():
batch_input_dict[key] = torch.Tensor(
segments_input_dict[key][pointer : pointer + batch_size]
).to(self.device)
pointer += batch_size
with torch.no_grad():
model.eval()
batch_output_dict = model(batch_input_dict)
for key in batch_output_dict.keys():
self._append_to_dict(
output_dict, key, batch_output_dict[key].data.cpu().numpy()
)
for key in output_dict.keys():
output_dict[key] = np.concatenate(output_dict[key], axis=0)
return output_dict
def _append_to_dict(self, dict, key, value):
if key in dict.keys():
dict[key].append(value)
else:
dict[key] = [value]
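# A minimal usage sketch of Separator (assumes `model` is an already-loaded
# bytesep model and 'mixture.wav' is a stereo 44.1 kHz file; both names are
# placeholders, not defined in this module):
#
#     audio, _ = librosa.load('mixture.wav', sr=44100, mono=False)
#     separator = Separator(
#         model=model, segment_samples=44100 * 30, batch_size=1, device='cuda'
#     )
#     sep_wav = separator.separate({'waveform': audio})
#     # sep_wav: (channels_num, audio_samples)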
class SeparatorWrapper:
def __init__(
self, source_type='vocals', model=None, checkpoint_path=None, device='cuda'
):
input_channels = 2
target_sources_num = 1
model_type = "ResUNet143_Subbandtime"
segment_samples = 44100 * 10
batch_size = 1
self.checkpoint_path = self.download_checkpoints(checkpoint_path, source_type)
if device == 'cuda' and torch.cuda.is_available():
self.device = 'cuda'
else:
self.device = 'cpu'
# Get model class.
Model = get_model_class(model_type)
# Create model.
self.model = Model(
input_channels=input_channels, target_sources_num=target_sources_num
)
# Load checkpoint.
checkpoint = torch.load(self.checkpoint_path, map_location='cpu')
self.model.load_state_dict(checkpoint["model"])
# Move model to device.
self.model.to(self.device)
# Create separator.
self.separator = Separator(
model=self.model,
segment_samples=segment_samples,
batch_size=batch_size,
device=self.device,
)
def download_checkpoints(self, checkpoint_path, source_type):
if source_type == "vocals":
checkpoint_bare_name = "resunet143_subbtandtime_vocals_8.8dB_350k_steps"
elif source_type == "accompaniment":
# The ".pth" suffix is appended below when building the local path and URL.
checkpoint_bare_name = (
"resunet143_subbtandtime_accompaniment_16.4dB_350k_steps"
)
else:
raise NotImplementedError
if not checkpoint_path:
checkpoint_path = '{}/bytesep_data/{}.pth'.format(
str(pathlib.Path.home()), checkpoint_bare_name
)
print('Checkpoint path: {}'.format(checkpoint_path))
if (
not os.path.exists(checkpoint_path)
or os.path.getsize(checkpoint_path) < 4e8
):
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
zenodo_dir = "https://zenodo.org/record/5507029/files"
zenodo_path = os.path.join(
zenodo_dir, "{}.pth?download=1".format(checkpoint_bare_name)
)
os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
return checkpoint_path
def separate(self, audio):
input_dict = {'waveform': audio}
sep_wav = self.separator.separate(input_dict)
return sep_wav
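# A minimal usage sketch of SeparatorWrapper (the audio path is a placeholder;
# the vocals checkpoint is downloaded to ~/bytesep_data on first use):
#
#     audio, _ = librosa.load('mixture.wav', sr=44100, mono=False)
#     wrapper = SeparatorWrapper(source_type='vocals', device='cuda')
#     sep_wav = wrapper.separate(audio)
#     # sep_wav: (channels_num, audio_samples)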
def inference(args):
# Need to use torch.distributed if models contain inplace_abn.abn.InPlaceABNSync.
import torch.distributed as dist
dist.init_process_group(
'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
)
# Arguments & parameters
config_yaml = args.config_yaml
checkpoint_path = args.checkpoint_path
audio_path = args.audio_path
output_path = args.output_path
device = (
torch.device('cuda')
if args.cuda and torch.cuda.is_available()
else torch.device('cpu')
)
configs = read_yaml(config_yaml)
sample_rate = configs['train']['sample_rate']
input_channels = configs['train']['channels']
target_source_types = configs['train']['target_source_types']
target_sources_num = len(target_source_types)
model_type = configs['train']['model_type']
segment_samples = int(30 * sample_rate)
batch_size = 1
print("Using {} for separating ..".format(device))
# paths
if os.path.dirname(output_path) != "":
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Get model class.
Model = get_model_class(model_type)
# Create model.
model = Model(input_channels=input_channels, target_sources_num=target_sources_num)
# Load checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
model.load_state_dict(checkpoint["model"])
# Move model to device.
model.to(device)
# Create separator.
separator = Separator(
model=model,
segment_samples=segment_samples,
batch_size=batch_size,
device=device,
)
# Load audio.
audio, _ = librosa.load(audio_path, sr=sample_rate, mono=False)
# For a mono input (1-D array), add a channel axis first: audio = audio[None, :]
input_dict = {'waveform': audio}
# Separate
separate_time = time.time()
sep_wav = separator.separate(input_dict)
# (channels_num, audio_samples)
print('Separation time: {:.3f} s'.format(time.time() - separate_time))
# Write out separated audio.
# Write the separated audio to a temporary wav file, then let ffmpeg transcode
# it to the requested output format.
soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)
os.system('ffmpeg -y -loglevel panic -i _zz.wav "{}"'.format(output_path))
print('Write out to {}'.format(output_path))
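# Example command line (a sketch; all paths are placeholders for local files):
#
#     python3 bytesep/inference.py \
#         --config_yaml=./config.yaml \
#         --checkpoint_path=./checkpoint.pth \
#         --audio_path=./mixture.wav \
#         --output_path=./sep_vocals.mp3 \
#         --cuda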
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="")
parser.add_argument("--config_yaml", type=str, required=True)
parser.add_argument("--checkpoint_path", type=str, required=True)
parser.add_argument("--audio_path", type=str, required=True)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--cuda", action='store_true', default=True)
args = parser.parse_args()
inference(args)