noblebarkrr's picture
test1
d0cd3b0
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
import os
import sys
import librosa
import tempfile
import soundfile as sf
import numpy as np
import argparse
from separator.audio_writer import write_audio_file
def stft(wave, nfft, hl):
    """
    Compute the STFT of each of the first two channels of a waveform.

    Only ``wave[0]`` (left) and ``wave[1]`` (right) are transformed; each is
    made Fortran-contiguous before being handed to ``librosa.stft``.

    :param wave: array indexable as wave[0], wave[1] (one row per channel)
    :param nfft: FFT window size, passed to librosa.stft as ``n_fft``
    :param hl: hop length, passed to librosa.stft as ``hop_length``
    :return: Fortran-ordered array stacking [left_spectrogram, right_spectrogram]
    """
    per_channel = [
        librosa.stft(np.asfortranarray(wave[channel]), n_fft=nfft, hop_length=hl)
        for channel in (0, 1)
    ]
    return np.asfortranarray(per_channel)
def istft(spec, hl, length):
    """
    Invert the spectrograms of the first two channels back to waveforms.

    :param spec: array indexable as spec[0], spec[1] (left, right spectrogram)
    :param hl: hop length, passed to librosa.istft as ``hop_length``
    :param length: samples per output channel, passed to librosa.istft
    :return: Fortran-ordered array stacking [left_wave, right_wave]
    """
    waves = []
    for channel_spec in (spec[0], spec[1]):
        contiguous = np.asfortranarray(channel_spec)
        waves.append(librosa.istft(contiguous, hop_length=hl, length=length))
    return np.asfortranarray(waves)
def absmax(a, *, axis):
    """
    Select, along ``axis``, the element of ``a`` with the largest magnitude.

    The original (signed or complex) value is returned, not its magnitude.
    Ties resolve to the first occurrence, as with ``numpy.argmax``.

    :param a: numpy array (real or complex)
    :param axis: axis to reduce over; negative values count from the end
    :return: array with ``axis`` removed
    """
    remaining = list(a.shape)
    del remaining[axis]
    # Broadcastable index arrays for every axis except the reduced one.
    # np.ogrid may return a tuple, so unpack it into a mutable list.
    index = [*np.ogrid[tuple(slice(0, size) for size in remaining)]]
    winner = np.abs(a).argmax(axis=axis)
    # Put the winning positions back in the slot of the reduced axis.
    index.insert(axis % len(a.shape), winner)
    return a[tuple(index)]
def absmin(a, *, axis):
    """
    Select, along ``axis``, the element of ``a`` with the smallest magnitude.

    The original (signed or complex) value is returned, not its magnitude.
    Ties resolve to the first occurrence, as with ``numpy.argmin``.

    :param a: numpy array (real or complex)
    :param axis: axis to reduce over; negative values count from the end
    :return: array with ``axis`` removed
    """
    dims = list(a.shape)
    dims.pop(axis)
    # np.ogrid returns a tuple in newer NumPy releases, which has no
    # .insert(); convert to a list before splicing in the selected indices.
    indices = list(np.ogrid[tuple(slice(0, d) for d in dims)])
    argmin = np.abs(a).argmin(axis=axis)
    indices.insert(axis % len(a.shape), argmin)
    return a[tuple(indices)]
