# asr / main.py
# Author: kgout
# Update main.py
# Commit cb6ae20 (verified)
import gradio as gr
import torch
import gc # free up memory
import spaces
import gc
import os
import random
import numpy as np
from scipy.signal.windows import hann
import soundfile as sf
import torch
import librosa
from audiosr import build_model, super_resolution
from scipy import signal
import pyloudnorm as pyln
import tempfile
import spaces
class AudioUpscaler:
    """
    Upscales audio to a 48 kHz output using the AudioSR model.

    Pipeline (see ``_process_audio``): band-limit the input, split every
    channel into fixed-size chunks, run each chunk through AudioSR,
    loudness-match it to its input chunk, and overlap-add the results with
    linear crossfades. Optionally ("multiband ensemble") only the generated
    content above a crossover frequency is kept, while the original input
    supplies everything below it.
    """

    def __init__(self, model_name="basic", device="auto"):
        """
        Initializes the AudioUpscaler without loading the model.

        Args:
            model_name (str, optional): Name of the AudioSR model to use. Defaults to "basic".
            device (str, optional): Device to use for inference. Defaults to "auto".
        """
        self.model_name = model_name
        self.device = device
        self.sr = 48000  # Output sample rate in Hz.
        self.audiosr = None  # Model will be loaded in setup()

    def setup(self):
        """Loads the AudioSR model; required before ``predict`` can run inference."""
        print("Loading Model...")
        self.audiosr = build_model(model_name=self.model_name, device=self.device)
        print("Model loaded!")

    def _match_array_shapes(self, array_1: np.ndarray, array_2: np.ndarray):
        """
        Trims or zero-pads ``array_1`` along its time axis to match ``array_2``.

        For two 1-D arrays the time axis is axis 0; otherwise both arrays are
        treated as ``(channels, samples)`` and axis 1 is matched. Padding is
        appended at the END so the audio stays time-aligned with ``array_2``.

        Args:
            array_1 (np.ndarray): Array to trim or pad.
            array_2 (np.ndarray): Reference array whose length is matched.

        Returns:
            np.ndarray: ``array_1`` with the same time-axis length as ``array_2``.
        """
        if array_1.ndim == 1 and array_2.ndim == 1:
            target = array_2.shape[0]
            current = array_1.shape[0]
            if current > target:
                return array_1[:target]
            if current < target:
                return np.pad(
                    array_1, (0, target - current), "constant", constant_values=0
                )
            return array_1

        target = array_2.shape[1]
        current = array_1.shape[1]
        if current > target:
            return array_1[:, :target]
        if current < target:
            return np.pad(
                array_1, ((0, 0), (0, target - current)), "constant", constant_values=0
            )
        return array_1

    def _lr_filter(self, audio, cutoff, filter_type, order=12, sr=48000):
        """
        Applies a zero-phase low-pass or high-pass filter to the audio.

        A Butterworth filter of order ``order // 2`` is run forwards and
        backwards with ``sosfiltfilt``. That squares its magnitude response
        (as in a Linkwitz-Riley crossover) and cancels its phase shift.

        Args:
            audio (np.ndarray): Audio shaped ``(samples,)`` or ``(samples, channels)``.
            cutoff (float): Cutoff frequency in Hz.
            filter_type (str): Filter type ("lowpass" or "highpass").
            order (int, optional): Effective filter order. Defaults to 12.
            sr (int, optional): Sample rate in Hz. Defaults to 48000.

        Returns:
            np.ndarray: Filtered audio with the same shape as ``audio``.
        """
        nyquist = 0.5 * sr
        # Designing directly in second-order sections avoids the numerical
        # error of building (b, a) first and converting it with tf2sos.
        sos = signal.butter(
            order // 2, cutoff / nyquist, btype=filter_type, output="sos"
        )
        # Filter along the sample axis (the last axis after transposing).
        return signal.sosfiltfilt(sos, audio.T).T

    @staticmethod
    def _split_into_chunks(audio, chunk_samples, step):
        """
        Splits a 1-D signal into ``chunk_samples``-long chunks starting every ``step`` samples.

        A chunk that runs past the end of the signal is zero-padded to full length.

        Returns:
            tuple[list[np.ndarray], list[int]]: The chunks, and the number of
            real (unpadded) samples in each chunk.
        """
        chunks = []
        original_lengths = []
        for start in range(0, len(audio), step):
            print(f"{start} / {len(audio)}")
            chunk = audio[start:start + chunk_samples]
            original_lengths.append(len(chunk))
            if len(chunk) < chunk_samples:
                chunk = np.concatenate([chunk, np.zeros(chunk_samples - len(chunk))])
            chunks.append(chunk)
        return chunks, original_lengths

    @staticmethod
    def _match_loudness(in_chunk, out_chunk, meter_before, meter_after):
        """
        Scales ``out_chunk`` to the integrated loudness of ``in_chunk``.

        The chunk is returned unchanged when normalisation is impossible:
        - pyloudnorm raises ValueError for signals shorter than one gating
          block (e.g. a short final chunk);
        - a silent output (-inf LUFS) would be turned into NaN/inf.
        """
        if (
            in_chunk.shape[0] < meter_before.block_size * meter_before.rate
            or out_chunk.shape[0] < meter_after.block_size * meter_after.rate
        ):
            return out_chunk
        loudness_before = meter_before.integrated_loudness(in_chunk)
        loudness_after = meter_after.integrated_loudness(out_chunk)
        if not np.isfinite(loudness_after):
            return out_chunk
        return pyln.normalize.loudness(out_chunk, loudness_after, loudness_before)

    @staticmethod
    def _apply_crossfade(out_chunk, index, num_chunks, overlap_samples):
        """
        Applies linear crossfade ramps to ``out_chunk`` (in place) for overlap-add.

        Every chunk except the first fades in over its first ``overlap_samples``
        samples. Every chunk except the last fades out over its last
        ``overlap_samples`` samples. The ramps are complementary, so they sum
        to 1 where neighbours overlap. A lone chunk is left untouched.

        Returns:
            np.ndarray: ``out_chunk``.
        """
        n = min(overlap_samples, out_chunk.shape[0])
        if n <= 0:
            return out_chunk
        if index > 0:
            out_chunk[:n] *= np.linspace(0.0, 1.0, n)
        if index < num_chunks - 1:
            out_chunk[-n:] *= np.linspace(1.0, 0.0, n)
        return out_chunk

    @staticmethod
    def _crossover_frequency(input_cutoff):
        """
        Returns the multiband crossover frequency for ``input_cutoff`` (Hz).

        Normally this is 1 kHz below the cutoff. For cutoffs under 2 kHz that
        would fall to or below half the cutoff (and to <= 0 Hz at 1 kHz, which
        the Butterworth design rejects), so half the cutoff is used instead.
        """
        return max(input_cutoff - 1000, input_cutoff / 2)

    def _process_audio(
        self,
        input_file,
        chunk_size=5.12,
        overlap=0.1,
        seed=None,
        guidance_scale=3.5,
        ddim_steps=50,
        multiband_ensemble=True,
        input_cutoff=14000,
    ):
        """
        Processes the audio in chunks and performs upsampling.

        Args:
            input_file (str): Path to the input audio file.
            chunk_size (float, optional): Chunk size in seconds. Defaults to 5.12.
            overlap (float, optional): Overlap between consecutive chunks as a
                fraction of ``chunk_size``, in [0, 1). Defaults to 0.1.
            seed (int, optional): Random seed. Defaults to None.
            guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
            ddim_steps (int, optional): Number of inference steps. Defaults to 50.
            multiband_ensemble (bool, optional): Keep the original audio below the
                crossover and take only the generated highs. Defaults to True.
            input_cutoff (int, optional): The input is resampled to twice this
                frequency (band-limiting it); also sets the crossover. Defaults to 14000.

        Returns:
            np.ndarray: Upsampled audio at ``self.sr`` Hz, shaped ``(samples,)``
            for mono input or ``(samples, channels)`` otherwise.

        Raises:
            ValueError: If ``overlap`` is outside [0, 1), ``chunk_size`` is too
                small, or the input audio is empty.
        """
        # overlap >= 1 would make the chunk step zero or negative (endless loop).
        if not 0 <= overlap < 1:
            raise ValueError(
                f"overlap must be a fraction of the chunk size in [0, 1), got {overlap}"
            )

        # Band-limit the input by loading it at twice the cutoff frequency.
        sr = int(input_cutoff * 2)
        audio, _ = librosa.load(input_file, sr=sr, mono=False)
        if audio.shape[-1] == 0:
            raise ValueError("input audio is empty")
        # librosa yields (samples,) for mono and (channels, samples) otherwise.
        channels = [audio] if audio.ndim == 1 else list(audio)

        chunk_samples = int(chunk_size * sr)
        if chunk_samples <= 0:
            raise ValueError(f"chunk_size too small: {chunk_size}")
        overlap_samples = int(overlap * chunk_samples)
        output_chunk_samples = int(chunk_size * self.sr)
        output_overlap_samples = int(overlap * output_chunk_samples)
        enable_overlap = overlap > 0
        input_step = chunk_samples - overlap_samples if enable_overlap else chunk_samples
        output_step = (
            output_chunk_samples - output_overlap_samples
            if enable_overlap
            else output_chunk_samples
        )
        sample_rate_ratio = self.sr / sr

        meter_before = pyln.Meter(sr)
        meter_after = pyln.Meter(self.sr)

        reconstructed = []
        # A temporary directory (rather than an open NamedTemporaryFile) lets
        # the chunk file be reopened by path on every platform, including Windows.
        with tempfile.TemporaryDirectory() as work_dir:
            chunk_path = os.path.join(work_dir, "chunk.wav")
            for channel in channels:
                chunks, original_lengths = self._split_into_chunks(
                    channel, chunk_samples, input_step
                )
                total_length = (len(chunks) - 1) * output_step + output_chunk_samples
                buffer = np.zeros((1, total_length))
                for i, chunk in enumerate(chunks):
                    print(f"{i} / {len(chunks)}")
                    sf.write(chunk_path, chunk, sr)
                    out_chunk = super_resolution(
                        self.audiosr,
                        chunk_path,
                        seed=seed,
                        guidance_scale=guidance_scale,
                        ddim_steps=ddim_steps,
                        latent_t_per_second=12.8,
                    )[0]
                    # Drop output that corresponds to the zero padding.
                    num_samples_to_keep = int(original_lengths[i] * sample_rate_ratio)
                    out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()
                    out_chunk = self._match_loudness(
                        chunk, out_chunk, meter_before, meter_after
                    )
                    if enable_overlap:
                        out_chunk = self._apply_crossfade(
                            out_chunk, i, len(chunks), output_overlap_samples
                        )
                    start = i * output_step
                    buffer[0, start:start + out_chunk.shape[0]] += out_chunk.flatten()
                reconstructed.append(buffer)

        if audio.ndim == 1:
            reconstructed_audio = reconstructed[0]  # (1, samples)
        else:
            reconstructed_audio = np.stack(reconstructed, axis=-1)  # (1, samples, channels)

        if not multiband_ensemble:
            return reconstructed_audio[0]

        # Keep the original signal below the crossover; take only the generated
        # content above it from the model.
        low, _ = librosa.load(input_file, sr=self.sr, mono=False)
        output = self._match_array_shapes(reconstructed_audio[0].T, low)
        crossover_freq = self._crossover_frequency(input_cutoff)
        low = self._lr_filter(low.T, crossover_freq, "lowpass", order=10, sr=self.sr)
        high = self._lr_filter(output.T, crossover_freq, "highpass", order=10, sr=self.sr)
        # Gentle roll-off just below the 24 kHz Nyquist limit of the output.
        high = self._lr_filter(high, 23000, "lowpass", order=2, sr=self.sr)
        return low + high

    def predict(
        self,
        input_file,
        output_folder,
        ddim_steps=50,
        guidance_scale=3.5,
        overlap=0.04,
        chunk_size=10.24,
        seed=None,
        multiband_ensemble=True,
        input_cutoff=14000,
    ):
        """
        Upscales the audio and saves it as ``SR_<input name>.wav`` in ``output_folder``.

        Args:
            input_file (str): Path to the input audio file.
            output_folder (str): Path to the output folder (created if missing;
                an empty string means the current directory).
            ddim_steps (int, optional): Number of inference steps. Defaults to 50.
            guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
            overlap (float, optional): Overlap between chunks as a fraction of
                ``chunk_size``. Defaults to 0.04.
            chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
            seed (int, optional): Random seed; 0 picks a random one. Defaults to None.
            multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
            input_cutoff (int, optional): Input cutoff frequency in Hz. Defaults to 14000.

        Returns:
            np.ndarray: The upscaled waveform at ``self.sr`` Hz.
        """
        if seed == 0:
            seed = random.randint(0, 2**32 - 1)
        # os.makedirs("") raises, so treat an empty folder as the current directory.
        output_folder = output_folder or "."
        os.makedirs(output_folder, exist_ok=True)
        waveform = self._process_audio(
            input_file,
            chunk_size=chunk_size,
            overlap=overlap,
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            multiband_ensemble=multiband_ensemble,
            input_cutoff=input_cutoff,
        )
        filename = os.path.splitext(os.path.basename(input_file))[0]
        output_file = os.path.join(output_folder, f"SR_{filename}.wav")
        sf.write(output_file, data=waveform, samplerate=self.sr, subtype="PCM_16")
        print(f"File created: {output_file}")
        # Cleanup
        gc.collect()
        torch.cuda.empty_cache()
        return waveform
