Spaces:
Sleeping
Sleeping
| """ | |
| Tests for src/services/agents/ — agentic RAG pipeline. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from unittest.mock import MagicMock | |
| # ----------------------------------------------------------------------- | |
| # Mock context and LLM | |
| # ----------------------------------------------------------------------- | |
class MockMessage:
    """Bare-bones chat reply object that only carries a ``content`` string."""

    def __init__(self, content: str):
        # Stored verbatim; callers read it back through ``.content``.
        self.content = content
class MockLLM:
    """Scripted stand-in for a chat model.

    Each ``invoke`` call returns the next canned string from ``responses``
    wrapped in a ``MockMessage``. Once the script is exhausted, every further
    call returns ``'{"score": 80}'``.
    """

    def __init__(self, responses: list[str] | None = None):
        # A falsy argument (None or an empty list) means "no script at all".
        self._responses = responses if responses else []
        self._call_count = 0

    def invoke(self, messages: list) -> MockMessage:
        """Return the next scripted reply; ``messages`` is ignored."""
        position = self._call_count
        scripted = position < len(self._responses)
        text = self._responses[position] if scripted else '{"score": 80}'
        self._call_count = position + 1
        return MockMessage(text)
| class MockContext: | |
| llm: Any | None = None | |
| embedding_service: Any | None = None | |
| opensearch_client: Any | None = None | |
| cache: Any | None = None | |
| tracer: Any | None = None | |
| # ----------------------------------------------------------------------- | |
| # Guardrail node | |
| # ----------------------------------------------------------------------- | |
class TestGuardrailNode:
    """Behaviour of the scope-checking guardrail node."""

    def test_in_scope_query(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        scripted = MockLLM(['{"score": 85}'])
        outcome = guardrail_node(
            {"query": "What does high HbA1c mean?"},
            context=MockContext(llm=scripted),
        )

        assert outcome["is_in_scope"] is True
        assert outcome["guardrail_score"] == 85.0

    def test_out_of_scope_query(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        scripted = MockLLM(['{"score": 10}'])
        outcome = guardrail_node(
            {"query": "What is the weather today?"},
            context=MockContext(llm=scripted),
        )

        assert outcome["is_in_scope"] is False
        assert outcome["routing_decision"] == "out_of_scope"

    def test_biomarkers_bypass(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        # An unscripted MockLLM would answer 80, so a score of 95 shows the
        # biomarker shortcut was taken instead of the LLM path.
        outcome = guardrail_node(
            {"query": "analyze", "biomarkers": {"Glucose": 185}},
            context=MockContext(llm=MockLLM()),
        )

        assert outcome["is_in_scope"] is True
        assert outcome["guardrail_score"] == 95.0

    def test_llm_failure_defaults_in_scope(self):
        from src.services.agents.nodes.guardrail_node import guardrail_node

        failing_llm = MagicMock()
        failing_llm.invoke.side_effect = Exception("LLM down")

        outcome = guardrail_node(
            {"query": "What is HbA1c?"},
            context=MockContext(llm=failing_llm),
        )

        # When scoring is impossible the node gives the query the benefit of the doubt.
        assert outcome["is_in_scope"] is True
| # ----------------------------------------------------------------------- | |
| # Out-of-scope node | |
| # ----------------------------------------------------------------------- | |
class TestOutOfScopeNode:
    """The out-of-scope node must reply with the canned rejection text."""

    def test_returns_rejection(self):
        from src.services.agents.nodes.out_of_scope_node import out_of_scope_node
        from src.services.agents.prompts import OUT_OF_SCOPE_RESPONSE

        reply = out_of_scope_node({}, context=MockContext())

        assert reply["final_answer"] == OUT_OF_SCOPE_RESPONSE
| # ----------------------------------------------------------------------- | |
| # Grade documents node | |
| # ----------------------------------------------------------------------- | |
class TestGradeDocumentsNode:
    """Relevance grading of retrieved documents."""

    def test_grades_relevant(self):
        from src.services.agents.nodes.grade_documents_node import grade_documents_node

        docs = [
            {"id": "1", "text": "Diabetes is treated with insulin."},
            {"id": "2", "text": "The weather is sunny today."},
        ]
        # Verdicts are handed out in order; presumably the node asks the LLM
        # once per document.
        grader = MockLLM(['{"relevant": true}', '{"relevant": false}'])

        graded = grade_documents_node(
            {"query": "diabetes treatment", "retrieved_documents": docs},
            context=MockContext(llm=grader),
        )

        assert len(graded["relevant_documents"]) == 1
        verdicts = graded["grading_results"]
        assert verdicts[0]["relevant"] is True
        assert verdicts[1]["relevant"] is False

    def test_empty_docs_needs_rewrite(self):
        from src.services.agents.nodes.grade_documents_node import grade_documents_node

        graded = grade_documents_node(
            {"query": "test", "retrieved_documents": []},
            context=MockContext(),
        )

        assert graded["needs_rewrite"] is True
| # ----------------------------------------------------------------------- | |
| # Rewrite query node | |
| # ----------------------------------------------------------------------- | |
class TestRewriteQueryNode:
    """Query rewriting node: LLM-driven rewrite with a safe fallback."""

    def test_rewrites(self):
        from src.services.agents.nodes.rewrite_query_node import rewrite_query_node

        ctx = MockContext(llm=MockLLM(["diabetes HbA1c glucose management guidelines"]))
        state = {"query": "sugar problems"}
        result = rewrite_query_node(state, context=ctx)
        rewritten = result["rewritten_query"]

        # The input query never mentions "diabetes". Finding it in the result
        # proves the LLM's rewrite was used, not a fallback to the original
        # text. A bare truthiness check would accept any non-empty string.
        assert isinstance(rewritten, str)
        assert "diabetes" in rewritten.lower()

    def test_llm_failure_keeps_original(self):
        from src.services.agents.nodes.rewrite_query_node import rewrite_query_node

        broken_llm = MagicMock()
        broken_llm.invoke.side_effect = Exception("timeout")
        ctx = MockContext(llm=broken_llm)
        state = {"query": "original query"}
        result = rewrite_query_node(state, context=ctx)

        # On LLM failure the node must fall back to the user's query unchanged.
        assert result["rewritten_query"] == "original query"
| # ----------------------------------------------------------------------- | |
| # Generate answer node | |
| # ----------------------------------------------------------------------- | |
class TestGenerateAnswerNode:
    """Answer synthesis from the graded evidence."""

    def test_generates_answer(self):
        from src.services.agents.nodes.generate_answer_node import generate_answer_node

        canned = "Based on the evidence, HbA1c of 8.2% indicates poor glycemic control."
        evidence = [
            {"title": "Diabetes Guide", "section": "Diagnosis", "text": "HbA1c above 6.5% indicates diabetes."}
        ]

        produced = generate_answer_node(
            {"query": "What does HbA1c 8.2 mean?", "relevant_documents": evidence},
            context=MockContext(llm=MockLLM([canned])),
        )

        assert "final_answer" in produced
        assert len(produced["final_answer"]) > 10

    def test_llm_failure_returns_fallback(self):
        from src.services.agents.nodes.generate_answer_node import generate_answer_node

        dead_llm = MagicMock()
        dead_llm.invoke.side_effect = Exception("dead")

        produced = generate_answer_node(
            {"query": "test", "relevant_documents": []},
            context=MockContext(llm=dead_llm),
        )

        # A failed generation yields an apology and records at least one error.
        assert "apologize" in produced["final_answer"].lower()
        assert len(produced["errors"]) >= 1
| # ----------------------------------------------------------------------- | |
| # Agentic RAG state | |
| # ----------------------------------------------------------------------- | |
class TestAgenticRAGState:
    """Shape of the shared state object threaded through the agent nodes."""

    def test_state_is_typed_dict(self):
        from src.services.agents.state import AgenticRAGState

        # TypedDict classes expose ``__total__`` plus an ``__annotations__``
        # mapping of every declared key, inherited ones included. ``dict`` or
        # ``Any`` would not, so these checks fail if the state loses its typing.
        assert hasattr(AgenticRAGState, "__total__")
        declared_keys = AgenticRAGState.__annotations__
        assert "query" in declared_keys
        assert "errors" in declared_keys

        # Should still be usable as a dict type hint.
        state: AgenticRAGState = {
            "query": "test",
            "errors": [],
        }
        assert state["query"] == "test"