def lambda_max(arr, axis=None, key=None, keepdims=False):
    """
    Return the element(s) of ``arr`` whose ``key`` score is largest.

    Works like ``np.max``, except that the comparison is made on ``key(arr)``
    while the returned values come from ``arr`` itself. For example,
    ``key=np.abs`` keeps the sign/phase of the largest-magnitude entry.
    Ties resolve to the first occurrence, as with ``np.argmax``.

    :param arr: numpy array
    :param axis: axis to reduce over; None reduces over the flattened array
    :param key: elementwise callable applied to the whole array; None
                (the default) compares the raw values
    :param keepdims: keep the reduced axis with size 1 (ignored when
                     axis is None)
    :return: reduced array, or a single element when axis is None
    """
    # A key of None previously crashed ('NoneType' is not callable).
    scores = arr if key is None else key(arr)
    idxs = np.argmax(scores, axis)
    if axis is None:
        # argmax over a flattened view -> index into the flattened array.
        return arr.ravel()[idxs]
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
def lambda_min(arr, axis=None, key=None, keepdims=False):
    """
    Return the element(s) of ``arr`` whose ``key`` score is smallest.

    Works like ``np.min``, except that the comparison is made on ``key(arr)``
    while the returned values come from ``arr`` itself. For example,
    ``key=np.abs`` keeps the sign/phase of the smallest-magnitude entry.
    Ties resolve to the first occurrence, as with ``np.argmin``.

    :param arr: numpy array
    :param axis: axis to reduce over; None reduces over the flattened array
    :param key: elementwise callable applied to the whole array; None
                (the default) compares the raw values
    :param keepdims: keep the reduced axis with size 1 (ignored when
                     axis is None)
    :return: reduced array, or a single element when axis is None
    """
    # A key of None previously crashed ('NoneType' is not callable).
    scores = arr if key is None else key(arr)
    idxs = np.argmin(scores, axis)
    if axis is None:
        # argmin over a flattened view -> index into the flattened array.
        return arr.ravel()[idxs]
    idxs = np.expand_dims(idxs, axis)
    result = np.take_along_axis(arr, idxs, axis)
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result
def average_waveforms(pred_track, weights, algorithm, nfft=2048, hop_length=1024):
    """
    Combine several predictions of the same track into a single waveform.

    ``*_wave`` algorithms combine the raw samples directly. ``*_fft``
    algorithms first transform each prediction with ``stft``, combine the
    complex spectrograms bin by bin, then transform back with ``istft``.

    - avg_*:    weighted mean, sum(w_i * x_i) / sum(w_i)
    - median_*: ``np.median`` across predictions
    - min_*:    per position, the value with the smallest magnitude (sign or
                phase kept)
    - max_*:    per position, the value with the largest magnitude (sign or
                phase kept)

    :param pred_track: predictions, shape = (num, channels, length)
    :param weights: one weight per prediction, shape = (num, ); only read by
                    avg_wave and avg_fft
    :param algorithm: one of avg_wave, median_wave, min_wave, max_wave,
                      avg_fft, median_fft, min_fft, max_fft
    :param nfft: FFT window size used by the *_fft algorithms
    :param hop_length: STFT hop length used by the *_fft algorithms
    :return: combined waveform, shape = (channels, length)
    :raises ValueError: unknown ``algorithm``; or, for avg_*, a weight count
                        that differs from the number of predictions, or
                        weights summing to zero
    """
    wave_algorithms = ('avg_wave', 'median_wave', 'min_wave', 'max_wave')
    fft_algorithms = ('avg_fft', 'median_fft', 'min_fft', 'max_fft')
    if algorithm not in wave_algorithms and algorithm not in fft_algorithms:
        # Previously an unknown name silently produced an empty array.
        raise ValueError('Unknown ensemble algorithm: {}. Expected one of: {}'.format(
            algorithm, ', '.join(wave_algorithms + fft_algorithms)))

    pred_track = np.array(pred_track)
    final_length = pred_track.shape[-1]
    use_fft = algorithm in fft_algorithms

    if use_fft:
        # Work in the time-frequency domain: one complex spectrogram per
        # prediction, stacked along a new leading axis.
        stack = np.array([stft(track, nfft=nfft, hl=hop_length) for track in pred_track])
    else:
        stack = pred_track
    num = stack.shape[0]

    if algorithm in ('avg_wave', 'avg_fft'):
        w = np.asarray(weights)
        if w.ndim != 1 or len(w) != num:
            # Extra weights used to be silently added to the divisor.
            raise ValueError('Expected {} weights, got {}'.format(num, w.shape))
        total = w.sum()
        if total == 0:
            raise ValueError('Weights must not sum to zero')
        combined = np.array([stack[i] * w[i] for i in range(num)]).sum(axis=0)
        # Not in-place, so integer inputs do not hit a casting error.
        combined = combined / total
    elif algorithm in ('median_wave', 'median_fft'):
        combined = np.median(stack, axis=0)
    elif algorithm in ('min_wave', 'min_fft'):
        combined = lambda_min(stack, axis=0, key=np.abs)
    else:  # 'max_wave' / 'max_fft'
        combined = lambda_max(stack, axis=0, key=np.abs)

    if use_fft:
        combined = istft(combined, hop_length, final_length)
    return combined
