"""Gradio front-end for audio super resolution with AudioSR."""

import gc  # free up memory
import os
import random
import tempfile

import gradio as gr
import librosa
import numpy as np
import pyloudnorm as pyln
import soundfile as sf
import spaces
import torch
from audiosr import build_model, super_resolution
from scipy import signal
from scipy.signal.windows import hann


class AudioUpscaler:
    """
    Upscales audio using the AudioSR model.

    The input is resampled to ``2 * input_cutoff`` Hz, cut into overlapping
    chunks, each chunk is passed through the model, and the model outputs are
    cross-faded back together at ``self.sr`` Hz.
    """

    def __init__(self, model_name="basic", device="auto"):
        """
        Initializes the AudioUpscaler.

        Args:
            model_name (str, optional): Name of the AudioSR model to use.
                Defaults to "basic".
            device (str, optional): Device to use for inference.
                Defaults to "auto".
        """
        self.model_name = model_name
        self.device = device
        self.sr = 48000  # sample rate of everything this class returns/writes
        self.audiosr = None  # Model will be loaded in setup()

    def setup(self):
        """
        Loads the AudioSR model into ``self.audiosr``.
        """
        print("Loading Model...")
        self.audiosr = build_model(model_name=self.model_name, device=self.device)
        print("Model loaded!")

    def _match_array_shapes(self, array_1: np.ndarray, array_2: np.ndarray):
        """
        Trims or zero-pads ``array_1`` so its sample axis matches ``array_2``.

        When both arrays are 1-D the only axis is the sample axis; otherwise
        axis 1 is used. Padding is always appended at the end so that sample
        ``k`` of the result still corresponds to sample ``k`` of ``array_2``.

        Args:
            array_1 (np.ndarray): Array to resize.
            array_2 (np.ndarray): Array providing the target length.

        Returns:
            np.ndarray: ``array_1`` with the same sample count as ``array_2``.
        """
        if array_1.ndim == 1 and array_2.ndim == 1:
            target = array_2.shape[0]
            if array_1.shape[0] > target:
                array_1 = array_1[:target]
            elif array_1.shape[0] < target:
                # Pad at the END (the original padded at the front, which
                # shifted the signal relative to array_2).
                array_1 = np.pad(
                    array_1,
                    (0, target - array_1.shape[0]),
                    "constant",
                    constant_values=0,
                )
        else:
            target = array_2.shape[1]
            if array_1.shape[1] > target:
                array_1 = array_1[:, :target]
            elif array_1.shape[1] < target:
                array_1 = np.pad(
                    array_1,
                    ((0, 0), (0, target - array_1.shape[1])),
                    "constant",
                    constant_values=0,
                )
        return array_1

    def _lr_filter(self, audio, cutoff, filter_type, order=12, sr=48000):
        """
        Applies a zero-phase low-pass or high-pass filter to the audio.

        A Butterworth filter of order ``order // 2`` is run forwards and
        backwards (``sosfiltfilt``) along the sample axis.

        Args:
            audio (np.ndarray): Audio data, samples on axis 0 (the array is
                transposed before filtering and transposed back afterwards).
            cutoff (int | float): Cutoff frequency in Hz; must lie strictly
                between 0 and ``sr / 2``.
            filter_type (str): Filter type ("lowpass" or "highpass").
            order (int, optional): Nominal filter order. Defaults to 12.
            sr (int, optional): Sample rate. Defaults to 48000.

        Returns:
            np.ndarray: Filtered audio data, same layout as the input.
        """
        audio = audio.T
        nyquist = 0.5 * sr
        normal_cutoff = cutoff / nyquist
        # Design directly in second-order sections rather than going through
        # (b, a) polynomials, which SciPy documents as numerically sensitive.
        sos = signal.butter(
            order // 2,
            normal_cutoff,
            btype=filter_type,
            analog=False,
            output="sos",
        )
        filtered_audio = signal.sosfiltfilt(sos, audio)
        return filtered_audio.T

    def _split_into_chunks(self, audio, chunk_samples, hop_samples):
        """
        Cuts a 1-D signal into fixed-length chunks, zero-padding the last one.

        Args:
            audio (np.ndarray): 1-D signal.
            chunk_samples (int): Length of every returned chunk.
            hop_samples (int): Distance between consecutive chunk starts.

        Returns:
            tuple[list[np.ndarray], list[int]]: The chunks and, for each
            chunk, how many of its samples came from ``audio`` (the rest is
            padding).
        """
        chunks = []
        original_lengths = []
        start = 0
        while start < len(audio):
            print(f"{start} / {len(audio)}")
            chunk = audio[start:start + chunk_samples]
            original_lengths.append(len(chunk))
            if len(chunk) < chunk_samples:
                pad = np.zeros(chunk_samples - len(chunk))
                chunk = np.concatenate([chunk, pad])
            chunks.append(chunk)
            start += hop_samples
        return chunks, original_lengths

    def _super_resolve_chunk(self, chunk, sr, seed, guidance_scale, ddim_steps):
        """
        Runs the AudioSR model on one chunk and returns its first output item.

        The model takes a file path, so the chunk is written to a WAV file in
        a private temporary directory that is removed afterwards. A directory
        is used instead of an open ``NamedTemporaryFile`` because reopening
        such a file by name is not portable across platforms.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, "chunk.wav")
            sf.write(temp_path, chunk, sr)
            out_chunk = super_resolution(
                self.audiosr,
                temp_path,
                seed=seed,
                guidance_scale=guidance_scale,
                ddim_steps=ddim_steps,
                latent_t_per_second=12.8,
            )
        return out_chunk[0]

    def _match_loudness(self, out_chunk, meter_after, loudness_before):
        """
        Scales ``out_chunk`` so its integrated loudness equals ``loudness_before``.

        The chunk is returned unchanged when its loudness cannot be used:
        the meter raises ``ValueError`` (NOTE(review): pyloudnorm appears to
        do this for audio shorter than its analysis block, e.g. a short final
        chunk — confirm), or the measured loudness is not finite, in which
        case the gain would be inf/NaN.
        """
        try:
            loudness_after = meter_after.integrated_loudness(out_chunk)
        except ValueError:
            return out_chunk
        if not np.isfinite(loudness_after) or np.isnan(loudness_before):
            return out_chunk
        return pyln.normalize.loudness(out_chunk, loudness_after, loudness_before)

    def _upscale_channel(
        self,
        channel,
        sr,
        chunk_size,
        overlap,
        seed,
        guidance_scale,
        ddim_steps,
    ):
        """
        Upscales one 1-D channel sampled at ``sr`` Hz to ``self.sr`` Hz.

        Args:
            channel (np.ndarray): 1-D input signal.
            sr (int): Sample rate of ``channel``.
            chunk_size (float): Chunk length in seconds.
            overlap (float): Overlap between chunks as a fraction of the
                chunk length; values <= 0 disable cross-fading.
            seed: Seed forwarded to the model.
            guidance_scale (float): Forwarded to the model.
            ddim_steps (int): Forwarded to the model.

        Returns:
            np.ndarray: 1-D overlap-added output.

        Raises:
            ValueError: If the chunk length or the hop between chunks is not
                positive (the chunking loop would never terminate).
        """
        enable_overlap = overlap > 0
        chunk_samples = int(chunk_size * sr)
        overlap_samples = int(overlap * chunk_samples)
        output_chunk_samples = int(chunk_size * self.sr)
        output_overlap_samples = int(overlap * output_chunk_samples)

        input_hop = chunk_samples - overlap_samples if enable_overlap else chunk_samples
        output_hop = (
            output_chunk_samples - output_overlap_samples
            if enable_overlap
            else output_chunk_samples
        )
        if chunk_samples <= 0 or input_hop <= 0:
            raise ValueError(
                "chunk_size must be positive and overlap must be smaller than 1"
            )

        chunks, original_lengths = self._split_into_chunks(
            channel, chunk_samples, input_hop
        )

        sample_rate_ratio = self.sr / sr
        total_length = len(chunks) * output_chunk_samples - (len(chunks) - 1) * (
            output_overlap_samples if enable_overlap else 0
        )
        reconstructed = np.zeros(total_length)

        meter_before = pyln.Meter(sr)
        meter_after = pyln.Meter(self.sr)
        last_index = len(chunks) - 1

        for i, chunk in enumerate(chunks):
            print(f"{i} / {len(chunks)}")
            loudness_before = meter_before.integrated_loudness(chunk)

            out_chunk = self._super_resolve_chunk(
                chunk, sr, seed, guidance_scale, ddim_steps
            )
            # Drop the part of the output that corresponds to zero padding.
            num_samples_to_keep = int(original_lengths[i] * sample_rate_ratio)
            out_chunk = out_chunk[:, :num_samples_to_keep].reshape(-1)
            out_chunk = self._match_loudness(out_chunk, meter_after, loudness_before)

            if enable_overlap:
                fade_len = min(output_overlap_samples, out_chunk.shape[0])
                # fade_len == 0 must be skipped: x[-0:] is the whole array.
                if fade_len > 0:
                    # Fade only the edges that actually meet a neighbour, so
                    # the start and end of the signal keep their full level
                    # (a single chunk is left untouched).
                    if i > 0:
                        out_chunk[:fade_len] *= np.linspace(0.0, 1.0, fade_len)
                    if i < last_index:
                        out_chunk[-fade_len:] *= np.linspace(1.0, 0.0, fade_len)

            start = i * output_hop
            # Clip to the buffer so rounding can never cause a shape mismatch.
            end = min(start + out_chunk.shape[0], total_length)
            reconstructed[start:end] += out_chunk[: end - start]

        return reconstructed

    def _process_audio(
        self,
        input_file,
        chunk_size=5.12,
        overlap=0.1,
        seed=None,
        guidance_scale=3.5,
        ddim_steps=50,
        multiband_ensemble=True,
        input_cutoff=14000,
    ):
        """
        Processes the audio in chunks and performs upsampling.

        Args:
            input_file (str): Path to the input audio file.
            chunk_size (float, optional): Chunk size in seconds.
                Defaults to 5.12.
            overlap (float, optional): Overlap between chunks as a fraction
                of the chunk size (0.1 = 10 %). Defaults to 0.1.
            seed (int, optional): Random seed. Defaults to None.
            guidance_scale (float, optional): Scale for classifier-free
                guidance. Defaults to 3.5.
            ddim_steps (int, optional): Number of inference steps.
                Defaults to 50.
            multiband_ensemble (bool, optional): Whether to keep the low band
                of the original file and take only the high band from the
                model output. Defaults to True.
            input_cutoff (int, optional): Input cutoff frequency in Hz. The
                input is resampled to twice this value before it is fed to
                the model, and the ensemble crossover is placed 1000 Hz
                below it. Defaults to 14000.

        Returns:
            np.ndarray: Upsampled audio at ``self.sr`` Hz; 1-D for mono
            input, ``(samples, 2)`` for multi-channel input (only the first
            two channels are processed).
        """
        sr = input_cutoff * 2
        audio, _ = librosa.load(input_file, sr=sr, mono=False)
        audio = audio.T

        is_stereo = audio.ndim == 2
        channels = [audio[:, 0], audio[:, 1]] if is_stereo else [audio]

        upscaled = [
            self._upscale_channel(
                channel,
                sr,
                chunk_size,
                overlap,
                seed,
                guidance_scale,
                ddim_steps,
            )
            for channel in channels
        ]
        reconstructed_audio = np.stack(upscaled, axis=-1) if is_stereo else upscaled[0]

        if not multiband_ensemble:
            return reconstructed_audio

        low, _ = librosa.load(input_file, sr=self.sr, mono=False)
        if low.ndim == 2:
            # Only two channels were upscaled; keep the same two here so the
            # bands can be summed.
            low = low[:2]
        output = self._match_array_shapes(reconstructed_audio.T, low)

        crossover_freq = input_cutoff - 1000
        if crossover_freq <= 0:
            # A non-positive cutoff is rejected by the filter design; fall
            # back to half the input cutoff for very low settings.
            crossover_freq = input_cutoff / 2

        low = self._lr_filter(low.T, crossover_freq, "lowpass", order=10, sr=self.sr)
        high = self._lr_filter(
            output.T, crossover_freq, "highpass", order=10, sr=self.sr
        )
        high = self._lr_filter(high, 23000, "lowpass", order=2, sr=self.sr)
        return low + high

    def predict(
        self,
        input_file,
        output_folder,
        ddim_steps=50,
        guidance_scale=3.5,
        overlap=0.04,
        chunk_size=10.24,
        seed=None,
        multiband_ensemble=True,
        input_cutoff=14000,
    ):
        """
        Upscales the audio and saves the result.

        The result is written to ``<output_folder>/SR_<input name>.wav`` as
        16-bit PCM at ``self.sr`` Hz.

        Args:
            input_file (str): Path to the input audio file.
            output_folder (str): Path to the output folder; created if
                missing. An empty value means the current directory.
            ddim_steps (int, optional): Number of inference steps.
                Defaults to 50.
            guidance_scale (float, optional): Scale for classifier-free
                guidance. Defaults to 3.5.
            overlap (float, optional): Overlap between chunks as a fraction
                of the chunk size. Defaults to 0.04.
            chunk_size (float, optional): Chunk size in seconds.
                Defaults to 10.24.
            seed (int, optional): Random seed; ``None`` or ``0`` picks a
                random one. Defaults to None.
            multiband_ensemble (bool, optional): Whether to use multiband
                ensemble. Defaults to True.
            input_cutoff (int, optional): Input cutoff frequency for
                multiband ensemble. Defaults to 14000.

        Returns:
            np.ndarray: The upscaled waveform that was written to disk.
        """
        # Always hand the model a concrete, non-zero integer seed so every
        # chunk of this run uses the same one.
        if seed is None or seed == 0:
            seed = random.randint(1, 2**32 - 1)

        output_folder = output_folder or "."
        os.makedirs(output_folder, exist_ok=True)

        waveform = self._process_audio(
            input_file,
            chunk_size=chunk_size,
            overlap=overlap,
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            multiband_ensemble=multiband_ensemble,
            input_cutoff=input_cutoff,
        )

        filename = os.path.splitext(os.path.basename(input_file))[0]
        output_file = os.path.join(output_folder, f"SR_{filename}.wav")
        sf.write(output_file, data=waveform, samplerate=self.sr, subtype="PCM_16")
        print(f"File created: {output_file}")

        # Cleanup
        gc.collect()
        torch.cuda.empty_cache()

        return waveform


def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
    """
    Runs AudioSR on a whole file in one pass (no chunking).

    Args:
        audio_file (str): Path to the input audio file.
        model_name (str): Name of the AudioSR model to build.
        guidance_scale (float): Scale for classifier-free guidance.
        ddim_steps (int): Number of inference steps.
        seed (int): Random seed; 0 picks a random one.

    Returns:
        tuple: ``(48000, waveform)`` as returned by the model.
    """
    audiosr = build_model(model_name=model_name)

    gc.collect()

    # set random seed when seed input value is 0
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)

    waveform = super_resolution(
        audiosr,
        audio_file,
        seed,
        guidance_scale=guidance_scale,
        ddim_steps=ddim_steps,
    )

    return (48000, waveform)


@spaces.GPU(duration=300)
def upscale_audio(
    input_file,
    output_folder,
    ddim_steps=20,
    guidance_scale=3.5,
    overlap=0.04,
    chunk_size=10.24,
    seed=0,
    multiband_ensemble=True,
    input_cutoff=14000,
):
    """
    Upscales the audio using the AudioSR model.

    A fresh ``AudioUpscaler`` is built and its model loaded on every call.

    Args:
        input_file (str): Path to the input audio file.
        output_folder (str): Path to the output folder.
        ddim_steps (int, optional): Number of inference steps. Defaults to 20.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 3.5.
        overlap (float, optional): Overlap between chunks as a fraction of
            the chunk size. Defaults to 0.04.
        chunk_size (float, optional): Chunk size in seconds.
            Defaults to 10.24.
        seed (int, optional): Random seed (0 for random). Defaults to 0.
        multiband_ensemble (bool, optional): Whether to use multiband
            ensemble. Defaults to True.
        input_cutoff (int, optional): Input cutoff frequency for multiband
            ensemble. Defaults to 14000.

    Returns:
        tuple: ``(sample_rate, waveform)`` for the Gradio audio output.

    Raises:
        gr.Error: If no input file was provided.
    """
    if input_file is None:
        # Fail with a readable UI message instead of a TypeError deep inside
        # the pipeline when the form is submitted without audio.
        raise gr.Error("Please provide an input audio file.")

    torch.cuda.empty_cache()
    gc.collect()

    upscaler = AudioUpscaler()
    upscaler.setup()

    waveform = upscaler.predict(
        input_file,
        output_folder,
        ddim_steps=ddim_steps,
        guidance_scale=guidance_scale,
        overlap=overlap,
        chunk_size=chunk_size,
        seed=seed,
        multiband_ensemble=multiband_ensemble,
        input_cutoff=input_cutoff,
    )

    torch.cuda.empty_cache()
    gc.collect()

    return (upscaler.sr, waveform)


# The order of `inputs` must match the positional parameters of upscale_audio.
iface = gr.Interface(
    fn=upscale_audio,
    inputs=[
        gr.Audio(type="filepath", label="Input Audio"),
        gr.Textbox(".", label="Out-dir"),
        gr.Slider(
            10,
            500,
            value=20,
            step=1,
            label="DDIM Steps",
            info="Number of inference steps (quality/speed)",
        ),
        gr.Slider(
            1.0,
            20.0,
            value=3.5,
            step=0.1,
            label="Guidance Scale",
            info="Guidance scale (creativity/fidelity)",
        ),
        # The value is multiplied by the chunk length, so it is a fraction of
        # a chunk, not a duration in seconds.
        gr.Slider(
            0.0,
            0.5,
            value=0.04,
            step=0.01,
            label="Overlap (fraction of chunk)",
            info="Overlap between chunks (smooth transitions)",
        ),
        gr.Slider(
            5.12,
            20.48,
            value=5.12,
            step=0.64,
            label="Chunk Size (s)",
            info="Chunk size (memory/artifact balance)",
        ),
        gr.Number(
            value=0,
            precision=0,
            label="Seed",
            info="Random seed (0 for random)",
        ),
        gr.Checkbox(
            label="Multiband Ensemble",
            value=False,
            info="Enhance high frequencies",
        ),
        gr.Slider(
            500,
            15000,
            value=9000,
            step=500,
            label="Crossover Frequency (Hz)",
            info="For multiband processing",
            visible=True,
        ),
    ],
    outputs=gr.Audio(type="numpy", label="Output Audio"),
    title="AudioSR",
    description="Audio Super Resolution with AudioSR",
)

iface.launch(share=False)