| from __future__ import annotations | |
| from pathlib import Path | |
| from fastapi import HTTPException | |
| from fastapi.testclient import TestClient | |
| from api.app import app, get_inference_service | |
class DummyService:
    """Deterministic stand-in for the inference service used by the API tests.

    Scores are derived purely from each record's ``Amount`` so the tests can
    predict the outcome without loading any model artifact.
    """

    # Decision cut-off compared against the fake probability.
    threshold = 0.74
    # Artifact locations exposed as attributes; nothing in this class reads them.
    model_path = Path("models/model.pkl")
    preprocessor_path = Path("models/preprocessor.pkl")

    def predict_records(self, records):
        """Return one prediction dict per input record, in input order.

        Args:
            records: Iterable of mappings, each with an ``"Amount"`` entry that
                ``float()`` can convert.

        Returns:
            A list of dicts with the keys ``is_fraud``, ``fraud_probability``,
            ``risk_level`` and ``threshold``.
        """

        def _score(entry):
            # Amounts strictly above 200 get the "fraud" probability.
            probability = 0.1
            if float(entry["Amount"]) > 200:
                probability = 0.9
            return {
                "is_fraud": probability >= self.threshold,
                "fraud_probability": probability,
                "risk_level": "high" if probability >= 0.7 else "low",
                "threshold": self.threshold,
            }

        return [_score(entry) for entry in records]
| def _transaction(amount: float = 10.0) -> dict[str, float]: | |
| payload = {"Time": 0.0, "Amount": amount} | |
| for i in range(1, 29): | |
| payload[f"V{i}"] = 0.0 | |
| return payload | |
def test_health_endpoint() -> None:
    """``GET /health`` reports an OK status and a loaded model with a working service."""
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)
        response = client.get("/health")

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ok"
        assert body["model_loaded"] is True
    finally:
        # Always drop the override so a failing assertion cannot leak the
        # dummy service into tests that run afterwards.
        app.dependency_overrides.clear()
def test_predict_endpoint_valid_payload() -> None:
    """``POST /predict`` with a full payload returns a prediction and a request id header."""
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)
        # Amount 350.0 is above DummyService's 200 cut-off, so it scores 0.9.
        response = client.post("/predict", json=_transaction(amount=350.0))

        assert response.status_code == 200
        body = response.json()
        assert body["is_fraud"] is True
        assert body["risk_level"] == "high"
        assert response.headers.get("X-Request-ID")
    finally:
        # Always drop the override so a failing assertion cannot leak the
        # dummy service into tests that run afterwards.
        app.dependency_overrides.clear()
def test_predict_endpoint_invalid_payload() -> None:
    """``POST /predict`` answers 422 when the payload is missing the ``V28`` field."""
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)
        payload = _transaction()
        payload.pop("V28")  # make the payload incomplete on purpose

        response = client.post("/predict", json=payload)

        assert response.status_code == 422
    finally:
        # Always drop the override so a failing assertion cannot leak the
        # dummy service into tests that run afterwards.
        app.dependency_overrides.clear()
def test_batch_prediction_endpoint() -> None:
    """``POST /predict/batch`` returns one prediction per transaction, in order."""
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)
        # One amount below and one above DummyService's 200 cut-off.
        response = client.post(
            "/predict/batch",
            json={"transactions": [_transaction(20.0), _transaction(300.0)]},
        )

        assert response.status_code == 200
        body = response.json()
        assert len(body["predictions"]) == 2
        assert body["predictions"][0]["is_fraud"] is False
        assert body["predictions"][1]["is_fraud"] is True
    finally:
        # Always drop the override so a failing assertion cannot leak the
        # dummy service into tests that run afterwards.
        app.dependency_overrides.clear()
def test_metrics_endpoint_tracks_predictions_and_requests() -> None:
    """``GET /metrics`` counters grow after a prediction and rates stay within [0, 1]."""
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)

        before = client.get("/metrics")
        assert before.status_code == 200
        before_body = before.json()

        predict_response = client.post("/predict", json=_transaction(amount=350.0))
        assert predict_response.status_code == 200

        after = client.get("/metrics")
        assert after.status_code == 200
        after_body = after.json()

        # Lower bounds rather than exact deltas: the counters are compared
        # against the snapshot taken at the start of this test.
        assert after_body["total_requests"] >= before_body["total_requests"] + 2
        assert after_body["total_predictions"] >= before_body["total_predictions"] + 1
        assert 0.0 <= after_body["error_rate"] <= 1.0
        assert 0.0 <= after_body["fraud_prediction_rate"] <= 1.0
    finally:
        # Always drop the override so a failing assertion cannot leak the
        # dummy service into tests that run afterwards.
        app.dependency_overrides.clear()
def test_health_returns_503_when_service_unavailable() -> None:
    """``GET /health`` surfaces a 503 and its detail when the service dependency raises."""

    def _raise():
        # Dependency override that fails with a 503 instead of returning a service.
        raise HTTPException(status_code=503, detail="Model artifact not found")

    app.dependency_overrides[get_inference_service] = _raise
    try:
        client = TestClient(app)
        response = client.get("/health")

        assert response.status_code == 503
        assert "Model artifact not found" in response.json()["detail"]
    finally:
        # Always drop the override so a failing assertion cannot leak the
        # raising dependency into tests that run afterwards.
        app.dependency_overrides.clear()