def inference(audio_file, model_name, guidance_scale, ddim_steps, seed):
    """
    Runs AudioSR super-resolution on a whole file in one pass (no chunking).

    Args:
        audio_file (str): Path to the input audio file.
        model_name (str): AudioSR model name passed to ``build_model``.
        guidance_scale (float): Classifier-free guidance scale.
        ddim_steps (int): Number of inference steps.
        seed (int): Random seed; 0 picks a random seed in [1, 2**32 - 1].

    Returns:
        tuple: ``(48000, waveform)`` where ``waveform`` is the raw return
        value of ``super_resolution``.
    """
    model = build_model(model_name=model_name)
    gc.collect()
    # A seed of 0 means "choose one at random".
    if seed == 0:
        seed = random.randint(1, 2**32 - 1)
    waveform = super_resolution(
        model,
        audio_file,
        seed,
        guidance_scale=guidance_scale,
        ddim_steps=ddim_steps,
    )
    return (48000, waveform)
@spaces.GPU(duration=300)
def upscale_audio(
    input_file,
    output_folder,
    ddim_steps=20,
    guidance_scale=3.5,
    overlap=0.04,
    chunk_size=10.24,
    seed=0,
    multiband_ensemble=True,
    input_cutoff=14000,
):
    """
    Gradio entry point: upscales the audio using the AudioSR model.

    A fresh AudioUpscaler (and model) is built for every call and released
    afterwards, even if processing fails.

    Args:
        input_file (str): Path to the input audio file.
        output_folder (str): Path to the output folder.
        ddim_steps (int, optional): Number of inference steps. Defaults to 20.
        guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 3.5.
        overlap (float, optional): Overlap between chunks as a fraction of the chunk size. Defaults to 0.04.
        chunk_size (float, optional): Chunk size in seconds. Defaults to 10.24.
        seed (int, optional): Random seed; 0 picks a random one. Defaults to 0.
        multiband_ensemble (bool, optional): Whether to use multiband ensemble. Defaults to True.
        input_cutoff (int, optional): Input cutoff frequency in Hz. Defaults to 14000.

    Returns:
        tuple: ``(48000, waveform)`` for a numpy ``gr.Audio`` output.

    Raises:
        gr.Error: If no input file was provided.
    """
    if not input_file:
        raise gr.Error("Please provide an input audio file.")
    torch.cuda.empty_cache()
    gc.collect()
    upscaler = AudioUpscaler()
    try:
        upscaler.setup()
        waveform = upscaler.predict(
            input_file,
            output_folder,
            ddim_steps=ddim_steps,
            guidance_scale=guidance_scale,
            overlap=overlap,
            chunk_size=chunk_size,
            seed=seed,
            multiband_ensemble=multiband_ensemble,
            input_cutoff=input_cutoff,
        )
    finally:
        # Drop the model reference BEFORE clearing the CUDA cache; while it is
        # still referenced its GPU memory is in use and cannot be released.
        del upscaler
        gc.collect()
        torch.cuda.empty_cache()
    return (48000, waveform)
