import hashlib
import os
import pickle
import tempfile

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T

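# This is a Streamlit app, so it is launched through the Streamlit CLI rather
# than plain `python`. Assuming the file is saved as app.py (the name here is
# illustrative), run:
#
#     streamlit run app.py
#
# 'best_bird_model_extended.pth' and 'label_encoder.pkl' are expected in the
# working directory; see load_model_and_encoder() below.
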
class ImprovedBirdSoundCNN(nn.Module):
    """CNN classifier for bird-sound mel spectrograms of shape (1, n_mels, time)."""

    def __init__(self, num_classes, dropout_rate=0.3):
        super().__init__()

        # Three double-conv blocks: 1 -> 64 -> 128 -> 256 channels, 3x3 kernels
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # One batch-norm layer per conv layer
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(128)
        self.bn5 = nn.BatchNorm2d(256)
        self.bn6 = nn.BatchNorm2d(256)

        self.pool = nn.MaxPool2d(2, 2)
        # Adaptive pooling fixes the spatial size at 4x4, so the classifier
        # head below works for any input spectrogram length
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        self.dropout = nn.Dropout(dropout_rate)

        # Classifier head: 256 * 4 * 4 = 4096 features down to num_classes logits
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        # Block 1: conv-bn-relu x2, downsample, dropout
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = self.dropout(x)

        # Block 2
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)
        x = self.dropout(x)

        # Block 3 ends with adaptive pooling instead of max pooling
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = self.adaptive_pool(x)
        x = self.dropout(x)

        # Flatten and classify; returns raw logits (softmax is applied by the caller)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x

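# Shape sanity check (a sketch under the preprocessing settings used below:
# 5 s at 22,050 Hz with n_fft=2048, hop_length=512, n_mels=128 gives a
# (1, 128, 216) spectrogram per clip). Uncomment to smoke-test the model:
#
#     _model = ImprovedBirdSoundCNN(num_classes=107)
#     _out = _model(torch.randn(2, 1, 128, 216))   # batch of 2 dummy clips
#     assert _out.shape == (2, 107)
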
@st.cache_resource
def load_model_and_encoder():
    """Load the model and label encoder, skipping size-mismatched layers."""
    device = torch.device('cpu')

    try:
        # The fitted label encoder defines the class list and output size
        with open('label_encoder.pkl', 'rb') as f:
            label_encoder = pickle.load(f)

        num_classes = len(label_encoder.classes_)
        print(f"Label encoder has {num_classes} classes")

        # weights_only=False: the checkpoint may contain non-tensor objects
        checkpoint = torch.load('best_bird_model_extended.pth',
                                map_location=device, weights_only=False)

        # Size the output layer from the encoder so logits align with labels
        model = ImprovedBirdSoundCNN(num_classes=num_classes)
        current_model_dict = model.state_dict()

        # The checkpoint may be a raw state dict or wrapped in a dict
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            saved_state_dict = checkpoint['model_state_dict']
        else:
            saved_state_dict = checkpoint

        # Keep only parameters whose names and shapes match the current model
        filtered_dict = {}
        for k, v in saved_state_dict.items():
            if k in current_model_dict:
                if v.size() == current_model_dict[k].size():
                    filtered_dict[k] = v
                else:
                    print(f"Skipping {k}: checkpoint {v.size()} vs model {current_model_dict[k].size()}")
            else:
                print(f"Parameter {k} not found in current model")

        current_model_dict.update(filtered_dict)
        model.load_state_dict(current_model_dict)
        model.eval()

        print(f"Model loaded successfully with {len(filtered_dict)} compatible layers")
        return model, label_encoder, device

    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None

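# The two artifacts loaded above are assumed to come from the training
# pipeline; a sketch of how they would have been produced (file names are from
# this app, but the checkpoint layout is an assumption, not confirmed here):
#
#     with open('label_encoder.pkl', 'wb') as f:
#         pickle.dump(label_encoder, f)              # fitted sklearn LabelEncoder
#     torch.save({'model_state_dict': model.state_dict()},
#                'best_bird_model_extended.pth')
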
def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
    """Convert an uploaded audio file into a normalized mel-spectrogram tensor."""
    import librosa  # imported lazily so the app can start without it

    tmp_file_path = None
    try:
        audio_bytes = uploaded_file.getvalue()

        # Hash the contents so repeated uploads of the same file reuse one temp name
        file_hash = hashlib.md5(audio_bytes).hexdigest()[:8]
        file_ext = uploaded_file.name.split('.')[-1].lower()
        tmp_file_path = os.path.join(tempfile.gettempdir(), f"audio_{file_hash}.{file_ext}")

        with open(tmp_file_path, 'wb') as f:
            f.write(audio_bytes)

        if not os.path.exists(tmp_file_path) or os.path.getsize(tmp_file_path) == 0:
            return None

        # Decode up to `duration` seconds, resampled to `sample_rate`
        waveform, _ = librosa.load(tmp_file_path, sr=sample_rate, duration=duration)
        waveform = torch.from_numpy(waveform).float()

        # Ensure a channel dimension: (samples,) -> (1, samples)
        if len(waveform.shape) == 1:
            waveform = waveform.unsqueeze(0)

        # Peak-normalize to [-1, 1]
        max_val = torch.max(torch.abs(waveform))
        if max_val > 0:
            waveform = waveform / max_val

        # Center-crop long clips, zero-pad short ones to exactly `duration` seconds
        target_length = sample_rate * duration
        current_length = waveform.shape[1]

        if current_length > target_length:
            start = (current_length - target_length) // 2
            waveform = waveform[:, start:start + target_length]
        elif current_length < target_length:
            padding = target_length - current_length
            waveform = F.pad(waveform, (0, padding))

        # 128-band log-mel spectrogram, matching the training front end
        mel_transform = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=2048,
            hop_length=512,
            n_mels=128,
            f_min=0,
            f_max=8000,
            window_fn=torch.hann_window,
            power=2.0,
        )
        amplitude_to_db = T.AmplitudeToDB(stype='power', top_db=80)

        mel_spec = mel_transform(waveform)
        mel_spec_db = amplitude_to_db(mel_spec)

        # Per-clip standardization (zero mean, unit variance)
        mean_val = mel_spec_db.mean()
        std_val = mel_spec_db.std()
        if std_val > 0:
            mel_spec_db = (mel_spec_db - mean_val) / std_val

        return mel_spec_db.unsqueeze(0)  # (1, 1, n_mels, time)

    except Exception:
        st.error("❌ Failed to process audio file.")
        return None

    finally:
        # Always remove the temp file, including on early returns and errors
        try:
            if tmp_file_path and os.path.exists(tmp_file_path):
                os.unlink(tmp_file_path)
        except OSError:
            pass

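# preprocess_audio() only needs .getvalue() and .name from the upload object,
# so it can be exercised outside Streamlit with a small stand-in (hypothetical
# helper, not part of the app):
#
#     class _LocalFile:
#         def __init__(self, path):
#             self.name = os.path.basename(path)
#             with open(path, 'rb') as fh:
#                 self._data = fh.read()
#         def getvalue(self):
#             return self._data
#
#     spec = preprocess_audio(_LocalFile('robin.wav'))  # -> (1, 1, 128, 216) tensor
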
def predict_bird_species(model, spectrogram, label_encoder, device):
    """Run the model on a spectrogram and return the top prediction plus top 3."""
    try:
        spectrogram = spectrogram.to(device)

        with torch.no_grad():
            outputs = model(spectrogram)
            probabilities = torch.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        predicted_species = label_encoder.inverse_transform([predicted.item()])[0]
        confidence_score = confidence.item()

        # Collect the three most probable species for the alternatives panel
        top3_probs, top3_indices = torch.topk(probabilities, 3, dim=1)
        top3_species = []
        for i in range(3):
            species = label_encoder.inverse_transform([top3_indices[0][i].item()])[0]
            prob = top3_probs[0][i].item()
            top3_species.append((species, prob))

        return predicted_species, confidence_score, top3_species

    except Exception as e:
        st.error(f"Error making prediction: {str(e)}")
        return None, None, None

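# End-to-end inference chain, as used by main() below (sketch):
#
#     model, label_encoder, device = load_model_and_encoder()
#     spec = preprocess_audio(uploaded_file)
#     if spec is not None:
#         species, conf, top3 = predict_bird_species(model, spec, label_encoder, device)
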
def main():
    st.set_page_config(
        page_title="Bird Sound Classifier",
        page_icon="🐦",
        layout="wide"
    )

    st.title("🐦 AI Bird Sound Classifier")
    st.markdown("### Upload a bird audio recording to identify the species!")
    st.markdown("**Trained on 110+ species with 80% accuracy**")

    # Sidebar: context and usage instructions
    st.sidebar.header("🌿 About This App")
    st.sidebar.info(
        "This AI model identifies bird species from audio recordings using "
        "deep learning on spectrograms. Perfect for conservation efforts!"
    )

    st.sidebar.header("📋 Instructions")
    st.sidebar.markdown(
        """
        1. Upload an audio file (.mp3, .wav, .flac)
        2. Click 'Identify Bird Species'
        3. View predictions and confidence scores
        4. Check alternative species suggestions
        """
    )

    # Load model once; @st.cache_resource caches it across Streamlit reruns
    model, label_encoder, device = load_model_and_encoder()

    if model is None:
        st.error("❌ Failed to load model. Please check the model files.")
        st.stop()

    st.success("✅ Model loaded successfully!")

    uploaded_file = st.file_uploader(
        "Choose an audio file",
        type=['mp3', 'wav', 'flac'],
        help="Upload a bird sound recording (first 5 seconds will be analyzed)"
    )

    if uploaded_file is not None:
        # Show file details and an inline player side by side
        col1, col2 = st.columns(2)
        with col1:
            st.write("**📁 File Details:**")
            st.write(f"• Name: {uploaded_file.name}")
            st.write(f"• Size: {uploaded_file.size:,} bytes")

        with col2:
            st.write("**🎵 Audio Player:**")
            st.audio(uploaded_file, format='audio/wav')

        if st.button("🔍 Identify Bird Species", type="primary", use_container_width=True):
            with st.spinner("🔄 Processing audio and making prediction..."):
                try:
                    spectrogram = preprocess_audio(uploaded_file)

                    if spectrogram is not None:
                        predicted_species, confidence, top3_predictions = predict_bird_species(
                            model, spectrogram, label_encoder, device
                        )

                        if predicted_species is not None:
                            st.success("🎉 Prediction Complete!")

                            st.subheader("🏆 Primary Prediction")
                            clean_species = predicted_species.replace("_sound", "").replace("_", " ")

                            col1, col2 = st.columns([2, 1])
                            with col1:
                                st.metric(
                                    label="Predicted Species",
                                    value=clean_species,
                                    delta=f"{confidence:.1%} confidence"
                                )

                            with col2:
                                if confidence > 0.8:
                                    st.success("🎯 High Confidence")
                                elif confidence > 0.6:
                                    st.warning("⚠️ Moderate Confidence")
                                else:
                                    st.info("🔍 Low Confidence")

                            st.subheader("📊 Alternative Predictions")
                            for i, (species, prob) in enumerate(top3_predictions):
                                clean_name = species.replace("_sound", "").replace("_", " ")
                                st.write(f"**{i+1}.** {clean_name}")
                                st.progress(prob)
                                st.caption(f"Confidence: {prob:.1%}")

                            st.subheader("🌿 Conservation Impact")
                            st.info(
                                f"Identifying '{clean_species}' helps with biodiversity monitoring "
                                "and conservation efforts in national parks and protected areas."
                            )

                    else:
                        st.error("❌ Failed to process audio file.")

                except Exception as e:
                    st.error(f"❌ Error processing audio: {str(e)}")

    st.markdown("---")
    st.markdown(
        "**🌍 Built for Conservation** | "
        "This tool supports wildlife monitoring and biodiversity research."
    )

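# Dependency note: besides the packages imported at the top, unpickling
# 'label_encoder.pkl' requires scikit-learn to be installed (pickle re-imports
# the fitted LabelEncoder's class), and preprocess_audio() imports librosa
# lazily at call time, so both must be available in the environment.
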
if __name__ == "__main__":
    main()