|
|
|
|
|
""" |
|
|
Simple test script for the Marine Species Identification API. |
|
|
This script can be used to quickly test the API functionality. |
|
|
""" |
|
|
|
|
|
import requests |
|
|
import base64 |
|
|
import json |
|
|
import time |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import io |
|
|
|
|
|
|
|
|
def create_test_image(width: int = 640, height: int = 480, seed: "int | None" = None) -> str:
    """Create a synthetic RGB test image and return it as a base64 JPEG string.

    The image is uniform random noise with a solid red square covering
    rows/columns 100-199 and a solid green square covering rows/columns
    300-399. NumPy slicing clips these squares, so a smaller image simply
    shows less (or none) of them.

    Args:
        width: Image width in pixels; must be positive.
        height: Image height in pixels; must be positive.
        seed: Optional seed for the noise generator. Pass an int for a
            reproducible image; ``None`` (the default) draws fresh noise on
            every call.

    Returns:
        The JPEG bytes (quality 85), base64-encoded and decoded to ``str`` so
        the result can be embedded directly in a JSON request body.

    Raises:
        ValueError: If ``width`` or ``height`` is not positive.
    """
    if width <= 0 or height <= 0:
        raise ValueError(f"width and height must be positive, got {width}x{height}")

    rng = np.random.default_rng(seed)
    # ``high`` is exclusive, so 256 is required for channel values to span
    # the full 0-255 range.
    image = rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)

    image[100:200, 100:200] = [255, 0, 0]  # red square
    image[300:400, 300:400] = [0, 255, 0]  # green square

    with io.BytesIO() as buffer:
        Image.fromarray(image).save(buffer, format="JPEG", quality=85)
        image_bytes = buffer.getvalue()

    return base64.b64encode(image_bytes).decode("utf-8")
|
|
|
|
|
|
|
|
def _check_root(base_url: str, timeout: float) -> None:
    """Step 1: GET ``/`` and print the status and, on HTTP 200, the JSON body."""
    print("1. Testing root endpoint...")
    response = requests.get(f"{base_url}/", timeout=timeout)
    print(f" Status: {response.status_code}")
    if response.status_code == 200:
        print(f" Response: {response.json()}")
    print()


def _check_health(base_url: str, timeout: float) -> None:
    """Step 2: GET ``/api/v1/health`` and print its ``status``/``model_loaded`` fields."""
    print("2. Testing health check...")
    response = requests.get(f"{base_url}/api/v1/health", timeout=timeout)
    print(f" Status: {response.status_code}")
    if response.status_code == 200:
        health_data = response.json()
        print(f" API Status: {health_data.get('status')}")
        print(f" Model Loaded: {health_data.get('model_loaded')}")
    print()


def _check_info(base_url: str, timeout: float) -> None:
    """Step 3: GET ``/api/v1/info`` and print name, version and class count."""
    print("3. Testing API info...")
    response = requests.get(f"{base_url}/api/v1/info", timeout=timeout)
    print(f" Status: {response.status_code}")
    if response.status_code == 200:
        info_data = response.json()
        print(f" API Name: {info_data.get('name')}")
        print(f" Version: {info_data.get('version')}")
        # ``or {}`` also covers an explicit JSON null for model_info.
        model_info = info_data.get('model_info') or {}
        print(f" Model Classes: {model_info.get('total_classes')}")
    print()


def _check_species(base_url: str, timeout: float) -> None:
    """Step 4: GET ``/api/v1/species`` and print the total and first three entries."""
    print("4. Testing species list...")
    response = requests.get(f"{base_url}/api/v1/species", timeout=timeout)
    print(f" Status: {response.status_code}")
    if response.status_code == 200:
        species_data = response.json()
        # A missing or null count is treated as 0 so the comparison below
        # cannot raise TypeError.
        total_species = species_data.get('total_count') or 0
        print(f" Total Species: {total_species}")
        if total_species > 0:
            print(" First 3 species:")
            for species in (species_data.get('species') or [])[:3]:
                print(f" - {species.get('class_name')} (ID: {species.get('class_id')})")
    print()


