import sys
sys.path.append('./BigVGAN')
import time
import torch
import torchaudio
import argparse
from tqdm import tqdm
import librosa
from BigVGAN import bigvgan
from BigVGAN.meldataset import get_mel_spectrogram
from model import OptimizedAudioRestorationModel

# Select the device: CUDA when available, otherwise CPU (e.g. on Apple-silicon MacBooks)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Initialize the pretrained BigVGAN vocoder (24 kHz, 100 mel bands, 256x upsampling)
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
    'nvidia/bigvgan_v2_24khz_100band_256x',
    use_cuda_kernel=False,
    force_download=False
).to(device)
bigvgan_model.remove_weight_norm()  # fold weight norm into the weights for inference


def measure_gpu_memory():
    """Return the peak GPU memory allocated so far, in MB (0 when running on CPU)."""
    if device == 'cuda':
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / (1024 ** 2)  # bytes -> MB
    return 0


def apply_overlap_windowing_waveform(waveform, window_size_samples, overlap):
    """Split a waveform into overlapping windows of window_size_samples.

    Trailing samples that do not fill a complete window are dropped.
    """
    step_size = int(window_size_samples * (1 - overlap))
    num_chunks = (waveform.shape[-1] - window_size_samples) // step_size + 1
    windows = []
    for i in range(num_chunks):
        start_idx = i * step_size
        end_idx = start_idx + window_size_samples
        chunk = waveform[..., start_idx:end_idx]
        windows.append(chunk)
    return torch.stack(windows)
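
# Example (illustrative numbers): a 30 s mono clip at 24 kHz has 720,000 samples;
# with a 5 s window (120,000 samples) and overlap=0.5 (step of 60,000 samples),
# apply_overlap_windowing_waveform yields (720,000 - 120,000) // 60,000 + 1 = 11
# windows, stacked into a tensor of shape [11, 1, 120,000].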


def reconstruct_waveform_from_windows(windows, window_size_samples, overlap):
    """Overlap-add the windows back into one waveform, averaging overlapping regions."""
    step_size = int(window_size_samples * (1 - overlap))
    shape = windows.shape
    if len(shape) == 2:
        # windows.shape == (num_windows, window_len)
        num_windows, window_len = shape
        channels = 1
        windows = windows.unsqueeze(1)  # Now windows.shape == (num_windows, 1, window_len)
    elif len(shape) == 3:
        num_windows, channels, window_len = shape
    else:
        raise ValueError(f"Unexpected windows.shape: {windows.shape}")
    output_length = (num_windows - 1) * step_size + window_size_samples
    reconstructed = torch.zeros((channels, output_length))
    window_sums = torch.zeros((channels, output_length))
    for i in range(num_windows):
        start_idx = i * step_size
        end_idx = start_idx + window_len
        reconstructed[:, start_idx:end_idx] += windows[i]
        window_sums[:, start_idx:end_idx] += 1
    # Average the overlapping regions; the clamp guards against division by zero
    reconstructed = reconstructed / window_sums.clamp(min=1e-6)
    if channels == 1:
        reconstructed = reconstructed.squeeze(0)  # Remove channel dimension for mono
    return reconstructed
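
# Example (continuing the illustrative numbers above): reconstructing 11 windows of
# 120,000 samples with overlap=0.5 gives (11 - 1) * 60,000 + 120,000 = 720,000
# samples, i.e. the original 30 s; samples covered by two windows are averaged.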


def load_model(save_path):
    """
    Load the optimized audio restoration model.

    Parameters:
    - save_path: Path to the checkpoint file.
    """
    optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=bigvgan_model)
    state_dict = torch.load(save_path, map_location=device)
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
    return optimized_model


def restore_audio(model, input_path, output_path, steps=16, cfg_strength=0.5, window_size_sec=5.0, overlap=0.5):
    """Restore long-form audio window by window and overlap-add the results."""
    start_time = time.time()
    initial_gpu_memory = measure_gpu_memory()

    # Load the audio file, resampled to BigVGAN's 24 kHz
    wav, sr = librosa.load(input_path, sr=24000, mono=True)
    wav = torch.FloatTensor(wav).unsqueeze(0)  # Shape: [1, num_samples]
    window_size_samples = int(window_size_sec * sr)

    # Apply overlapping windowing to the waveform
    wav_windows = apply_overlap_windowing_waveform(wav, window_size_samples, overlap)

    restored_wav_windows = []
    for wav_window in tqdm(wav_windows):
        wav_window = wav_window.to(device)  # Shape: [1, window_size_samples]

        # Convert to mel spectrogram
        processed_mel = get_mel_spectrogram(wav_window, bigvgan_model.h).to(device)

        # Restore the mel spectrogram
        with torch.no_grad():
            with torch.autocast(device):
                restored_mel = model.voice_restore.sample(processed_mel.transpose(1, 2), steps=steps, cfg_strength=cfg_strength)
                restored_mel = restored_mel.squeeze(0).transpose(0, 1)

        # Convert the restored mel spectrogram back to a waveform
        with torch.no_grad():
            with torch.autocast(device):
                restored_wav = bigvgan_model(restored_mel.unsqueeze(0)).squeeze(0).float().cpu()  # Shape: [1, num_samples]

        restored_wav_windows.append(restored_wav)
        del wav_window, processed_mel, restored_mel, restored_wav
        torch.cuda.empty_cache()

    restored_wav_windows = torch.stack(restored_wav_windows)  # Shape: [num_windows, 1, num_samples]

    # Reconstruct the full waveform from the processed windows
    restored_wav = reconstruct_waveform_from_windows(restored_wav_windows, window_size_samples, overlap)

    # Ensure restored_wav has a channel dimension for saving
    if restored_wav.dim() == 1:
        restored_wav = restored_wav.unsqueeze(0)  # Shape: [1, num_samples]

    # Save the restored audio
    torchaudio.save(output_path, restored_wav, 24000)

    total_time = time.time() - start_time
    peak_gpu_memory = measure_gpu_memory()
    gpu_memory_used = peak_gpu_memory - initial_gpu_memory
    print(f"Total inference time: {total_time:.2f} seconds")
    print(f"Peak GPU memory usage: {peak_gpu_memory:.2f} MB")
    print(f"GPU memory used: {gpu_memory_used:.2f} MB")


if __name__ == "__main__":
    # Argument parser setup
    parser = argparse.ArgumentParser(description="Audio restoration using OptimizedAudioRestorationModel for long-form audio.")
    parser.add_argument('--checkpoint', type=str, required=True, help="Path to the checkpoint file")
    parser.add_argument('--input', type=str, required=True, help="Path to the input audio file")
    parser.add_argument('--output', type=str, required=True, help="Path to save the restored audio file")
    parser.add_argument('--steps', type=int, default=16, help="Number of sampling steps")
    parser.add_argument('--cfg_strength', type=float, default=0.5, help="CFG strength value")
    parser.add_argument('--window_size_sec', type=float, default=5.0, help="Window size in seconds for overlapping")
    parser.add_argument('--overlap', type=float, default=0.5, help="Overlap ratio for windowing")
    args = parser.parse_args()

    # Load the optimized model
    optimized_model = load_model(args.checkpoint)
    if device == 'cuda':
        optimized_model.bfloat16()
    optimized_model.eval()
    optimized_model.to(device)

    # Use the model to restore audio
    restore_audio(
        optimized_model,
        args.input,
        args.output,
        steps=args.steps,
        cfg_strength=args.cfg_strength,
        window_size_sec=args.window_size_sec,
        overlap=args.overlap
    )
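
# Example invocation (script and file names are illustrative):
#   python inference_long_form.py \
#       --checkpoint checkpoints/voice_restore.pt \
#       --input noisy_speech.wav \
#       --output restored_speech.wav \
#       --steps 16 --cfg_strength 0.5 --window_size_sec 5.0 --overlap 0.5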