import streamlit as st
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import vit_b_16
import torch.nn as nn
from PIL import Image
import pickle
import os

# Set page config
st.set_page_config(
    page_title="Baby Cry Analyzer",
    page_icon="👶",
    layout="wide"
)

# Custom CSS
st.markdown("""
    <style>
    .main {
        padding: 2rem;
    }
    .stAlert {
        margin-top: 1rem;
    }
    </style>
""", unsafe_allow_html=True)

def load_model():
    try:
        # Force CPU device
        device = torch.device('cpu')

        # Load the model from pickle file with CPU mapping
        with open("baby_cry_model.pkl", "rb") as f:
            # Convert CUDA tensors to CPU during unpickling
            model_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v
                                for k, v in pickle.load(f).items()}

        # Initialize model architecture
        model = vit_b_16(pretrained=True)
        num_classes = 3  # Adjust based on your actual number of classes
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

        # Load state dict with CPU mapping
        model = model.to(device)
        model.load_state_dict(model_state_dict)
        model.eval()

        return model, device
    except Exception as e:
        st.error(f"""
        Error loading model. Make sure the model file exists and is accessible.
        If this error persists, the model might need to be re-saved for CPU compatibility.
        Technical details: {str(e)}
        """)
        raise e
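
# Note: load_model() assumes "baby_cry_model.pkl" holds a pickled state_dict
# (a mapping of parameter names to tensors), not a whole model object. A hedged
# sketch of the assumed export step (the `trained_model` name is hypothetical):
#
#     with open("baby_cry_model.pkl", "wb") as f:
#         pickle.dump(trained_model.state_dict(), f)
#
# Saving with torch.save(trained_model.state_dict(), "baby_cry_model.pt") and
# loading via torch.load("baby_cry_model.pt", map_location="cpu") would avoid
# the manual .cpu() conversion above.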

def create_spectrogram(audio_file):
    # Create spectrogram
    y, sr = librosa.load(audio_file, sr=22050)
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    plt.figure(figsize=(5, 5))
    librosa.display.specshow(mel_spec_db, sr=sr, x_axis="time", y_axis="mel")
    plt.axis("off")

    # Save spectrogram
    temp_path = "temp_spectrogram.png"
    plt.savefig(temp_path, bbox_inches="tight", pad_inches=0)
    plt.close()

    return temp_path
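
# Note: the fixed "temp_spectrogram.png" path can collide if several Streamlit
# sessions run concurrently. A sketch of an alternative using the standard
# library tempfile module (behaviour otherwise unchanged):
#
#     import tempfile
#     tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
#     plt.savefig(tmp.name, bbox_inches="tight", pad_inches=0)
#     plt.close()
#     return tmp.name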

def classify_audio(model, device, spectrogram_path):
    # Prepare image for classification
    img = Image.open(spectrogram_path).convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Per-channel values for the 3 RGB channels
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    img = transform(img).unsqueeze(0).to(device)

    # Classify
    with torch.no_grad():
        output = model(img)
        predicted_class = torch.argmax(output, dim=1).item()
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

    return predicted_class, probabilities
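
# Standalone usage sketch, outside the Streamlit flow (the "sample_cry.wav"
# path is hypothetical):
#
#     model, device = load_model()
#     spec_path = create_spectrogram("sample_cry.wav")
#     cls_idx, probs = classify_audio(model, device, spec_path)
#     print(cls_idx, probs)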

def main():
    st.title("👶 Baby Cry Analyzer")
    st.write("Upload a WAV file to analyze the type of baby cry")

    # Load model
    try:
        model, device = load_model()
        st.success("Model loaded successfully!")
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return

    # File upload
    audio_file = st.file_uploader("Choose a WAV file", type=['wav'])

    if audio_file is not None:
        st.audio(audio_file)

        with st.spinner("Analyzing audio..."):
            # Create and display spectrogram
            spec_path = create_spectrogram(audio_file)
            st.image(spec_path, caption="Generated Spectrogram", width=300)

            # Classify
            predicted_class, probabilities = classify_audio(model, device, spec_path)

            # Display results
            classes = ['Belly Pain', 'Hungry', 'Tired']  # Adjust based on your classes
            st.subheader("Classification Results:")

            # Display prediction with confidence
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Predicted Cry Type", classes[predicted_class])
            with col2:
                confidence = float(probabilities[predicted_class]) * 100
                st.metric("Confidence", f"{confidence:.2f}%")

            # Show all probabilities
            st.subheader("Probability Distribution:")
            for cls, prob in zip(classes, probabilities):
                st.write(f"{cls}: {float(prob)*100:.2f}%")

        # Cleanup
        if os.path.exists(spec_path):
            os.remove(spec_path)

if __name__ == "__main__":
    main()