|
|
|
|
|
"""
Ensemble several audio files into a single output file.

Every input is loaded with ``librosa.load(..., sr=None, mono=False)`` and the
inputs are combined either sample-wise (``*_wave`` types) or on their STFT
(``*_fft`` types) by a (weighted) mean, a median, or by keeping the value with
the smallest / largest magnitude. The result is written with ``soundfile`` using
the ``FLOAT`` subtype.

Run as a script, e.g. ``--files a.wav b.wav --type avg_wave --output res.wav``.
"""
__author__ = "Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/"
|
|
|
|
|
import os |
|
|
import librosa |
|
|
import soundfile as sf |
|
|
import numpy as np |
|
|
import argparse |
|
|
|
|
|
|
|
|
def stft(wave, nfft, hl):
    """
    Compute the STFT of every channel of ``wave`` with ``librosa.stft``.

    :param wave: waveform of shape (channels, length), or a 1-D mono waveform
        of shape (length,)
    :param nfft: FFT window size, passed to ``librosa.stft`` as ``n_fft``
    :param hl: hop length, passed to ``librosa.stft`` as ``hop_length``
    :return: Fortran-ordered array of the per-channel ``librosa.stft`` results
        stacked along a new leading channel axis; for a 1-D input, the single
        spectrogram without a channel axis
    """
    wave = np.asarray(wave)
    if wave.ndim == 1:
        # Mono input: one spectrogram, no channel axis (istft mirrors this).
        return np.asfortranarray(
            librosa.stft(np.asfortranarray(wave), n_fft=nfft, hop_length=hl)
        )
    # Transform every channel, not just the first two, so no channel is dropped.
    specs = [
        librosa.stft(np.asfortranarray(channel), n_fft=nfft, hop_length=hl)
        for channel in wave
    ]
    return np.asfortranarray(specs)
|
|
|
|
|
|
|
|
def istft(spec, hl, length):
    """
    Invert per-channel spectrograms back to a waveform with ``librosa.istft``.

    :param spec: stacked spectrograms of shape (channels, freq, frames), or a
        single 2-D spectrogram (freq, frames) for mono audio
    :param hl: hop length, passed to ``librosa.istft`` as ``hop_length``
    :param length: number of output samples per channel, passed as ``length``
    :return: Fortran-ordered waveform of shape (channels, length), or a 1-D
        waveform of shape (length,) for a 2-D input
    """
    spec = np.asarray(spec)
    if spec.ndim == 2:
        # A single (freq, frames) spectrogram: the mono counterpart of stft().
        return librosa.istft(np.asfortranarray(spec), hop_length=hl, length=length)
    # Invert every channel, not just the first two.
    waves = [
        librosa.istft(np.asfortranarray(channel_spec), hop_length=hl, length=length)
        for channel_spec in spec
    ]
    return np.asfortranarray(waves)
|
|
|
|
|
|
|
|
def absmax(a, *, axis):
    """
    Return the element with the largest magnitude along ``axis``.

    The selected elements keep their original sign (or complex phase); only the
    choice is made on ``np.abs(a)``. On ties the first occurrence wins
    (``argmax`` semantics).

    :param a: array-like input, real or complex
    :param axis: axis to reduce (keyword-only); negative values count from the end
    :return: array with ``axis`` removed
    """
    a = np.asarray(a)
    # take_along_axis needs the index array to keep the reduced axis (length 1).
    # This replaces an np.ogrid-based index list, which broke once np.ogrid
    # started returning a tuple (no .insert).
    idx = np.expand_dims(np.abs(a).argmax(axis=axis), axis)
    return np.take_along_axis(a, idx, axis=axis).squeeze(axis=axis)
|
|
|
|
|
|
|
|
def absmin(a, *, axis):
    """
    Return the element with the smallest magnitude along ``axis``.

    The selected elements keep their original sign (or complex phase); only the
    choice is made on ``np.abs(a)``. On ties the first occurrence wins
    (``argmin`` semantics).

    :param a: array-like input, real or complex
    :param axis: axis to reduce (keyword-only); negative values count from the end
    :return: array with ``axis`` removed
    """
    a = np.asarray(a)
    # take_along_axis needs the index array to keep the reduced axis (length 1).
    # This replaces an np.ogrid-based index list, which broke once np.ogrid
    # started returning a tuple (no .insert).
    argmin = np.expand_dims(np.abs(a).argmin(axis=axis), axis)
    return np.take_along_axis(a, argmin, axis=axis).squeeze(axis=axis)
