| """Tests for /pipeline/{bbb,eeg,mri} POST endpoints.""" |
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from unittest.mock import patch |
|
|
| import pandas as pd |
| import pytest |
| from fastapi.testclient import TestClient |
|
|
| from src.api.main import app |
|
|
|
|
# In-process test client shared by all test classes below (no real HTTP server).
client = TestClient(app)
# Fixture data directory; presumably tests/fixtures relative to this file — TODO confirm layout.
_FIXTURES = Path(__file__).resolve().parents[1] / "fixtures"
|
|
|
|
class TestBBBRoute:
    """POST /pipeline/bbb: happy path, missing input file, malformed body."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        output_file = tmp_path / "out.parquet"
        payload = {
            "input_path": str(_FIXTURES / "bbbp_sample.csv"),
            "output_path": str(output_file),
        }
        response = client.post("/pipeline/bbb", json=payload)
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ok"
        assert data["rows"] > 0
        # The pipeline must actually write the parquet output to disk.
        assert output_file.exists()

    def test_returns_404_when_input_missing(self, tmp_path: Path):
        payload = {
            "input_path": str(tmp_path / "does_not_exist.csv"),
            "output_path": str(tmp_path / "out.parquet"),
        }
        response = client.post("/pipeline/bbb", json=payload)
        assert response.status_code == 404

    def test_returns_422_on_malformed_body(self):
        # Body missing required fields → FastAPI validation error.
        response = client.post("/pipeline/bbb", json={"banana": 1})
        assert response.status_code == 422
|
|
|
|
class TestEEGRoute:
    """POST /pipeline/eeg: happy path against the bundled FIF fixture."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        payload = {
            "input_path": str(_FIXTURES / "eeg_sample.fif"),
            "output_path": str(tmp_path / "out.parquet"),
        }
        response = client.post("/pipeline/eeg", json=payload)
        assert response.status_code == 200
        assert response.json()["rows"] > 0
|
|
|
|
class TestMRIRoute:
    """POST /pipeline/mri: happy path against a freshly built fixture dir."""

    def test_returns_200_with_valid_input(self, tmp_path: Path):
        from tests.fixtures.build_mri_fixture import build as build_mri

        # Build the synthetic MRI dataset (NIfTI files + sites.csv) on the fly.
        mri_dir = build_mri(out_dir=tmp_path / "mri_fixture")
        response = client.post(
            "/pipeline/mri",
            json={
                "input_dir": str(mri_dir),
                "sites_csv": str(mri_dir / "sites.csv"),
                "output_path": str(tmp_path / "out.parquet"),
            },
        )
        assert response.status_code == 200
        assert response.json()["rows"] > 0
|
|
|
|
class TestPipelineWrap:
    """Unit-level check of routes._wrap's MLflow kill-switch behavior."""

    def test_wrap_skips_mlflow_lookup_when_disabled(self, tmp_path: Path, monkeypatch):
        from src.api import routes

        # _wrap reads the output parquet, so one must exist on disk.
        parquet_path = tmp_path / "out.parquet"
        pd.DataFrame({"x": [1]}).to_parquet(parquet_path)
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_MLFLOW", "1")

        with patch("src.api.routes.mlflow.search_runs") as mocked_search:
            result = routes._wrap("bbb_pipeline", parquet_path, lambda: None)

        # With the kill-switch set, MLflow must never be queried.
        mocked_search.assert_not_called()
        assert result.status == "ok"
        assert result.mlflow_run_id is None
