# Apollo-PyHARP / app.py
import os

import gradio as gr
import torch
import torchaudio
from huggingface_hub import hf_hub_download
# PyHARP wrapper; load_audio/save_audio are defined locally below, so they
# are deliberately not imported from pyharp (they would be shadowed anyway)
from pyharp import ModelCard, build_endpoint
# Create a ModelCard describing the model to HARP hosts
model_card = ModelCard(
    name="Apollo",
    description=(
        "High-quality audio restoration for lossy MP3-compressed audio. "
        "Converts low-bitrate MP3s to near-lossless quality using "
        "band-sequence modeling."
    ),
    author="JusperLee",
    tags=["audio restoration", "music", "apollo", "mp3", "lossless"],
)
def load_audio(file_path):
    """Load an audio file and add a batch dimension (kept on CPU)."""
    audio, sample_rate = torchaudio.load(file_path)  # [channels, samples]
    return audio.unsqueeze(0), sample_rate  # [1, channels, samples]


def save_audio(file_path, audio, sample_rate=44100):
    """Drop the batch dimension and write audio to disk."""
    audio = audio.squeeze(0).cpu()  # [channels, samples]
    torchaudio.save(file_path, audio, sample_rate)
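
# Example (hypothetical file): a 10-second stereo WAV at 44.1 kHz loads as
#   load_audio("song.wav") -> (tensor of shape [1, 2, 441000], 44100)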
# Define the process function invoked by the HARP endpoint
@torch.inference_mode()
def process_fn(input_audio_path: str) -> str:
    # Stay on CPU; no CUDA device is assumed to be available
    device = torch.device("cpu")
    print(f"Using device: {device}")

    print("Loading Apollo model...")
    # Download the model weights from the Hugging Face Hub
    model_path = hf_hub_download(
        repo_id="JusperLee/Apollo",
        filename="pytorch_model.bin",
        cache_dir="./checkpoints",
    )

    # Load the checkpoint (it embeds the OmegaConf model config)
    print(f"Loading checkpoint from {model_path}")
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
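    # weights_only=False is required above: the checkpoint stores non-tensor
    # objects (the OmegaConf model config), which the weights-only loader
    # (the default in PyTorch >= 2.6) would reject.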
    # Extract model info
    model_name = checkpoint["model_name"]
    state_dict = checkpoint["state_dict"]
    model_args = checkpoint.get("model_args", {})
    print(f"Model class: {model_name}")
    print(f"Model args: {model_args}")
    # Import the model class registered under this name in look2hear
    from look2hear.models import get
    model_class = get(model_name)

    # Convert the OmegaConf config to a plain dict if needed; DictConfig has
    # no .to_container() method, so go through OmegaConf.to_container()
    from omegaconf import OmegaConf
    if OmegaConf.is_config(model_args):
        model_args = OmegaConf.to_container(model_args, resolve=True)

    # Create the model instance with model_args
    print(f"Instantiating {model_name}...")
    model = model_class(**model_args)
    # Load the weights and switch to inference mode
    print("Loading state dict...")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print("✓ Model loaded successfully")
    # Load the input audio and move it to the device
    sig, sample_rate = load_audio(input_audio_path)
    sig = sig.to(device)

    # Apollo's released checkpoint was trained on 44.1 kHz audio, and the
    # output is written at 44.1 kHz below, so resample if needed
    if sample_rate != 44100:
        sig = torchaudio.functional.resample(sig, sample_rate, 44100)

    # Add a batch dimension if needed (Apollo expects [batch, channels, samples])
    if sig.dim() == 2:
        sig = sig.unsqueeze(0)

    # @torch.inference_mode() above already disables autograd here
    output = model(sig)
    # Remove the batch dimension and write the result to disk
    output = output.squeeze(0)
    output_audio_path = os.path.join("src", "_outputs", "output_restored.wav")
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    torchaudio.save(output_audio_path, output.cpu(), 44100)
    print(f"✓ Saved output to {output_audio_path}")
    return output_audio_path

# Build the Gradio UI
with gr.Blocks() as demo:
    # Define input Gradio components
    input_components = [
        gr.Audio(
            type="filepath",
            label="Input Audio",
        ).harp_required(True),
    ]

    # Define output Gradio components
    output_components = [
        gr.Audio(
            type="filepath",
            label="Output Audio",
        ).set_info("The restored audio."),
    ]
    # Build a HARP-compatible endpoint
    app = build_endpoint(
        model_card=model_card,
        input_components=input_components,
        output_components=output_components,
        process_fn=process_fn,
    )

# Launch the app
demo.queue().launch(share=True, show_error=False, pwa=True)
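# Note: share=True asks Gradio for a temporary public link when running
# locally; a hosted Space already exposes its own public URL.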