File size: 3,057 Bytes
438a416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bb48de
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import librosa
from tensorflow.keras.models import load_model

# --- 1. Configuration ---
# Where each model stage is loaded from, plus the rate audio is resampled to.
YAMNET_MODEL_PATH = "https://tfhub.dev/google/yamnet/1"
SAVED_KERAS_MODEL_PATH = "emergency_audio_classifier.keras"
SAMPLE_RATE = 16000  # The sample rate YAMNet expects

# --- 2. Load Both Models Once, When The Module Is Imported ---
# Stage one: the YAMNet model fetched through tensorflow_hub.
print("Loading YAMNet model...")
yamnet_model = hub.load(YAMNET_MODEL_PATH)
print("YAMNet loaded.")

# Stage two: the Keras classifier that is fed YAMNet embeddings.
print("Loading custom classifier model...")
classifier_model = load_model(SAVED_KERAS_MODEL_PATH)
print("Custom classifier loaded.")

# --- 3. Define The Classification Function ---
# classify_audio receives the recording as a file path and loads the audio
# itself with librosa.
def classify_audio(filepath):
    """
    Classify one audio clip as "Distress" or "Normal".

    Args:
        filepath: Path to the audio file handed over by Gradio, or None
            when nothing was recorded or uploaded.

    Returns:
        A dict mapping label text to a float score. On success it holds the
        keys "Distress" and "Normal" (in that order), whose values sum to 1.
        On any failure it holds a single "Error: ..." key with the value 1.0,
        so the problem is displayed in the label widget instead of raising.
    """
    # Guard: nothing to classify.
    if filepath is None:
        print("Error: No audio file provided")
        return {"Error: No audio provided": 1.0}

    print(f"[DEBUG] Received file: {filepath}")

    try:
        # librosa decodes the file, resamples it to SAMPLE_RATE and mixes it
        # down to a single channel in one call.
        waveform, _ = librosa.load(filepath, sr=SAMPLE_RATE, mono=True)

        sample_count = len(waveform)
        if sample_count < 100:
            print("Error: Audio is too short after processing")
            return {"Error: Audio is too short": 1.0}

        print(f"[DEBUG] Audio loaded and processed: {sample_count} samples")

        # Stage one: only the middle element of YAMNet's three outputs (the
        # embeddings) is used.
        _, frame_embeddings, _ = yamnet_model(waveform)

        # Average over axis 0 while keeping that axis, which yields the
        # batch-of-one input handed to the classifier.
        clip_batch = np.mean(frame_embeddings.numpy(), axis=0, keepdims=True)

        # Stage two: the first value of the first output row is treated as
        # the probability of distress; "Normal" is its complement.
        distress_score = float(classifier_model.predict(clip_batch)[0][0])
        scores = {
            "Distress": distress_score,
            "Normal": 1.0 - distress_score,
        }

        print(f"[DEBUG] Prediction: {scores}")
        return scores

    except Exception as exc:
        # Report any failure as a label rather than letting it propagate.
        print(f"Error during classification: {exc}")
        return {f"Error: {str(exc)}": 1.0}

# --- 4. Create the Gradio Interface ---
# The audio input uses type="filepath", so classify_audio is given a path to
# the recorded or uploaded file rather than a NumPy array.
# Input widget: takes a microphone recording or an uploaded file and passes
# the handler a file path (type="filepath") instead of a NumPy array.
_audio_input = gr.Audio(
    sources=["microphone", "upload"],
    type="filepath",
    label="Record Audio or Upload File",
)

# Output widget: displays the label scores returned by classify_audio.
_prediction_output = gr.Label(num_top_classes=2, label="Prediction")

iface = gr.Interface(
    fn=classify_audio,
    inputs=_audio_input,
    outputs=_prediction_output,
    title="🚨 Emergency Audio Detection Module",
    description="A two-stage (YAMNet + Custom) model to classify audio as 'Distress' or 'Normal'.",
)

# --- 5. Launch the App ---
# Only start the server when this file is run as a script.
if __name__ == "__main__":
    iface.launch()