# Hugging Face Space page header (status: "Running") — artifact of page
# extraction, not part of the script.
| import os | |
| import torch | |
| import torchaudio | |
| import argparse | |
| from huggingface_hub import hf_hub_download | |
| # For PyHARP wrapper | |
| from pyharp import ModelCard, build_endpoint, load_audio, save_audio | |
| import gradio as gr | |
# ModelCard shown to HARP clients describing this Space.
card_fields = dict(
    name="Apollo",
    description="High-quality audio restoration for lossy MP3 compressed audio. Converts low-bitrate MP3s to near-lossless quality using band-sequence modeling.",
    author="JusperLee",
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)
model_card = ModelCard(**card_fields)
def load_audio(file_path):
    """Read an audio file and prepend a batch axis.

    Returns a CPU tensor shaped [1, channels, samples] (so [1, 1, samples]
    only for mono input). The file's sample rate is read but discarded.
    NOTE(review): this shadows the `load_audio` imported from pyharp above.
    """
    waveform, _sample_rate = torchaudio.load(file_path)
    return waveform.unsqueeze(0)
def save_audio(file_path, audio, samplerate=44100):
    """Write `audio` to `file_path`, dropping the leading batch axis.

    The tensor is moved to CPU before saving.
    NOTE(review): this shadows the `save_audio` imported from pyharp above.
    """
    waveform = audio.squeeze(0).cpu()
    torchaudio.save(file_path, waveform, samplerate)
# Defining the process function
def _get_apollo_model(device):
    """Download (if needed), build, and cache the Apollo model.

    The original code re-downloaded the checkpoint and rebuilt the model on
    every request; the ready model is cached on the function object so the
    expensive load happens once per process.
    """
    cached = getattr(_get_apollo_model, "_model", None)
    if cached is not None:
        return cached
    print("Loading Apollo model...")
    # Download model weights from HuggingFace
    model_path = hf_hub_download(
        repo_id="JusperLee/Apollo",
        filename="pytorch_model.bin",
        cache_dir="./checkpoints",
    )
    print(f"Loading checkpoint from {model_path}")
    # weights_only=False because the checkpoint carries OmegaConf objects,
    # not just tensors. NOTE(review): only safe for this trusted repo —
    # torch.load with weights_only=False unpickles arbitrary code.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    # Extract model class name, weights, and constructor kwargs
    model_name = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    model_args = checkpoint.get('model_args', {})
    print(f"Model class: {model_name}")
    print(f"Model args: {model_args}")
    # Resolve the concrete model class by name
    from look2hear.models import get
    model_class = get(model_name)
    # Convert OmegaConf container to a plain dict before ** expansion
    if hasattr(model_args, 'to_container'):
        model_args = model_args.to_container(resolve=True)
    print(f"Instantiating {model_name}...")
    model = model_class(**model_args)
    print("Loading state dict...")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print("✓ Model loaded successfully")
    _get_apollo_model._model = model
    return model


def process_fn(
    input_audio_path: str
) -> str:
    """Restore `input_audio_path` with Apollo and return the output wav path.

    Runs entirely on CPU (this Space has no GPU). The restored file is
    written to src/_outputs/output_restored.wav.
    """
    device = torch.device("cpu")  # don't touch CUDA — CPU-only Space
    print(f"Using device: {device}")
    model = _get_apollo_model(device)
    # Load directly with torchaudio so we keep the file's sample rate
    # (the module-level load_audio helper discards it).
    sig, samplerate = torchaudio.load(input_audio_path)
    sig = sig.to(device)
    # Apollo expects [batch, channels, samples]
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)
    with torch.no_grad():
        output = model(sig)
    # Remove batch dimension -> [channels, samples] for torchaudio.save
    output = output.squeeze(0)
    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    # Bug fix: the original hard-coded 44100 here, mislabeling the rate of
    # any non-44.1 kHz input; write at the input's own sample rate.
    torchaudio.save(output_audio_path, output, samplerate)
    print(f"✓ Saved output to {output_audio_path}")
    return output_audio_path
# Build the Gradio UI and wrap it as a HARP-compatible endpoint.
with gr.Blocks() as demo:
    # Input widget: the audio file HARP sends for restoration (required).
    input_audio = gr.Audio(type="filepath",
                           label="Input Audio A").harp_required(True)
    # Output widget: path to the restored audio written by process_fn.
    output_audio = gr.Audio(type="filepath",
                            label="Output Audio").set_info("The restored audio.")

    # Wire card, components, and process_fn into a HARP endpoint.
    app = build_endpoint(
        model_card=model_card,
        input_components=[input_audio],
        output_components=[output_audio],
        process_fn=process_fn,
    )

# run the thing
demo.queue().launch(share=True, show_error=False, pwa=True)
# original inference function run
# NOTE(review): dead code kept as a module-level string literal. It calls a
# `main` function that does not exist anywhere in this file — reimplement
# `main` (or delete this stub) before re-enabling the CLI entry point.
'''
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Audio Inference Script")
    parser.add_argument("--in_wav", type=str, required=True, help="Path to input wav file")
    parser.add_argument("--out_wav", type=str, required=True, help="Path to output wav file")
    args = parser.parse_args()
    main(args.in_wav, args.out_wav)
'''