File size: 10,433 Bytes
f0ca515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
import gc
import os
import random
import numpy as np
from scipy.signal.windows import hann
import soundfile as sf
import torch
from cog import BasePredictor, Input, Path
import tempfile
import argparse
import librosa
from audiosr import build_model, super_resolution
from scipy import signal
import pyloudnorm as pyln
import warnings
# Install a blanket "ignore" filter: no category/module is given, so every
# warning raised after this point is suppressed for the whole process.
warnings.filterwarnings("ignore")
# Exported before any model code runs. Presumably read by the Hugging Face
# tokenizers library to keep its parallelism enabled — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Per the PyTorch documentation, "high" allows float32 matrix multiplications
# to use faster, reduced-precision internal formats where the hardware has them.
torch.set_float32_matmul_precision("high")
def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
    """Trim or zero-pad ``array_1`` so its sample axis is as long as ``array_2``'s.

    The sample axis is axis 0 for 1-D arrays and axis 1 for arrays with more
    dimensions (i.e. ``(channels, samples)`` layout for 2-D input).

    Both trimming and padding act on the *end* of the axis, so the start of
    ``array_1`` stays aligned with the start of ``array_2``.

    Args:
        array_1: Array to resize.
        array_2: Reference array that supplies the target length.

    Returns:
        ``array_1`` itself when the lengths already match, otherwise a
        trimmed view or a zero-padded copy of it.
    """
    axis = 0 if array_1.ndim == 1 else 1
    current = array_1.shape[axis]
    target = array_2.shape[0 if array_2.ndim == 1 else 1]

    if current > target:
        # Keep the first `target` samples on the sample axis only.
        index = [slice(None)] * array_1.ndim
        index[axis] = slice(target)
        return array_1[tuple(index)]

    if current < target:
        # Append zeros after the existing samples; padding in front would
        # shift the signal in time relative to the reference.
        pad_width = [(0, 0)] * array_1.ndim
        pad_width[axis] = (0, target - current)
        return np.pad(array_1, pad_width, 'constant', constant_values=0)

    return array_1
def lr_filter(audio, cutoff, filter_type, order=12, sr=48000):
    """Zero-phase Butterworth filter applied along the sample axis of ``audio``.

    A Butterworth filter of order ``order // 2`` is run forward and backward
    (``sosfiltfilt``), which doubles the effective order and cancels the phase
    shift. Presumably this is meant as a Linkwitz-Riley style crossover
    filter (hence the name) — TODO confirm.

    Args:
        audio: 1-D array of samples, or 2-D array laid out as
            ``(samples, channels)``; it is transposed so filtering runs along
            the last axis and transposed back before returning.
        cutoff: Cutoff frequency in the same unit as ``sr``.
        filter_type: Passed to ``scipy.signal.butter`` as ``btype``
            (e.g. ``'lowpass'`` or ``'highpass'``).
        order: Target order of the combined forward-backward filter.
        sr: Sample rate of ``audio``.

    Returns:
        The filtered audio with the same layout as the input.
    """
    nyquist = 0.5 * sr
    normal_cutoff = cutoff / nyquist
    # Design directly as second-order sections: going through (b, a)
    # polynomials and tf2sos loses precision for higher orders / low cutoffs.
    sos = signal.butter(order // 2, normal_cutoff, btype=filter_type,
                        analog=False, output='sos')
    filtered_audio = signal.sosfiltfilt(sos, audio.T)
    return filtered_audio.T
