"""Hugging Face ``transformers`` wrapper around the VoiceRestore audio model.

Defines a config class and a ``PreTrainedModel`` subclass that loads a BigVGAN
model plus the optimized restoration checkpoint, and exposes two file-to-file
inference paths: a single-pass one and an overlapping-window one.
"""

import argparse

import librosa
import torch
import torch.nn as nn
import torchaudio
from huggingface_hub import snapshot_download
from transformers import PretrainedConfig, PreTrainedModel

from BigVGAN import bigvgan
from BigVGAN.meldataset import get_mel_spectrogram
# NOTE: this imported name is shadowed by the ``VoiceRestore`` class defined
# below; the import is kept because other modules may rely on it being loaded.
from voice_restore import VoiceRestore
from model import OptimizedAudioRestorationModel
from inference_long import (
    apply_overlap_windowing_waveform,
    reconstruct_waveform_from_windows,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint location used when the config does not provide ``checkpoint_path``.
DEFAULT_CHECKPOINT_PATH = (
    "/content/voicerestore/checkpoints/voice-restore-20d-16h-optim.pt"
)


# Configuration class for VoiceRestore
class VoiceRestoreConfig(PretrainedConfig):
    """Inference settings for :class:`VoiceRestore`.

    Recognised keyword arguments (all optional):
        steps: forwarded as ``steps`` to the restoration model (default 16).
        cfg_strength: forwarded as ``cfg_strength`` (default 0.5).
        window_size_sec: window length in seconds for long inference
            (default 5.0).
        overlap: overlap value handed to the windowing helpers (default 0.5).
        checkpoint_path: file the restoration weights are loaded from
            (default ``DEFAULT_CHECKPOINT_PATH``).
    """

    model_type = "voice_restore"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.steps = kwargs.get("steps", 16)
        self.cfg_strength = kwargs.get("cfg_strength", 0.5)
        self.window_size_sec = kwargs.get("window_size_sec", 5.0)
        self.overlap = kwargs.get("overlap", 0.5)
        self.checkpoint_path = kwargs.get(
            "checkpoint_path", DEFAULT_CHECKPOINT_PATH
        )


# Model class for VoiceRestore
class VoiceRestore(PreTrainedModel):
    """File-to-file audio restoration model.

    Construction loads the pretrained BigVGAN model and the restoration
    checkpoint onto the module-level ``device`` and puts the restoration model
    in eval mode.
    """

    config_class = VoiceRestoreConfig

    def __init__(self, config: VoiceRestoreConfig):
        """Build the BigVGAN model and load the restoration checkpoint.

        Args:
            config: settings object; ``steps``, ``cfg_strength``,
                ``window_size_sec`` and ``overlap`` are copied onto the model,
                and ``checkpoint_path`` (if present) selects the weights file.
        """
        super().__init__(config)
        self.steps = config.steps
        self.cfg_strength = config.cfg_strength
        self.window_size_sec = config.window_size_sec
        self.overlap = config.overlap

        # Initialize BigVGAN model
        self.bigvgan_model = bigvgan.BigVGAN.from_pretrained(
            'nvidia/bigvgan_v2_24khz_100band_256x',
            use_cuda_kernel=False,
            force_download=False,
        ).to(device)
        self.bigvgan_model.remove_weight_norm()

        # Optimized restoration model
        self.optimized_model = OptimizedAudioRestorationModel(
            device=device, bigvgan_model=self.bigvgan_model
        )

        # getattr keeps config objects that predate ``checkpoint_path`` working.
        save_path = getattr(config, "checkpoint_path", DEFAULT_CHECKPOINT_PATH)
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(save_path, map_location=device)
        # Some checkpoints wrap the weights under 'model_state_dict'.
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']

        self.optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
        self.optimized_model.eval()

    def forward(self, input_path, output_path, short=True):
        """Restore the audio file at ``input_path`` and write it to ``output_path``.

        Args:
            input_path: path of the audio file to restore.
            output_path: path the restored audio is saved to.
            short: when true, run single-pass inference; otherwise run the
                overlapping-window inference.

        Returns:
            None. The result is written to ``output_path``.
        """
        # Restore the audio using the parameters from the config
        if short:
            self.restore_audio_short(
                self.optimized_model,
                input_path,
                output_path,
                self.steps,
                self.cfg_strength,
            )
        else:
            self.restore_audio_long(
                self.optimized_model,
                input_path,
                output_path,
                self.steps,
                self.cfg_strength,
                self.window_size_sec,
                self.overlap,
            )

    def restore_audio_short(self, model, input_path, output_path, steps, cfg_strength):
        """Short inference for audio restoration.

        Loads the file, resamples it to ``model.target_sample_rate`` if needed,
        averages over the first dimension, runs ``model`` once over the whole
        signal and saves the result.

        Args:
            model: callable restoration model exposing ``target_sample_rate``.
            input_path: path of the audio file to restore.
            output_path: path the restored audio is saved to.
            steps: forwarded to ``model`` as ``steps``.
            cfg_strength: forwarded to ``model`` as ``cfg_strength``.
        """
        # torch.autocast wants the device type string, not a torch.device.
        device_type = device.type

        # Load the audio file
        audio, sr = torchaudio.load(input_path)
        if sr != model.target_sample_rate:
            audio = torchaudio.functional.resample(audio, sr, model.target_sample_rate)

        # Convert to mono if stereo
        audio = audio.mean(dim=0, keepdim=True) if audio.dim() > 1 else audio

        with torch.inference_mode():
            with torch.autocast(device_type):
                restored_wav = model(audio, steps=steps, cfg_strength=cfg_strength)
                # Move to CPU after processing
                restored_wav = restored_wav.squeeze(0).float().cpu()

        # Save the restored audio
        torchaudio.save(output_path, restored_wav, model.target_sample_rate)

    def restore_audio_long(self, model, input_path, output_path, steps, cfg_strength, window_size_sec, overlap):
        """Long inference for audio restoration using overlapping windows.

        Loads the file as 24 kHz mono, splits it with
        ``apply_overlap_windowing_waveform``, restores each window separately,
        recombines the windows with ``reconstruct_waveform_from_windows`` and
        saves the result at 24 kHz.

        Args:
            model: restoration model exposing ``voice_restore.sample``.
            input_path: path of the audio file to restore.
            output_path: path the restored audio is saved to.
            steps: forwarded to ``voice_restore.sample`` as ``steps``.
            cfg_strength: forwarded to ``voice_restore.sample``.
            window_size_sec: window length in seconds.
            overlap: overlap value passed to both windowing helpers.
        """
        # Load the audio file
        wav, sr = librosa.load(input_path, sr=24000, mono=True)
        wav = torch.FloatTensor(wav).unsqueeze(0)  # Shape: [1, num_samples]

        window_size_samples = int(window_size_sec * sr)
        wav_windows = apply_overlap_windowing_waveform(wav, window_size_samples, overlap)

        restored_wav_windows = []
        for wav_window in wav_windows:
            wav_window = wav_window.to(device)
            processed_mel = get_mel_spectrogram(wav_window, self.bigvgan_model.h).to(device)

            # Restore audio
            with torch.no_grad():
                # Pass the device type string; handing torch.autocast the
                # torch.device object itself is rejected.
                with torch.autocast(device.type):
                    restored_mel = model.voice_restore.sample(
                        processed_mel.transpose(1, 2),
                        steps=steps,
                        cfg_strength=cfg_strength,
                    )
                    restored_mel = restored_mel.squeeze(0).transpose(0, 1)

                # NOTE(review): the BigVGAN call is kept under no_grad but
                # outside autocast — confirm this matches the intended nesting.
                restored_wav = (
                    self.bigvgan_model(restored_mel.unsqueeze(0))
                    .squeeze(0)
                    .float()
                    .cpu()
                )
                restored_wav_windows.append(restored_wav)
                torch.cuda.empty_cache()

        restored_wav_windows = torch.stack(restored_wav_windows)
        restored_wav = reconstruct_waveform_from_windows(
            restored_wav_windows, window_size_samples, overlap
        )

        # Save the restored audio
        torchaudio.save(output_path, restored_wav.unsqueeze(0), 24000)


# # Function to load the model using AutoModel
# from transformers import AutoModel

# def load_voice_restore_model(checkpoint_path: str):
#     model = AutoModel.from_pretrained(checkpoint_path, config=VoiceRestoreConfig())
#     return model

# # Example Usage
# model = load_voice_restore_model("./checkpoints/voice-restore-20d-16h-optim.pt")
# model("test_input.wav", "test_output.wav")