import gradio as gr
import torch
from torch import nn
import cv2
import numpy as np
import json
from torchvision import models
import librosa

# BirdCallRNN: ResNet feature extractor followed by a bidirectional LSTM classifier
class BirdCallRNN(nn.Module):
    def __init__(self, resnet, num_features, num_classes):
        super(BirdCallRNN, self).__init__()
        self.resnet = resnet  # CNN backbone (its fc head is replaced with Identity below)
        self.rnn = nn.LSTM(input_size=num_features, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(512, num_classes)  # 512 = 2 * 256 (bidirectional LSTM output)

    def forward(self, x):
        batch, seq_len, C, H, W = x.size()
        x = x.view(batch * seq_len, C, H, W)  # fold the sequence into the batch for the CNN
        features = self.resnet(x)
        features = features.view(batch, seq_len, -1)  # restore (batch, seq_len, num_features)
        rnn_out, _ = self.rnn(features)
        output = self.fc(rnn_out[:, -1, :])  # classify from the last timestep (seq_len is 1 at inference)
        return output

# Convert an MP3 file to a fixed-size log-mel spectrogram (pad or truncate the time axis, then resize)
def mp3_to_mel_spectrogram(mp3_file, target_shape=(128, 500), resize_shape=(224, 224)):
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)
    current_time_steps = log_S.shape[1]
    target_time_steps = target_shape[1]
    if current_time_steps < target_time_steps:
        pad_width = target_time_steps - current_time_steps
        log_S_resized = np.pad(log_S, ((0, 0), (0, pad_width)), mode='constant')
    elif current_time_steps > target_time_steps:
        log_S_resized = log_S[:, :target_time_steps]
    else:
        log_S_resized = log_S
    log_S_resized = cv2.resize(log_S_resized, resize_shape, interpolation=cv2.INTER_CUBIC)
    return log_S_resized
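
# Usage sketch (hypothetical file name). Note that this helper is not called by the
# app below; infer_birdcall recomputes the spectrogram itself so it can segment it first.
#   spec = mp3_to_mel_spectrogram("some_clip.mp3")  # -> 224x224 array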

# Load class mapping globally
with open('class_mapping.json', 'r') as f:
    class_names = json.load(f)
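# Assumed format, inferred from the str(pred) lookup in infer_birdcall:
#   {"0": "Species A", "1": "Species B", ...}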

# Inference function: split the spectrogram into fixed-width segments and predict one species per segment
def infer_birdcall(model, mp3_file, segment_length=500, device="cuda"):
    model.eval()
    # Load audio and compute the log-mel spectrogram
    y, sr = librosa.load(mp3_file, sr=None)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=128, fmax=8000)
    log_S = librosa.power_to_db(S, ref=np.max)
    # Split into full segments; a trailing remainder shorter than segment_length is dropped.
    # Clips shorter than one segment are used whole.
    num_segments = log_S.shape[1] // segment_length
    if num_segments == 0:
        segments = [log_S]
    else:
        segments = [log_S[:, i * segment_length:(i + 1) * segment_length] for i in range(num_segments)]

    predictions = []
    # Process each segment individually, with gradient tracking disabled for inference
    with torch.no_grad():
        for seg in segments:
            seg_resized = cv2.resize(seg, (224, 224), interpolation=cv2.INTER_CUBIC)
            # Replicate the single channel to get the 3-channel input the ResNet expects
            seg_rgb = np.repeat(seg_resized[:, :, np.newaxis], 3, axis=-1)
            # Shape: (batch=1, seq_len=1, 3, 224, 224)
            seg_tensor = torch.from_numpy(seg_rgb).permute(2, 0, 1).float().unsqueeze(0).unsqueeze(0).to(device)
            output = model(seg_tensor)
            pred = output.argmax(dim=1).item()
            predictions.append(class_names[str(pred)])  # JSON keys are strings
    return predictions
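
# Usage sketch (hypothetical clip name; `model` and `device` are built below):
#   preds = infer_birdcall(model, "some_clip.mp3", segment_length=500, device=device)
#   # -> one predicted species name per 500-frame segment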

# Build the model: ResNet-50 backbone feeding the LSTM classifier
resnet = models.resnet50(weights='IMAGENET1K_V2')
num_features = resnet.fc.in_features
resnet.fc = nn.Identity()  # expose pooled features instead of ImageNet logits
num_classes = len(class_names)  # should be 114
model = BirdCallRNN(resnet, num_features, num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.load_state_dict(torch.load('model_weights.pth', map_location=device))
model.eval()
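# Note: model_weights.pth is assumed to sit next to this script and to match the
# architecture above; a mismatched checkpoint will raise in load_state_dict.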

# Prediction function for Gradio
def predict_bird(file_path):
    predictions = infer_birdcall(model, file_path, segment_length=500, device=device)
    # Format the per-segment predictions as a numbered list
    return "\n".join(f"{i+1}. {pred}" for i, pred in enumerate(predictions))

# Gradio callback: run the prediction and pass through the audio plus static info images
def gradio_interface(file_path):
    prediction = predict_bird(file_path)
    # Return plain values; the output components themselves are declared in gr.Interface below.
    # "1.jpg", "2.jpg", and "3.jpg" are static images shipped alongside the app.
    return prediction, file_path, "1.jpg", "2.jpg", "3.jpg"

# Build and launch the Gradio interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.File(label="Upload MP3 file", file_types=['.mp3']),
    outputs=[
        gr.Textbox(label="Predicted Bird Species"),
        gr.Audio(label="Uploaded MP3 File", autoplay=True),
        gr.Image(label="Bird Species"),
        gr.Image(label="Bird Description"),
        gr.Image(label="Bird Origins")
    ]
)
interface.launch()
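
# To expose a temporary public link instead of serving on localhost only:
#   interface.launch(share=True)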