|
|
|
|
class TestBBBPredictRoute:
    """POST /predict/bbb: prediction, SHAP attributions, calibration,
    drift tracking, provenance, and error paths."""

    def _setup_model_artifact(self, tmp_path: Path) -> Path:
        """Build features + train + save a tiny model. Returns artifact path."""
        from src.pipelines import bbb_pipeline
        from src.models import bbb_model

        # NOTE: uses the module-level `pandas as pd` import; the previous
        # function-local re-import was redundant and has been removed.
        features_path = tmp_path / "features.parquet"
        bbb_pipeline.run_pipeline(
            input_path=_FIXTURES / "bbbp_sample.csv",
            output_path=features_path,
        )
        df = pd.read_parquet(features_path)
        # Tiny forest keeps the test fast; fixed seed keeps it deterministic.
        model = bbb_model.train(df, label_col="p_np", n_estimators=10, random_state=42)
        artifact = tmp_path / "bbb_model.joblib"
        bbb_model.save(model, artifact)
        return artifact

    @pytest.fixture
    def _set_bbb_model_path(self, tmp_path: Path, monkeypatch):
        """Build a model artifact and point BBB_MODEL_PATH at it for the test."""
        artifact = self._setup_model_artifact(tmp_path)
        monkeypatch.setenv("BBB_MODEL_PATH", str(artifact))
        return artifact

    def test_returns_200_with_prediction_and_attributions(self, _set_bbb_model_path):
        # Uses the shared fixture instead of duplicating the artifact setup.
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "CCO", "top_k": 5},
        )
        assert resp.status_code == 200
        body = resp.json()
        assert body["label"] in (0, 1)
        assert body["label_text"] in ("permeable", "non-permeable")
        assert 0.0 <= body["confidence"] <= 1.0
        # top_k=5 must yield exactly 5 fingerprint-feature attributions.
        assert len(body["top_features"]) == 5
        for f in body["top_features"]:
            assert f["feature"].startswith("fp_")
            assert isinstance(f["shap_value"], float)

        # Calibration must report a known threshold at or below the confidence.
        assert body["calibration"] is not None
        cal = body["calibration"]
        valid_thresholds = [0.50, 0.60, 0.70, 0.75, 0.80, 0.90]
        assert any(
            cal["threshold"] == pytest.approx(t) for t in valid_thresholds
        ), f"threshold {cal['threshold']} not in {valid_thresholds}"
        assert cal["threshold"] <= body["confidence"]
        assert 0.0 <= cal["precision"] <= 1.0
        assert isinstance(cal["support"], int)
        assert cal["support"] >= 0

    def test_predict_response_includes_drift_z_and_rolling_n(
        self, _set_bbb_model_path,
    ):
        """T1B: drift_z and rolling_n keys must always appear in the body."""
        from src.api import routes
        # Start from an empty rolling window so rolling_n is predictable.
        routes.WORKER_CONFIDENCE_DEQUE.clear()

        resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 5})
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "drift_z" in body
        assert "rolling_n" in body
        # First prediction → window size 1 and no z-score yet.
        assert body["rolling_n"] == 1
        assert body["drift_z"] is None

    def test_predict_deque_rolls_at_100(self, _set_bbb_model_path):
        """T1B: after 100 predictions, deque caps at maxlen=100 (rolls)."""
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()

        last_body = None
        for _ in range(105):
            resp = client.post(
                "/predict/bbb", json={"smiles": "CCO", "top_k": 3},
            )
            assert resp.status_code == 200
            last_body = resp.json()
        # 105 predictions made, but the rolling window caps at 100.
        assert last_body["rolling_n"] == 100
        assert isinstance(last_body["drift_z"], float)

    def test_predict_response_includes_provenance(self, _set_bbb_model_path):
        """T2: provenance field is present in body (fields may be None)."""
        from src.api import routes
        routes.WORKER_CONFIDENCE_DEQUE.clear()

        resp = client.post("/predict/bbb", json={"smiles": "CCO", "top_k": 3})
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "provenance" in body
        assert body["provenance"] is not None, "provenance should be populated even when MLflow is empty"
        prov = body["provenance"]
        assert "mlflow_run_id" in prov
        assert "model_version" in prov
        assert prov["model_version"] == "v1"
        assert "train_date" in prov
        assert "n_examples" in prov
        # The fixture CSV has at least one row, so n_examples must be >= 1.
        assert isinstance(prov["n_examples"], int) and prov["n_examples"] >= 1

    def test_returns_400_on_invalid_smiles(self, _set_bbb_model_path):
        # Uses the shared fixture instead of duplicating the artifact setup.
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "this_is_not_a_smiles", "top_k": 5},
        )
        assert resp.status_code == 400

    def test_returns_503_when_artifact_missing(self, tmp_path: Path, monkeypatch):
        # Point at a nonexistent artifact; no model is built for this test.
        monkeypatch.setenv("BBB_MODEL_PATH", str(tmp_path / "does_not_exist.joblib"))
        resp = client.post(
            "/predict/bbb",
            json={"smiles": "CCO", "top_k": 5},
        )
        assert resp.status_code == 503