# Gradio UI: the inputs are passed positionally to upscale_audio().
# NOTE(review): "Out-dir" is a free-text, user-controlled path on the server;
# on a public deployment consider restricting it to a fixed folder.
iface = gr.Interface(
    fn=upscale_audio,
    inputs=[
        gr.Audio(type="filepath", label="Input Audio"),
        gr.Textbox(".", label="Out-dir"),
        gr.Slider(10, 500, value=20, step=1, label="DDIM Steps", info="Number of inference steps (quality/speed)"),
        gr.Slider(1.0, 20.0, value=3.5, step=0.1, label="Guidance Scale", info="Guidance scale (creativity/fidelity)"),
        # The overlap is multiplied by the chunk length, i.e. it is a fraction, not seconds.
        gr.Slider(0.0, 0.5, value=0.04, step=0.01, label="Overlap (fraction of chunk)", info="Overlap between chunks as a fraction of the chunk size (smooth transitions)"),
        gr.Slider(5.12, 20.48, value=5.12, step=0.64, label="Chunk Size (s)", info="Chunk size (memory/artifact balance)"),
        gr.Number(value=0, precision=0, label="Seed", info="Random seed (0 for random)"),
        gr.Checkbox(label="Multiband Ensemble", value=False, info="Enhance high frequencies"),
        # This value is used as input_cutoff: the input is resampled to twice it
        # before upscaling, and the multiband crossover is derived from it.
        gr.Slider(500, 15000, value=9000, step=500, label="Input Cutoff (Hz)", info="Input is band-limited to this frequency before upscaling; also sets the multiband crossover", visible=True)
    ],
    outputs=gr.Audio(type="numpy", label="Output Audio"),
    title="AudioSR",
    description="Audio Super Resolution with AudioSR"
)
iface.launch(share=False)