"""Chunked audio super-resolution with AudioSR.

Splits the input into (optionally overlapping) chunks, upsamples each chunk to
48 kHz with AudioSR, loudness-matches and crossfades the results, and can blend
the original low band with the generated high band (multiband ensemble).
"""

import argparse
import gc
import os
import random
import tempfile
import warnings

import librosa
import numpy as np
import pyloudnorm as pyln
import soundfile as sf
import torch
from audiosr import build_model, super_resolution
from cog import BasePredictor, Input, Path
from scipy import signal

warnings.filterwarnings("ignore")
os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")


def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
    """Trim or zero-pad array_1 so its length matches array_2 (1-D or 2-D)."""
    if len(array_1.shape) == 1 and len(array_2.shape) == 1:
        if array_1.shape[0] > array_2.shape[0]:
            array_1 = array_1[:array_2.shape[0]]
        elif array_1.shape[0] < array_2.shape[0]:
            array_1 = np.pad(array_1, (array_2.shape[0] - array_1.shape[0], 0),
                             'constant', constant_values=0)
    else:
        if array_1.shape[1] > array_2.shape[1]:
            array_1 = array_1[:, :array_2.shape[1]]
        elif array_1.shape[1] < array_2.shape[1]:
            padding = array_2.shape[1] - array_1.shape[1]
            array_1 = np.pad(array_1, ((0, 0), (0, padding)),
                             'constant', constant_values=0)
    return array_1


def lr_filter(audio, cutoff, filter_type, order=12, sr=48000):
    """Zero-phase Butterworth low/high-pass filter applied along the time axis."""
    audio = audio.T
    nyquist = 0.5 * sr
    normal_cutoff = cutoff / nyquist
    b, a = signal.butter(order // 2, normal_cutoff, btype=filter_type, analog=False)
    sos = signal.tf2sos(b, a)
    filtered_audio = signal.sosfiltfilt(sos, audio)
    return filtered_audio.T


class Predictor(BasePredictor):
    def setup(self, model_name="basic", device="auto"):
        self.model_name = model_name
        self.device = device
        self.sr = 48000
        print("Loading Model...")
        self.audiosr = build_model(model_name=self.model_name, device=self.device)
        print("Model loaded!")

    def process_audio(self, input_file, chunk_size=5.12, overlap=0.1, seed=None,
                      guidance_scale=3.5, ddim_steps=50):
        # input_cutoff, multiband_ensemble and crossover_freq are module-level
        # globals set in the __main__ block below.
        audio, sr = librosa.load(input_file, sr=input_cutoff * 2, mono=False)
        audio = audio.T
        sr = input_cutoff * 2

        print(f"audio.shape = {audio.shape}")
        print(f"input cutoff = {input_cutoff}")

        is_stereo = len(audio.shape) == 2
        audio_channels = [audio] if not is_stereo else [audio[:, 0], audio[:, 1]]
        print("Audio is stereo" if is_stereo else "Audio is mono")

        chunk_samples = int(chunk_size * sr)
        overlap_samples = int(overlap * chunk_samples)
        output_chunk_samples = int(chunk_size * self.sr)
        output_overlap_samples = int(overlap * output_chunk_samples)
        enable_overlap = overlap > 0
        print(f"enable_overlap = {enable_overlap}")

        def process_chunks(audio):
            """Split one channel into fixed-size chunks, zero-padding the last one."""
            chunks = []
            original_lengths = []
            start = 0
            while start < len(audio):
                end = min(start + chunk_samples, len(audio))
                chunk = audio[start:end]
                if len(chunk) < chunk_samples:
                    original_lengths.append(len(chunk))
                    chunk = np.concatenate([chunk, np.zeros(chunk_samples - len(chunk))])
                else:
                    original_lengths.append(chunk_samples)
                chunks.append(chunk)
                start += (chunk_samples - overlap_samples) if enable_overlap else chunk_samples
            return chunks, original_lengths

        # Process both channels (mono or stereo)
        chunks_per_channel = [process_chunks(channel) for channel in audio_channels]

        sample_rate_ratio = self.sr / sr
        total_length = (len(chunks_per_channel[0][0]) * output_chunk_samples
                        - (len(chunks_per_channel[0][0]) - 1)
                        * (output_overlap_samples if enable_overlap else 0))
        reconstructed_channels = [np.zeros((1, total_length)) for _ in audio_channels]

        meter_before = pyln.Meter(sr)
        meter_after = pyln.Meter(self.sr)

        # Process chunks for each channel
        for ch_idx, (chunks, original_lengths) in enumerate(chunks_per_channel):
            for i, chunk in enumerate(chunks):
                loudness_before = meter_before.integrated_loudness(chunk)
print(f"Processing chunk {i+1} of {len(chunks)} for {'Left/Mono' if ch_idx == 0 else 'Right'} channel") with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav: sf.write(temp_wav.name, chunk, sr) out_chunk = super_resolution( self.audiosr, temp_wav.name, seed=seed, guidance_scale=guidance_scale, ddim_steps=ddim_steps, latent_t_per_second=12.8 ) out_chunk = out_chunk[0] num_samples_to_keep = int(original_lengths[i] * sample_rate_ratio) out_chunk = out_chunk[:, :num_samples_to_keep].squeeze() loudness_after = meter_after.integrated_loudness(out_chunk) out_chunk = pyln.normalize.loudness(out_chunk, loudness_after, loudness_before) if enable_overlap: actual_overlap_samples = min(output_overlap_samples, num_samples_to_keep) fade_out = np.linspace(1., 0., actual_overlap_samples) fade_in = np.linspace(0., 1., actual_overlap_samples) if i == 0: out_chunk[-actual_overlap_samples:] *= fade_out elif i < len(chunks) - 1: out_chunk[:actual_overlap_samples] *= fade_in out_chunk[-actual_overlap_samples:] *= fade_out else: out_chunk[:actual_overlap_samples] *= fade_in start = i * (output_chunk_samples - output_overlap_samples if enable_overlap else output_chunk_samples) end = start + out_chunk.shape[0] reconstructed_channels[ch_idx][0, start:end] += out_chunk.flatten() reconstructed_audio = np.stack(reconstructed_channels, axis=-1) if is_stereo else reconstructed_channels[0] if multiband_ensemble: low, _ = librosa.load(input_file, sr=48000, mono=False) output = match_array_shapes(reconstructed_audio[0].T, low) low = lr_filter(low.T, crossover_freq, 'lowpass', order=10) high = lr_filter(output.T, crossover_freq, 'highpass', order=10) high = lr_filter(high, 23000, 'lowpass', order=2) output = low + high else: output = reconstructed_audio[0] # print(output, type(output)) return output def predict(self, input_file: Path = Input(description="Audio to upsample"), ddim_steps: int = Input(description="Number of inference steps", default=50, ge=10, le=500), guidance_scale: float = Input(description="Scale for classifier free guidance", default=3.5, ge=1.0, le=20.0), overlap: float = Input(description="overlap size", default=0.04), chunk_size: float = Input(description="chunksize", default=10.24), seed: int = Input(description="Random seed. 
    ) -> Path:
        # seed == 0 (the CLI default) or None means "pick a random seed".
        if seed is None or seed == 0:
            seed = random.randint(0, 2**32 - 1)
        print(f"Setting seed to: {seed}")
        print(f"overlap = {overlap}")
        print(f"guidance_scale = {guidance_scale}")
        print(f"ddim_steps = {ddim_steps}")
        print(f"chunk_size = {chunk_size}")
        print(f"multiband_ensemble = {multiband_ensemble}")
        print(f"input file = {os.path.basename(input_file)}")

        os.makedirs(output_folder, exist_ok=True)

        waveform = self.process_audio(
            input_file,
            chunk_size=chunk_size,
            overlap=overlap,
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps
        )

        filename = os.path.splitext(os.path.basename(input_file))[0]
        output_path = f"{output_folder}/SR_{filename}.wav"
        sf.write(output_path, data=waveform, samplerate=48000, subtype="PCM_16")
        print(f"file created: {output_path}")

        del self.audiosr, waveform
        gc.collect()
        torch.cuda.empty_cache()
        return Path(output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Upsample an audio file to 48 kHz with AudioSR.")
    parser.add_argument("--input", help="Path to input audio file")
    parser.add_argument("--output", help="Output folder")
    parser.add_argument("--ddim_steps", help="Number of ddim steps", type=int, required=False, default=50)
    parser.add_argument("--chunk_size", help="Chunk size in seconds", type=float, required=False, default=10.24)
    parser.add_argument("--guidance_scale", help="Guidance scale value", type=float, required=False, default=3.5)
    parser.add_argument("--seed", help="Seed value, 0 = random seed", type=int, required=False, default=0)
    parser.add_argument("--overlap", help="Overlap between chunks, as a fraction of the chunk size", type=float, required=False, default=0.04)
    parser.add_argument("--multiband_ensemble", action="store_true", help="Use multiband ensemble with the input")
    parser.add_argument("--input_cutoff", help="Define the crossover of the audio input in the multiband ensemble", type=int, required=False, default=12000)

    args = parser.parse_args()

    input_file_path = args.input
    output_folder = args.output
    ddim_steps = args.ddim_steps
    chunk_size = args.chunk_size
    guidance_scale = args.guidance_scale
    seed = args.seed
    overlap = args.overlap
    input_cutoff = args.input_cutoff
    multiband_ensemble = args.multiband_ensemble
    crossover_freq = input_cutoff - 1000

    p = Predictor()
    p.setup(device='auto')
    out = p.predict(
        input_file_path,
        ddim_steps=ddim_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        chunk_size=chunk_size,
        overlap=overlap
    )

    del p
    gc.collect()
    torch.cuda.empty_cache()
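
# Example invocation (a sketch, not part of the original script): it assumes this
# file is saved as predict.py, that the audiosr, cog, librosa, soundfile and
# pyloudnorm packages are installed, and that "song.wav" / "./upsampled" are
# placeholder paths:
#
#   python predict.py --input song.wav --output ./upsampled \
#       --ddim_steps 50 --guidance_scale 3.5 \
#       --chunk_size 10.24 --overlap 0.04 \
#       --multiband_ensemble --input_cutoff 12000
#
# The result is written as ./upsampled/SR_song.wav (48 kHz, 16-bit PCM).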