|
|
""" |
|
|
Tests for API endpoints and WebSocket functionality. |
|
|
""" |
|
|
|
|
|
import pytest |
|
|
import json |
|
|
import base64 |
|
|
from unittest.mock import Mock, patch, MagicMock |
|
|
from fastapi.testclient import TestClient |
|
|
import numpy as np |
|
|
import cv2 |
|
|
|
|
|
from api.api_server import app, ProcessingRequest, ProcessingResponse |
|
|
from api.websocket import WebSocketHandler, WSMessage, MessageType |
|
|
|
|
|
|
|
|
class TestAPIEndpoints:
    """Exercise the REST endpoints exposed by the FastAPI ``app``."""

    @staticmethod
    def _as_test_user():
        """Return a patcher that makes ``api.api_server.verify_token`` yield ``"test-user"``."""
        return patch('api.api_server.verify_token', return_value="test-user")

    @staticmethod
    def _jpeg_upload(image):
        """JPEG-encode *image* and wrap it as the multipart ``file`` field."""
        _, encoded = cv2.imencode('.jpg', image)
        return {"file": ("test.jpg", encoded.tobytes(), "image/jpeg")}

    @pytest.fixture
    def client(self):
        """Build a ``TestClient`` bound to the application under test."""
        return TestClient(app)

    @pytest.fixture
    def auth_headers(self):
        """Bearer-token header sent to the authenticated endpoints."""
        return {"Authorization": "Bearer test-token"}

    def test_root_endpoint(self, client):
        """The root endpoint answers with the service name."""
        reply = client.get("/")
        assert reply.status_code == 200
        body = reply.json()
        assert "name" in body
        assert body["name"] == "BackgroundFX Pro API"

    def test_health_check(self, client):
        """The health endpoint reports a healthy status and lists services."""
        reply = client.get("/health")
        assert reply.status_code == 200
        body = reply.json()
        assert body["status"] == "healthy"
        assert "services" in body

    def test_process_image_endpoint(self, client, auth_headers, sample_image):
        """Uploading an image returns a job id in the ``processing`` state."""
        with self._as_test_user():
            upload = self._jpeg_upload(sample_image)
            form = {"background": "blur", "quality": "high"}

            # The background task is stubbed so nothing is actually processed.
            with patch('api.api_server.process_image_task'):
                reply = client.post(
                    "/api/v1/process/image",
                    headers=auth_headers,
                    files=upload,
                    data=form,
                )

            assert reply.status_code == 200
            body = reply.json()
            assert "job_id" in body
            assert body["status"] == "processing"

    def test_process_video_endpoint(self, client, auth_headers, sample_video):
        """Uploading a video returns a job id."""
        with self._as_test_user():
            with open(sample_video, 'rb') as video_file:
                raw_video = video_file.read()
            upload = {"file": ("test.mp4", raw_video, "video/mp4")}
            form = {"background": "office", "quality": "medium"}

            with patch('api.api_server.process_video_task'):
                reply = client.post(
                    "/api/v1/process/video",
                    headers=auth_headers,
                    files=upload,
                    data=form,
                )

            assert reply.status_code == 200
            assert "job_id" in reply.json()

    def test_batch_processing_endpoint(self, client, auth_headers):
        """Submitting a two-item batch returns a job id."""
        with self._as_test_user():
            payload = {
                "items": [
                    {"id": "1", "input_path": "/tmp/img1.jpg", "output_path": "/tmp/out1.jpg"},
                    {"id": "2", "input_path": "/tmp/img2.jpg", "output_path": "/tmp/out2.jpg"},
                ],
                "parallel": True,
                "priority": "normal",
            }

            with patch('api.api_server.process_batch_task'):
                reply = client.post("/api/v1/batch", headers=auth_headers, json=payload)

            assert reply.status_code == 200
            assert "job_id" in reply.json()

    def test_job_status_endpoint(self, client, auth_headers):
        """The job endpoint echoes what the job manager returns."""
        with self._as_test_user():
            job_id = "test-job-123"

            with patch.object(app.state.job_manager, 'get_job') as fake_get_job:
                fake_get_job.return_value = ProcessingResponse(
                    job_id=job_id,
                    status="completed",
                    progress=1.0,
                )
                reply = client.get(f"/api/v1/job/{job_id}", headers=auth_headers)

            assert reply.status_code == 200
            body = reply.json()
            assert body["job_id"] == job_id
            assert body["status"] == "completed"

    def test_streaming_endpoints(self, client, auth_headers):
        """A webcam stream can be started (status ``streaming``) and stopped."""
        with self._as_test_user():
            start_payload = {
                "source": "0",
                "stream_type": "webcam",
                "output_format": "hls",
            }

            # Start: the processor hook is stubbed to report success.
            with patch.object(
                app.state.video_processor, 'start_stream_processing', return_value=True
            ):
                started = client.post(
                    "/api/v1/stream/start",
                    headers=auth_headers,
                    json=start_payload,
                )

            assert started.status_code == 200
            assert started.json()["status"] == "streaming"

            # Stop: the processor hook is replaced by a bare mock.
            with patch.object(app.state.video_processor, 'stop_stream_processing'):
                stopped = client.get("/api/v1/stream/stop", headers=auth_headers)

            assert stopped.status_code == 200
