OfficerRaccoon committed on
Commit
a1f7c70
·
verified ·
1 Parent(s): a274886

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -85
app.py CHANGED
@@ -166,8 +166,11 @@ def load_model_and_encoder():
166
  # ------------------------------------------------------------------------------------------------------------------------------
167
 
168
  def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
169
- """Process audio from Streamlit uploaded file without seek() method"""
 
 
170
  tmp_file_path = None
 
171
  try:
172
  # Get the raw bytes from Streamlit uploaded file
173
  audio_bytes = uploaded_file.getvalue()
@@ -176,7 +179,10 @@ def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
176
  # Create a unique temporary file path
177
  import hashlib
178
  file_hash = hashlib.md5(audio_bytes).hexdigest()[:8]
179
- tmp_file_path = f"/tmp/audio_{file_hash}.wav"
 
 
 
180
 
181
  # Write bytes to temporary file
182
  with open(tmp_file_path, 'wb') as f:
@@ -194,31 +200,28 @@ def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
194
 
195
  st.write(f"βœ… Created temp file: {file_size} bytes at {tmp_file_path}")
196
 
197
- # Load audio with torchaudio
198
  try:
199
- waveform, sr = torchaudio.load(tmp_file_path)
200
- st.write(f"βœ… Audio loaded: shape {waveform.shape}, sample rate {sr}")
 
201
  except Exception as load_error:
202
- st.error(f"❌ torchaudio.load failed: {load_error}")
203
  return None
204
 
205
- # Resample if necessary
206
- if sr != sample_rate:
207
- resampler = T.Resample(sr, sample_rate)
208
- waveform = resampler(waveform)
209
- st.write(f"βœ… Resampled to {sample_rate} Hz")
210
 
211
- # Convert to mono
212
- if waveform.shape[0] > 1:
213
- waveform = torch.mean(waveform, dim=0, keepdim=True)
214
- st.write("βœ… Converted to mono")
215
 
216
  # Normalize audio
217
  max_val = torch.max(torch.abs(waveform))
218
  if max_val > 0:
219
  waveform = waveform / max_val
220
 
221
- # Pad or trim to fixed duration
222
  target_length = sample_rate * duration
223
  current_length = waveform.shape[1]
224
 
@@ -231,7 +234,7 @@ def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
231
  waveform = torch.nn.functional.pad(waveform, (0, padding))
232
  st.write(f"βœ… Padded audio to {target_length} samples")
233
 