|
|
|
|
|
|
|
|
def lambda_max(arr, axis=None, key=None, keepdims=False):
    """
    Return the elements of ``arr`` whose ``key`` value is largest along ``axis``.

    Like ``np.max`` but the comparison uses ``key(arr)`` while the returned
    values come from ``arr`` itself (e.g. ``key=np.abs`` keeps the sign).
    On ties the first occurrence wins (``np.argmax`` semantics).

    :param arr: array-like input
    :param axis: axis to reduce; None reduces over the flattened array
    :param key: elementwise function applied before comparing; None compares
        the values themselves, like built-in ``max``
    :param keepdims: keep the reduced axis with length 1 (only when ``axis`` is given)
    :return: selected element(s)
    """
    arr = np.asarray(arr)
    if key is None:
        # Previously None(arr) raised TypeError; default to comparing raw values.
        key = np.asarray
    idxs = np.argmax(key(arr), axis)
    if axis is None:
        return arr.ravel()[idxs]
    # take_along_axis needs the index array to keep the reduced axis (length 1).
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
|
|
|
|
|
|
|
|
def lambda_min(arr, axis=None, key=None, keepdims=False):
    """
    Return the elements of ``arr`` whose ``key`` value is smallest along ``axis``.

    Like ``np.min`` but the comparison uses ``key(arr)`` while the returned
    values come from ``arr`` itself (e.g. ``key=np.abs`` keeps the sign).
    On ties the first occurrence wins (``np.argmin`` semantics).

    :param arr: array-like input
    :param axis: axis to reduce; None reduces over the flattened array
    :param key: elementwise function applied before comparing; None compares
        the values themselves, like built-in ``min``
    :param keepdims: keep the reduced axis with length 1 (only when ``axis`` is given)
    :return: selected element(s)
    """
    arr = np.asarray(arr)
    if key is None:
        # Previously None(arr) raised TypeError; default to comparing raw values.
        key = np.asarray
    idxs = np.argmin(key(arr), axis)
    if axis is None:
        return arr.ravel()[idxs]
    # take_along_axis needs the index array to keep the reduced axis (length 1).
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
|
|
|
|
|
|
|
|
def average_waveforms(pred_track, weights, algorithm, nfft=2048, hop_length=1024):
    """
    Combine several predictions of the same audio into a single waveform.

    ``*_wave`` algorithms combine the time-domain samples directly. ``*_fft``
    algorithms combine the spectrograms produced by ``stft`` (with ``nfft`` /
    ``hop_length``) and convert the result back with ``istft``, trimmed to the
    input length.

    - ``avg_*``: weighted mean, normalised by the sum of ``weights``.
    - ``median_*``: element-wise median.
    - ``min_*`` / ``max_*``: element-wise value with the smallest / largest
      magnitude (sign or complex phase preserved).

    :param pred_track: shape = (num, channels, length)
    :param weights: shape = (num, ); only used by avg_wave and avg_fft
    :param algorithm: One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft
    :param nfft: FFT window size for the *_fft algorithms (default 2048)
    :param hop_length: STFT hop length for the *_fft algorithms (default 1024)
    :return: averaged waveform in shape (channels, length)
    :raises ValueError: unknown ``algorithm``, no predictions, or (for avg_*)
        a weight count different from ``num`` or weights summing to zero
    """
    wave_algorithms = ("avg_wave", "median_wave", "min_wave", "max_wave")
    fft_algorithms = ("avg_fft", "median_fft", "min_fft", "max_fft")
    if algorithm not in wave_algorithms + fft_algorithms:
        # Previously an unknown name fell through every branch and an empty
        # array was returned (and written to disk) without any error.
        raise ValueError(
            "Unknown ensemble algorithm: {}. Expected one of: {}".format(
                algorithm, ", ".join(wave_algorithms + fft_algorithms)
            )
        )

    pred_track = np.array(pred_track)
    if pred_track.ndim == 0 or pred_track.shape[0] == 0:
        raise ValueError("No predictions to ensemble")
    num_tracks = pred_track.shape[0]
    final_length = pred_track.shape[-1]

    weighted = algorithm in ("avg_wave", "avg_fft")
    if weighted:
        # The numerator uses one weight per track while the denominator sums all
        # weights, so a length mismatch would silently mis-normalise the result.
        if len(weights) != num_tracks:
            raise ValueError(
                "Number of weights ({}) must be equal to number of predictions ({})".format(
                    len(weights), num_tracks
                )
            )
        weight_sum = np.array(weights).sum()
        if weight_sum == 0:
            raise ValueError("Weights must not sum to zero")

    if algorithm in fft_algorithms:
        tracks = np.array(
            [stft(track, nfft=nfft, hl=hop_length) for track in pred_track]
        )
    else:
        tracks = pred_track

    if weighted:
        combined = np.array([track * w for track, w in zip(tracks, weights)]).sum(axis=0)
        combined /= weight_sum
    elif algorithm.startswith("median_"):
        combined = np.median(tracks, axis=0)
    elif algorithm.startswith("min_"):
        combined = lambda_min(tracks, axis=0, key=np.abs)
    else:  # max_wave / max_fft
        combined = lambda_max(tracks, axis=0, key=np.abs)

    if algorithm in fft_algorithms:
        combined = istft(combined, hop_length, final_length)
    return combined