def ensemble_audio_files(files, output="res.wav", ensemble_type='avg_wave', weights=None, out_format="wav"):
    """
    Main entry point: combine several audio files into one.

    Each input is loaded with librosa at its native sample rate, without
    downmixing. The predictions are combined with ``average_waveforms``, and
    the result is written twice:

    - ``{output}.{out_format}`` in the requested format;
    - ``{output}_orig.wav`` as 16-bit PCM WAV.

    :param files: list of paths to the audio files to combine
    :param output: output path stem; ".{out_format}" and "_orig.wav" are
                   appended to it as-is
    :param ensemble_type: combination method (avg_wave, median_wave, min_wave,
                          max_wave, avg_fft, median_fft, min_fft, max_fft)
    :param weights: one weight per file (None for equal weights)
    :param out_format: "wav"/"flac" (written with soundfile as PCM_16) or
                       "mp3"/"m4a"/"aac"/"ogg"/"opus"/"aiff" (written with
                       write_audio_file at "320k")
    :return: tuple (path of the formatted output, path of the WAV copy)
    :raises ValueError: unsupported out_format, no input files, weight count
                        not matching the file count, or differing sample rates
    :raises FileNotFoundError: an input path is not an existing file
    """
    soundfile_formats = ("wav", "flac")
    encoder_formats = ("mp3", "m4a", "aac", "ogg", "opus", "aiff")
    # Validate everything cheap before loading any audio. Previously an
    # unknown format wrote nothing yet still returned the path, and errors
    # called exit(), killing the host process.
    if out_format not in soundfile_formats and out_format not in encoder_formats:
        raise ValueError('Unsupported output format: {}. Expected one of: {}'.format(
            out_format, ', '.join(soundfile_formats + encoder_formats)))
    if len(files) == 0:
        raise ValueError('No input files given')
    if weights is not None:
        weights = np.array(weights)
        if len(weights) != len(files):
            raise ValueError('Expected {} weights, got {}'.format(len(files), len(weights)))
    else:
        weights = np.ones(len(files))
    for f in files:
        if not os.path.isfile(f):
            raise FileNotFoundError("Can't find file: {}. Check paths.".format(f))

    print('Ensemble type: {}'.format(ensemble_type))
    print('Number of input files: {}'.format(len(files)))
    print('Weights: {}'.format(weights))
    print('Output file: {}'.format(output))

    data = []
    sr = None
    for f in files:
        print('Reading file: {}'.format(f))
        wav, current_sr = librosa.load(f, sr=None, mono=False)
        if sr is None:
            sr = current_sr
        elif sr != current_sr:
            raise ValueError('Sample rates must be equal for all files: {} has {}, expected {}'.format(
                f, current_sr, sr))
        print("Waveform shape: {} sample rate: {}".format(wav.shape, sr))
        data.append(wav)
    data = np.array(data)

    res = average_waveforms(data, weights, ensemble_type)
    print('Result shape: {}'.format(res.shape))

    output_wav = f"{output}_orig.wav"
    output = f"{output}.{out_format}"
    # The writers take (length, channels), hence the transpose.
    if out_format in soundfile_formats:
        sf.write(output, res.T, sr, subtype='PCM_16')
    else:
        write_audio_file(output, res.T, sr, out_format, "320k")
    sf.write(output_wav, res.T, sr, subtype='PCM_16')
    return output, output_wav
# input_settings = [("demucs / v4", 1.0, "vocals"), ("mel_band_roformer / mel_4_stems", 0.5, "vocals")]
# out, wav = ensembless(input_audio, input_settings, "max_fft", format)