File size: 12,921 Bytes
bf4dd33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
#!/usr/bin/env python3
"""BirdNET Real-Time Audio Classification Script

This script captures audio from the microphone and uses the BirdNET ONNX model
to predict bird species in real-time with continuous display updates.

Created using Copilot.
"""

from __future__ import annotations

import numpy as np
import sounddevice as sd
import onnxruntime as ort
import argparse
import os
import time
import threading
from collections import deque
from datetime import datetime
import queue


class RealTimeBirdDetector:
    """Real-time bird detection using microphone input.

    Audio blocks delivered by the input stream are queued by
    ``_audio_callback``, accumulated into a rolling buffer and scored with the
    ONNX model by a worker thread (``_process_audio_buffer`` /
    ``_analyze_audio_window``), and rendered to the terminal by a second
    thread (``_display_results``).
    """

    # Width, in characters, of the confidence bar drawn for each detection.
    _BAR_WIDTH = 20

    def __init__(
        self,
        model_path: str = "model.onnx",
        labels_path: str = "BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        sample_rate: int = 48000,
        window_duration: float = 3.0,
        confidence_threshold: float = 0.1,
        top_k: int = 5,
        update_interval: float = 1.0,
    ):
        """
        Initialize the real-time bird detector.

        Args:
            model_path: Path to the ONNX model file
            labels_path: Path to the species labels file
            sample_rate: Audio sample rate (48kHz for BirdNET)
            window_duration: Duration of each analysis window in seconds
            confidence_threshold: Minimum confidence for detections
            top_k: Number of top predictions to display
            update_interval: How often the terminal display is refreshed
                (seconds)

        Raises:
            RuntimeError: If the model or the labels file cannot be loaded.
        """
        self.model_path = model_path
        self.labels_path = labels_path
        self.sample_rate = sample_rate
        self.window_duration = window_duration
        self.window_size = int(sample_rate * window_duration)
        self.confidence_threshold = confidence_threshold
        self.top_k = top_k
        self.update_interval = update_interval

        # Rolling buffer holding up to two analysis windows of samples
        # (6 seconds with the default settings).
        self.audio_buffer = deque(maxlen=self.window_size * 2)
        # Hand-off between the audio callback and the processing thread.
        self.audio_queue = queue.Queue()

        # Detection results
        self.current_detections = []
        self.detection_history = deque(maxlen=100)  # Keep last 100 detections
        self.running = False

        # Load model and labels
        self._load_model()
        self._load_labels()

    def _load_model(self) -> None:
        """Load the ONNX model and print its first input/output description.

        Raises:
            RuntimeError: If the session cannot be created or inspected; the
                original exception is attached as ``__cause__``.
        """
        try:
            print(f"Loading ONNX model: {self.model_path}")
            self.session = ort.InferenceSession(self.model_path)

            # Get model info
            input_info = self.session.get_inputs()[0]
            output_info = self.session.get_outputs()[0]
            print(f"Model input: {input_info.name}, shape: {input_info.shape}")
            print(f"Model output: {output_info.name}, shape: {output_info.shape}")

        except Exception as e:
            raise RuntimeError(
                f"Error loading ONNX model {self.model_path}: {e}"
            ) from e

    def _load_labels(self) -> None:
        """Load species labels from ``self.labels_path`` into ``self.labels``.

        Blank lines are skipped. For a line containing an underscore, only the
        text after the first underscore is kept (the file format is
        "Scientific name_Common Name"); other lines are kept verbatim.

        Raises:
            RuntimeError: If the file cannot be read; the original exception
                is attached as ``__cause__``.
        """
        try:
            print(f"Loading labels from: {self.labels_path}")
            labels = []
            with open(self.labels_path, "r", encoding="utf-8") as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:
                        continue
                    _, separator, common_name = line.partition("_")
                    labels.append(common_name if separator else line)
            self.labels = labels
            print(f"Loaded {len(self.labels)} species labels")

        except Exception as e:
            raise RuntimeError(
                f"Error loading labels file {self.labels_path}: {e}"
            ) from e

    def _audio_callback(
        self, indata: np.ndarray, frames: int, time_info, status
    ) -> None:
        """Queue one block of captured audio for the processing thread.

        Args:
            indata: Block of samples handed over by the input stream.
            frames: Number of frames in the block (unused).
            time_info: Stream timing information (unused).
            status: Stream status flags; printed when truthy.
        """
        if status:
            print(f"Audio status: {status}")

        # Collapse multi-channel input to mono by averaging across channels.
        # Both branches produce a fresh array, so the stream's own buffer is
        # never shared with the queue.
        if indata.ndim > 1:
            audio_data = np.mean(indata, axis=1)
        else:
            audio_data = indata.flatten()

        self.audio_queue.put(audio_data)

    def _process_audio_buffer(self) -> None:
        """Consume queued audio and analyze the most recent window.

        Runs until ``self.running`` becomes false. All chunks already waiting
        in the queue are folded into the rolling buffer before a single
        inference is run, so a model that is slower than the capture rate
        skips intermediate windows instead of letting the queue (and the
        detection latency) grow without bound.
        """
        while self.running:
            try:
                # Short timeout so the loop notices ``self.running`` changes.
                audio_chunk = self.audio_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            try:
                self.audio_buffer.extend(audio_chunk)

                # Drain any backlog that accumulated during the last inference.
                while True:
                    try:
                        self.audio_buffer.extend(self.audio_queue.get_nowait())
                    except queue.Empty:
                        break

                if len(self.audio_buffer) >= self.window_size:
                    # Copy the buffer straight into an array (no intermediate
                    # Python list) and keep the most recent window.
                    samples = np.fromiter(
                        self.audio_buffer,
                        dtype=np.float32,
                        count=len(self.audio_buffer),
                    )
                    self._analyze_audio_window(samples[-self.window_size :])

            except Exception as e:
                print(f"Error processing audio: {e}")

    def _analyze_audio_window(self, audio_data: np.ndarray) -> None:
        """Run the model on one audio window and publish the detections.

        Classes whose score exceeds ``self.confidence_threshold`` are ranked
        by descending score (ties keep ascending class-index order) and the
        first ``self.top_k`` are stored in ``self.current_detections`` and
        appended to ``self.detection_history``. Errors are printed, not
        raised.

        Args:
            audio_data: One-dimensional array of samples for the window.
        """
        try:
            # Ensure correct format and add the batch dimension
            samples = np.asarray(audio_data, dtype=np.float32)
            input_data = np.expand_dims(samples, axis=0)

            # Get input name from the model
            input_name = self.session.get_inputs()[0].name

            # Run inference
            outputs = self.session.run(None, {input_name: input_data})
            predictions = np.asarray(outputs[0])

            # Drop the batch dimension when present
            scores = predictions[0] if predictions.ndim > 1 else predictions

            # Rank only the classes above threshold. A stable sort on the
            # negated scores yields descending confidence with ties left in
            # class-index order.
            above_threshold = np.where(scores > self.confidence_threshold)[0]
            ranking = np.argsort(
                -np.asarray(scores[above_threshold], dtype=np.float64),
                kind="stable",
            )
            top_indices = above_threshold[ranking][: self.top_k]

            # Build result records only for the entries that are kept; they
            # all come from the same window, so they share one timestamp.
            timestamp = datetime.now()
            detections = []
            for idx in map(int, top_indices):
                species_name = (
                    self.labels[idx] if idx < len(self.labels) else f"Class {idx}"
                )
                detections.append(
                    {
                        "species": species_name,
                        "confidence": float(scores[idx]),
                        "timestamp": timestamp,
                    }
                )

            # Rebinding (rather than mutating) lets the display thread keep
            # iterating whichever list it already picked up.
            self.current_detections = detections

            # Add to history
            if detections:
                self.detection_history.extend(detections)

        except Exception as e:
            print(f"Error during inference: {e}")

    @staticmethod
    def _confidence_bar(confidence: float, width: int = 20) -> str:
        """Return a fixed-width text bar for *confidence*.

        The filled portion is clamped to ``[0, width]`` so scores outside
        0..1 (or NaN) can never produce a bar of the wrong length.
        """
        if not confidence > 0:  # also true for NaN
            filled = 0
        else:
            filled = int(min(confidence, 1.0) * width)
        return "█" * filled + "░" * (width - filled)

    def _render_display(self) -> list[str]:
        """Build the lines of one display frame from the current results."""
        # Take one snapshot of each shared structure so the header count and
        # the rows below it always describe the same data.
        current = self.current_detections
        recent = list(self.detection_history)[-10:]

        lines = [
            "🎤 BirdNET Real-Time Detection",
            "=" * 50,
            f"Listening... (Confidence > {self.confidence_threshold:.2f})",
            f"Time: {datetime.now().strftime('%H:%M:%S')}",
            "",
        ]

        # Current detections
        if current:
            lines.append(f"🐦 Current Detections (Top {len(current)}):")
            lines.append("-" * 40)
            for rank, detection in enumerate(current, 1):
                confidence = detection["confidence"]
                bar = self._confidence_bar(confidence, self._BAR_WIDTH)
                lines.append(f"{rank:2d}. {detection['species']}")
                lines.append(f"    {bar} {confidence:.4f}")
        else:
            lines.append("🔍 No detections above threshold...")

        lines.append("")

        # Recent activity, newest first
        if recent:
            lines.append("📊 Recent Activity (Last 10):")
            lines.append("-" * 40)
            for detection in reversed(recent):
                stamp = detection["timestamp"].strftime("%H:%M:%S")
                lines.append(
                    f"{stamp} - {detection['species']} "
                    f"({detection['confidence']:.3f})"
                )

        lines.append("")
        lines.append("Press Ctrl+C to stop")
        return lines

    def _display_results(self) -> None:
        """Continuously display detection results.

        Redraws the screen every ``self.update_interval`` seconds until
        ``self.running`` becomes false. The pause happens after every
        iteration, including failed ones, so a persistent display error
        cannot turn into a busy loop.
        """
        try:
            while self.running:
                try:
                    lines = self._render_display()
                    # Clear screen (works on most terminals)
                    os.system("clear" if os.name == "posix" else "cls")
                    print("\n".join(lines))
                except Exception as e:
                    print(f"Display error: {e}")

                # Wait before next update
                time.sleep(max(self.update_interval, 0.0))
        except KeyboardInterrupt:
            pass

    def start_detection(self) -> None:
        """Start real-time detection and block until it is stopped.

        Starts the processing and display threads, opens the microphone
        stream and then idles until ``self.running`` is cleared or Ctrl+C is
        pressed. Errors are printed rather than raised, and ``self.running``
        is always reset on exit.
        """
        try:
            print("Starting real-time bird detection...")
            print(f"Sample rate: {self.sample_rate} Hz")
            print(f"Window size: {self.window_duration} seconds")
            print(f"Confidence threshold: {self.confidence_threshold}")
            print("Press Ctrl+C to stop\n")

            self.running = True

            # Start audio processing thread
            audio_thread = threading.Thread(
                target=self._process_audio_buffer, daemon=True
            )
            audio_thread.start()

            # Start display thread
            display_thread = threading.Thread(target=self._display_results, daemon=True)
            display_thread.start()

            # Start audio input stream
            with sd.InputStream(
                callback=self._audio_callback,
                channels=1,
                samplerate=self.sample_rate,
                blocksize=int(self.sample_rate * 0.1),  # 100ms blocks
                dtype=np.float32,
            ):
                print("🎤 Microphone active - listening for birds...")

                # Keep main thread alive
                try:
                    while self.running:
                        time.sleep(0.1)
                except KeyboardInterrupt:
                    pass

        except Exception as e:
            print(f"Error during detection: {e}")
        finally:
            # Signals both worker threads to leave their loops.
            self.running = False
            print("\n🛑 Detection stopped.")

    def stop_detection(self) -> None:
        """Stop detection by clearing the flag all loops poll."""
        self.running = False


