| """Parlant tool definitions for the TrialPath agent.""" |
|
|
| import json |
|
|
| from parlant.sdk import ToolContext, ToolResult, tool |
|
|
| from trialpath.config import ( |
| GEMINI_API_KEY, |
| GEMINI_MODEL, |
| HF_TOKEN, |
| MCP_URL, |
| MEDGEMMA_ENDPOINT_URL, |
| ) |
|
|
| |
| |
| |
|
|
# Lazily constructed service singletons. Each starts as None and is created on
# first use by _get_extractor / _get_planner / _get_mcp_client respectively.
_extractor = _planner = _mcp_client = None
|
|
|
|
def _get_extractor():
    """Return the cached MedGemmaExtractor, building it on the first call.

    The instance is stored in the module-level ``_extractor`` so later calls
    reuse it.
    """
    global _extractor
    if _extractor is not None:
        return _extractor

    # Deferred import: the service module is only loaded when first needed.
    from trialpath.services.medgemma_extractor import MedGemmaExtractor

    _extractor = MedGemmaExtractor(hf_token=HF_TOKEN, endpoint_url=MEDGEMMA_ENDPOINT_URL)
    return _extractor
|
|
|
|
def _get_planner():
    """Return the cached GeminiPlanner, building it on the first call.

    The instance is stored in the module-level ``_planner`` so later calls
    reuse it.
    """
    global _planner
    if _planner is not None:
        return _planner

    # Deferred import: the service module is only loaded when first needed.
    from trialpath.services.gemini_planner import GeminiPlanner

    _planner = GeminiPlanner(api_key=GEMINI_API_KEY, model=GEMINI_MODEL)
    return _planner
|
|
|
|
def _get_mcp_client():
    """Return the cached ClinicalTrialsMCPClient, building it on the first call.

    The instance is stored in the module-level ``_mcp_client`` so later calls
    reuse it.
    """
    global _mcp_client
    if _mcp_client is not None:
        return _mcp_client

    # Deferred import: the service module is only loaded when first needed.
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    _mcp_client = ClinicalTrialsMCPClient(mcp_url=MCP_URL)
    return _mcp_client
|
|
|
|
| |
| |
| |
|
|
|
|
@tool
async def extract_patient_profile(
    context: ToolContext,
    document_urls: str,
    metadata: str,
) -> ToolResult:
    """Extract a structured patient profile from uploaded medical documents.

    Args:
        context: Parlant tool context.
        document_urls: JSON list of document file paths. A single JSON string
            path is also accepted and treated as a one-element list.
        metadata: JSON object with known patient metadata (age, sex).

    Returns:
        ToolResult whose ``data`` is the extractor's profile and whose metadata
        records the source model and the number of documents processed.

    Raises:
        json.JSONDecodeError: if either argument is not valid JSON.
        ValueError: if ``document_urls`` decodes to something other than a
            list or a single string.
    """
    urls = json.loads(document_urls)
    if isinstance(urls, str):
        # A bare string would otherwise be handed to the extractor as-is (and
        # iterated per character), and len(urls) would report its length
        # rather than the number of documents.
        urls = [urls]
    if not isinstance(urls, list):
        raise ValueError(
            "document_urls must be a JSON list of file paths, "
            f"got {type(urls).__name__}"
        )
    meta = json.loads(metadata)

    # Inputs are validated before the (lazily built) extractor is constructed.
    extractor = _get_extractor()
    profile = await extractor.extract(urls, meta)

    return ToolResult(
        data=profile,
        metadata={"source": "medgemma", "doc_count": len(urls)},
    )
|
|
|
|
@tool
async def generate_search_anchors(
    context: ToolContext,
    patient_profile: str,
) -> ToolResult:
    """Derive ClinicalTrials.gov search parameters from a patient profile.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.

    Returns:
        ToolResult carrying ``model_dump()`` of the planner's search anchors.
    """
    gemini = _get_planner()
    anchors_model = await gemini.generate_search_anchors(json.loads(patient_profile))
    return ToolResult(data=anchors_model.model_dump(), metadata={"source": "gemini"})
|
|
|
|
@tool
async def search_clinical_trials(
    context: ToolContext,
    search_anchors: str,
) -> ToolResult:
    """Search ClinicalTrials.gov for matching trials using search anchors.

    The MCP route (``client.search``) is tried first. If it raises, the search
    is retried once via ``client.search_direct``. The MCP failure is logged
    instead of being silently discarded, and the result metadata records
    whether the fallback path produced the results.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of SearchAnchors data.

    Returns:
        ToolResult whose ``data`` holds the normalized ``trials`` and their
        ``count``. Metadata carries ``source`` and ``fallback_used``.

    Raises:
        Whatever ``search_direct`` raises if the fallback also fails; the MCP
        error remains attached as the exception's context.
    """
    import logging

    from trialpath.models.search_anchors import SearchAnchors
    from trialpath.services.mcp_client import ClinicalTrialsMCPClient

    client = _get_mcp_client()
    anchors = SearchAnchors.model_validate(json.loads(search_anchors))

    used_fallback = False
    try:
        raw_studies = await client.search(anchors)
    except Exception:
        # Deliberate best-effort fallback: keep serving results when the MCP
        # route fails, but leave a diagnosable trace of why it failed.
        logging.getLogger(__name__).warning(
            "MCP trial search failed; falling back to direct search",
            exc_info=True,
        )
        used_fallback = True
        raw_studies = await client.search_direct(anchors)

    trials = [ClinicalTrialsMCPClient.normalize_trial(s).model_dump() for s in raw_studies]

    return ToolResult(
        data={"trials": trials, "count": len(trials)},
        metadata={"source": "clinicaltrials_mcp", "fallback_used": used_fallback},
    )
