import gc
import os
import random
import tempfile
import argparse
import warnings

import numpy as np
import librosa
import soundfile as sf
import torch
import pyloudnorm as pyln
from scipy import signal
from cog import BasePredictor, Input, Path
from audiosr import build_model, super_resolution

warnings.filterwarnings("ignore")

os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")


def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
    """Trim or zero-pad array_1 so its length matches array_2 (1-D mono or 2-D multichannel)."""
    if len(array_1.shape) == 1 and len(array_2.shape) == 1:
        if array_1.shape[0] > array_2.shape[0]:
            array_1 = array_1[:array_2.shape[0]]
        elif array_1.shape[0] < array_2.shape[0]:
            array_1 = np.pad(array_1, (array_2.shape[0] - array_1.shape[0], 0), 'constant', constant_values=0)
    else:
        if array_1.shape[1] > array_2.shape[1]:
            array_1 = array_1[:, :array_2.shape[1]]
        elif array_1.shape[1] < array_2.shape[1]:
            padding = array_2.shape[1] - array_1.shape[1]
            array_1 = np.pad(array_1, ((0, 0), (0, padding)), 'constant', constant_values=0)
    return array_1


def lr_filter(audio, cutoff, filter_type, order=12, sr=48000):
    """Zero-phase Butterworth low-pass/high-pass filter applied along the time axis."""
    audio = audio.T
    nyquist = 0.5 * sr
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order // 2, normal_cutoff, btype=filter_type, analog=False)
    sos = signal.tf2sos(b, a)
    filtered_audio = signal.sosfiltfilt(sos, audio)
    return filtered_audio.T
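
# Because sosfiltfilt runs the half-order Butterworth forwards and backwards,
# the result is zero-phase and the effective slope is doubled, so a matched
# low-pass/high-pass pair behaves like a Linkwitz-Riley-style crossover: the
# two bands sum back to roughly the original signal. A minimal sketch of the
# split used later (the names and the 11 kHz cutoff here are illustrative,
# not part of this script):
#
#   low = lr_filter(x, 11000, 'lowpass', order=10)
#   high = lr_filter(x, 11000, 'highpass', order=10)
#   y = low + high  # ~x, apart from small ripple near the cutoff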


class Predictor(BasePredictor):

    def setup(self, model_name="basic", device="auto"):
        """Load the AudioSR model once so repeated predictions reuse it."""
        self.model_name = model_name
        self.device = device
        self.sr = 48000
        print("Loading Model...")
        self.audiosr = build_model(model_name=self.model_name, device=self.device)
        print("Model loaded!")

    def process_audio(self, input_file, chunk_size=5.12, overlap=0.1, seed=None, guidance_scale=3.5, ddim_steps=50):
        # input_cutoff, multiband_ensemble and crossover_freq are module-level
        # globals set in the __main__ block below.
        audio, sr = librosa.load(input_file, sr=input_cutoff * 2, mono=False)
        audio = audio.T
        sr = input_cutoff * 2
        print(f"audio.shape = {audio.shape}")
        print(f"input cutoff = {input_cutoff}")

        is_stereo = len(audio.shape) == 2
        audio_channels = [audio] if not is_stereo else [audio[:, 0], audio[:, 1]]
        print("Audio is stereo" if is_stereo else "Audio is mono")

        chunk_samples = int(chunk_size * sr)
        overlap_samples = int(overlap * chunk_samples)
        output_chunk_samples = int(chunk_size * self.sr)
        output_overlap_samples = int(overlap * output_chunk_samples)
        enable_overlap = overlap > 0
        print(f"enable_overlap = {enable_overlap}")
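
        # Chunk sizes are tracked at two rates: chunk_samples/overlap_samples at
        # the resampled input rate (2 * input_cutoff) and output_chunk_samples /
        # output_overlap_samples at the 48 kHz model output rate. For example,
        # with the defaults input_cutoff=12000 and chunk_size=10.24 the input
        # rate is 24 kHz, so each chunk is 245760 input samples and 491520
        # output samples (these numbers just illustrate the arithmetic).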

        def process_chunks(audio):
            """Split one channel into fixed-length chunks, zero-padding the last one."""
            chunks = []
            original_lengths = []
            start = 0
            while start < len(audio):
                end = min(start + chunk_samples, len(audio))
                chunk = audio[start:end]
                if len(chunk) < chunk_samples:
                    original_lengths.append(len(chunk))
                    chunk = np.concatenate([chunk, np.zeros(chunk_samples - len(chunk))])
                else:
                    original_lengths.append(chunk_samples)
                chunks.append(chunk)
                start += chunk_samples - overlap_samples if enable_overlap else chunk_samples
            return chunks, original_lengths

        chunks_per_channel = [process_chunks(channel) for channel in audio_channels]
        sample_rate_ratio = self.sr / sr
        total_length = len(chunks_per_channel[0][0]) * output_chunk_samples - (len(chunks_per_channel[0][0]) - 1) * (output_overlap_samples if enable_overlap else 0)
        reconstructed_channels = [np.zeros((1, total_length)) for _ in audio_channels]

        meter_before = pyln.Meter(sr)
        meter_after = pyln.Meter(self.sr)
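
        # Stitching scheme: each upsampled chunk is loudness-matched to its input
        # chunk, faded with linear ramps over the overlap region, and summed into
        # place. The first chunk only fades out, the last only fades in, and
        # interior chunks do both, so overlapping ramps sum to unity.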
        for ch_idx, (chunks, original_lengths) in enumerate(chunks_per_channel):
            for i, chunk in enumerate(chunks):
                loudness_before = meter_before.integrated_loudness(chunk)
                print(f"Processing chunk {i + 1} of {len(chunks)} for {'Left/Mono' if ch_idx == 0 else 'Right'} channel")
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
                    sf.write(temp_wav.name, chunk, sr)

                    out_chunk = super_resolution(
                        self.audiosr,
                        temp_wav.name,
                        seed=seed,
                        guidance_scale=guidance_scale,
                        ddim_steps=ddim_steps,
                        latent_t_per_second=12.8
                    )

                    out_chunk = out_chunk[0]
                    num_samples_to_keep = int(original_lengths[i] * sample_rate_ratio)
                    out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()

                    # Restore the chunk's original integrated loudness after super-resolution.
                    loudness_after = meter_after.integrated_loudness(out_chunk)
                    out_chunk = pyln.normalize.loudness(out_chunk, loudness_after, loudness_before)

                    if enable_overlap:
                        actual_overlap_samples = min(output_overlap_samples, num_samples_to_keep)
                        fade_out = np.linspace(1., 0., actual_overlap_samples)
                        fade_in = np.linspace(0., 1., actual_overlap_samples)

                        if i == 0:
                            out_chunk[-actual_overlap_samples:] *= fade_out
                        elif i < len(chunks) - 1:
                            out_chunk[:actual_overlap_samples] *= fade_in
                            out_chunk[-actual_overlap_samples:] *= fade_out
                        else:
                            out_chunk[:actual_overlap_samples] *= fade_in

                    start = i * (output_chunk_samples - output_overlap_samples if enable_overlap else output_chunk_samples)
                    end = start + out_chunk.shape[0]
                    reconstructed_channels[ch_idx][0, start:end] += out_chunk.flatten()

        reconstructed_audio = np.stack(reconstructed_channels, axis=-1) if is_stereo else reconstructed_channels[0]
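
        # Multiband ensemble: keep the band below crossover_freq from the original
        # input (which is trusted up to input_cutoff) and take only the band above
        # it from the super-resolved output, then sum the two. crossover_freq is
        # set to input_cutoff - 1000 in the __main__ block below.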
        if multiband_ensemble:
            low, _ = librosa.load(input_file, sr=48000, mono=False)
            output = match_array_shapes(reconstructed_audio[0].T, low)
            low = lr_filter(low.T, crossover_freq, 'lowpass', order=10)
            high = lr_filter(output.T, crossover_freq, 'highpass', order=10)
            high = lr_filter(high, 23000, 'lowpass', order=2)
            output = low + high
        else:
            output = reconstructed_audio[0]

        return output

    def predict(self,
                input_file: Path = Input(description="Audio to upsample"),
                ddim_steps: int = Input(description="Number of inference steps", default=50, ge=10, le=500),
                guidance_scale: float = Input(description="Scale for classifier-free guidance", default=3.5, ge=1.0, le=20.0),
                overlap: float = Input(description="Overlap size as a fraction of the chunk", default=0.04),
                chunk_size: float = Input(description="Chunk size in seconds", default=10.24),
                seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None)
                ) -> Path:
        if seed is None or seed == 0:
            seed = random.randint(0, 2**32 - 1)
            print(f"Setting seed to: {seed}")
        print(f"overlap = {overlap}")
        print(f"guidance_scale = {guidance_scale}")
        print(f"ddim_steps = {ddim_steps}")
        print(f"chunk_size = {chunk_size}")
        print(f"multiband_ensemble = {multiband_ensemble}")
        print(f"input file = {os.path.basename(input_file)}")

        # output_folder is a module-level global set in the __main__ block below.
        os.makedirs(output_folder, exist_ok=True)
        waveform = self.process_audio(
            input_file,
            chunk_size=chunk_size,
            overlap=overlap,
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps
        )

        filename = os.path.splitext(os.path.basename(input_file))[0]
        output_path = f"{output_folder}/SR_{filename}.wav"
        sf.write(output_path, data=waveform, samplerate=48000, subtype="PCM_16")
        print(f"file created: {output_path}")

        # Free the model and intermediate buffers once the file is written.
        del self.audiosr, waveform
        gc.collect()
        torch.cuda.empty_cache()

        return Path(output_path)
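
# Example command-line invocation (the script and file names here are hypothetical):
#
#   python predict.py --input input.wav --output ./out \
#       --ddim_steps 50 --guidance_scale 3.5 --chunk_size 10.24 \
#       --overlap 0.04 --multiband_ensemble --input_cutoff 12000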

if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Upsample an audio file with AudioSR, processing it in overlapping chunks.")
    parser.add_argument("--input", help="Path to input audio file")
    parser.add_argument("--output", help="Output folder")
    parser.add_argument("--ddim_steps", help="Number of DDIM steps", type=int, required=False, default=50)
    parser.add_argument("--chunk_size", help="Chunk size in seconds", type=float, required=False, default=10.24)
    parser.add_argument("--guidance_scale", help="Guidance scale value", type=float, required=False, default=3.5)
    parser.add_argument("--seed", help="Seed value, 0 = random seed", type=int, required=False, default=0)
    parser.add_argument("--overlap", help="Overlap value", type=float, required=False, default=0.04)
    parser.add_argument("--multiband_ensemble", action="store_true", help="Use multiband ensemble with the input")
    parser.add_argument("--input_cutoff", help="Cutoff of the audio input, used as the multiband ensemble crossover reference", type=int, required=False, default=12000)

    args = parser.parse_args()

    input_file_path = args.input
    output_folder = args.output
    ddim_steps = args.ddim_steps
    chunk_size = args.chunk_size
    guidance_scale = args.guidance_scale
    seed = args.seed
    overlap = args.overlap
    input_cutoff = args.input_cutoff
    multiband_ensemble = args.multiband_ensemble

    crossover_freq = input_cutoff - 1000

    p = Predictor()
    p.setup(device='auto')

    out = p.predict(
        input_file_path,
        ddim_steps=ddim_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        chunk_size=chunk_size,
        overlap=overlap
    )

    del p
    gc.collect()
    torch.cuda.empty_cache()