Update app.py
Browse files
app.py
CHANGED
|
@@ -166,8 +166,11 @@ def load_model_and_encoder():
|
|
| 166 |
# ------------------------------------------------------------------------------------------------------------------------------
|
| 167 |
|
| 168 |
def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
|
| 169 |
-
"""Process audio
|
|
|
|
|
|
|
| 170 |
tmp_file_path = None
|
|
|
|
| 171 |
try:
|
| 172 |
# Get the raw bytes from Streamlit uploaded file
|
| 173 |
audio_bytes = uploaded_file.getvalue()
|
|
@@ -176,7 +179,10 @@ def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
|
|
| 176 |
# Create a unique temporary file path
|
| 177 |
import hashlib
|
| 178 |
file_hash = hashlib.md5(audio_bytes).hexdigest()[:8]
|
| 179 |
-
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
# Write bytes to temporary file
|
| 182 |
with open(tmp_file_path, 'wb') as f:
|
|
@@ -194,31 +200,28 @@ def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
|
|
| 194 |
|
| 195 |
st.write(f"✅ Created temp file: {file_size} bytes at {tmp_file_path}")
|
| 196 |
|
| 197 |
-
# Load audio with torchaudio
|
| 198 |
try:
|
| 199 |
-
|
| 200 |
-
|
|
|
|
| 201 |
except Exception as load_error:
|
| 202 |
-
st.error(f"β
|
| 203 |
return None
|
| 204 |
|
| 205 |
-
#
|
| 206 |
-
|
| 207 |
-
resampler = T.Resample(sr, sample_rate)
|
| 208 |
-
waveform = resampler(waveform)
|
| 209 |
-
st.write(f"✅ Resampled to {sample_rate} Hz")
|
| 210 |
|
| 211 |
-
#
|
| 212 |
-
if waveform.shape
|
| 213 |
-
waveform =
|
| 214 |
-
st.write("✅ Converted to mono")
|
| 215 |
|
| 216 |
# Normalize audio
|
| 217 |
max_val = torch.max(torch.abs(waveform))
|
| 218 |
if max_val > 0:
|
| 219 |
waveform = waveform / max_val
|
| 220 |
|
| 221 |
-
#
|
| 222 |
target_length = sample_rate * duration
|
| 223 |
current_length = waveform.shape[1]
|
| 224 |
|
|
@@ -231,7 +234,7 @@ def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
|
|
| 231 |
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
| 232 |
st.write(f"✅ Padded audio to {target_length} samples")
|
| 233 |
|
| 234 |
-
# Create mel spectrogram
|
| 235 |
mel_transform = T.MelSpectrogram(
|
| 236 |
sample_rate=sample_rate,
|
| 237 |
n_fft=2048,
|
|
@@ -360,93 +363,31 @@ def main():
|
|
| 360 |
st.write("**🎵 Audio Player:**")
|
| 361 |
st.audio(uploaded_file, format='audio/wav')
|
| 362 |
|
|
|
|
| 363 |
# Prediction button
|
| 364 |
if st.button("π Identify Bird Species", type="primary", use_container_width=True):
|
| 365 |
with st.spinner("π Processing audio and making prediction..."):
|
| 366 |
try:
|
| 367 |
-
#
|
| 368 |
-
|
| 369 |
-
# Write the uploaded file data
|
| 370 |
-
tmp_file.write(uploaded_file.getvalue())
|
| 371 |
-
tmp_file.flush() # Ensure data is written
|
| 372 |
-
tmp_file_path = tmp_file.name
|
| 373 |
-
|
| 374 |
-
# Verify file was created successfully
|
| 375 |
-
if not os.path.exists(tmp_file_path):
|
| 376 |
-
st.error("❌ Failed to create temporary file")
|
| 377 |
-
return
|
| 378 |
-
|
| 379 |
-
file_size = os.path.getsize(tmp_file_path)
|
| 380 |
-
if file_size == 0:
|
| 381 |
-
st.error("❌ Temporary file is empty")
|
| 382 |
-
return
|
| 383 |
-
|
| 384 |
-
st.write(f"✅ Temporary file created: {file_size} bytes")
|
| 385 |
-
|
| 386 |
-
# Process audio
|
| 387 |
-
spectrogram = preprocess_audio(uploaded_file)
|
| 388 |
|
| 389 |
if spectrogram is not None:
|
| 390 |
predicted_species, confidence, top3_predictions = predict_bird_species(
|
| 391 |
model, spectrogram, label_encoder, device
|
| 392 |
)
|
| 393 |
|
| 394 |
-
# Clean up temp file
|
| 395 |
-
try:
|
| 396 |
-
os.unlink(tmp_file_path)
|
| 397 |
-
except:
|
| 398 |
-
pass # Ignore cleanup errors
|
| 399 |
-
|
| 400 |
# Display results
|
| 401 |
if predicted_species is not None:
|
| 402 |
st.success("π Prediction Complete!")
|
| 403 |
|
| 404 |
-
#
|
| 405 |
-
st.subheader("π Primary Prediction")
|
| 406 |
-
clean_species = predicted_species.replace("_sound", "").replace("_", " ")
|
| 407 |
|
| 408 |
-
col1, col2 = st.columns([2, 1])
|
| 409 |
-
with col1:
|
| 410 |
-
st.metric(
|
| 411 |
-
label="Predicted Species",
|
| 412 |
-
value=clean_species,
|
| 413 |
-
delta=f"{confidence:.1%} confidence"
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
with col2:
|
| 417 |
-
if confidence > 0.8:
|
| 418 |
-
st.success("🎯 High Confidence")
|
| 419 |
-
elif confidence > 0.6:
|
| 420 |
-
st.warning("⚠️ Moderate Confidence")
|
| 421 |
-
else:
|
| 422 |
-
st.info("π Low Confidence")
|
| 423 |
-
|
| 424 |
-
# Top 3 predictions
|
| 425 |
-
st.subheader("π Alternative Predictions")
|
| 426 |
-
for i, (species, prob) in enumerate(top3_predictions):
|
| 427 |
-
clean_name = species.replace("_sound", "").replace("_", " ")
|
| 428 |
-
st.write(f"**{i+1}.** {clean_name}")
|
| 429 |
-
st.progress(prob)
|
| 430 |
-
st.caption(f"Confidence: {prob:.1%}")
|
| 431 |
-
|
| 432 |
-
# Conservation note
|
| 433 |
-
st.subheader("🌿 Conservation Impact")
|
| 434 |
-
st.info(
|
| 435 |
-
f"Identifying '{clean_species}' helps with biodiversity monitoring "
|
| 436 |
-
"and conservation efforts in national parks and protected areas."
|
| 437 |
-
)
|
| 438 |
-
|
| 439 |
else:
|
| 440 |
st.error("❌ Failed to process audio file.")
|
| 441 |
-
|
| 442 |
except Exception as e:
|
| 443 |
st.error(f"❌ Error processing audio: {str(e)}")
|
| 444 |
-
|
| 445 |
-
try:
|
| 446 |
-
if 'tmp_file_path' in locals():
|
| 447 |
-
os.unlink(tmp_file_path)
|
| 448 |
-
except:
|
| 449 |
-
pass
|
| 450 |
|
| 451 |
# Footer
|
| 452 |
st.markdown("---")
|
|
|
|
| 166 |
# ------------------------------------------------------------------------------------------------------------------------------
|
| 167 |
|
| 168 |
def preprocess_audio(uploaded_file, sample_rate=22050, duration=5):
|
| 169 |
+
"""Process audio using librosa instead of torchaudio for better compatibility"""
|
| 170 |
+
import librosa
|
| 171 |
+
import numpy as np
|
| 172 |
tmp_file_path = None
|
| 173 |
+
|
| 174 |
try:
|
| 175 |
# Get the raw bytes from Streamlit uploaded file
|
| 176 |
audio_bytes = uploaded_file.getvalue()
|
|
|
|
| 179 |
# Create a unique temporary file path
|
| 180 |
import hashlib
|
| 181 |
file_hash = hashlib.md5(audio_bytes).hexdigest()[:8]
|
| 182 |
+
|
| 183 |
+
# Determine file extension from uploaded file name
|
| 184 |
+
file_ext = uploaded_file.name.split('.')[-1].lower()
|
| 185 |
+
tmp_file_path = f"/tmp/audio_{file_hash}.{file_ext}"
|
| 186 |
|
| 187 |
# Write bytes to temporary file
|
| 188 |
with open(tmp_file_path, 'wb') as f:
|
|
|
|
| 200 |
|
| 201 |
st.write(f"✅ Created temp file: {file_size} bytes at {tmp_file_path}")
|
| 202 |
|
| 203 |
+
# Load audio with librosa (more reliable than torchaudio)
|
| 204 |
try:
|
| 205 |
+
# librosa can handle MP3, WAV, FLAC automatically
|
| 206 |
+
waveform, sr = librosa.load(tmp_file_path, sr=sample_rate, duration=duration)
|
| 207 |
+
st.write(f"✅ Audio loaded with librosa: shape {waveform.shape}, sample rate {sr}")
|
| 208 |
except Exception as load_error:
|
| 209 |
+
st.error(f"❌ librosa.load failed: {load_error}")
|
| 210 |
return None
|
| 211 |
|
| 212 |
+
# Convert numpy array to torch tensor
|
| 213 |
+
waveform = torch.from_numpy(waveform).float()
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
+
# Add channel dimension (librosa loads as 1D, we need 2D)
|
| 216 |
+
if len(waveform.shape) == 1:
|
| 217 |
+
waveform = waveform.unsqueeze(0) # Shape: (1, time)
|
|
|
|
| 218 |
|
| 219 |
# Normalize audio
|
| 220 |
max_val = torch.max(torch.abs(waveform))
|
| 221 |
if max_val > 0:
|
| 222 |
waveform = waveform / max_val
|
| 223 |
|
| 224 |
+
# Ensure exact duration
|
| 225 |
target_length = sample_rate * duration
|
| 226 |
current_length = waveform.shape[1]
|
| 227 |
|
|
|
|
| 234 |
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
| 235 |
st.write(f"✅ Padded audio to {target_length} samples")
|
| 236 |
|
| 237 |
+
# Create mel spectrogram using torchaudio transforms
|
| 238 |
mel_transform = T.MelSpectrogram(
|
| 239 |
sample_rate=sample_rate,
|
| 240 |
n_fft=2048,
|
|
|
|
| 363 |
st.write("**🎵 Audio Player:**")
|
| 364 |
st.audio(uploaded_file, format='audio/wav')
|
| 365 |
|
| 366 |
+
# Prediction button
|
| 367 |
# Prediction button
|
| 368 |
if st.button("π Identify Bird Species", type="primary", use_container_width=True):
|
| 369 |
with st.spinner("π Processing audio and making prediction..."):
|
| 370 |
try:
|
| 371 |
+
# Process audio using librosa (more reliable)
|
| 372 |
+
spectrogram = preprocess_audio_librosa(uploaded_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
if spectrogram is not None:
|
| 375 |
predicted_species, confidence, top3_predictions = predict_bird_species(
|
| 376 |
model, spectrogram, label_encoder, device
|
| 377 |
)
|
| 378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
# Display results
|
| 380 |
if predicted_species is not None:
|
| 381 |
st.success("π Prediction Complete!")
|
| 382 |
|
| 383 |
+
# Your existing result display code...
|
|
|
|
|
|
|
| 384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
else:
|
| 386 |
st.error("❌ Failed to process audio file.")
|
| 387 |
+
|
| 388 |
except Exception as e:
|
| 389 |
st.error(f"❌ Error processing audio: {str(e)}")
|
| 390 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
|
| 392 |
# Footer
|
| 393 |
st.markdown("---")
|