# hackathon/tests/api/test_routes.py
# (added in commit c0a7163, "Add project files", by bekir32419)
"""Tests for /pipeline/{bbb,eeg,mri} POST endpoints."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pandas as pd
import pytest
from fastapi.testclient import TestClient
from src.api.main import app
client = TestClient(app)
_FIXTURES = Path(__file__).resolve().parents[1] / "fixtures"
class TestBBBRoute:
def test_returns_200_with_valid_input(self, tmp_path: Path):
out = tmp_path / "out.parquet"
resp = client.post(
"/pipeline/bbb",
json={
"input_path": str(_FIXTURES / "bbbp_sample.csv"),
"output_path": str(out),
},
)
assert resp.status_code == 200
body = resp.json()
assert body["status"] == "ok"
assert body["rows"] > 0
assert out.exists()
def test_returns_404_when_input_missing(self, tmp_path: Path):
resp = client.post(
"/pipeline/bbb",
json={
"input_path": str(tmp_path / "does_not_exist.csv"),
"output_path": str(tmp_path / "out.parquet"),
},
)
assert resp.status_code == 404
def test_returns_422_on_malformed_body(self):
resp = client.post("/pipeline/bbb", json={"banana": 1})
assert resp.status_code == 422 # pydantic validation
class TestEEGRoute:
def test_returns_200_with_valid_input(self, tmp_path: Path):
fif = _FIXTURES / "eeg_sample.fif"
out = tmp_path / "out.parquet"
resp = client.post(
"/pipeline/eeg",
json={"input_path": str(fif), "output_path": str(out)},
)
assert resp.status_code == 200
assert resp.json()["rows"] > 0
class TestMRIRoute:
def test_returns_200_with_valid_input(self, tmp_path: Path):
from tests.fixtures.build_mri_fixture import build as build_mri
fixture_dir = build_mri(out_dir=tmp_path / "mri_fixture")
out = tmp_path / "out.parquet"
resp = client.post(
"/pipeline/mri",
json={
"input_dir": str(fixture_dir),
"sites_csv": str(fixture_dir / "sites.csv"),
"output_path": str(out),
},
)
assert resp.status_code == 200
assert resp.json()["rows"] > 0
class TestPipelineWrap:
def test_wrap_skips_mlflow_lookup_when_disabled(self, tmp_path: Path, monkeypatch):
from src.api import routes
out = tmp_path / "out.parquet"
pd.DataFrame({"x": [1]}).to_parquet(out)
monkeypatch.setenv("NEUROBRIDGE_DISABLE_MLFLOW", "1")
with patch("src.api.routes.mlflow.search_runs") as search_runs:
resp = routes._wrap("bbb_pipeline", out, lambda: None)
search_runs.assert_not_called()
assert resp.status == "ok"
assert resp.mlflow_run_id is None
class TestBBBPredictRoute:
    """POST /predict/bbb: prediction payload, SHAP attributions, calibration,
    drift tracking (T1B), provenance (T2), and error statuses.

    Fixes vs. previous revision: the redundant function-local
    ``import pandas as pd`` is dropped (pandas is imported at module level),
    and the two tests that duplicated the artifact/env setup by hand now use
    the class's own ``_set_bbb_model_path`` fixture for consistency.
    """

    def _setup_model_artifact(self, tmp_path: Path) -> Path:
        """Build features + train + save a tiny model. Returns artifact path."""
        from src.pipelines import bbb_pipeline
        from src.models import bbb_model

        features_path = tmp_path / "features.parquet"
        bbb_pipeline.run_pipeline(
            input_path=_FIXTURES / "bbbp_sample.csv",
            output_path=features_path,
        )
        df = pd.read_parquet(features_path)
        # Tiny forest keeps the test fast; fixed seed keeps it deterministic.
        model = bbb_model.train(df, label_col="p_np", n_estimators=10, random_state=42)
        artifact = tmp_path / "bbb_model.joblib"
        bbb_model.save(model, artifact)
        return artifact

    @pytest.fixture
    def _set_bbb_model_path(self, tmp_path: Path, monkeypatch):
        """Build a model artifact and point BBB_MODEL_PATH at it for the test."""
        artifact = self._setup_model_artifact(tmp_path)
        monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
        return artifact

    def test_returns_200_with_prediction_and_attributions(self, _set_bbb_model_path):
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "CCO", "top_k": 5},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["label"] in (0, 1)
        assert body["label_text"] in ("permeable", "non-permeable")
        assert 0.0 <= body["confidence"] <= 1.0
        assert len(body["top_features"]) == 5
        for f in body["top_features"]:
            assert f["feature"].startswith("fp_")
            assert isinstance(f["shap_value"], float)
        # Day-6 calibration assertions: trained test fixture model has
        # _neurobridge_calibration metadata, so calibration must be populated.
        assert body["calibration"] is not None
        cal = body["calibration"]
        valid_thresholds = [0.50, 0.60, 0.70, 0.75, 0.80, 0.90]
        assert any(
            cal["threshold"] == pytest.approx(t) for t in valid_thresholds
        ), f"threshold {cal['threshold']} not in {valid_thresholds}"
        assert cal["threshold"] <= body["confidence"]
        assert 0.0 <= cal["precision"] <= 1.0
        assert isinstance(cal["support"], int)
        assert cal["support"] >= 0

    def test_predict_response_includes_drift_z_and_rolling_n(
        self, _set_bbb_model_path,
    ):
        """T1B: drift_z and rolling_n keys must always appear in the body."""
        # Reset deque before this test so rolling_n starts deterministic.
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()
        resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 5})
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "drift_z" in body
        assert "rolling_n" in body
        # First request: buffer has 1 sample (just appended), so warming up.
        assert body["rolling_n"] == 1
        assert body["drift_z"] is None  # <10 samples = warming up

    def test_predict_deque_rolls_at_100(self, _set_bbb_model_path):
        """T1B: after 100 predictions, deque caps at maxlen=100 (rolls)."""
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()
        # Fire 105 calls; final rolling_n must be 100, not 105.
        last_body = None
        for _ in range(105):
            resp = client.post(
                "/predict/bbb", json={"smiles": "CCO", "top_k": 3},
            )
            assert resp.status_code == 200
            last_body = resp.json()
        assert last_body["rolling_n"] == 100
        # By call 105, drift_z is computable (≥10 samples) — assert numeric.
        assert isinstance(last_body["drift_z"], float)

    def test_predict_response_includes_provenance(self, _set_bbb_model_path):
        """T2: provenance field is present in body (fields may be None)."""
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()
        resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 3})
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "provenance" in body
        assert body["provenance"] is not None, "provenance should be populated even when MLflow is empty"
        prov = body["provenance"]
        assert "mlflow_run_id" in prov
        assert "model_version" in prov
        assert prov["model_version"] == "v1"  # default until bumped manually
        assert "train_date" in prov
        assert "n_examples" in prov
        # n_examples comes from train_stats — must be a positive int for the test fixture
        assert isinstance(prov["n_examples"], int) and prov["n_examples"] >= 1

    def test_returns_400_on_invalid_smiles(self, _set_bbb_model_path):
        # Unparseable SMILES is a client error, not a server failure.
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "this_is_not_a_smiles", "top_k": 5},
        )
        assert resp.status_code == 400

    def test_returns_503_when_artifact_missing(self, tmp_path: Path, monkeypatch):
        # Point the env var at a nonexistent artifact; route must degrade to 503.
        monkeypatch.setenv("BBB_MODEL_PATH", str(tmp_path / "does_not_exist.joblib"))
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "CCO", "top_k": 5},
        )
        assert resp.status_code == 503
class TestMRIPredictRoute:
def test_returns_503_when_artifact_missing(self, tmp_path: Path, monkeypatch):
monkeypatch.setenv("MRI_MODEL_PATH", str(tmp_path / "missing.onnx"))
resp = client.post(
"/predict/mri",
json={"input_path": str(_FIXTURES / "mri_sample" / "subject_0.nii.gz")},
)
assert resp.status_code == 503
assert "MRI model artifact not available" in resp.text
def test_returns_404_when_input_missing(self, tmp_path: Path, monkeypatch):
from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_mri_onnx
artifact = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
monkeypatch.setenv("MRI_MODEL_PATH", str(artifact))
resp = client.post(
"/predict/mri",
json={"input_path": str(tmp_path / "missing.nii.gz"), "target_shape": [8, 8, 8]},
)
assert resp.status_code == 404
def test_returns_200_with_prediction(self, tmp_path: Path, monkeypatch):
from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_mri_onnx
artifact = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
monkeypatch.setenv("MRI_MODEL_PATH", str(artifact))
resp = client.post(
"/predict/mri",
json={
"input_path": str(_FIXTURES / "mri_sample" / "subject_0.nii.gz"),
"target_shape": [8, 8, 8],
"label_names": ["control", "abnormal"],
},
)
assert resp.status_code == 200, resp.text
body = resp.json()
assert body["label"] == 1
assert body["label_text"] == "abnormal"
assert body["confidence"] > 0.5
assert body["input_path"].endswith("subject_0.nii.gz")
assert body["model_path"] == str(artifact)
assert len(body["probabilities"]) == 2
class TestMRIDiagnosticsRoute:
def test_returns_200_with_pre_and_post_data(self, tmp_path: Path):
from tests.fixtures.build_mri_fixture import build as build_mri
fixture_dir = build_mri(out_dir=tmp_path / "mri")
resp = client.post(
"/pipeline/mri/diagnostics",
json={
"input_dir": str(fixture_dir),
"sites_csv": str(fixture_dir / "sites.csv"),
},
)
assert resp.status_code == 200
body = resp.json()
assert len(body["rows"]) > 0
assert body["site_gap_pre"] >= 0.0
assert body["site_gap_post"] >= 0.0
# Reduction factor is the headline KPI
assert body["reduction_factor"] >= 1.0 # ComBat must reduce, not amplify
states = {r["harmonization_state"] for r in body["rows"]}
assert states == {"Pre-ComBat", "Post-ComBat"}
def test_returns_404_when_input_dir_missing(self, tmp_path: Path):
resp = client.post(
"/pipeline/mri/diagnostics",
json={
"input_dir": str(tmp_path / "does_not_exist"),
"sites_csv": str(tmp_path / "sites.csv"),
},
)
assert resp.status_code == 404
class TestExplainBBBRoute:
"""Day-7 T3B: POST /explain/bbb."""
def test_returns_200_with_template_source(self, monkeypatch):
"""Kill-switch on → /explain/bbb returns rationale with source=template."""
monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
body = {
"smiles": "CCO",
"label": 1,
"label_text": "permeable",
"confidence": 0.82,
"top_features": [
{"feature": "fp_341", "shap_value": 0.045},
{"feature": "fp_902", "shap_value": -0.031},
{"feature": "fp_77", "shap_value": 0.022},
],
"calibration": {"threshold": 0.80, "precision": 0.92, "support": 18},
"drift_z": 0.42,
"user_question": "Why permeable?",
}
resp = client.post("/explain/bbb", json=body)
assert resp.status_code == 200, resp.text
out = resp.json()
assert out["source"] == "template"
assert out["model"] is None
# Template must mention all three features
for feat in ("fp_341", "fp_902", "fp_77"):
assert feat in out["rationale"]
assert "permeable" in out["rationale"]
class TestExplainEEGRoute:
"""Day-8 T1B: POST /explain/eeg."""
def test_returns_200_with_template_source(self, monkeypatch):
monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
body = {
"rows": 30,
"columns": 95,
"duration_sec": 4.32,
"mlflow_run_id": "abc12345",
"user_question": "Why were epochs dropped?",
}
resp = client.post("/explain/eeg", json=body)
assert resp.status_code == 200, resp.text
out = resp.json()
assert out["source"] == "template"
assert out["model"] is None
assert "30" in out["rationale"]
assert "95" in out["rationale"]
class TestExplainMRIRoute:
"""Day-8 T1B: POST /explain/mri."""
def test_returns_200_with_template_source(self, monkeypatch):
monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
body = {
"site_gap_pre": 5.0004,
"site_gap_post": 0.0015,
"reduction_factor": 3290.0,
"n_subjects": 6,
"user_question": "Why does ComBat matter?",
}
resp = client.post("/explain/mri", json=body)
assert resp.status_code == 200, resp.text
out = resp.json()
assert out["source"] == "template"
assert "3290" in out["rationale"]
assert "6" in out["rationale"]
class TestExperimentsRoutes:
    """Day-8 T2A: GET /experiments/runs and POST /experiments/diff."""

    def test_runs_endpoint_returns_list(self):
        """GET /experiments/runs returns a runs list (may be empty if no MLflow data)."""
        resp = client.get("/experiments/runs")
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "runs" in body
        assert isinstance(body["runs"], list)
        # If any runs exist, each must have the expected keys
        for run in body["runs"]:
            for key in ("run_id", "experiment_name", "start_time", "status", "metrics", "params"):
                assert key in run

    def test_diff_endpoint_handles_unknown_runs_gracefully(self):
        """POST /experiments/diff with bogus run ids must not 500: either
        404 (documented contract) or 200 with empty rows is acceptable."""
        resp = client.post(
            "/experiments/diff",
            json={"run_id_a": "nonexistent_aaa", "run_id_b": "nonexistent_bbb"},
        )
        assert resp.status_code in (404, 200), (
            f"unexpected status {resp.status_code}: {resp.text}"
        )
        # 404 is the documented contract; 200 with empty rows is acceptable too
        # because some MLflow stores treat unknown ids as "empty result".
        body = resp.json()
        if resp.status_code == 200:
            assert body.get("rows", []) == []