|
|
|
|
class TestMRIPredictRoute:
    """POST /predict/mri: missing artifact, missing input, happy path."""

    def test_returns_503_when_artifact_missing(self, tmp_path: Path, monkeypatch):
        monkeypatch.setenv("MRI_MODEL_PATH", str(tmp_path / "missing.onnx"))

        payload = {"input_path": str(_FIXTURES / "mri_sample" / "subject_0.nii.gz")}
        response = client.post("/predict/mri", json=payload)

        assert response.status_code == 503
        assert "MRI model artifact not available" in response.text

    def test_returns_404_when_input_missing(self, tmp_path: Path, monkeypatch):
        from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_mri_onnx

        # Valid model, but the requested NIfTI input does not exist.
        model_path = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
        monkeypatch.setenv("MRI_MODEL_PATH", str(model_path))

        payload = {
            "input_path": str(tmp_path / "missing.nii.gz"),
            "target_shape": [8, 8, 8],
        }
        response = client.post("/predict/mri", json=payload)

        assert response.status_code == 404

    def test_returns_200_with_prediction(self, tmp_path: Path, monkeypatch):
        from tests.fixtures.build_dummy_mri_onnx import build as build_dummy_mri_onnx

        model_path = build_dummy_mri_onnx(tmp_path / "mri_model.onnx")
        monkeypatch.setenv("MRI_MODEL_PATH", str(model_path))

        response = client.post(
            "/predict/mri",
            json={
                "input_path": str(_FIXTURES / "mri_sample" / "subject_0.nii.gz"),
                "target_shape": [8, 8, 8],
                "label_names": ["control", "abnormal"],
            },
        )

        assert response.status_code == 200, response.text
        body = response.json()
        # The dummy ONNX model is expected to pick class index 1 here.
        assert body["label"] == 1
        assert body["label_text"] == "abnormal"
        assert body["confidence"] > 0.5
        assert body["input_path"].endswith("subject_0.nii.gz")
        assert body["model_path"] == str(model_path)
        assert len(body["probabilities"]) == 2
|
|
|
|
class TestMRIDiagnosticsRoute:
    """POST /pipeline/mri/diagnostics: ComBat pre/post harmonization stats."""

    def test_returns_200_with_pre_and_post_data(self, tmp_path: Path):
        from tests.fixtures.build_mri_fixture import build as build_mri

        mri_dir = build_mri(out_dir=tmp_path / "mri")
        response = client.post(
            "/pipeline/mri/diagnostics",
            json={
                "input_dir": str(mri_dir),
                "sites_csv": str(mri_dir / "sites.csv"),
            },
        )
        assert response.status_code == 200
        body = response.json()
        assert len(body["rows"]) > 0
        assert body["site_gap_pre"] >= 0.0
        assert body["site_gap_post"] >= 0.0
        # Harmonization should never widen the site gap.
        assert body["reduction_factor"] >= 1.0
        observed_states = {row["harmonization_state"] for row in body["rows"]}
        assert observed_states == {"Pre-ComBat", "Post-ComBat"}

    def test_returns_404_when_input_dir_missing(self, tmp_path: Path):
        response = client.post(
            "/pipeline/mri/diagnostics",
            json={
                "input_dir": str(tmp_path / "does_not_exist"),
                "sites_csv": str(tmp_path / "sites.csv"),
            },
        )
        assert response.status_code == 404
