# File size: 4,798 Bytes
# dc6c835
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import streamlit as st
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import vit_b_16
import torch.nn as nn
from PIL import Image
import pickle
import os

# Page-level settings; kept ahead of every other Streamlit call in this file.
st.set_page_config(layout="wide", page_icon="👶", page_title="Baby Cry Analyzer")

# Stylesheet injected into the page: pads the main container and adds
# spacing above alert boxes.
_CUSTOM_CSS = """

    <style>

    .main {

        padding: 2rem;

    }

    .stAlert {

        margin-top: 1rem;

    }

    </style>

    """
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

@st.cache_resource
def load_model():
    """Build the ViT-B/16 classifier and load its fine-tuned weights on CPU.

    The state dict is read from ``baby_cry_model.pkl`` in the current working
    directory, the classification head is replaced by a 3-way linear layer,
    and the model is switched to eval mode. The function is wrapped in
    ``st.cache_resource``.

    Returns:
        tuple: ``(model, device)`` where ``device`` is the CPU
        ``torch.device`` the model was moved to.

    Raises:
        Exception: Any failure while reading the file or building the model
            is shown with ``st.error`` and then re-raised unchanged.
    """
    try:
        # Force CPU device
        device = torch.device('cpu')

        # SECURITY: unpickling executes arbitrary code from the file; only
        # ever ship a model file that comes from a trusted source.
        with open("baby_cry_model.pkl", "rb") as f:
            raw_state = pickle.load(f)

        # NOTE(review): this conversion runs only after pickle.load has
        # already rebuilt the tensors, so a file pickled with CUDA tensors
        # may still fail on a CPU-only host — confirm how the file was saved.
        model_state_dict = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v
            for k, v in raw_state.items()
        }

        # Build the architecture without pretrained weights: the strict
        # load_state_dict below has to supply every parameter anyway, so
        # fetching ImageNet weights first would be a wasted download.
        model = vit_b_16()
        num_classes = 3  # Adjust based on your actual number of classes
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

        model = model.to(device)
        model.load_state_dict(model_state_dict)
        model.eval()

        return model, device
    except Exception as e:
        st.error(f"""

            Error loading model. Make sure the model file exists and is accessible.

            If this error persists, the model might need to be re-saved for CPU compatibility.

            Technical details: {str(e)}

        """)
        # Bare raise keeps the original traceback intact.
        raise

def create_spectrogram(audio_file):
    """Render a mel spectrogram of the audio to a PNG file and return its path.

    The audio is loaded at 22050 Hz, converted to a 128-band mel power
    spectrogram in dB (relative to its maximum), and drawn without axes on a
    5x5 inch figure.

    Args:
        audio_file: Whatever ``librosa.load`` accepts as its first argument
            (the caller in this file passes a Streamlit upload object).

    Returns:
        str: Path of the PNG that was written. Every call writes to its own
        temporary file; the caller is responsible for deleting it.
    """
    import tempfile  # local import so the module's import block is untouched

    y, sr = librosa.load(audio_file, sr=22050)
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Use an explicit figure rather than pyplot's implicit "current figure"
    # so the figure is always closed, even when plotting or saving fails.
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        librosa.display.specshow(
            mel_spec_db, sr=sr, x_axis="time", y_axis="mel", ax=ax
        )
        ax.axis("off")

        # A unique file per call: a fixed name would let concurrent sessions
        # overwrite or delete each other's image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png", bbox_inches="tight", pad_inches=0)
            temp_path = tmp.name
    finally:
        plt.close(fig)

    return temp_path

def classify_audio(model, device, spectrogram_path):
    """Classify a spectrogram image with the given model.

    The image is converted to RGB, resized to 224x224, turned into a tensor,
    normalized with mean 0.5 / std 0.5, and fed to the model as a batch of
    one with gradients disabled.

    Args:
        model: Callable network applied to the ``(1, 3, 224, 224)`` input.
        device: Device the input tensor is moved to before inference.
        spectrogram_path: Image file to open.

    Returns:
        tuple: ``(predicted_class, probabilities)`` — the argmax index as a
        Python int, and the softmax of the first output row.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    rgb_image = Image.open(spectrogram_path).convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0).to(device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        logits = model(batch)
        scores = torch.softmax(logits[0], dim=0)
        best_index = logits.argmax(dim=1).item()

    return best_index, scores

def main():
    """Run the Streamlit page: load the model, accept a WAV upload, show results.

    Flow: render the title, load the model (stopping with an error message if
    that fails), wait for an uploaded WAV file, then generate its
    spectrogram, classify it, and display the predicted class, its
    confidence, and the full probability list. The spectrogram file is
    removed afterwards, including when display or classification fails.
    """
    st.title("👶 Baby Cry Analyzer")
    st.write("Upload a WAV file to analyze the type of baby cry")

    # Load model
    try:
        model, device = load_model()
        st.success("Model loaded successfully!")
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    # File upload
    audio_file = st.file_uploader("Choose a WAV file", type=['wav'])
    if audio_file is None:
        # Nothing uploaded yet; the script reruns once a file is chosen.
        return

    st.audio(audio_file)

    with st.spinner("Analyzing audio..."):
        # Rewind in case an earlier reader left the stream position past
        # the start of the upload.
        audio_file.seek(0)
        spec_path = create_spectrogram(audio_file)
        try:
            st.image(spec_path, caption="Generated Spectrogram", width=300)

            # Classify
            predicted_class, probabilities = classify_audio(model, device, spec_path)

            # Display results
            classes = ['Belly Pain', 'Hungry', 'Tired']  # Adjust based on your classes
            st.subheader("Classification Results:")

            # Display prediction with confidence
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Predicted Cry Type", classes[predicted_class])
            with col2:
                confidence = float(probabilities[predicted_class]) * 100
                st.metric("Confidence", f"{confidence:.2f}%")

            # Show all probabilities
            st.subheader("Probability Distribution:")
            for cls, prob in zip(classes, probabilities):
                st.write(f"{cls}: {float(prob)*100:.2f}%")
        finally:
            # Always remove the spectrogram file, even if a step above raised.
            if os.path.exists(spec_path):
                os.remove(spec_path)

# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()