Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import torchaudio | |
import torchaudio.transforms as T | |
import numpy as np | |
import traceback | |
import io | |
import time | |
# SNAC is a hard dependency of this app: fail fast at import time with an
# actionable message if the package is missing (it should be declared in
# requirements.txt so the Space installs it).
try:
    from snac import SNAC
except ImportError as import_err:
    print(f"Error importing SNAC: {import_err}")
    raise ImportError(
        "Could not import SNAC. Make sure 'snac' is listed in requirements.txt and installed correctly."
    ) from import_err
else:
    print("SNAC module imported successfully.")
# --- Configuration ---
MODEL_ID = "hubertsiuzdak/snac_32khz"  # Hugging Face Hub checkpoint to load
TARGET_SR = 32000  # Sample rate fed to the snac_32khz checkpoint
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Load Model (Load once globally) ---
# snac_model stays None unless loading, moving to DEVICE and switching to eval
# mode ALL succeed; process_audio() and the __main__ guard check for None.
snac_model = None
try:
    print("Loading SNAC model...")
    start_time = time.perf_counter()  # monotonic clock for elapsed-time measurement
    # Build the model in a temporary so that a failure in .to()/.eval() cannot
    # leave a half-initialised model (wrong device / training mode) published.
    _loaded_model = SNAC.from_pretrained(MODEL_ID)
    _loaded_model = _loaded_model.to(DEVICE)
    _loaded_model.eval()  # Set model to evaluation mode
    end_time = time.perf_counter()
    snac_model = _loaded_model
    print(f"SNAC model loaded successfully to {DEVICE}. Time taken: {end_time - start_time:.2f} seconds.")
except Exception as e:
    # Deliberate best-effort: log the failure and leave snac_model as None so
    # callers can report it, instead of crashing at import time.
    print(f"FATAL: Error loading SNAC model: {e}")
    print(traceback.format_exc())
# --- Main Processing Function ---
def _sync_device():
    """Block until queued GPU work has finished so wall-clock timings are meaningful.

    CUDA kernels are launched asynchronously, so reading the clock right after
    ``encode()``/``decode()`` would only measure kernel launch. No-op on CPU.
    """
    if DEVICE == "cuda":
        torch.cuda.synchronize()


