import os
# ---------------------------
# Quiet OpenMP noise on Spaces
# ---------------------------
# Thread caps must be set before numpy/librosa are imported, otherwise the
# BLAS/OpenMP runtimes may have already spawned their thread pools.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
import gradio as gr
import transformers
import numpy as np
import librosa
import spaces
# ---------------------------
# Model config
# ---------------------------
MODEL_ID = "sarvamai/shuka_v1"
TARGET_SR = 16000 # Shuka uses 16k audio
# ---------------------------
# Global pipeline (lazy-loaded)
# ---------------------------
pipe = None
def load_model():
"""Load the Shuka v1 pipeline (8.73B)."""
global pipe
if pipe is not None:
return "βœ… Model already loaded!"
try:
print(f"Loading Shuka model: {MODEL_ID}")
pipe = transformers.pipeline(
model=MODEL_ID,
trust_remote_code=True, # required for Shuka custom pipeline
device_map="auto", # Use auto device mapping for HF Spaces
torch_dtype="bfloat16",
)
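        # Note (assumption about the Spaces runtime): with the ZeroGPU
        # @spaces.GPU decorator on analyze_audio below, this lazy load runs
        # inside a GPU context, and device_map="auto" lets accelerate place
        # the 8.73B weights.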
print("βœ… Pipeline loaded successfully!")
return "βœ… Model pipeline loaded successfully!"
except Exception as e:
import traceback
err = f"❌ Error loading model: {e}\n\n{traceback.format_exc()}"
print(err)
return err
# ---------------------------
# Audio utilities
# ---------------------------
def load_audio_from_gradio(audio_input):
"""
Supports both gr.Audio types:
- type="numpy" -> (sample_rate, np.ndarray)
- type="filepath" -> "/tmp/....wav"
Returns (audio: float32 mono @ 16k, sr: int)
"""
if isinstance(audio_input, tuple):
sr, audio = audio_input
elif isinstance(audio_input, str):
# Read from tmp filepath
audio, sr = librosa.load(audio_input, sr=None)
else:
raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
    # Ensure float32 ndarray (gradio's numpy audio is typically int16; the
    # peak normalization below rescales it to ~[-1, 1])
    audio = np.asarray(audio, dtype=np.float32)
    # Stereo -> mono (gradio delivers shape (samples, channels))
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    # Trim leading/trailing silence; top_db=30 treats anything more than
    # 30 dB below peak as silence, stricter than librosa's default of 60
    audio, _ = librosa.effects.trim(audio, top_db=30)
# Remove DC offset
if audio.size:
audio = audio - float(np.mean(audio))
# Normalize peak to ~0.98 to improve quiet recordings
peak = float(np.max(np.abs(audio))) if audio.size else 0.0
if peak > 0:
audio = (0.98 / peak) * audio
# Resample to 16k
if sr != TARGET_SR:
audio = librosa.resample(audio, orig_sr=sr, target_sr=TARGET_SR)
sr = TARGET_SR
    # CRITICAL: the Whisper encoder has a hard limit of 3000 mel frames.
    # At 16 kHz that is exactly 30 seconds (100 mel frames per second).
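    # (Derivation, for reference: Whisper's log-mel front end hops 160
    # samples at 16 kHz -> 16000 / 160 = 100 frames/s, and
    # 3000 frames / 100 frames/s = 30 s.)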
max_sec = 30
if len(audio) / float(sr) > max_sec:
audio = audio[: int(max_sec * sr)]
return audio, sr
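# Minimal usage sketch for the helper above (illustrative; "clip.wav" is a
# hypothetical file, and the tuple mirrors gr.Audio's type="numpy" output):
#   audio, sr = load_audio_from_gradio("clip.wav")
#   audio, sr = load_audio_from_gradio((48000, np.random.randn(48000)))
#   # both return float32 mono audio at TARGET_SR (16 kHz)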
# ---------------------------
# Inference handler
# ---------------------------
@spaces.GPU
def analyze_audio(audio_file, system_prompt):
"""
System prompt contains analysis instructions.
Audio is processed using the <|audio|> placeholder token.
"""
global pipe
if pipe is None:
status = load_model()
if status.startswith("❌"):
return status
if audio_file is None:
return "❌ Please upload or record an audio file."
# Load & preprocess audio
try:
audio, sr = load_audio_from_gradio(audio_file)
except Exception as e:
return f"❌ Failed to read/process audio: {e}"
    # Quick quality checks (preprocessing already peak-normalizes, so the
    # RMS gate below mainly catches near-digital silence)
    dur = len(audio) / float(sr) if sr else 0.0
    rms = float(np.sqrt(np.mean(audio**2))) if audio.size else 0.0
    if dur < 1.0:
        return "❌ Audio too short (<1s after silence trimming). Please upload a longer sample."
    if rms < 1e-3:
        return "❌ Audio is essentially silent. Increase mic gain or speak closer to the microphone."
sys_text = (system_prompt or "Respond naturally and informatively.").strip()
# Build turns: system message with user instructions + user message with audio token
turns = [
{"role": "system", "content": sys_text},
{"role": "user", "content": "<|audio|>"}
]
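    # What the chat template effectively renders (illustrative):
    #   system: <analysis instructions>
    #   user:   <|audio|>  <- placeholder the pipeline swaps for audio
    #                         embeddings from its Whisper-based encoder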
try:
out = pipe(
{"audio": audio, "turns": turns, "sampling_rate": sr},
max_new_tokens=512,
)
# Debug: print raw output
print(f"Raw output type: {type(out)}")
print(f"Raw output: {out}")
# Extract text from response
if isinstance(out, list) and len(out) > 0:
text = out[0].get("generated_text", str(out[0]))
elif isinstance(out, dict):
text = out.get("generated_text", str(out))
else:
text = str(out)
return f"βœ… Processed.\n\n{text}"
except Exception as e:
import traceback
error_details = traceback.format_exc()
print(f"Full error: {error_details}")
return f"❌ Inference error: {e}\n\nDetails:\n{error_details}"
# ---------------------------
# UI
# ---------------------------
startup_status = "⏳ Model loads on first request (8.73B parameters)."
with gr.Blocks(title="Shuka v1 (8.73B) β€” Audio Analyzer", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎀 Shuka v1 (8.73B) β€” Audio Analyzer
Upload an audio file (or record) and provide **analysis instructions**.
The instructions tell the AI what to analyze in the audio using the `<|audio|>` token.
**Shuka** is a multilingual audio-language model with strong capabilities in **11 Indic languages** including Hindi, Bengali, Tamil, Telugu, Marathi, Gujarati, Kannada, Malayalam, Punjabi, Odia, and Assamese.
⚠️ **Note:** Audio is automatically capped at **30 seconds maximum** due to Whisper encoder constraints (3000 mel features limit). For best results, use clear, concise audio recordings.
""")
with gr.Row():
with gr.Column():
# For uploads, `filepath` is robust; mic also works.
audio_input = gr.Audio(
label="🎡 Upload or Record Audio",
sources=["upload", "microphone"],
type="filepath", # handler also supports numpy tuples
)
system_prompt = gr.Textbox(
label="🧠 Analysis Instructions (what should the AI analyze in the audio?)",
value="Respond naturally and informatively.",
lines=8,
max_lines=20,
)
submit_btn = gr.Button("πŸš€ Analyze", variant="primary")
with gr.Column():
output = gr.Markdown(
label="πŸ€– Model Response",
value=f"**Model Status:** {startup_status}",
)
submit_btn.click(
fn=analyze_audio,
inputs=[audio_input, system_prompt],
outputs=output,
)
if __name__ == "__main__":
demo.launch()
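# Local smoke test (assumes a CUDA GPU with bfloat16 support; outside HF
# Spaces the @spaces.GPU decorator is a no-op):
#   python app.py   # serves the UI at http://127.0.0.1:7860 by default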