Spaces:
Sleeping
Sleeping
File size: 4,798 Bytes
dc6c835 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import streamlit as st
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.models import vit_b_16
import torch.nn as nn
from PIL import Image
import pickle
import os
# Page-level configuration: browser tab title, tab icon and a wide layout.
st.set_page_config(page_title="Baby Cry Analyzer", page_icon="👶", layout="wide")

# Inject custom CSS: padding for the main container and a top margin on alerts.
# The stylesheet is assembled from adjacent string literals, one rule per line.
st.markdown(
    "\n<style>\n"
    ".main {\npadding: 2rem;\n}\n"
    ".stAlert {\nmargin-top: 1rem;\n}\n"
    "</style>\n",
    unsafe_allow_html=True,
)
@st.cache_resource
def load_model():
    """Build the ViT-B/16 classifier on the CPU and load its saved weights.

    Reads a pickled state dict from ``baby_cry_model.pkl`` (resolved against
    the current working directory), moves every tensor in it to the CPU,
    creates a ``vit_b_16`` network whose classification head is replaced by a
    3-output linear layer, and loads the weights into it.

    The function is decorated with ``st.cache_resource``.

    Returns:
        tuple: ``(model, device)`` where ``model`` is in eval mode and
        ``device`` is ``torch.device('cpu')``.

    Raises:
        Exception: any failure is reported with ``st.error`` and then
            re-raised unchanged so the caller can stop the page.
    """
    try:
        # Force CPU device
        device = torch.device('cpu')

        # NOTE(security): pickle can execute arbitrary code while loading;
        # only use a model file that comes from a trusted source.
        with open("baby_cry_model.pkl", "rb") as f:
            raw_state = pickle.load(f)

        # Tensors are moved to the CPU *after* unpickling has finished.
        # NOTE(review): if the file was pickled with CUDA tensors, the
        # pickle.load call above may itself fail on a CPU-only host -- confirm
        # how the file was saved.
        model_state_dict = {
            k: v.cpu() if isinstance(v, torch.Tensor) else v
            for k, v in raw_state.items()
        }

        # Build the architecture without pretrained weights: the strict
        # load_state_dict() below overwrites every parameter, so downloading
        # the ImageNet checkpoint first would only cost time and require
        # network access.
        model = vit_b_16()
        num_classes = 3  # Adjust based on your actual number of classes
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)

        model = model.to(device)
        model.load_state_dict(model_state_dict)
        model.eval()
        return model, device
    except Exception as e:
        st.error(
            "\nError loading model. Make sure the model file exists and is accessible.\n"
            "If this error persists, the model might need to be re-saved for CPU compatibility.\n"
            f"Technical details: {e}\n"
        )
        # Bare raise keeps the original traceback intact.
        raise
def create_spectrogram(audio_file):
    """Render a mel spectrogram of an audio clip to a PNG and return its path.

    Args:
        audio_file: Audio source passed straight to ``librosa.load`` (``main``
            passes the object returned by ``st.file_uploader``).

    Returns:
        str: Path of the PNG file that was written. Each call gets its own
        uniquely named temporary file; the caller is responsible for deleting
        it when done.
    """
    import tempfile

    # Load at 22050 Hz and compute a 128-band mel spectrogram in dB,
    # referenced to the peak power.
    y, sr = librosa.load(audio_file, sr=22050)
    mel_spec = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128)
    mel_spec_db = librosa.power_to_db(mel_spec, ref=np.max)

    # A unique file per call: a fixed name in the working directory would let
    # concurrent sessions overwrite or delete each other's image.
    fd, temp_path = tempfile.mkstemp(prefix="spectrogram_", suffix=".png")
    os.close(fd)

    try:
        # Explicit figure/axes instead of pyplot's implicit "current figure".
        fig, ax = plt.subplots(figsize=(5, 5))
        try:
            librosa.display.specshow(
                mel_spec_db, sr=sr, x_axis="time", y_axis="mel", ax=ax
            )
            ax.axis("off")
            fig.savefig(temp_path, bbox_inches="tight", pad_inches=0)
        finally:
            # Always release the figure, even when rendering fails.
            plt.close(fig)
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        os.remove(temp_path)
        raise

    return temp_path
def classify_audio(model, device, spectrogram_path):
    """Classify a spectrogram image with the given model.

    Args:
        model: Callable network that takes an image batch and returns logits.
        device: Torch device the input batch is moved to.
        spectrogram_path: Path of the spectrogram image to classify.

    Returns:
        tuple: ``(predicted_class, probabilities)`` where ``predicted_class``
        is the integer index of the highest-scoring output and
        ``probabilities`` is the softmax over the first (only) row of the
        model output.
    """
    # Preprocessing: resize to 224x224, convert to a tensor, then normalise
    # with mean 0.5 / std 0.5.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    rgb_image = Image.open(spectrogram_path).convert("RGB")
    # Add a leading batch dimension of size 1 before moving to the device.
    batch = preprocess(rgb_image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch)
        predicted_class = logits.argmax(dim=1).item()
        probabilities = torch.softmax(logits[0], dim=0)

    return predicted_class, probabilities
def main():
    """Run the Streamlit page.

    Loads the model, lets the user upload a WAV file, renders its spectrogram,
    classifies it and shows the predicted class, its confidence and the full
    probability distribution. The temporary spectrogram file is removed even
    when display or classification fails.
    """
    st.title("👶 Baby Cry Analyzer")
    st.write("Upload a WAV file to analyze the type of baby cry")

    # Load model; without it nothing else on the page can work.
    try:
        model, device = load_model()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return
    else:
        st.success("Model loaded successfully!")

    # File upload
    audio_file = st.file_uploader("Choose a WAV file", type=['wav'])
    if audio_file is None:
        # Nothing uploaded yet.
        return

    st.audio(audio_file)

    # Order must match the class indices the model outputs.
    classes = ['Belly Pain', 'Hungry', 'Tired']  # Adjust based on your classes

    with st.spinner("Analyzing audio..."):
        # Create and display spectrogram
        spec_path = create_spectrogram(audio_file)
        try:
            st.image(spec_path, caption="Generated Spectrogram", width=300)

            # Classify
            predicted_class, probabilities = classify_audio(model, device, spec_path)

            # Display prediction with confidence
            st.subheader("Classification Results:")
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Predicted Cry Type", classes[predicted_class])
            with col2:
                confidence = float(probabilities[predicted_class]) * 100
                st.metric("Confidence", f"{confidence:.2f}%")

            # Show all probabilities
            st.subheader("Probability Distribution:")
            for cls, prob in zip(classes, probabilities):
                st.write(f"{cls}: {float(prob)*100:.2f}%")
        finally:
            # Cleanup runs on success and on failure; a file that is already
            # gone is not an error.
            try:
                os.remove(spec_path)
            except FileNotFoundError:
                pass


# Run the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|