def process_audio(audio_filepath):
    """
    Load, resample, encode and decode an audio file with SNAC.

    Args:
        audio_filepath: Path of the uploaded audio file (Gradio ``type="filepath"``),
            or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(original, resampled, reconstructed, log_text)``. Each audio
        item is a ``(sample_rate, 1-D numpy array)`` pair for ``gr.Audio``, or
        ``None`` if processing did not complete; ``log_text`` is the
        newline-joined processing log (including a traceback on failure).
    """
    if snac_model is None:
        return None, None, None, "Error: SNAC model could not be loaded. Cannot process audio."
    if audio_filepath is None:
        return None, None, None, "Please upload an audio file."

    logs = ["--- Starting Audio Processing ---"]
    try:
        # 1. Load audio; torchaudio.load yields a [channels, samples] tensor.
        logs.append(f"Loading audio file: {audio_filepath}")
        load_start = time.perf_counter()
        original_waveform, original_sr = torchaudio.load(audio_filepath)
        load_end = time.perf_counter()
        logs.append(f"Audio loaded. Original SR: {original_sr} Hz, Shape: {original_waveform.shape}, Time: {load_end - load_start:.2f}s")

        if original_waveform.shape[-1] == 0:
            raise ValueError("Input audio contains no samples.")

        # Ensure float32
        original_waveform = original_waveform.to(dtype=torch.float32)

        # Handle multi-channel audio: keep only the first channel, retaining the
        # channel dim so the tensor stays [1, samples].
        if original_waveform.shape[0] > 1:
            logs.append(f"Warning: Input audio has {original_waveform.shape[0]} channels. Using only the first channel.")
            original_waveform = original_waveform[0:1, :]

        # --- Prepare Original for Playback ---
        # gr.Audio takes (sample_rate, numpy_array). Index the channel rather than
        # squeeze() so even a 1-sample clip stays a 1-D array.
        original_audio_playback = (original_sr, original_waveform[0].numpy())
        logs.append("Prepared original audio for playback.")

        # 2. Resample if necessary
        resample_start = time.perf_counter()
        if original_sr != TARGET_SR:
            logs.append(f"Resampling waveform from {original_sr} Hz to {TARGET_SR} Hz...")
            resampler = T.Resample(orig_freq=original_sr, new_freq=TARGET_SR).to(original_waveform.device)
            waveform_to_encode = resampler(original_waveform)
            logs.append(f"Resampling complete. New Shape: {waveform_to_encode.shape}")
        else:
            logs.append(f"Waveform is already at the target sample rate ({TARGET_SR} Hz).")
            waveform_to_encode = original_waveform
        resample_end = time.perf_counter()
        logs.append(f"Resampling time: {resample_end - resample_start:.2f}s")

        # --- Prepare Resampled for Playback ---
        resampled_audio_playback = (TARGET_SR, waveform_to_encode[0].numpy())
        logs.append("Prepared resampled audio for playback.")

        # 3. Prepare for SNAC encoding: [1, samples] -> [batch=1, channel=1, samples] on DEVICE.
        num_samples = waveform_to_encode.shape[-1]
        waveform_batch = waveform_to_encode.unsqueeze(0).to(DEVICE)
        logs.append(f"Waveform prepared for encoding. Shape: {waveform_batch.shape}, Device: {DEVICE}")

        # 4. Encode Audio using SNAC
        logs.append("Encoding audio with snac_model.encode()...")
        encode_start = time.perf_counter()
        with torch.inference_mode():
            codes = snac_model.encode(waveform_batch)
        _sync_device()
        encode_end = time.perf_counter()

        # Expect a non-empty sequence of code tensors (one per codebook level).
        if (
            not isinstance(codes, (list, tuple))
            or not codes
            or not all(isinstance(c, torch.Tensor) for c in codes)
        ):
            log_msg = f"Encoding failed: Expected list of Tensors, but got: {type(codes)}"
            if isinstance(codes, (list, tuple)):
                log_msg += f" with types {[type(c) for c in codes]}"
            logs.append(log_msg)
            raise ValueError(log_msg)
        logs.append(f"Encoding complete. Received {len(codes)} code layers. Time: {encode_end - encode_start:.2f}s")
        for i, layer_codes in enumerate(codes, start=1):
            logs.append(f"  Layer {i} codes shape: {layer_codes.shape}, Device: {layer_codes.device}")

        # 5. Decode the Codes using SNAC (codes already live on DEVICE).
        logs.append("Decoding the generated codes with snac_model.decode()...")
        decode_start = time.perf_counter()
        with torch.inference_mode():
            reconstructed_waveform = snac_model.decode(codes)
            # SNAC's encode() right-pads its input to a multiple of the model's
            # frame size, so decode() returns that padding as well. Trim back to
            # the encoded length, as SNAC's own forward() does.
            reconstructed_waveform = reconstructed_waveform[..., :num_samples]
        _sync_device()
        decode_end = time.perf_counter()
        logs.append(f"Decoding complete. Reconstructed waveform shape: {reconstructed_waveform.shape} (trimmed to {num_samples} samples), Device: {reconstructed_waveform.device}. Time: {decode_end - decode_start:.2f}s")

        # 6. Prepare Reconstructed Audio for Playback. Batch and channel are both
        # size 1 here, so flattening yields the mono sample sequence.
        reconstructed_audio_np = reconstructed_waveform.cpu().reshape(-1).numpy()
        logs.append(f"Reconstructed audio prepared for playback. Shape: {reconstructed_audio_np.shape}")
        reconstructed_audio_playback = (TARGET_SR, reconstructed_audio_np)

        logs.append("\n--- Audio Processing Completed Successfully ---")
        return original_audio_playback, resampled_audio_playback, reconstructed_audio_playback, "\n".join(logs)

    except Exception as e:
        # UI boundary: report the failure in the log box instead of raising.
        logs.append("\n--- An Error Occurred ---")
        logs.append(f"Error Type: {type(e).__name__}")
        logs.append(f"Error Details: {e}")
        logs.append("\n--- Traceback ---")
        logs.append(traceback.format_exc())
        # Return None for audio components on error, and the detailed log
        return None, None, None, "\n".join(logs)
# --- Gradio Interface ---
# The description is rendered as Markdown: the blank lines are required, otherwise
# the "Note" line is treated as a lazy continuation of list item 5.
DESCRIPTION = """
This Space demonstrates the **SNAC (Scalable Neural Audio Codec)** model (`hubertsiuzdak/snac_32khz`).

1. Upload an audio file (wav, mp3, flac, etc.).
2. The audio will be automatically resampled to 32kHz if needed.
3. The 32kHz audio is encoded into discrete codes by SNAC.
4. These codes are then decoded back into audio by SNAC.
5. You can listen to the original, the 32kHz version (if resampled), and the final reconstructed audio.

**Note:** Processing happens on the server. Larger files will take longer. If the input has more than one channel (e.g. stereo), only the first channel is processed.
"""

iface = gr.Interface(
    fn=process_audio,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    # Order must match the 4-tuple returned by process_audio().
    outputs=[
        gr.Audio(label="Original Audio"),
        gr.Audio(label="Resampled Audio (32kHz Input to SNAC)"),
        gr.Audio(label="Reconstructed Audio (Output from SNAC)"),
        gr.Textbox(label="Log Output", lines=15),
    ],
    title="SNAC Audio Codec Demo (32kHz)",
    description=DESCRIPTION,
    # To offer examples, upload audio files to the Space repo and list them here,
    # one single-item list per example, e.g. [["examples/example1.wav"], ["examples/example2.mp3"]].
    examples=[],
    cache_examples=False,  # Disable caching if examples change or have issues
)
# Script entry point: only start the web UI when the codec is actually usable.
if __name__ == "__main__":
    model_ready = snac_model is not None
    if model_ready:
        print("Launching Gradio Interface...")
        iface.launch()
    else:
        print("Cannot launch Gradio interface because SNAC model failed to load.")