|
|
|
|
class TestExplainBBBRoute:
    """Day-7 T3B: POST /explain/bbb."""

    def test_returns_200_with_template_source(self, monkeypatch):
        """Kill-switch on → /explain/bbb returns rationale with source=template."""
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        payload = {
            "smiles": "CCO",
            "label": 1,
            "label_text": "permeable",
            "confidence": 0.82,
            "top_features": [
                {"feature": "fp_341", "shap_value": 0.045},
                {"feature": "fp_902", "shap_value": -0.031},
                {"feature": "fp_77", "shap_value": 0.022},
            ],
            "calibration": {"threshold": 0.80, "precision": 0.92, "support": 18},
            "drift_z": 0.42,
            "user_question": "Why permeable?",
        }
        response = client.post("/explain/bbb", json=payload)
        assert response.status_code == 200, response.text
        result = response.json()
        assert result["source"] == "template"
        assert result["model"] is None
        # Every submitted feature name must be echoed in the rationale text.
        for feature_name in ("fp_341", "fp_902", "fp_77"):
            assert feature_name in result["rationale"]
        assert "permeable" in result["rationale"]
|
|
|
|
class TestExplainEEGRoute:
    """Day-8 T1B: POST /explain/eeg."""

    def test_returns_200_with_template_source(self, monkeypatch):
        # LLM disabled → deterministic template-based rationale.
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        payload = {
            "rows": 30,
            "columns": 95,
            "duration_sec": 4.32,
            "mlflow_run_id": "abc12345",
            "user_question": "Why were epochs dropped?",
        }
        response = client.post("/explain/eeg", json=payload)
        assert response.status_code == 200, response.text
        result = response.json()
        assert result["source"] == "template"
        assert result["model"] is None
        # The row/column counts from the payload must appear in the rationale.
        assert "30" in result["rationale"]
        assert "95" in result["rationale"]
|
|
|
|
class TestExplainMRIRoute:
    """Day-8 T1B: POST /explain/mri."""

    def test_returns_200_with_template_source(self, monkeypatch):
        # LLM disabled → deterministic template-based rationale.
        monkeypatch.setenv("NEUROBRIDGE_DISABLE_LLM", "1")
        payload = {
            "site_gap_pre": 5.0004,
            "site_gap_post": 0.0015,
            "reduction_factor": 3290.0,
            "n_subjects": 6,
            "user_question": "Why does ComBat matter?",
        }
        response = client.post("/explain/mri", json=payload)
        assert response.status_code == 200, response.text
        result = response.json()
        assert result["source"] == "template"
        # Key numbers from the payload must be echoed in the rationale.
        assert "3290" in result["rationale"]
        assert "6" in result["rationale"]
|
|
|
|
class TestExperimentsRoutes:
    """Day-8 T2A: GET /experiments/runs and POST /experiments/diff."""

    def test_runs_endpoint_returns_list(self):
        """GET /experiments/runs returns a runs list (may be empty if no MLflow data)."""
        resp = client.get("/experiments/runs")
        assert resp.status_code == 200, resp.text
        body = resp.json()
        assert "runs" in body
        assert isinstance(body["runs"], list)
        # Every run entry (if any) must carry the full summary schema.
        for run in body["runs"]:
            for key in ("run_id", "experiment_name", "start_time", "status", "metrics", "params"):
                assert key in run

    def test_diff_endpoint_handles_unknown_runs_gracefully(self):
        """POST /experiments/diff with bogus run ids returns 404 or an empty
        200 diff — never a 500.

        (Docstring previously claimed only 404; the assertion has always
        accepted an empty 200 response as well.)
        """
        resp = client.post(
            "/experiments/diff",
            json={"run_id_a": "nonexistent_aaa", "run_id_b": "nonexistent_bbb"},
        )
        assert resp.status_code in (404, 200), (
            f"unexpected status {resp.status_code}: {resp.text}"
        )
        # Only parse the body on success; a 404 body's shape is not asserted.
        if resp.status_code == 200:
            assert resp.json().get("rows", []) == []
|
|