| """Tests for src.agents.tools — Tool dataclass + registry + 4 tool wrappers.""" |
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch

import pytest
from pydantic import BaseModel

from src.agents.tools import (
    Tool,
    build_default_tools,
    BBBPipelineInput,
    EEGPipelineInput,
    MRIPipelineInput,
    RetrieveContextInput,
)


class _DummyInput(BaseModel):
    x: int
    y: str = "default"


class _DummyOutput(BaseModel):
    result: int


class TestTool:
    def test_openai_schema_shape(self) -> None:
        tool = Tool(
            name="dummy",
            description="A dummy tool",
            input_model=_DummyInput,
            output_model=_DummyOutput,
            execute=lambda inp: _DummyOutput(result=inp.x * 2),
        )
        schema = tool.openai_schema()
        assert schema["type"] == "function"
        assert schema["function"]["name"] == "dummy"
        assert schema["function"]["description"] == "A dummy tool"
        params = schema["function"]["parameters"]
        assert params["type"] == "object"
        assert "x" in params["properties"]
        assert "x" in params["required"]
        assert "y" not in params["required"]

    def test_invoke_validates_and_returns_dict(self) -> None:
        tool = Tool(
            name="dummy",
            description="d",
            input_model=_DummyInput,
            output_model=_DummyOutput,
            execute=lambda inp: _DummyOutput(result=inp.x * 2),
        )
        out = tool.invoke({"x": 5})
        assert out == {"result": 10}

    def test_invoke_invalid_input_raises(self) -> None:
        tool = Tool(
            name="dummy",
            description="d",
            input_model=_DummyInput,
            output_model=_DummyOutput,
            execute=lambda inp: _DummyOutput(result=inp.x * 2),
        )
        with pytest.raises(ValueError, match="invalid input"):
            tool.invoke({"y": "missing-x"})


class TestBuildDefaultTools:
    def test_default_set_has_seven_tools(self) -> None:
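        # rag_index_dir=None still registers retrieve_context; it simply has
        # no index to query (see the short-circuit test below).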
        tools = build_default_tools(rag_index_dir=None)
        names = {t.name for t in tools}
        assert names == {
            "run_bbb_pipeline",
            "run_eeg_pipeline",
            "run_mri_pipeline",
            "retrieve_context",
            "run_fusion",
            "compute_bbb_leakage_score",
            "adjust_drug_dose",
        }

    def test_each_tool_has_pydantic_input_model(self) -> None:
        tools = build_default_tools(rag_index_dir=None)
        for t in tools:
            assert issubclass(t.input_model, BaseModel)
            assert issubclass(t.output_model, BaseModel)

    def test_input_models_have_smiles_paths(self) -> None:
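        # Spot-check the key request fields on each tool's input schema;
        # sites_csv is the only optional one checked here.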
        assert "smiles" in BBBPipelineInput.model_fields
        assert "input_path" in EEGPipelineInput.model_fields
        assert "input_dir" in MRIPipelineInput.model_fields
        assert "sites_csv" in MRIPipelineInput.model_fields
        assert "sites_csv" not in MRIPipelineInput.model_json_schema().get("required", [])
        assert "query" in RetrieveContextInput.model_fields
        assert "k" in RetrieveContextInput.model_fields

    def test_retrieve_context_short_circuits_when_no_index(self) -> None:
        tools = build_default_tools(rag_index_dir=None)
        retrieve = next(t for t in tools if t.name == "retrieve_context")
        out = retrieve.invoke({"query": "anything", "k": 3})
        assert out == {"query": "anything", "chunks": []}

    def test_processed_dir_parameter_threads_to_executors(self, tmp_path: Path) -> None:
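        # Smoke test: building the registry with an explicit processed_dir
        # should succeed and still expose the pipeline tools.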
        tools = build_default_tools(rag_index_dir=None, processed_dir=tmp_path)
        names = {t.name for t in tools}
        assert "run_eeg_pipeline" in names
        assert "run_mri_pipeline" in names

    def test_default_processed_dir_when_omitted(self) -> None:
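        # Omitting processed_dir should fall back to the default location
        # and still build the full set of seven tools.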
        tools = build_default_tools(rag_index_dir=None)
        assert len(tools) == 7

    def test_bbb_executor_translates_httpexception_to_valueerror(self) -> None:
        from fastapi import HTTPException

        tools = build_default_tools(rag_index_dir=None)
        bbb = next(t for t in tools if t.name == "run_bbb_pipeline")
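        # The executor should translate the route's HTTPException into a
        # plain ValueError so callers see a single failure type.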
        with patch(
            "src.api.routes.predict_bbb",
            side_effect=HTTPException(status_code=503, detail="model missing"),
        ):
            with pytest.raises(ValueError, match="bbb tool failed"):
                bbb.invoke({"smiles": "CCO"})

    def test_mri_executor_defaults_sites_csv_to_input_dir_sites_csv(self, tmp_path: Path) -> None:
        tools = build_default_tools(rag_index_dir=None, processed_dir=tmp_path / "processed")
        mri = next(t for t in tools if t.name == "run_mri_pipeline")
        input_dir = tmp_path / "mri"
        input_dir.mkdir()
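        # No sites_csv is passed, so the executor should default it to
        # <input_dir>/sites.csv; we assert that on the captured request.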
        with patch(
            "src.api.routes.run_mri",
            return_value=SimpleNamespace(
                output_path=str(tmp_path / "processed" / "mri_features.parquet"),
                rows=2,
                columns=3,
                duration_sec=0.1,
            ),
        ) as run_mri:
            out = mri.invoke({"input_dir": str(input_dir)})

        assert out["rows"] == 2
        req = run_mri.call_args.args[0]
        assert req.input_dir == str(input_dir)
        assert req.sites_csv == str(input_dir / "sites.csv")