|
|
|
|
|
|
|
|
class TestWebSocket:
    """Test WebSocket functionality."""

    @pytest.fixture
    def ws_handler(self):
        """Create a fresh ``WebSocketHandler`` for each test."""
        return WebSocketHandler()

    def test_websocket_connection(self, ws_handler, mock_websocket):
        """Drive ``handle_connection`` with a mock socket and check it was accepted."""
        import asyncio

        async def connect():
            # NOTE(review): the handler may keep reading from the socket until
            # a disconnect that a mock never sends (assumption -- confirm
            # against WebSocketHandler). Bound the run so the suite cannot hang;
            # reaching the timeout is an expected outcome, not a failure.
            try:
                await asyncio.wait_for(
                    ws_handler.handle_connection(mock_websocket), timeout=1.0
                )
            except asyncio.TimeoutError:
                pass

        # The coroutine must actually be executed. Merely defining it runs nothing.
        asyncio.run(connect())

        assert mock_websocket.accept.called

    def test_message_parsing(self, ws_handler):
        """A dict with a type and data round-trips into a ``WSMessage``."""
        message_data = {
            "type": "process_frame",
            "data": {"frame": "base64_data"},
        }

        message = WSMessage.from_dict(message_data)

        assert message.type == MessageType.PROCESS_FRAME
        assert message.data["frame"] == "base64_data"

    def test_frame_encoding_decoding(self, ws_handler, sample_image):
        """A base64-encoded JPEG decodes back to an array of the original shape."""
        _, buffer = cv2.imencode('.jpg', sample_image)
        encoded = base64.b64encode(buffer).decode('utf-8')

        decoded = ws_handler.frame_processor._decode_frame(encoded)

        assert decoded is not None
        assert decoded.shape == sample_image.shape

    def test_session_management(self, ws_handler):
        """Adding a session records the given client id."""
        import asyncio

        mock_ws = MagicMock()

        async def add():
            return await ws_handler.session_manager.add_session(mock_ws, "test-client")

        # Run the coroutine so that add_session and the assertion below execute.
        session = asyncio.run(add())

        assert session.client_id == "test-client"
        assert ws_handler.session_manager is not None

    def test_message_routing(self, ws_handler):
        """Each message type is a ``MessageType`` member and serialises to a dict."""
        messages = [
            WSMessage(type=MessageType.PING, data={}),
            WSMessage(type=MessageType.UPDATE_CONFIG, data={"quality": "high"}),
            WSMessage(type=MessageType.START_STREAM, data={"source": 0}),
        ]

        for msg in messages:
            assert msg.type in MessageType
            assert isinstance(msg.to_dict(), dict)

    def test_statistics_tracking(self, ws_handler):
        """The statistics snapshot exposes the expected counters."""
        stats = ws_handler.get_statistics()

        assert "uptime" in stats
        assert "total_connections" in stats
        assert "active_connections" in stats
        assert "total_frames_processed" in stats