def _print_detection_summary(detection_data: dict) -> None:
    """Print timing, the top three detections and annotated-image presence."""
    detections = detection_data.get('detections') or []
    # A missing/null processing time is shown as 0 instead of crashing the
    # ``:.3f`` format.
    processing_time = detection_data.get('processing_time') or 0

    print(f" Processing Time: {processing_time:.3f}s")
    print(f" Detections Found: {len(detections)}")

    if detections:
        print(" Top detections:")
        for rank, detection in enumerate(detections[:3], start=1):
            confidence = detection.get('confidence')
            confidence_text = (
                f"{confidence:.3f}" if isinstance(confidence, (int, float)) else "n/a"
            )
            print(f" {rank}. {detection.get('class_name')} "
                  f"(confidence: {confidence_text})")

    if detection_data.get('annotated_image'):
        print(" ✅ Annotated image returned")
    else:
        print(" ❌ No annotated image returned")


def _check_detection(base_url: str, timeout: float) -> None:
    """Step 5: POST a synthetic image to ``/api/v1/detect`` and summarize the reply."""
    print("5. Testing marine species detection...")
    print(" Creating test image...")
    detection_request = {
        "image": create_test_image(),
        "confidence_threshold": 0.25,
        "iou_threshold": 0.45,
        "image_size": 640,
        "return_annotated_image": True,
    }

    print(" Sending detection request...")
    # perf_counter is monotonic, unlike time.time(), so it is the right clock
    # for measuring an elapsed duration.
    start = time.perf_counter()
    response = requests.post(
        f"{base_url}/api/v1/detect",
        json=detection_request,
        timeout=timeout,
    )
    request_time = time.perf_counter() - start

    print(f" Status: {response.status_code}")
    print(f" Request Time: {request_time:.2f}s")

    if response.status_code == 200:
        _print_detection_summary(response.json())
    elif response.status_code == 503:
        print(" ⚠️ Service unavailable (model may not be loaded)")
    else:
        print(f" ❌ Error: {response.text}")
    print()


def test_api(base_url: str = "http://localhost:7860", timeout: float = 30.0):
    """Exercise each Marine Species Identification API endpoint and print results.

    Runs five checks in order:
    1. the root endpoint;
    2. ``/api/v1/health``;
    3. ``/api/v1/info``;
    4. ``/api/v1/species``;
    5. ``POST /api/v1/detect`` with a synthetic image.

    Results are printed to stdout and nothing is returned. Errors raised
    while running a check are printed rather than propagated.

    If the root endpoint cannot be reached at all, the remaining checks are
    skipped because the server is most likely down. A failure in any later
    check is reported, and the next check still runs.

    Args:
        base_url: Base URL of the running API, e.g. ``http://localhost:7860``.
            A trailing slash is tolerated.
        timeout: Per-request timeout in seconds, applied to every HTTP call so
            an unresponsive server cannot hang the script indefinitely.
    """
    print(f"🧪 Testing Marine Species Identification API at {base_url}")
    print("=" * 60)

    # Avoid "//api/v1/..." URLs when the caller passes a trailing slash.
    api_root = base_url.rstrip("/")

    try:
        _check_root(api_root, timeout)
    except Exception as e:
        # Root unreachable: further checks would fail the same way, so stop.
        print(f" Error: {e}")
        return

    for check in (_check_health, _check_info, _check_species, _check_detection):
        try:
            check(api_root, timeout)
        except Exception as e:  # best-effort diagnostics: report and continue
            print(f" Error: {e}")
            print()

    print("🎉 API testing completed!")
    print("=" * 60)
|
|
print("=" * 60) |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # An optional first command-line argument overrides the default server URL.
    cli_args = sys.argv[1:]
    base_url = cli_args[0] if cli_args else "http://localhost:7860"
    test_api(base_url)
|
|
|