| """Shared pytest fixtures for TrialPath test suite.""" |
|
|
from __future__ import annotations

import copy
import os
from collections.abc import Mapping
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
|
|
try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is optional: without it, only variables already present in
    # the process environment are visible to the live fixtures below.
    pass
else:
    # Populate os.environ from a .env file before any fixture reads API keys.
    # The call sits outside the ``try`` so an ImportError raised while loading
    # is not misreported as "python-dotenv is not installed".
    load_dotenv()
|
|
| from app.services.mock_data import ( |
| MOCK_ELIGIBILITY_LEDGERS, |
| MOCK_PATIENT_PROFILE, |
| MOCK_TRIAL_CANDIDATES, |
| ) |
| from trialpath.models import ( |
| EligibilityLedger, |
| PatientProfile, |
| SearchAnchors, |
| TrialCandidate, |
| ) |
|
|
| |
| |
| |
|
|
|
|
@pytest.fixture()
def sample_profile() -> PatientProfile:
    """Return a private deep copy of the mock patient profile.

    Each test gets its own copy, so in-place edits (e.g. to the diagnosis or
    biomarkers) cannot leak into other tests through the module-level
    ``MOCK_PATIENT_PROFILE`` object. Compare against mock outputs with ``==``
    rather than ``is``, because the copy is a distinct object.
    """
    return copy.deepcopy(MOCK_PATIENT_PROFILE)
|
|
|
|
@pytest.fixture()
def sample_trials() -> list[TrialCandidate]:
    """Return a fresh list of deep-copied mock trial candidates.

    Both the list and every candidate in it are copies. A test may therefore
    append, remove, or mutate entries without affecting
    ``MOCK_TRIAL_CANDIDATES`` or any other test.
    """
    # Copy item by item so the result is always a ``list``, matching the
    # annotation whatever sequence type MOCK_TRIAL_CANDIDATES uses.
    return [copy.deepcopy(candidate) for candidate in MOCK_TRIAL_CANDIDATES]
|
|
|
|
@pytest.fixture()
def sample_ledgers() -> list[EligibilityLedger]:
    """Return a fresh list of deep-copied mock eligibility ledgers.

    Both the list and every ledger in it are copies. A test may therefore
    mutate them without affecting ``MOCK_ELIGIBILITY_LEDGERS`` or any other
    test.
    """
    # Copy item by item so the result is always a ``list``, matching the
    # annotation whatever sequence type MOCK_ELIGIBILITY_LEDGERS uses.
    return [copy.deepcopy(ledger) for ledger in MOCK_ELIGIBILITY_LEDGERS]
|
|
|
|
@pytest.fixture()
def sample_anchors(sample_profile: PatientProfile) -> SearchAnchors:
    """Derive ``SearchAnchors`` from the ``sample_profile`` fixture.

    Anchors are built from the profile's condition, histology, biomarker
    names, stage, age, and performance-status value. The profile must have
    both a diagnosis and a performance status; otherwise the fixture fails
    with an AssertionError.
    """
    diagnosis = sample_profile.diagnosis
    performance = sample_profile.performance_status
    assert diagnosis is not None
    assert performance is not None

    biomarker_names = [marker.name for marker in sample_profile.biomarkers]
    return SearchAnchors(
        condition=diagnosis.primary_condition,
        subtype=diagnosis.histology,
        biomarkers=biomarker_names,
        stage=diagnosis.stage,
        age=sample_profile.demographics.age,
        performance_status_max=performance.value,
    )
|
|
|
|
| |
| |
| |
|
|
|
|
@pytest.fixture()
def mock_medgemma():
    """Patch ``MedGemmaExtractor`` for one test and yield the mocked instance.

    The patched class returns a ``MagicMock`` instance with two async methods:

    * ``extract(...)`` resolves to ``MOCK_PATIENT_PROFILE``;
    * ``evaluate_medical_criterion(...)`` resolves to a "met" decision dict
      with confidence 0.9.

    The instance is yielded, not the patched class.

    NOTE(review): ``patch`` replaces the name in
    ``trialpath.services.medgemma_extractor`` only. Code that imported
    ``MedGemmaExtractor`` by name elsewhere before the patch started still sees
    the real class; confirm the code under test looks it up through that
    module.
    """
    criterion_verdict = {
        "decision": "met",
        "confidence": 0.9,
        "reasoning": "Criterion satisfied based on profile data.",
    }
    fake_extractor = MagicMock()
    fake_extractor.extract = AsyncMock(return_value=MOCK_PATIENT_PROFILE)
    fake_extractor.evaluate_medical_criterion = AsyncMock(
        return_value=criterion_verdict
    )

    target = "trialpath.services.medgemma_extractor.MedGemmaExtractor"
    with patch(target) as extractor_cls:
        extractor_cls.return_value = fake_extractor
        yield fake_extractor
