import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T
import librosa
import pickle
import hashlib
import tempfile
import os
# Model architecture (must match the definition used at training time so the checkpoint loads)
class ImprovedBirdSoundCNN(nn.Module):
def __init__(self, num_classes, dropout_rate=0.3):
super(ImprovedBirdSoundCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(64)
self.bn2 = nn.BatchNorm2d(64)
self.bn3 = nn.BatchNorm2d(128)
self.bn4 = nn.BatchNorm2d(128)
self.bn5 = nn.BatchNorm2d(256)
self.bn6 = nn.BatchNorm2d(256)
self.pool = nn.MaxPool2d(2, 2)
self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
self.dropout = nn.Dropout(dropout_rate)
self.fc1 = nn.Linear(256 * 4 * 4, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, num_classes)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = self.pool(x)
x = self.dropout(x)
x = F.relu(self.bn3(self.conv3(x)))
x = F.relu(self.bn4(self.conv4(x)))
x = self.pool(x)
x = self.dropout(x)
x = F.relu(self.bn5(self.conv5(x)))
x = F.relu(self.bn6(self.conv6(x)))
x = self.adaptive_pool(x)
x = self.dropout(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = F.relu(self.fc2(x))
x = self.dropout(x)
x = self.fc3(x)
return x
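# Quick shape check (an illustrative sketch added for documentation; the
# `_shape_sanity_check` helper is not part of the original training code).
# A 5 s clip at 22 050 Hz preprocessed as in preprocess_audio() below yields a
# (1, 1, 128, 216) batch, and the adaptive pooling keeps the classifier head
# independent of the exact time dimension.
def _shape_sanity_check(num_classes=107):
    dummy = torch.randn(1, 1, 128, 216)  # (batch, channel, n_mels, time frames)
    logits = ImprovedBirdSoundCNN(num_classes=num_classes)(dummy)
    assert logits.shape == (1, num_classes)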
@st.cache_resource
def load_model_and_encoder():
"""Load model and label encoder with size mismatch handling"""
device = torch.device('cpu')
try:
# Load label encoder
with open('label_encoder.pkl', 'rb') as f:
label_encoder = pickle.load(f)
num_classes = len(label_encoder.classes_)
print(f"Label encoder has {num_classes} classes")
# Load checkpoint first to check its structure
checkpoint = torch.load('best_bird_model_extended.pth', map_location=device, weights_only=False)
        # Initialize the model with the number of classes from the label encoder
        model = ImprovedBirdSoundCNN(num_classes=num_classes)
current_model_dict = model.state_dict()
# Get the saved state dict
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
saved_state_dict = checkpoint['model_state_dict']
else:
saved_state_dict = checkpoint
# Filter out mismatched layers (fc3 layer)
filtered_dict = {}
for k, v in saved_state_dict.items():
if k in current_model_dict:
if v.size() == current_model_dict[k].size():
filtered_dict[k] = v
else:
print(f"Skipping {k}: checkpoint {v.size()} vs model {current_model_dict[k].size()}")
else:
print(f"Parameter {k} not found in current model")
# Update model dict with compatible weights
current_model_dict.update(filtered_dict)
# Load the filtered state dict
model.load_state_dict(current_model_dict)
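        # Note: load_state_dict(..., strict=False) would tolerate missing or
        # unexpected keys but still raises on size mismatches, so the explicit
        # filtering above is what allows reusing a checkpoint whose fc3 head
        # was trained with a different class count.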
model.eval()
print(f"Model loaded successfully with {len(filtered_dict)} compatible layers")
return model, label_encoder, device
except Exception as e:
st.error(f"Error loading model: {str(e)}")
return None, None, None
def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
    """Load an uploaded audio file with librosa and convert it to a
    normalized log-mel spectrogram tensor for the model."""
    tmp_file_path = None
try:
# Get the raw bytes from Streamlit uploaded file
audio_bytes = uploaded_file.getvalue()
        # Create a unique temporary file path (portable across OSes)
        file_hash = hashlib.md5(audio_bytes).hexdigest()[:8]
        # Keep the original extension so the audio backend picks the right decoder
        file_ext = uploaded_file.name.split('.')[-1].lower()
        tmp_file_path = os.path.join(tempfile.gettempdir(), f"audio_{file_hash}.{file_ext}")
# Write bytes to temporary file
with open(tmp_file_path, 'wb') as f:
f.write(audio_bytes)
# Verify file was created
if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
return None
        # Load with librosa, which resamples to sample_rate and truncates to duration
        waveform, _ = librosa.load(tmp_file_path, sr=sample_rate, duration=duration)
# Convert numpy array to torch tensor
waveform = torch.from_numpy(waveform).float()
# Add channel dimension
if len(waveform.shape) == 1:
waveform = waveform.unsqueeze(0)
# Normalize audio
max_val = torch.max(torch.abs(waveform))
if max_val > 0:
waveform = waveform / max_val
# Ensure exact duration
target_length = sample_rate * duration
current_length = waveform.shape[1]
if current_length > target_length:
start = (current_length - target_length) // 2
waveform = waveform[:, start:start + target_length]
elif current_length < target_length:
padding = target_length - current_length
waveform = torch.nn.functional.pad(waveform, (0, padding))
# Create mel spectrogram
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate,
n_fft=2048,
hop_length=512,
n_mels=128,
f_min=0,
f_max=8000,
window_fn=torch.hann_window,
power=2.0
)
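        # At 22 050 Hz, a 5 s clip has 110 250 samples; with hop_length=512 and
        # center padding this yields 1 + 110250 // 512 = 216 time frames over 128 mel bins.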
amplitude_to_db = T.AmplitudeToDB(stype='power', top_db=80)
# Generate spectrogram
mel_spec = mel_transform(waveform)
mel_spec_db = amplitude_to_db(mel_spec)
# Normalize spectrogram
mean_val = mel_spec_db.mean()
std_val = mel_spec_db.std()
if std_val > 0:
mel_spec_db = (mel_spec_db - mean_val) / std_val
# Clean up temp file
if os.path.exists(tmp_file_path):
os.unlink(tmp_file_path)
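        # Add a batch dimension: final shape is (1, 1, n_mels, time frames),
        # i.e. (batch, channel, freq, time) as expected by the CNN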
return mel_spec_db.unsqueeze(0)
    except Exception:
        # Show a short error to the user rather than a full traceback
        st.error("❌ Failed to process audio file.")
        # Clean up the temporary file on error as well
        try:
            if tmp_file_path and os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
        except OSError:
            pass
return None
def predict_bird_species(model, spectrogram, label_encoder, device):
"""Make prediction on spectrogram"""
try:
spectrogram = spectrogram.to(device)
with torch.no_grad():
outputs = model(spectrogram)
probabilities = torch.softmax(outputs, dim=1)
confidence, predicted = torch.max(probabilities, 1)
predicted_species = label_encoder.inverse_transform([predicted.item()])[0]
confidence_score = confidence.item()
# Get top 3 predictions
top3_probs, top3_indices = torch.topk(probabilities, 3, dim=1)
top3_species = []
for i in range(3):
species = label_encoder.inverse_transform([top3_indices[0][i].item()])[0]
prob = top3_probs[0][i].item()
top3_species.append((species, prob))
return predicted_species, confidence_score, top3_species
except Exception as e:
st.error(f"Error making prediction: {str(e)}")
return None, None, None
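# Offline usage sketch (hypothetical, for debugging outside the Streamlit UI;
# assumes 'label_encoder.pkl' and 'best_bird_model_extended.pth' are present and
# wraps a local file in a minimal stand-in for Streamlit's UploadedFile):
#
#   class _FakeUpload:
#       def __init__(self, path):
#           self.name = path
#           with open(path, 'rb') as f:
#               self._data = f.read()
#       def getvalue(self):
#           return self._data
#
#   model, encoder, device = load_model_and_encoder()
#   spec = preprocess_audio(_FakeUpload('some_bird.wav'))
#   species, conf, top3 = predict_bird_species(model, spec, encoder, device)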
def main():
st.set_page_config(
page_title="Bird Sound Classifier",
page_icon="🐦",
layout="wide"
)
st.title("🐦 AI Bird Sound Classifier")
st.markdown("### Upload a bird audio recording to identify the species!")
st.markdown("**Trained on 110+ species with 80% accuracy**")
# Sidebar
st.sidebar.header("🌿 About This App")
st.sidebar.info(
"This AI model identifies bird species from audio recordings using "
"deep learning on spectrograms. Perfect for conservation efforts!"
)
    st.sidebar.header("📋 Instructions")
st.sidebar.markdown(
"""
        1. Upload an audio file (.mp3, .wav, .flac)
2. Click 'Identify Bird Species'
3. View predictions and confidence scores
4. Check alternative species suggestions
"""
)
# Load model
model, label_encoder, device = load_model_and_encoder()
if model is None:
st.error("❌ Failed to load model. Please check the model files.")
st.stop()
    st.success("✅ Model loaded successfully!")
# File upload
uploaded_file = st.file_uploader(
"Choose an audio file",
type=['mp3', 'wav', 'flac'],
help="Upload a bird sound recording (first 5 seconds will be analyzed)"
)
if uploaded_file is not None:
# Display file info
col1, col2 = st.columns(2)
with col1:
st.write("**πŸ“ File Details:**")
st.write(f"β€’ Name: {uploaded_file.name}")
st.write(f"β€’ Size: {uploaded_file.size:,} bytes")
with col2:
st.write("**🎡 Audio Player:**")
st.audio(uploaded_file, format='audio/wav')
        # Prediction button
        if st.button("🔍 Identify Bird Species", type="primary", use_container_width=True):
            with st.spinner("🔄 Processing audio and making prediction..."):
try:
# Process audio using librosa (more reliable)
spectrogram = preprocess_audio(uploaded_file)
if spectrogram is not None:
predicted_species, confidence, top3_predictions = predict_bird_species(
model, spectrogram, label_encoder, device
)
# Display results
if predicted_species is not None:
                            st.success("🎉 Prediction Complete!")
# Main prediction
                            st.subheader("🏆 Primary Prediction")
clean_species = predicted_species.replace("_sound", "").replace("_", " ")
col1, col2 = st.columns([2, 1])
with col1:
st.metric(
label="Predicted Species",
value=clean_species,
delta=f"{confidence:.1%} confidence"
)
with col2:
if confidence > 0.8:
st.success("🎯 High Confidence")
elif confidence > 0.6:
st.warning("⚠️ Moderate Confidence")
else:
                                    st.info("💭 Low Confidence")
# Top 3 predictions
                            st.subheader("📊 Alternative Predictions")
for i, (species, prob) in enumerate(top3_predictions):
clean_name = species.replace("_sound", "").replace("_", " ")
st.write(f"**{i+1}.** {clean_name}")
st.progress(prob)
st.caption(f"Confidence: {prob:.1%}")
# Conservation note
st.subheader("🌿 Conservation Impact")
st.info(
f"Identifying '{clean_species}' helps with biodiversity monitoring "
"and conservation efforts in national parks and protected areas."
)
else:
st.error("❌ Failed to process audio file.")
except Exception as e:
st.error(f"❌ Error processing audio: {str(e)}")
# Footer
st.markdown("---")
st.markdown(
"**🌍 Built for Conservation** | "
"This tool supports wildlife monitoring and biodiversity research."
)
if __name__ == "__main__":
main()