webdeveloperdev's picture
Update app.py
438a416 verified
Raw
History Blame Contribute Delete
3.06 kB
import gradio as gr
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import librosa
from tensorflow.keras.models import load_model
# --- 1. Define Constants ---
YAMNET_MODEL_PATH = 'https://tfhub.dev/google/yamnet/1'
SAVED_KERAS_MODEL_PATH = "emergency_audio_classifier.keras"
SAMPLE_RATE = 16000  # The sample rate YAMNet expects

# --- 2. Load Your Models (On Startup) ---
# Both models are loaded once at import time so that each request can
# reuse them instead of paying the load cost again.
print("Loading YAMNet model...")
yamnet_model = hub.load(YAMNET_MODEL_PATH)
print("YAMNet loaded.")

print("Loading custom classifier model...")
classifier_model = load_model(SAVED_KERAS_MODEL_PATH)
print("Custom classifier loaded.")
# --- 3. Define The Classification Function ---
#
# Gradio passes this function a path to the audio file (the Audio input uses type="filepath").
#
def classify_audio(filepath):
    """
    Classify an audio file as "Distress" or "Normal".

    The file is loaded with librosa (resampled to SAMPLE_RATE, mixed down
    to mono), passed through YAMNet to obtain per-frame embeddings, and the
    frame-averaged embedding is scored by the custom Keras classifier.

    Args:
        filepath: Path to the audio file supplied by Gradio, or None when
            the user submitted nothing.

    Returns:
        A dict mapping label -> score. On success it holds "Distress" and
        "Normal", which sum to 1.0. On any failure it holds a single
        "Error: ..." key with score 1.0, so the message appears in the
        label widget instead of an exception being raised.
    """
    # Guard: nothing was recorded or uploaded.
    if filepath is None:
        print("Error: No audio file provided")
        return {"Error: No audio provided": 1.0}

    print(f"[DEBUG] Received file: {filepath}")

    try:
        # librosa does the decoding, resampling and mono conversion in one call.
        waveform, _native_rate = librosa.load(
            filepath,
            sr=SAMPLE_RATE,
            mono=True,
        )

        sample_count = len(waveform)
        if sample_count < 100:
            print("Error: Audio is too short after processing")
            return {"Error: Audio is too short": 1.0}

        print(f"[DEBUG] Audio loaded and processed: {sample_count} samples")

        # Stage 1: YAMNet returns three outputs; only the embeddings are used.
        _scores, frame_embeddings, _spectrogram = yamnet_model(waveform)

        # Average over frames, keeping a leading batch axis of size 1 for
        # the classifier.
        clip_embedding = np.mean(frame_embeddings.numpy(), axis=0, keepdims=True)

        # Stage 2: custom classifier.
        # NOTE(review): assumes the model emits one value per clip that is
        # the probability of distress - confirm against the training code.
        raw_score = classifier_model.predict(clip_embedding)[0][0]
        distress = float(raw_score)

        result = {
            "Distress": distress,
            "Normal": 1.0 - distress,
        }
        print(f"[DEBUG] Prediction: {result}")
        return result

    except Exception as err:
        # UI boundary: report the failure as a label rather than crashing.
        message = str(err)
        print(f"Error during classification: {message}")
        return {"Error: " + message: 1.0}
# --- 4. Create the Gradio Interface ---
#
# The audio input is delivered to classify_audio as a file path, not a numpy array.
#
# Input component: accepts a microphone recording or an uploaded file and
# passes it to the callback as a path on disk (type="filepath").
_audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Record Audio or Upload File",
)

# Output component: shows the label -> score dict returned by the callback.
_prediction_output = gr.Label(num_top_classes=2, label="Prediction")

iface = gr.Interface(
    fn=classify_audio,
    inputs=_audio_input,
    outputs=_prediction_output,
    title="🚨 Emergency Audio Detection Module",
    description="A two-stage (YAMNet + Custom) model to classify audio as 'Distress' or 'Normal'.",
)

# --- 5. Launch the App ---
# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    iface.launch()