class Predictor(BasePredictor):
    """Cog predictor that upsamples audio to 48 kHz with AudioSR, chunk by chunk."""

    def setup(self, model_name="basic", device="auto"):
        """Build the AudioSR model once and keep it on the instance.

        Args:
            model_name: Forwarded to ``build_model``.
            device: Forwarded to ``build_model``.
        """
        self.model_name = model_name
        self.device = device
        self.sr = 48000  # sample rate of everything this predictor outputs
        print("Loading Model...")
        self.audiosr = build_model(model_name=self.model_name, device=self.device)
        print("Model loaded!")

    @staticmethod
    def _script_setting(name, default):
        """Return the module-level setting *name*, or *default* when it is unset or None."""
        value = globals().get(name)
        return default if value is None else value

    @staticmethod
    def _split_into_chunks(channel, chunk_samples, hop):
        """Cut one channel into fixed-size chunks, zero-padding the last one.

        Returns ``(chunks, original_lengths)`` where ``original_lengths[i]`` is
        the number of real (unpadded) samples in ``chunks[i]``.
        """
        chunks = []
        original_lengths = []
        for start in range(0, len(channel), hop):
            chunk = channel[start:start + chunk_samples]
            original_lengths.append(len(chunk))
            if len(chunk) < chunk_samples:
                chunk = np.concatenate([chunk, np.zeros(chunk_samples - len(chunk))])
            chunks.append(chunk)
        return chunks, original_lengths

    @staticmethod
    def _match_loudness(out_chunk, meter, target_loudness):
        """Scale ``out_chunk`` to ``target_loudness``; return it unchanged when it cannot be measured."""
        # The meter rejects audio shorter than one gating block, which happens
        # for a short final chunk after trimming.
        if out_chunk.size < meter.block_size * meter.rate:
            return out_chunk
        measured = meter.integrated_loudness(out_chunk)
        # A non-finite measurement (e.g. -inf for silence) would turn the gain
        # into inf/NaN and poison the reconstructed signal.
        if not np.isfinite(measured):
            return out_chunk
        return pyln.normalize.loudness(out_chunk, measured, target_loudness)

    @staticmethod
    def _apply_crossfade(out_chunk, fade_samples, fade_in, fade_out):
        """Apply linear fades of ``fade_samples`` samples to ``out_chunk`` in place."""
        if fade_samples <= 0:
            return
        if fade_in:
            out_chunk[:fade_samples] *= np.linspace(0., 1., fade_samples)
        if fade_out:
            out_chunk[-fade_samples:] *= np.linspace(1., 0., fade_samples)

    def process_audio(self, input_file, chunk_size=5.12, overlap=0.1, seed=None,
                      guidance_scale=3.5, ddim_steps=50, input_cutoff=None,
                      multiband_ensemble=None, crossover_freq=None):
        """Upsample ``input_file`` in overlapping chunks and stitch the result.

        Args:
            input_file: Path of the audio file to load.
            chunk_size: Chunk duration in seconds.
            overlap: Fraction of a chunk shared with the next one; ``<= 0``
                disables overlapping and crossfading.
            seed, guidance_scale, ddim_steps: Forwarded to ``super_resolution``.
            input_cutoff: The input is loaded at ``2 * input_cutoff`` Hz. When
                None, the module-level ``input_cutoff`` is used if the script
                defined one, otherwise 12000.
            multiband_ensemble: When true, the low band of the original input
                is combined with the high band of the upsampled audio. When
                None, the module-level value is used if defined, else False.
            crossover_freq: Crossover frequency for the multiband mix. When
                None it falls back to the module-level value (only if
                ``input_cutoff`` was not given explicitly) or to
                ``input_cutoff - 1000``.

        Returns:
            The reconstructed audio as a NumPy array at ``self.sr``.

        Raises:
            ValueError: If ``chunk_size``/``overlap`` would produce a
                non-positive chunk or hop size.
        """
        # These settings used to be read as bare globals that only exist when
        # the file runs as a script; resolve them explicitly so the class also
        # works when imported or driven by Cog.
        cutoff_given = input_cutoff is not None
        if not cutoff_given:
            input_cutoff = self._script_setting("input_cutoff", 12000)
        if multiband_ensemble is None:
            multiband_ensemble = self._script_setting("multiband_ensemble", False)
        if crossover_freq is None:
            derived = input_cutoff - 1000
            crossover_freq = derived if cutoff_given else self._script_setting("crossover_freq", derived)

        sr = input_cutoff * 2
        audio, _ = librosa.load(input_file, sr=sr, mono=False)
        audio = audio.T
        print(f"audio.shape = {audio.shape}")
        print(f"input cutoff = {input_cutoff}")

        is_stereo = audio.ndim == 2
        audio_channels = [audio[:, 0], audio[:, 1]] if is_stereo else [audio]
        print("audio is stereo" if is_stereo else "Audio is mono")

        chunk_samples = int(chunk_size * sr)
        overlap_samples = int(overlap * chunk_samples)
        output_chunk_samples = int(chunk_size * self.sr)
        output_overlap_samples = int(overlap * output_chunk_samples)
        enable_overlap = overlap > 0
        print(f"enable_overlap = {enable_overlap}")

        hop = chunk_samples - overlap_samples if enable_overlap else chunk_samples
        output_hop = output_chunk_samples - output_overlap_samples if enable_overlap else output_chunk_samples
        if chunk_samples <= 0 or hop <= 0:
            # A non-positive hop would never advance through the audio.
            raise ValueError("chunk_size must be positive and overlap must be smaller than 1")

        # Process both channels (mono or stereo)
        chunks_per_channel = [
            self._split_into_chunks(channel, chunk_samples, hop)
            for channel in audio_channels
        ]

        sample_rate_ratio = self.sr / sr
        num_chunks = len(chunks_per_channel[0][0])
        total_length = (num_chunks - 1) * output_hop + output_chunk_samples
        reconstructed_channels = [np.zeros((1, total_length)) for _ in audio_channels]

        meter_before = pyln.Meter(sr)
        meter_after = pyln.Meter(self.sr)

        # Process chunks for each channel
        for ch_idx, (chunks, original_lengths) in enumerate(chunks_per_channel):
            last_index = len(chunks) - 1
            for i, chunk in enumerate(chunks):
                loudness_before = meter_before.integrated_loudness(chunk)
                print(f"Processing chunk {i+1} of {len(chunks)} for {'Left/Mono' if ch_idx == 0 else 'Right'} channel")

                # super_resolution takes a file path, so each chunk goes
                # through a temporary WAV file.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
                    sf.write(temp_wav.name, chunk, sr)
                    out_chunk = super_resolution(
                        self.audiosr,
                        temp_wav.name,
                        seed=seed,
                        guidance_scale=guidance_scale,
                        ddim_steps=ddim_steps,
                        latent_t_per_second=12.8
                    )

                out_chunk = out_chunk[0]
                # Drop the part of the output that corresponds to zero padding.
                num_samples_to_keep = int(original_lengths[i] * sample_rate_ratio)
                out_chunk = out_chunk[:, :num_samples_to_keep].squeeze()
                out_chunk = self._match_loudness(out_chunk, meter_after, loudness_before)

                if enable_overlap:
                    # Fade only where a neighbouring chunk exists to crossfade
                    # with; a lone chunk keeps its full-level start and end.
                    self._apply_crossfade(
                        out_chunk,
                        min(output_overlap_samples, num_samples_to_keep),
                        fade_in=i > 0,
                        fade_out=i < last_index,
                    )

                start = i * output_hop
                end = start + out_chunk.shape[0]
                reconstructed_channels[ch_idx][0, start:end] += out_chunk.flatten()

        reconstructed_audio = np.stack(reconstructed_channels, axis=-1) if is_stereo else reconstructed_channels[0]

        if multiband_ensemble:
            # Low band from the original input, high band from the model output.
            low, _ = librosa.load(input_file, sr=self.sr, mono=False)
            output = match_array_shapes(reconstructed_audio[0].T, low)
            low = lr_filter(low.T, crossover_freq, 'lowpass', order=10, sr=self.sr)
            high = lr_filter(output.T, crossover_freq, 'highpass', order=10, sr=self.sr)
            high = lr_filter(high, 23000, 'lowpass', order=2, sr=self.sr)
            output = low + high
        else:
            output = reconstructed_audio[0]

        return output

    def predict(self,
        input_file: Path = Input(description="Audio to upsample"),
        ddim_steps: int = Input(description="Number of inference steps", default=50, ge=10, le=500),
        guidance_scale: float = Input(description="Scale for classifier free guidance", default=3.5, ge=1.0, le=20.0),
        overlap: float = Input(description="overlap size", default=0.04),
        chunk_size: float = Input(description="chunksize", default=10.24),
        seed: int = Input(description="Random seed. Leave blank to randomize the seed", default=None)
    ) -> Path:
        """Upsample ``input_file`` and write ``SR_<name>.wav`` as 16-bit PCM.

        The output folder and the multiband switch are taken from the
        module-level ``output_folder`` / ``multiband_ensemble`` names when the
        command-line entry point defined them; otherwise a fresh temporary
        directory is used and the multiband mix is off.

        Returns:
            Path of the written WAV file.
        """
        # A blank seed (None) and 0 both mean "pick a random seed".
        if seed is None or seed == 0:
            seed = random.randint(0, 2**32 - 1)
        print(f"Setting seed to: {seed}")

        multiband_ensemble = self._script_setting("multiband_ensemble", False)
        output_folder = self._script_setting("output_folder", None)
        if output_folder is None:
            output_folder = tempfile.mkdtemp()

        print(f"overlap = {overlap}")
        print(f"guidance_scale = {guidance_scale}")
        print(f"ddim_steps = {ddim_steps}")
        print(f"chunk_size = {chunk_size}")
        print(f"multiband_ensemble = {multiband_ensemble}")
        print(f"input file = {os.path.basename(input_file)}")

        os.makedirs(output_folder, exist_ok=True)
        waveform = self.process_audio(
            input_file,
            chunk_size=chunk_size,
            overlap=overlap,
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            multiband_ensemble=multiband_ensemble
        )

        filename = os.path.splitext(os.path.basename(input_file))[0]
        output_path = os.path.join(output_folder, f"SR_{filename}.wav")
        sf.write(output_path, data=waveform, samplerate=self.sr, subtype="PCM_16")
        print(f"file created: {output_path}")

        # Free the large result buffer but keep the model: the same predictor
        # instance may be asked to run another prediction.
        del waveform
        gc.collect()
        torch.cuda.empty_cache()
        return Path(output_path)
