import gradio as gr
import torch
from torch import nn
import cv2
import numpy as np
import json
from torchvision import models
import librosa
# Define the BirdCallRNN model: a ResNet feature extractor followed by a
# bidirectional LSTM over the sequence of per-frame features
class BirdCallRNN(nn.Module):
    def __init__(self, resnet, num_features, num_classes):
        super(BirdCallRNN, self).__init__()
        self.resnet = resnet
        self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2,
                           batch_first=True, bidirectional=True)
        # 512 = 2 * hidden_size, since the LSTM is bidirectional
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # x: (batch, seq_len, C, H, W); fold the sequence into the batch so the
        # CNN can process every frame at once
        batch, seq_len, C, H, W = x.size()
        x = x.view(batch * seq_len, C, H, W)
        features = self.resnet(x)
        features = features.view(batch, seq_len, -1)
        rnn_out, _ = self.rnn(features)
        # Classify from the last time step (used here with single-segment sequences)
        output = self.fc(rnn_out[:, -1, :])
        return output
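# Shape sketch: a batch of 2 sequences of 4 frames each would enter forward()
# as a tensor of shape (2, 4, 3, 224, 224) and yield logits of shape (2, num_classes).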
# Convert an MP3 file to a log-mel spectrogram resized for the CNN input
# (kept for reference; the inference path below computes its spectrogram inline)
def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)
    # Pad or truncate along the time axis to a fixed number of frames
    current_time_steps = log_S.shape[1]
    target_time_steps = target_shape[1]
    if current_time_steps < target_time_steps:
        pad_width = target_time_steps - current_time_steps
        log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
    elif current_time_steps > target_time_steps:
        log_S_resized = log_S[:, :target_time_steps]
    else:
        log_S_resized = log_S
    log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
    return log_S_resized
# Load the class-index-to-species-name mapping globally
with open('class_mapping.json', 'r') as f:
    class_names = json.load(f)
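# class_mapping.json is expected to map stringified class indices to species
# names, e.g. {"0": "House Sparrow", "1": "Common Blackbird", ...} (species
# names here are illustrative); this matches the str(pred) lookup used below.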
# Inference: split the spectrogram into fixed-length segments and predict a
# species for each segment independently
def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
    model.eval()
    # Load audio and compute the log-mel spectrogram
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)
    # Segment the spectrogram along the time axis; clips shorter than one
    # segment are processed whole
    num_segments = log_S.shape[1] // segment_length
    if num_segments == 0:
        segments = [log_S]
    else:
        segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]
    predictions = []
    # Process each segment individually
    for seg in segments:
        seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
        # Replicate the single channel to the 3-channel input ResNet expects
        seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
        # Batch size 1, sequence length 1 -> shape (1, 1, 3, 224, 224)
        seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(seg_tensor)
        pred = torch.max(output, dim=1)[1].cpu().numpy()[0]
        predicted_bird = class_names[str(pred)]  # Convert pred to string to match JSON keys
        predictions.append(predicted_bird)
    return predictions
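# For example, a clip spanning two full segments might return something like
# ["House Sparrow", "Common Blackbird"] (hypothetical names), one entry per segment.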
# Initialize the model
resnet = models.resnet50(weights='IMAGENET1K_V2')
num_features = resnet.fc.in_features
resnet.fc = nn.Identity()
num_classes = len(class_names) # Should be 114
model = BirdCallRNN(resnet, num_features, num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
model.eval()
# Prediction function for Gradio
def predict_bird(file_path):
    predictions = infer_birdcall(model, file_path, segment_length=500, device=str(device))
    # Format predictions as a numbered list, one line per segment
    formatted_predictions = "\n".join([f"{i+1}. {pred}" for i, pred in enumerate(predictions)])
    return formatted_predictions
# Custom Gradio interface: assemble all outputs for a single uploaded file
def gradio_interface(file_path):
    # Predict bird species
    prediction = predict_bird(file_path)
    # Play back the uploaded MP3 file
    audio_player = gr.Audio(file_path, label="Uploaded MP3 File", visible=True, autoplay=True)
    # Static companion images with titles
    bird_species_image = gr.Image("1.jpg", label="Bird Species")
    bird_description_image = gr.Image("2.jpg", label="Bird Description")
    bird_origins_image = gr.Image("3.jpg", label="Bird Origins")
    return prediction, audio_player, bird_species_image, bird_description_image, bird_origins_image
# Build and launch the Gradio interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']),
    outputs=[
        gr.Textbox(label="Predicted Bird Species"),
        gr.Audio(label="Uploaded MP3 File"),
        gr.Image(label="Bird Species"),
        gr.Image(label="Bird Description"),
        gr.Image(label="Bird Origins")
    ]
)
interface.launch()