def main(argv=None) -> int:
    """Command-line entry point for real-time detection.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``main()`` callers are
            unaffected; passing a list allows programmatic invocation.

    Returns:
        0 on success (or after listing devices), 1 when the model or labels
        file is missing or the detector fails to initialize.
    """
    parser = argparse.ArgumentParser(
        description="BirdNET Real-Time Audio Classification"
    )
    parser.add_argument(
        "--model", default="model.onnx", help="Path to the ONNX model file"
    )
    parser.add_argument(
        "--labels",
        default="BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        help="Path to the labels file",
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.1,
        help="Minimum confidence threshold for detections (default: 0.1)",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top predictions to show (default: 5)",
    )
    parser.add_argument(
        "--update-interval",
        type=float,
        default=1.0,
        help="Display update interval in seconds (default: 1.0)",
    )
    parser.add_argument(
        "--list-devices", action="store_true", help="List available audio input devices"
    )

    args = parser.parse_args(argv)

    # List audio devices if requested
    if args.list_devices:
        print("Available audio input devices:")
        print(sd.query_devices())
        return 0

    # Both paths must be regular files; a directory with the right name would
    # otherwise only fail later, inside the loaders, with a less direct error.
    if not os.path.isfile(args.model):
        print(f"Error: Model file '{args.model}' not found.")
        return 1

    if not os.path.isfile(args.labels):
        print(f"Error: Labels file '{args.labels}' not found.")
        return 1

    try:
        # Create detector
        detector = RealTimeBirdDetector(
            model_path=args.model,
            labels_path=args.labels,
            confidence_threshold=args.confidence,
            top_k=args.top_k,
            update_interval=args.update_interval,
        )

        # Start detection (blocks until stopped)
        detector.start_detection()

        return 0

    except Exception as e:
        print(f"Error: {str(e)}")
        return 1


if __name__ == "__main__":
    # ``exit`` is a helper injected by the ``site`` module and is absent under
    # ``python -S`` or in frozen builds; raising SystemExit always works and
    # still propagates main()'s return value as the process exit status.
    raise SystemExit(main())