234
- # Create mel spectrogram
235
  mel_transform = T.MelSpectrogram(
236
  sample_rate=sample_rate,
237
  n_fft=2048,
@@ -360,93 +363,31 @@ def main():
360
  st.write("**🎡 Audio Player:**")
361
  st.audio(uploaded_file, format='audio/wav')
362
 
 
363
  # Prediction button
364
  if st.button("πŸ” Identify Bird Species", type="primary", use_container_width=True):
365
  with st.spinner("πŸ”„ Processing audio and making prediction..."):
366
  try:
367
- # Create temporary file with proper handling
368
- with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
369
- # Write the uploaded file data
370
- tmp_file.write(uploaded_file.getvalue())
371
- tmp_file.flush() # Ensure data is written
372
- tmp_file_path = tmp_file.name
373
-
374
- # Verify file was created successfully
375
- if not os.path.exists(tmp_file_path):
376
- st.error("❌ Failed to create temporary file")
377
- return
378
-
379
- file_size = os.path.getsize(tmp_file_path)
380
- if file_size == 0:
381
- st.error("❌ Temporary file is empty")
382
- return
383
-
384
- st.write(f"βœ… Temporary file created: {file_size} bytes")
385
-
386
- # Process audio
387
- spectrogram = preprocess_audio(uploaded_file)
388
 
389
  if spectrogram is not None:
390
  predicted_species, confidence, top3_predictions = predict_bird_species(
391
  model, spectrogram, label_encoder, device
392
  )
393
 
394
- # Clean up temp file
395
- try:
396
- os.unlink(tmp_file_path)
397
- except:
398
- pass # Ignore cleanup errors
399
-
400
  # Display results
401
  if predicted_species is not None:
402
  st.success("πŸŽ‰ Prediction Complete!")
403
 
404
- # Main prediction
405
- st.subheader("πŸ† Primary Prediction")
406
- clean_species = predicted_species.replace("_sound", "").replace("_", " ")
407
 
408
- col1, col2 = st.columns([2, 1])
409
- with col1:
410
- st.metric(
411
- label="Predicted Species",
412
- value=clean_species,
413
- delta=f"{confidence:.1%} confidence"
414
- )
415
-
416
- with col2:
417
- if confidence > 0.8:
418
- st.success("🎯 High Confidence")
419
- elif confidence > 0.6:
420
- st.warning("⚠️ Moderate Confidence")
421
- else:
422
- st.info("πŸ’­ Low Confidence")
423
-
424
- # Top 3 predictions
425
- st.subheader("πŸ“Š Alternative Predictions")
426
- for i, (species, prob) in enumerate(top3_predictions):
427
- clean_name = species.replace("_sound", "").replace("_", " ")
428
- st.write(f"**{i+1}.** {clean_name}")
429
- st.progress(prob)
430
- st.caption(f"Confidence: {prob:.1%}")
431
-
432
- # Conservation note
433
- st.subheader("🌿 Conservation Impact")
434
- st.info(
435
- f"Identifying '{clean_species}' helps with biodiversity monitoring "
436
- "and conservation efforts in national parks and protected areas."
437
- )
438
-
439
  else:
440
  st.error("❌ Failed to process audio file.")
441
-
442
  except Exception as e:
443
  st.error(f"❌ Error processing audio: {str(e)}")
444
- # Clean up on error
445
- try:
446
- if 'tmp_file_path' in locals():
447
- os.unlink(tmp_file_path)
448
- except:
449
- pass
450
 
451
  # Footer
452
  st.markdown("---")
 
166
  # ------------------------------------------------------------------------------------------------------------------------------
167
 
168
  def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
169
+ """Process audio using librosa instead of torchaudio for better compatibility"""
170
+ import librosa
171
+ import numpy as np
172
  tmp_file_path = None
173
+
174
  try:
175
  # Get the raw bytes from Streamlit uploaded file
176
  audio_bytes = uploaded_file.getvalue()
 
179
  # Create a unique temporary file path
180
  import hashlib
181
  file_hash = hashlib.md5(audio_bytes).hexdigest()[:8]
182
+
183
+ # Determine file extension from uploaded file name
184
+ file_ext = uploaded_file.name.split('.')[-1].lower()
185
+ tmp_file_path = f"/tmp/audio_{file_hash}.{file_ext}"
186
 
187
  # Write bytes to temporary file
188
  with open(tmp_file_path, 'wb') as f:
 
200
 
201
  st.write(f"βœ… Created temp file: {file_size} bytes at {tmp_file_path}")
202
 
203
+ # Load audio with librosa (more reliable than torchaudio)
204
  try:
205
+ # librosa can handle MP3, WAV, FLAC automatically
206
+ waveform, sr = librosa.load(tmp_file_path, sr=sample_rate, duration=duration)
207
+ st.write(f"βœ… Audio loaded with librosa: shape {waveform.shape}, sample rate {sr}")
208
  except Exception as load_error:
209
+ st.error(f"❌ librosa.load failed: {load_error}")
210
  return None
211
 
212
+ # Convert numpy array to torch tensor
213
+ waveform = torch.from_numpy(waveform).float()
 
 
 
214
 
215
+ # Add channel dimension (librosa loads as 1D, we need 2D)
216
+ if len(waveform.shape) == 1:
217
+ waveform = waveform.unsqueeze(0) # Shape: (1, time)
 
218
 
219
  # Normalize audio
220
  max_val = torch.max(torch.abs(waveform))
221
  if max_val > 0:
222
  waveform = waveform / max_val
223
 
224
+ # Ensure exact duration
225
  target_length = sample_rate * duration
226
  current_length = waveform.shape[1]
227
 
 
234
  waveform = torch.nn.functional.pad(waveform, (0, padding))
235
  st.write(f"βœ… Padded audio to {target_length} samples")
236
 
237
+ # Create mel spectrogram using torchaudio transforms
238
  mel_transform = T.MelSpectrogram(
239
  sample_rate=sample_rate,
240
  n_fft=2048,
 
363
  st.write("**🎡 Audio Player:**")
364
  st.audio(uploaded_file, format='audio/wav')
365
 
366
+ # Prediction button
367
  # Prediction button
368
  if st.button("πŸ” Identify Bird Species", type="primary", use_container_width=True):
369
  with st.spinner("πŸ”„ Processing audio and making prediction..."):
370
  try:
371
+ # Process audio using librosa (more reliable)
372
+ spectrogram = preprocess_audio_librosa(uploaded_file)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  if spectrogram is not None:
375
  predicted_species, confidence, top3_predictions = predict_bird_species(
376
  model, spectrogram, label_encoder, device
377
  )
378
 
 
 
 
 
 
 
379
  # Display results
380
  if predicted_species is not None:
381
  st.success("πŸŽ‰ Prediction Complete!")
382
 
383
+ # Your existing result display code...
 
 
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  else:
386
  st.error("❌ Failed to process audio file.")
387
+
388
  except Exception as e:
389
  st.error(f"❌ Error processing audio: {str(e)}")
390
+
 
 
 
 
 
391
 
392
  # Footer
393
  st.markdown("---")