OfficerRaccoon commited on
Commit
ecef223
Β·
verified Β·
1 Parent(s): eb3fddf

Upload 5 files

Browse files

Deploy bird sound classifier with 80% accuracy

Files changed (5) hide show
  1. README.md +22 -20
  2. app.py +296 -0
  3. best_bird_model_extended.pth +3 -0
  4. label_encoder.pkl +3 -0
  5. requirements.txt +11 -3
README.md CHANGED
@@ -1,20 +1,22 @@
1
- ---
2
- title: Bird Sound Classifier V2
3
- emoji: πŸš€
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: AI bird species identification from audio
12
- license: cc-by-nc-4.0
13
- ---
14
-
15
- # Welcome to Streamlit!
16
-
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
18
-
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
 
 
1
+ ---
2
+ title: Bird Sound Classifier
3
+ emoji: 🐦
4
+ colorFrom: green
5
+ colorTo: blue
6
+ sdk: streamlit
7
+ sdk_version: 1.28.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # Bird Sound Classifier 🐦
14
+
15
+ AI-powered bird species identification from audio recordings.
16
+
17
+ ## Features
18
+ - 80% accuracy across 110+ bird species
19
+ - Upload .mp3/.wav files
20
+ - Real-time predictions with confidence scores
21
+
22
+ Built for conservation efforts.
app.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ import torchaudio.transforms as T
7
+ import numpy as np
8
+ import pickle
9
+ import tempfile
10
+ import os
11
+
12
+ # Your model architecture (same as before)
13
+ class ImprovedBirdSoundCNN(nn.Module):
14
+ def __init__(self, num_classes, dropout_rate=0.3):
15
+ super(ImprovedBirdSoundCNN, self).__init__()
16
+
17
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
18
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
19
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
20
+ self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
21
+ self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
22
+ self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
23
+
24
+ self.bn1 = nn.BatchNorm2d(64)
25
+ self.bn2 = nn.BatchNorm2d(64)
26
+ self.bn3 = nn.BatchNorm2d(128)
27
+ self.bn4 = nn.BatchNorm2d(128)
28
+ self.bn5 = nn.BatchNorm2d(256)
29
+ self.bn6 = nn.BatchNorm2d(256)
30
+
31
+ self.pool = nn.MaxPool2d(2, 2)
32
+ self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
33
+ self.dropout = nn.Dropout(dropout_rate)
34
+
35
+ self.fc1 = nn.Linear(256 * 4 * 4, 512)
36
+ self.fc2 = nn.Linear(512, 256)
37
+ self.fc3 = nn.Linear(256, num_classes)
38
+
39
+ def forward(self, x):
40
+ x = F.relu(self.bn1(self.conv1(x)))
41
+ x = F.relu(self.bn2(self.conv2(x)))
42
+ x = self.pool(x)
43
+ x = self.dropout(x)
44
+
45
+ x = F.relu(self.bn3(self.conv3(x)))
46
+ x = F.relu(self.bn4(self.conv4(x)))
47
+ x = self.pool(x)
48
+ x = self.dropout(x)
49
+
50
+ x = F.relu(self.bn5(self.conv5(x)))
51
+ x = F.relu(self.bn6(self.conv6(x)))
52
+ x = self.adaptive_pool(x)
53
+ x = self.dropout(x)
54
+
55
+ x = x.view(x.size(0), -1)
56
+ x = F.relu(self.fc1(x))
57
+ x = self.dropout(x)
58
+ x = F.relu(self.fc2(x))
59
+ x = self.dropout(x)
60
+ x = self.fc3(x)
61
+
62
+ return x
63
+
64
+ @st.cache_resource
65
+ def load_model_and_encoder():
66
+ """Load model and label encoder - cached for performance"""
67
+ device = torch.device('cpu') # HF Spaces uses CPU
68
+
69
+ try:
70
+ # Load label encoder
71
+ with open('label_encoder.pkl', 'rb') as f:
72
+ label_encoder = pickle.load(f)
73
+
74
+ num_classes = len(label_encoder.classes_)
75
+
76
+ # Load model
77
+ model = ImprovedBirdSoundCNN(num_classes=num_classes)
78
+ checkpoint = torch.load('best_bird_model_extended.pth', map_location=device)
79
+
80
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
81
+ model.load_state_dict(checkpoint['model_state_dict'])
82
+ else:
83
+ model.load_state_dict(checkpoint)
84
+
85
+ model.eval()
86
+ return model, label_encoder, device
87
+
88
+ except Exception as e:
89
+ st.error(f"Error loading model: {str(e)}")
90
+ return None, None, None
91
+
92
+ def preprocess_audio(audio_file, sample_rate=22050, duration=5):
93
+ """Preprocess audio for prediction"""
94
+ try:
95
+ # Load audio
96
+ waveform, sr = torchaudio.load(audio_file)
97
+
98
+ # Resample if necessary
99
+ if sr != sample_rate:
100
+ resampler = T.Resample(sr, sample_rate)
101
+ waveform = resampler(waveform)
102
+
103
+ # Convert to mono
104
+ if waveform.shape[0] > 1:
105
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
106
+
107
+ # Normalize
108
+ waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
109
+
110
+ # Pad or trim
111
+ target_length = sample_rate * duration
112
+ if waveform.shape[1] > target_length:
113
+ start = (waveform.shape[1] - target_length) // 2
114
+ waveform = waveform[:, start:start + target_length]
115
+ else:
116
+ padding = target_length - waveform.shape[1]
117
+ waveform = torch.nn.functional.pad(waveform, (0, padding))
118
+
119
+ # Create spectrogram
120
+ mel_transform = T.MelSpectrogram(
121
+ sample_rate=sample_rate,
122
+ n_fft=2048,
123
+ hop_length=512,
124
+ n_mels=128,
125
+ f_min=0,
126
+ f_max=8000,
127
+ window_fn=torch.hann_window,
128
+ power=2.0
129
+ )
130
+
131
+ amplitude_to_db = T.AmplitudeToDB(stype='power', top_db=80)
132
+
133
+ mel_spec = mel_transform(waveform)
134
+ mel_spec_db = amplitude_to_db(mel_spec)
135
+ mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)
136
+
137
+ return mel_spec_db.unsqueeze(0)
138
+
139
+ except Exception as e:
140
+ st.error(f"Error preprocessing audio: {str(e)}")
141
+ return None
142
+
143
+ def predict_bird_species(model, spectrogram, label_encoder, device):
144
+ """Make prediction on spectrogram"""
145
+ try:
146
+ spectrogram = spectrogram.to(device)
147
+
148
+ with torch.no_grad():
149
+ outputs = model(spectrogram)
150
+ probabilities = torch.softmax(outputs, dim=1)
151
+ confidence, predicted = torch.max(probabilities, 1)
152
+
153
+ predicted_species = label_encoder.inverse_transform([predicted.item()])[0]
154
+ confidence_score = confidence.item()
155
+
156
+ # Get top 3 predictions
157
+ top3_probs, top3_indices = torch.topk(probabilities, 3, dim=1)
158
+ top3_species = []
159
+
160
+ for i in range(3):
161
+ species = label_encoder.inverse_transform([top3_indices[0][i].item()])[0]
162
+ prob = top3_probs[0][i].item()
163
+ top3_species.append((species, prob))
164
+
165
+ return predicted_species, confidence_score, top3_species
166
+
167
+ except Exception as e:
168
+ st.error(f"Error making prediction: {str(e)}")
169
+ return None, None, None
170
+
171
+ def main():
172
+ st.set_page_config(
173
+ page_title="Bird Sound Classifier",
174
+ page_icon="🐦",
175
+ layout="wide"
176
+ )
177
+
178
+ st.title("🐦 AI Bird Sound Classifier")
179
+ st.markdown("### Upload a bird audio recording to identify the species!")
180
+ st.markdown("**Trained on 110+ species with 80% accuracy**")
181
+
182
+ # Sidebar
183
+ st.sidebar.header("🌿 About This App")
184
+ st.sidebar.info(
185
+ "This AI model identifies bird species from audio recordings using "
186
+ "deep learning on spectrograms. Perfect for conservation efforts!"
187
+ )
188
+
189
+ st.sidebar.header("πŸ“‹ Instructions")
190
+ st.sidebar.markdown(
191
+ """
192
+ 1. Upload an audio file (.mp3, .wav)
193
+ 2. Click 'Identify Bird Species'
194
+ 3. View predictions and confidence scores
195
+ 4. Check alternative species suggestions
196
+ """
197
+ )
198
+
199
+ # Load model
200
+ model, label_encoder, device = load_model_and_encoder()
201
+
202
+ if model is None:
203
+ st.error("❌ Failed to load model. Please check the model files.")
204
+ st.stop()
205
+
206
+ st.success("βœ… Model loaded successfully!")
207
+
208
+ # File upload
209
+ uploaded_file = st.file_uploader(
210
+ "Choose an audio file",
211
+ type=['mp3', 'wav', 'flac'],
212
+ help="Upload a bird sound recording (first 5 seconds will be analyzed)"
213
+ )
214
+
215
+ if uploaded_file is not None:
216
+ # Display file info
217
+ col1, col2 = st.columns(2)
218
+ with col1:
219
+ st.write("**πŸ“ File Details:**")
220
+ st.write(f"β€’ Name: {uploaded_file.name}")
221
+ st.write(f"β€’ Size: {uploaded_file.size:,} bytes")
222
+
223
+ with col2:
224
+ st.write("**🎡 Audio Player:**")
225
+ st.audio(uploaded_file, format='audio/wav')
226
+
227
+ # Prediction button
228
+ if st.button("πŸ” Identify Bird Species", type="primary", use_container_width=True):
229
+ with st.spinner("πŸ”„ Processing audio and making prediction..."):
230
+ # Save uploaded file temporarily
231
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
232
+ tmp_file.write(uploaded_file.getvalue())
233
+ tmp_file_path = tmp_file.name
234
+
235
+ # Process and predict
236
+ spectrogram = preprocess_audio(tmp_file_path)
237
+
238
+ if spectrogram is not None:
239
+ predicted_species, confidence, top3_predictions = predict_bird_species(
240
+ model, spectrogram, label_encoder, device
241
+ )
242
+
243
+ # Clean up
244
+ os.unlink(tmp_file_path)
245
+
246
+ if predicted_species is not None:
247
+ # Display results
248
+ st.success("πŸŽ‰ Prediction Complete!")
249
+
250
+ # Main prediction
251
+ st.subheader("πŸ† Primary Prediction")
252
+ clean_species = predicted_species.replace("_sound", "").replace("_", " ")
253
+
254
+ col1, col2 = st.columns([2, 1])
255
+ with col1:
256
+ st.metric(
257
+ label="Predicted Species",
258
+ value=clean_species,
259
+ delta=f"{confidence:.1%} confidence"
260
+ )
261
+
262
+ with col2:
263
+ if confidence > 0.8:
264
+ st.success("🎯 High Confidence")
265
+ elif confidence > 0.6:
266
+ st.warning("⚠️ Moderate Confidence")
267
+ else:
268
+ st.info("πŸ’­ Low Confidence")
269
+
270
+ # Top 3 predictions
271
+ st.subheader("πŸ“Š Alternative Predictions")
272
+ for i, (species, prob) in enumerate(top3_predictions):
273
+ clean_name = species.replace("_sound", "").replace("_", " ")
274
+ st.write(f"**{i+1}.** {clean_name}")
275
+ st.progress(prob)
276
+ st.caption(f"Confidence: {prob:.1%}")
277
+
278
+ # Conservation note
279
+ st.subheader("🌿 Conservation Impact")
280
+ st.info(
281
+ f"Identifying '{clean_species}' helps with biodiversity monitoring "
282
+ "and conservation efforts in national parks and protected areas."
283
+ )
284
+
285
+ else:
286
+ st.error("❌ Failed to process audio file.")
287
+
288
+ # Footer
289
+ st.markdown("---")
290
+ st.markdown(
291
+ "**🌍 Built for Conservation** | "
292
+ "This tool supports wildlife monitoring and biodiversity research."
293
+ )
294
+
295
+ if __name__ == "__main__":
296
+ main()
best_bird_model_extended.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec8fe37088efd2913125f8af12564cb74ae2f00c6de14de4811c2227d5ad77c6
3
+ size 133
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5873dfe367f384888203d3e05f35af1e72484dd996b06cc5184b56b4f7d5bdb
3
+ size 14832
requirements.txt CHANGED
@@ -1,3 +1,11 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
 
 
 
 
1
+ <<<<<<< HEAD
2
+ streamlit==1.28.1
3
+ torch==2.0.1
4
+ torchaudio==2.0.2
5
+ scikit-learn==1.3.0
6
+ numpy==1.24.3
7
+ =======
8
+ altair
9
+ pandas
10
+ streamlit
11
+ >>>>>>> dc2102210ae764c915fc5f4ac1fa3ad6b0cadc59