# Baby Cry Analyzer -- Streamlit application.
#
# Pipeline: an uploaded WAV file is rendered as a mel-spectrogram image, the
# image is fed to a ViT-B/16 classifier whose head has one output per entry
# in CLASS_NAMES, and the predicted class plus per-class probabilities are
# displayed.

import os
import pickle
import tempfile
from typing import Tuple

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import vit_b_16

# Set page config
st.set_page_config(
    page_title="Baby Cry Analyzer",
    page_icon="👶",
    layout="wide"
)

# Custom CSS (the markup body is currently whitespace only)
st.markdown(""" """, unsafe_allow_html=True)

# Pickled state dict of the fine-tuned classifier, resolved relative to the
# working directory the app is started from.
MODEL_PATH = "baby_cry_model.pkl"

# Class labels, in the index order of the model's output head.
# Adjust based on your classes; the head size is derived from this list so
# the two can no longer drift apart.
CLASS_NAMES = ['Belly Pain', 'Hungry', 'Tired']

# Audio / spectrogram parameters.
SAMPLE_RATE = 22050
N_MELS = 128

# Image preprocessing applied to the spectrogram before classification.
# The single-element mean/std broadcast over all three RGB channels.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])


@st.cache_resource
def load_model() -> Tuple[nn.Module, torch.device]:
    """Load the classifier onto the CPU and return ``(model, device)``.

    The state dict is read from ``MODEL_PATH`` with ``pickle`` and loaded
    into a ViT-B/16 whose classification head is resized to
    ``len(CLASS_NAMES)`` outputs. The result is cached by Streamlit, so the
    file is read once per server process.

    Raises:
        Exception: any failure while reading the file or building the model
            is reported with ``st.error`` and then re-raised unchanged.
    """
    # Force CPU device
    device = torch.device('cpu')
    try:
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # model files that come from a trusted source.
        with open(MODEL_PATH, "rb") as f:
            raw_state = pickle.load(f)

        # Move tensors to the CPU *after* unpickling. A file that contains
        # CUDA tensors may still fail inside pickle.load on a machine
        # without CUDA; in that case the model has to be re-saved from CPU.
        model_state_dict = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v
            for k, v in raw_state.items()
        }

        # Initialize model architecture. No pretrained weights are
        # requested: load_state_dict (strict by default) replaces every
        # parameter, so downloading ImageNet weights would be wasted work.
        model = vit_b_16(weights=None)
        model.heads.head = nn.Linear(
            model.heads.head.in_features, len(CLASS_NAMES)
        )

        model = model.to(device)
        model.load_state_dict(model_state_dict)
        model.eval()
        return model, device
    except Exception as e:
        st.error(f"""
        Error loading model. Make sure the model file exists and is accessible.
        If this error persists, the model might need to be re-saved for CPU compatibility.

        Technical details: {str(e)}
        """)
        # Bare raise keeps the original traceback intact.
        raise


def create_spectrogram(audio_file) -> str:
    """Render *audio_file* as a mel-spectrogram PNG and return its path.

    Args:
        audio_file: anything ``librosa.load`` accepts (a path or a file-like
            object such as a Streamlit upload).

    Returns:
        Path of a newly created temporary PNG file. The caller owns the file
        and is responsible for deleting it.

    Raises:
        ValueError: if the decoded audio contains no samples.
    """
    # A file-like object may already have been read by an earlier consumer;
    # rewind it so decoding starts from the beginning.
    if hasattr(audio_file, "seek"):
        audio_file.seek(0)

    y, sr = librosa.load(audio_file, sr=SAMPLE_RATE)
    if y.size == 0:
        raise ValueError("The audio file contains no samples.")

    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=N_MELS)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # Unique path per call: a fixed file name would let concurrent sessions
    # overwrite or delete each other's spectrogram.
    fd, temp_path = tempfile.mkstemp(prefix="spectrogram_", suffix=".png")
    os.close(fd)

    # Draw on an explicit figure/axes instead of pyplot's global "current
    # figure", and always close the figure, even when rendering fails.
    fig = plt.figure(figsize=(5, 5))
    try:
        ax = fig.add_subplot(111)
        librosa.display.specshow(
            mel_spec_db, sr=sr, x_axis="time", y_axis="mel", ax=ax
        )
        ax.axis("off")
        fig.savefig(temp_path, bbox_inches="tight", pad_inches=0)
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        if os.path.exists(temp_path):
            os.remove(temp_path)
        raise
    finally:
        plt.close(fig)

    return temp_path


def classify_audio(model, device, spectrogram_path):
    """Classify the spectrogram image stored at *spectrogram_path*.

    Args:
        model: classifier returned by ``load_model``.
        device: torch device the model lives on.
        spectrogram_path: path of the spectrogram image to classify.

    Returns:
        ``(predicted_class, probabilities)`` where ``predicted_class`` is the
        integer index of the highest-scoring class and ``probabilities`` is
        a 1-D tensor holding the softmax over the model's outputs.
    """
    # convert() returns an independent image, so the file handle can be
    # released as soon as the conversion is done.
    with Image.open(spectrogram_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    batch = _PREPROCESS(rgb_image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)
        predicted_class = torch.argmax(output, dim=1).item()
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    return predicted_class, probabilities


def main() -> None:
    """Run the Streamlit page: upload, analyze, and display the results."""
    st.title("👶 Baby Cry Analyzer")
    st.write("Upload a WAV file to analyze the type of baby cry")

    # Load model
    try:
        model, device = load_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return
    st.success("Model loaded successfully!")

    # File upload
    audio_file = st.file_uploader("Choose a WAV file", type=['wav'])
    if audio_file is None:
        return

    st.audio(audio_file)

    spec_path = None
    try:
        with st.spinner("Analyzing audio..."):
            # Create and display spectrogram
            spec_path = create_spectrogram(audio_file)
            st.image(spec_path, caption="Generated Spectrogram", width=300)

            # Classify
            predicted_class, probabilities = classify_audio(
                model, device, spec_path
            )
    finally:
        # Cleanup runs even when analysis fails, so no temp file is leaked.
        if spec_path is not None and os.path.exists(spec_path):
            os.remove(spec_path)

    # Display results
    st.subheader("Classification Results:")

    # Display prediction with confidence
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Predicted Cry Type", CLASS_NAMES[predicted_class])
    with col2:
        confidence = float(probabilities[predicted_class]) * 100
        st.metric("Confidence", f"{confidence:.2f}%")

    # Show all probabilities
    st.subheader("Probability Distribution:")
    for cls, prob in zip(CLASS_NAMES, probabilities):
        st.write(f"{cls}: {float(prob)*100:.2f}%")


if __name__ == "__main__":
    main()