|
|
|
|
|
|
|
|
def _build_parser():
    """Create the argument parser for the ensemble command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--files",
        type=str,
        required=True,
        nargs="+",
        help="Path to all audio-files to ensemble",
    )
    parser.add_argument(
        "--type",
        type=str,
        default="avg_wave",
        # Reject unknown ensemble types up front instead of producing an empty result.
        choices=[
            "avg_wave",
            "median_wave",
            "min_wave",
            "max_wave",
            "avg_fft",
            "median_fft",
            "min_fft",
            "max_fft",
        ],
        help="One of avg_wave, median_wave, min_wave, max_wave, avg_fft, median_fft, min_fft, max_fft",
    )
    parser.add_argument(
        "--weights",
        type=float,
        nargs="+",
        help="Weights to create ensemble. Number of weights must be equal to number of files",
    )
    parser.add_argument(
        "--output",
        default="res.wav",
        type=str,
        help="Path to wav file where ensemble result will be stored",
    )
    return parser


def _load_audio_files(paths):
    """Load every file in *paths*; return (stacked waveforms, sample rate) or exit with status 1."""
    data = []
    sr = None
    for f in paths:
        if not os.path.isfile(f):
            print("Error. Can't find file: {}. Check paths.".format(f))
            # Non-zero status so callers and shell scripts can detect the failure.
            raise SystemExit(1)
        print("Reading file: {}".format(f))
        wav, file_sr = librosa.load(f, sr=None, mono=False)
        print("Waveform shape: {} sample rate: {}".format(wav.shape, file_sr))
        # All inputs are written out at one sample rate, so they must agree.
        if sr is None:
            sr = file_sr
        elif file_sr != sr:
            print(
                "Error. Sample rate of {} ({}) differs from previous files ({}).".format(
                    f, file_sr, sr
                )
            )
            raise SystemExit(1)
        # Waveforms are stacked into one array, which requires identical shapes.
        if data and wav.shape != data[0].shape:
            print(
                "Error. Waveform shape of {} {} differs from previous files {}.".format(
                    f, wav.shape, data[0].shape
                )
            )
            raise SystemExit(1)
        data.append(wav)
    return np.array(data), sr


def ensemble_files(args):
    """
    Command-line entry point: combine several audio files into one ensemble file.

    Loads every ``--files`` entry, combines them with ``average_waveforms``
    according to ``--type`` and ``--weights`` (all ones when omitted) and writes
    the result to ``--output`` with the soundfile ``FLOAT`` subtype.

    :param args: list of argument strings, or None to parse the process command line
    :raises SystemExit: status 2 for invalid arguments (unknown ``--type``, weight
        count not matching file count); status 1 for a missing input file or
        inputs whose sample rates or shapes differ
    """
    parser = _build_parser()
    # parse_args(None) falls back to sys.argv, so no special case is needed.
    args = parser.parse_args(args)
    if args.weights is not None and len(args.weights) != len(args.files):
        parser.error(
            "Number of weights ({}) must be equal to number of files ({})".format(
                len(args.weights), len(args.files)
            )
        )

    print("Ensemble type: {}".format(args.type))
    print("Number of input files: {}".format(len(args.files)))
    if args.weights is not None:
        weights = args.weights
    else:
        weights = np.ones(len(args.files))
    print("Weights: {}".format(weights))
    print("Output file: {}".format(args.output))

    data, sr = _load_audio_files(args.files)
    res = average_waveforms(data, weights, args.type)
    print("Result shape: {}".format(res.shape))
    sf.write(args.output, res.T, sr, "FLOAT")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: args=None makes ensemble_files read the real command line.
    ensemble_files(args=None)
|
|
|