|
|
|
|
|
|
|
|
class TestAPIIntegration:
    """Integration tests for API."""

    @pytest.fixture
    def client(self):
        """Create a test client for this class.

        pytest fixtures declared inside another test class are only visible
        to that class, so this class declares its own.
        """
        return TestClient(app)

    @pytest.mark.integration
    def test_full_image_processing_flow(self, client, sample_image, temp_dir):
        """Upload an image, then query the resulting job."""
        # Send the same bearer header as the other authenticated calls in this file.
        auth = {"Authorization": "Bearer test"}

        with patch('api.api_server.verify_token', return_value="test-user"):
            _, buffer = cv2.imencode('.jpg', sample_image)
            files = {"file": ("test.jpg", buffer.tobytes(), "image/jpeg")}

            response = client.post(
                "/api/v1/process/image",
                files=files,
                data={"background": "blur", "quality": "low"},
                headers=auth,
            )

            assert response.status_code == 200
            job_id = response.json()["job_id"]

            response = client.get(f"/api/v1/job/{job_id}", headers=auth)

            # NOTE(review): 404 is tolerated, presumably because the job may
            # not be visible to the status endpoint yet. Confirm this is intended.
            assert response.status_code in [200, 404]

    @pytest.mark.integration
    @pytest.mark.slow
    def test_concurrent_requests(self, client):
        """Ten concurrent health checks all succeed."""
        import concurrent.futures

        def make_request():
            return client.get("/health").status_code

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(make_request) for _ in range(10)]
            results = [f.result() for f in concurrent.futures.as_completed(futures)]

        assert all(status == 200 for status in results)

    @pytest.mark.integration
    def test_error_handling(self, client):
        """Cover unknown routes, missing auth, and invalid uploads."""
        # Unknown route.
        response = client.get("/api/v1/invalid")
        assert response.status_code == 404

        # Protected route without credentials.
        response = client.get("/api/v1/stats")
        assert response.status_code in [401, 422]

        # Non-image upload to the image endpoint.
        with patch('api.api_server.verify_token', return_value="test-user"):
            files = {"file": ("test.txt", b"text content", "text/plain")}
            response = client.post(
                "/api/v1/process/image",
                files=files,
                headers={"Authorization": "Bearer test"},
            )
            assert response.status_code == 400
|
|
|
|
|
|
|
|
class TestAPIPerformance:
    """Performance tests for API."""

    @pytest.fixture
    def client(self):
        """Create a test client for this class.

        pytest fixtures declared inside another test class are only visible
        to that class, so this class declares its own.
        """
        return TestClient(app)

    @pytest.mark.slow
    def test_response_time(self, client, performance_timer):
        """Each lightweight endpoint answers within 100 ms."""
        endpoints = ["/", "/health"]

        for endpoint in endpoints:
            with performance_timer as timer:
                response = client.get(endpoint)

            assert response.status_code == 200
            assert timer.elapsed < 0.1

    @pytest.mark.slow
    def test_file_upload_performance(self, client, performance_timer):
        """Uploading a 1024x1024 JPEG completes within 2 seconds."""
        # Seeded so the payload (and therefore its encoded size and timing) is
        # reproducible across runs. ``integers`` excludes ``high``, so 256 is
        # needed to cover the full 0..255 uint8 range.
        rng = np.random.default_rng(0)
        large_data = rng.integers(0, 256, size=(1024, 1024, 3), dtype=np.uint8)
        _, buffer = cv2.imencode('.jpg', large_data)

        with patch('api.api_server.verify_token', return_value="test-user"):
            # Stub the background task so only the upload path is timed.
            with patch('api.api_server.process_image_task'):
                with performance_timer as timer:
                    response = client.post(
                        "/api/v1/process/image",
                        files={"file": ("large.jpg", buffer.tobytes(), "image/jpeg")},
                        headers={"Authorization": "Bearer test"},
                    )

        assert response.status_code == 200
        assert timer.elapsed < 2.0