| | """ |
| | Inference script for PAM-SDZWA-v1 (Peruvian Amazon Species Classifier) |
| | |
| | This model classifies 53 species found in Peruvian Amazon rainforest habitats. |
| | Developed by Mathias Tobler from the San Diego Zoo Wildlife Alliance Conservation |
| | Technology Lab using their animl-py framework. |
| | |
| | Model: Peru Amazon v0.86 |
| | Input: Variable size (extracted from model config) |
| | Framework: TensorFlow/Keras (TensorFlow 1.x compatible) |
| | Classes: 53 Amazonian species and taxonomic groups |
| | Developer: San Diego Zoo Wildlife Alliance (Mathias Tobler) |
| | License: MIT |
| | Info: https://github.com/conservationtechlab |
| | |
| | Author: Peter van Lunteren |
| | Created: 2026-01-14 |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import os |
| | from pathlib import Path |
| |
|
| | import cv2 |
| | import numpy as np |
| | import tensorflow as tf |
| | from PIL import Image, ImageFile |
| | from tensorflow.keras.models import load_model |
| |
|
| | |
# Let PIL decode partially written / truncated image files instead of raising
# an error while reading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
| |
|
| |
|
class ModelInference:
    """TensorFlow/Keras inference implementation for Peruvian Amazon species classifier.

    Attributes:
        model_dir: Directory containing the model file and its class label file.
        model_path: Path to the Keras ``.h5`` model file.
        model: Loaded Keras model, or None until ``load_model()`` has run.
        img_size: Square input edge length (pixels) read from the model.
        class_map: Label-file identifier (digit string) -> species name.
        class_ids_sorted: Species names sorted alphabetically. Despite the name it
            holds names, not IDs; ``get_classification()`` pairs model output
            index ``i`` with ``class_ids_sorted[i]``.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files and class labels
            model_path: Path to Peru-Amazon_0.86.h5 file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None
        self.img_size = None
        self.class_map = {}
        self.class_ids_sorted = []

    def check_gpu(self) -> bool:
        """
        Check GPU availability for TensorFlow inference.

        Returns:
            True if GPU available, False otherwise
        """
        return len(tf.config.list_logical_devices('GPU')) > 0

    def load_model(self) -> None:
        """
        Load TensorFlow/Keras model and class labels into memory.

        This function is called once during worker initialization.
        The model is stored in self.model and reused for all subsequent
        classification requests.

        Raises:
            RuntimeError: If model loading fails, the model input size cannot be
                determined, or the label file cannot be parsed or contains no
                usable class entries
            FileNotFoundError: If model_path or label file is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            self.model = load_model(str(self.model_path))
            self.img_size = self._read_input_size(self.model)
        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        label_file = self.model_dir / "Peru-Amazon_0.86.txt"
        if not label_file.exists():
            raise FileNotFoundError(f"Class label file not found: {label_file}")

        try:
            self.class_map = self._read_class_map(label_file)
            if not self.class_map:
                # Fail at load time rather than on the first classification request.
                raise ValueError("no numeric class entries found")
            # get_classification() pairs model output index i with the i-th name
            # of this alphabetically sorted list.
            self.class_ids_sorted = sorted(self.class_map.values())
        except Exception as e:
            raise RuntimeError(f"Failed to load class labels from {label_file}: {e}") from e

    @staticmethod
    def _read_input_size(model) -> int:
        """Return the square input edge length (pixels) expected by ``model``."""
        shape = None
        try:
            # Input-layer config as written by TF 2.x Keras ("batch_input_shape");
            # newer Keras releases name the same entry "batch_shape".
            layer_cfg = model.get_config()["layers"][0]["config"]
            shape = layer_cfg.get("batch_input_shape") or layer_cfg.get("batch_shape")
        except (AttributeError, KeyError, IndexError, TypeError):
            shape = None
        if not shape:
            # Fall back to the shape the model itself reports.
            shape = model.input_shape
        size = shape[1]
        if size is None:
            # cv2.resize in get_classification needs a concrete size.
            raise ValueError(f"model input shape {shape} has no fixed spatial size")
        return int(size)

    @staticmethod
    def _read_class_map(label_file: Path) -> dict[str, str]:
        """Parse label lines whose first two double-quoted fields are an id and a name."""
        class_map: dict[str, str] = {}
        with open(label_file, 'r', encoding='utf-8') as file:
            for line in file:
                # Splitting on '"' puts the quoted fields at odd indices:
                # parts[1] is the identifier, parts[3] the species name.
                parts = line.strip().split('"')
                if len(parts) < 4:
                    continue
                identifier = parts[1].strip()
                if identifier.isdigit():
                    class_map[identifier] = parts[3].strip()
        return class_map

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using SDZWA animl-py preprocessing.

        This cropping method follows the San Diego Zoo Wildlife Alliance's animl-py
        framework approach with minimal buffering (0 pixels by default).

        Based on: https://github.com/conservationtechlab/animl-py/blob/main/src/animl/generator.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (not resized - resizing happens in get_classification)

        Raises:
            ValueError: If bbox is invalid
        """
        buffer = 0
        width, height = image.size
        x, y, w, h = bbox

        # Normalized (x, y, w, h) -> absolute pixel corners, widened by `buffer`
        # and clamped to the image bounds.
        left = max(0, int(width * x) - buffer)
        top = max(0, int(height * y) - buffer)
        right = min(width, int(width * (x + w)) + buffer)
        bottom = min(height, int(height * (y + h)) + buffer)

        if left >= right or top >= bottom:
            raise ValueError(f"Invalid bbox dimensions after cropping: left={left}, top={top}, right={right}, bottom={bottom}")

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run TensorFlow/Keras classification on cropped image.

        Preprocessing follows SDZWA animl-py framework:
        - Convert to 3-channel RGB if the crop is in another mode
        - Convert to numpy array
        - Resize to model input size (extracted from model config)
        - No normalization or augmentation (except potential horizontal flip during training)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, sorted by class ID.
            Example: [["Black-headed squirrel monkey", 0.001], ["Brazilian rabbit", 0.002], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded, the model emits more scores than
                there are class labels, or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # RGBA / greyscale / palette crops would yield 4-channel or 2-D arrays
            # that a 3-channel classifier cannot consume.
            rgb = crop if crop.mode == "RGB" else crop.convert("RGB")
            img = np.array(rgb)
            img = cv2.resize(img, (self.img_size, self.img_size))
            batch = np.expand_dims(img, axis=0)  # add the batch axis

            pred = self.model.predict(batch, verbose=0)[0]

            labels = self.class_ids_sorted
            if len(pred) > len(labels):
                raise ValueError(
                    f"model produced {len(pred)} scores but only {len(labels)} class labels are loaded"
                )
            return [[name, float(score)] for name, score in zip(labels, pred)]

        except Exception as e:
            raise RuntimeError(f"Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Class IDs are 1-indexed and correspond to the sorted order of class names.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "Black-headed squirrel monkey", "2": "Brazilian rabbit", ...}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            return {
                str(position): name
                for position, name in enumerate(self.class_ids_sorted, start=1)
            }
        except Exception as e:
            raise RuntimeError(f"Failed to extract class names: {e}") from e
| |
|