if __name__ == "__main__":
    # Command-line entry point: parse the options, publish them as module-level
    # names, then run a single prediction.

    def _parse_bool(text):
        """Parse a command-line boolean; `type=bool` would treat "False" as True."""
        lowered = text.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")

    parser = argparse.ArgumentParser(description="Upsample an audio file to 48 kHz with AudioSR.")
    parser.add_argument("--input", help="Path to input audio file", required=True)
    parser.add_argument("--output", help="Output folder", required=True)
    parser.add_argument("--ddim_steps", help="Number of ddim steps", type=int, required=False, default=50)
    parser.add_argument("--chunk_size", help="chunk size", type=float, required=False, default=10.24)
    parser.add_argument("--guidance_scale", help="Guidance scale value", type=float, required=False, default=3.5)
    parser.add_argument("--seed", help="Seed value, 0 = random seed", type=int, required=False, default=0)
    parser.add_argument("--overlap", help="overlap value", type=float, required=False, default=0.04)
    # Accepts both `--multiband_ensemble` and `--multiband_ensemble True/False`.
    parser.add_argument("--multiband_ensemble", type=_parse_bool, nargs="?", const=True, default=False,
                        help="Use multiband ensemble with input")
    parser.add_argument("--input_cutoff", help="Define the crossover of audio input in the multiband ensemble", type=int, required=False, default=12000)

    args = parser.parse_args()

    # NOTE: output_folder, input_cutoff, multiband_ensemble and crossover_freq
    # must stay module-level names; Predictor looks them up as globals.
    input_file_path = args.input
    output_folder = args.output
    ddim_steps = args.ddim_steps
    chunk_size = args.chunk_size
    guidance_scale = args.guidance_scale
    seed = args.seed
    overlap = args.overlap
    input_cutoff = args.input_cutoff
    multiband_ensemble = args.multiband_ensemble
    # Crossover sits 1 kHz below the input cutoff.
    crossover_freq = input_cutoff - 1000

    p = Predictor()
    p.setup(device='auto')

    out = p.predict(
        input_file_path,
        ddim_steps=ddim_steps,
        guidance_scale=guidance_scale,
        seed=seed,
        chunk_size=chunk_size,
        overlap=overlap
    )

    # Release the model and any cached GPU memory before exiting.
    del p
    gc.collect()
    torch.cuda.empty_cache()
|