|
|
|
|
@tool
async def refine_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Narrow the search parameters after a search returned too many results.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search.

    Returns:
        ToolResult carrying the refined anchors (``model_dump()``). Its metadata
        records the action and the previous result count.
    """
    from trialpath.models.search_anchors import SearchAnchors

    gemini = _get_planner()
    current = SearchAnchors.model_validate(json.loads(search_anchors))
    previous_count = int(result_count)
    narrowed = await gemini.refine_search(current, previous_count)
    return ToolResult(
        data=narrowed.model_dump(),
        metadata={"action": "refine", "prev_count": previous_count},
    )
|
|
|
|
@tool
async def relax_search_query(
    context: ToolContext,
    search_anchors: str,
    result_count: str,
) -> ToolResult:
    """Broaden the search parameters after a search returned too few results.

    Args:
        context: Parlant tool context.
        search_anchors: JSON string of current SearchAnchors.
        result_count: Number of results from last search.

    Returns:
        ToolResult carrying the relaxed anchors (``model_dump()``). Its metadata
        records the action and the previous result count.
    """
    from trialpath.models.search_anchors import SearchAnchors

    gemini = _get_planner()
    current = SearchAnchors.model_validate(json.loads(search_anchors))
    previous_count = int(result_count)
    broadened = await gemini.relax_search(current, previous_count)
    return ToolResult(
        data=broadened.model_dump(),
        metadata={"action": "relax", "prev_count": previous_count},
    )
|
|
|
|
@tool
async def evaluate_trial_eligibility(
    context: ToolContext,
    patient_profile: str,
    trial_candidate: str,
) -> ToolResult:
    """Evaluate patient eligibility for a clinical trial using two models.

    Criteria whose category is "medical" go to the MedGemma extractor; all
    others are judged structurally by the Gemini planner.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
        trial_candidate: JSON string of TrialCandidate data.

    Returns:
        ToolResult carrying the aggregated eligibility ledger (``model_dump()``).
        Its metadata records the number of criteria evaluated.
    """
    profile = json.loads(patient_profile)
    trial = json.loads(trial_candidate)

    planner = _get_planner()
    extractor = _get_extractor()

    # 1. Split the trial's eligibility text into individual criteria.
    criteria = await planner.slice_criteria(trial)

    # 2. Route each criterion to the responsible model; the model's result keys
    #    are merged over the criterion's own keys.
    assessments = []
    for criterion in criteria:
        if criterion.get("category") != "medical":
            verdict = await planner.evaluate_structural_criterion(criterion["text"], profile)
        else:
            verdict = await extractor.evaluate_medical_criterion(criterion["text"], profile, [])
        assessments.append({**criterion, **verdict})

    # 3. Combine the per-criterion assessments into one ledger.
    ledger = await planner.aggregate_assessments(profile, trial, assessments)
    return ToolResult(
        data=ledger.model_dump(),
        metadata={"source": "dual_model", "criteria_count": len(criteria)},
    )
|
|
|
|
@tool
async def analyze_gaps(
    context: ToolContext,
    patient_profile: str,
    eligibility_ledgers: str,
) -> ToolResult:
    """Summarize eligibility gaps across every evaluated trial.

    Args:
        context: Parlant tool context.
        patient_profile: JSON string of PatientProfile data.
        eligibility_ledgers: JSON list of EligibilityLedger data.

    Returns:
        ToolResult whose ``data`` holds the planner's ``gaps`` and their ``count``.
    """
    gemini = _get_planner()
    decoded_profile = json.loads(patient_profile)
    decoded_ledgers = json.loads(eligibility_ledgers)
    found = await gemini.analyze_gaps(decoded_profile, decoded_ledgers)
    return ToolResult(
        data={"gaps": found, "count": len(found)},
        metadata={"source": "gemini"},
    )
|
|
|
|
# Every tool defined in this module, grouped by role: profile extraction,
# trial search plus query refinement/relaxation, then eligibility and gaps.
ALL_TOOLS = [
    extract_patient_profile,
    generate_search_anchors, search_clinical_trials,
    refine_search_query, relax_search_query,
    evaluate_trial_eligibility, analyze_gaps,
]
|
|