|
|
|
|
@pytest.fixture()
def mock_gemini():
    """Patch ``GeminiPlanner`` for one test and yield the mocked instance.

    Every planner method named in ``canned_results`` becomes an ``AsyncMock``
    that resolves to fixed data:

    * ``generate_search_anchors``, ``refine_search`` and ``relax_search``
      return ``SearchAnchors``, each less specific than the last;
    * ``evaluate_eligibility`` and ``evaluate_structural_criterion`` return
      plain dicts;
    * ``slice_criteria`` returns one structural and one medical criterion;
    * ``aggregate_assessments`` returns ``MOCK_ELIGIBILITY_LEDGERS[0]``;
    * ``analyze_gaps`` returns a single high-importance gap.

    Each canned value is a single object returned on every call, so mutating a
    result is visible to later calls within the same test.
    """
    canned_results = {
        "generate_search_anchors": SearchAnchors(
            condition="Non-Small Cell Lung Cancer",
            biomarkers=["EGFR"],
            stage="IIIB",
        ),
        "evaluate_eligibility": {
            "overall_assessment": "uncertain",
            "criteria": [],
        },
        "refine_search": SearchAnchors(
            condition="NSCLC",
            biomarkers=["EGFR"],
            stage="IIIB",
        ),
        "relax_search": SearchAnchors(condition="Lung Cancer"),
        "slice_criteria": [
            {"text": "Age >= 18", "type": "structural"},
            {"text": "EGFR mutation positive", "type": "medical"},
        ],
        "evaluate_structural_criterion": {
            "decision": "met",
            "confidence": 0.95,
            "reasoning": "Patient is 62, meets age requirement.",
        },
        "aggregate_assessments": MOCK_ELIGIBILITY_LEDGERS[0],
        "analyze_gaps": [
            {
                "description": "Brain MRI status unknown",
                "recommended_action": "Order brain MRI",
                "clinical_importance": "high",
            }
        ],
    }

    fake_planner = MagicMock()
    for method_name, result in canned_results.items():
        setattr(fake_planner, method_name, AsyncMock(return_value=result))

    target = "trialpath.services.gemini_planner.GeminiPlanner"
    with patch(target) as planner_cls:
        planner_cls.return_value = fake_planner
        yield fake_planner
|
|
|
|
@pytest.fixture()
def mock_mcp():
    """Patch ``ClinicalTrialsMCPClient`` for one test and yield its mocked instance.

    The instance is an ``AsyncMock``, so methods such as ``search_studies``
    and ``get_study`` are themselves ``AsyncMock`` objects and must be awaited:

    * ``search_studies(...)`` resolves to ``{"studies": [...]}`` holding the
      ``model_dump()`` of every entry in ``MOCK_TRIAL_CANDIDATES``;
    * ``get_study(...)`` resolves to the ``model_dump()`` of the first
      candidate.
    """
    target = "trialpath.services.mcp_client.ClinicalTrialsMCPClient"
    with patch(target) as client_cls:
        fake_client = AsyncMock()
        dumped_studies = [
            candidate.model_dump() for candidate in MOCK_TRIAL_CANDIDATES
        ]
        fake_client.search_studies.return_value = {"studies": dumped_studies}
        fake_client.get_study.return_value = MOCK_TRIAL_CANDIDATES[0].model_dump()
        client_cls.return_value = fake_client
        yield fake_client
|
|
|
|
| |
| |
| |
|
|
|
|
| @pytest.fixture(scope="session") |
| def live_env(): |
| """Ensure env vars are loaded; skip the entire session block if missing.""" |
| if not os.environ.get("GEMINI_API_KEY"): |
| pytest.skip("GEMINI_API_KEY not set — skipping live tests") |
|
|
|
|
@pytest.fixture(scope="session")
def live_gemini(live_env):
    """Build one real ``GeminiPlanner`` shared by the whole test session.

    ``live_env`` is requested only for its gating effect: if GEMINI_API_KEY is
    unavailable, this fixture and its dependents are skipped. The import is
    deferred to fixture setup, so the module is only loaded when a live test
    actually runs.
    """
    from trialpath.services.gemini_planner import GeminiPlanner

    planner = GeminiPlanner()
    return planner
|
|
|
|
@pytest.fixture(scope="session")
def live_mcp_client(live_env):
    """Build one real ``ClinicalTrialsMCPClient`` shared by the whole session.

    ``live_env`` is requested only for its gating effect: if GEMINI_API_KEY is
    unavailable, this fixture and its dependents are skipped. The import is
    deferred to fixture setup.

    NOTE(review): nothing visible here shows the MCP client needing the Gemini
    key. Confirm that gating ClinicalTrials live tests on it is intended.
    """
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    client = ClinicalTrialsMCPClient()
    return client
|
|
|
|
@pytest.fixture(scope="session")
def live_medgemma(live_env):
    """Return one real ``MedGemmaExtractor`` shared by the whole session.

    Requires GEMINI_API_KEY (via ``live_env``) and a non-blank HF_TOKEN. If
    either is missing, this fixture and the tests using it are skipped. The
    import is deferred until both checks pass.
    """
    # A whitespace-only token is treated as missing, like an unset one.
    if not os.environ.get("HF_TOKEN", "").strip():
        pytest.skip("HF_TOKEN not set — skipping MedGemma live tests")

    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    return MedGemmaExtractor()
|
|