cds-agent / src /backend /validation /run_validation.py
bshepp
fix: resolve 13 Pylance type-checking warnings
0fe4d92
"""
Unified validation runner for the Clinical Decision Support Agent.
Runs all three dataset validations (MedQA, MTSamples, PMC Case Reports)
and produces a combined summary report.
Usage:
# From src/backend directory:
python -m validation.run_validation --all --max-cases 10
python -m validation.run_validation --medqa --max-cases 20
python -m validation.run_validation --mtsamples --max-cases 15
python -m validation.run_validation --pmc --max-cases 10
# Fetch data only (no pipeline execution):
python -m validation.run_validation --fetch-only
"""
from __future__ import annotations
import asyncio
import json
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
# Ensure backend is importable: put BACKEND_DIR (two levels above this file)
# at the front of sys.path so the `validation.*` imports below resolve when
# this file is run directly or as a module. This must execute before those
# imports.
BACKEND_DIR = Path(__file__).resolve().parent.parent
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

# Load BACKEND_DIR/.env before the harness modules are imported, and make sure
# HF_TOKEN (if set) is present in os.environ so huggingface_hub picks it up.
# NOTE(review): python-dotenv's load_dotenv() already writes loaded values into
# os.environ (without overriding variables that are already set, by default),
# so the re-assignment below appears to be a no-op - confirm before removing.
from dotenv import load_dotenv
load_dotenv(BACKEND_DIR / ".env")
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HF_TOKEN"] = hf_token
from validation.base import (
ValidationSummary,
print_summary,
save_results,
)
from validation.harness_medqa import fetch_medqa, validate_medqa
from validation.harness_mtsamples import fetch_mtsamples, validate_mtsamples
from validation.harness_pmc import fetch_pmc_cases, validate_pmc
async def _run_dataset(
    title: str,
    label: str,
    fetch,
    validate,
    *,
    max_cases: int,
    seed: int,
    fetch_only: bool,
    validate_kwargs: dict,
):
    """Fetch one dataset and, unless in fetch-only mode, validate it.

    Prints the dataset banner and awaits ``fetch(max_cases=..., seed=...)``.
    In fetch-only mode it only reports how many cases were fetched. Otherwise
    it awaits ``validate(cases, **validate_kwargs)``, prints the summary and
    persists it with ``save_results``.

    Returns:
        The summary returned by ``validate``, or ``None`` in fetch-only mode.
    """
    print("\n" + "=" * 60)
    print(f" {title}")
    print("=" * 60)
    cases = await fetch(max_cases=max_cases, seed=seed)
    if fetch_only:
        print(f" Fetched {len(cases)} {label} cases (fetch-only mode)")
        return None
    summary = await validate(cases, **validate_kwargs)
    print_summary(summary)
    save_results(summary)
    return summary


async def run_all_validations(
    run_medqa: bool = True,
    run_mtsamples: bool = True,
    run_pmc: bool = True,
    max_cases: int = 10,
    seed: int = 42,
    include_drug_check: bool = True,
    include_guidelines: bool = True,
    delay: float = 2.0,
    fetch_only: bool = False,
    resume: bool = False,
) -> dict:
    """
    Run validation against selected datasets.

    Datasets run sequentially in a fixed order: MedQA, MTSamples, PMC. Each
    validated dataset's summary is printed and saved as soon as it finishes.
    After all selected datasets are done, a combined report is printed and
    written to JSON (skipped in fetch-only mode).

    Args:
        run_medqa, run_mtsamples, run_pmc: Which datasets to include.
        max_cases: Forwarded to each fetcher as ``max_cases``.
        seed: Forwarded to each fetcher as ``seed``.
        include_drug_check, include_guidelines, resume: Forwarded to each
            ``validate_*`` harness.
        delay: Forwarded to each harness as ``delay_between_cases``.
        fetch_only: Only fetch data; skip pipeline execution and reporting.

    Returns:
        dict of {dataset_name: ValidationSummary}, keyed by "medqa",
        "mtsamples" and/or "pmc". Empty in fetch-only mode.
    """
    validate_kwargs = {
        "include_drug_check": include_drug_check,
        "include_guidelines": include_guidelines,
        "delay_between_cases": delay,
        "resume": resume,
    }
    # (result key, enabled?, banner title, fetch-only label, fetcher, validator)
    datasets = (
        (
            "medqa",
            run_medqa,
            "DATASET 1: MedQA (USMLE-style diagnostic accuracy)",
            "MedQA",
            fetch_medqa,
            validate_medqa,
        ),
        (
            "mtsamples",
            run_mtsamples,
            "DATASET 2: MTSamples (clinical note parsing robustness)",
            "MTSamples",
            fetch_mtsamples,
            validate_mtsamples,
        ),
        (
            "pmc",
            run_pmc,
            "DATASET 3: PMC Case Reports (real-world diagnostic accuracy)",
            "PMC",
            fetch_pmc_cases,
            validate_pmc,
        ),
    )

    results: dict = {}
    # perf_counter() is monotonic: the reported duration cannot be skewed by
    # system clock adjustments (e.g. NTP) during a long run, unlike time.time().
    start = time.perf_counter()
    for name, enabled, title, label, fetch, validate in datasets:
        if not enabled:
            continue
        summary = await _run_dataset(
            title,
            label,
            fetch,
            validate,
            max_cases=max_cases,
            seed=seed,
            fetch_only=fetch_only,
            validate_kwargs=validate_kwargs,
        )
        if not fetch_only:
            results[name] = summary

    # ── Combined Summary ──
    total_duration = time.perf_counter() - start
    if results and not fetch_only:
        _print_combined_summary(results, total_duration)
        _save_combined_report(results, total_duration)
    return results
def _print_combined_summary(results: dict, total_duration: float):
    """Print a combined summary across all datasets.

    The first table shows one headline metric per dataset. The next section
    lists each dataset's scalar metrics, followed by overall totals.

    Args:
        results: {dataset_name: summary}. Each summary must expose
            ``total_cases``, ``successful_cases`` and a ``metrics`` dict
            (``None`` is treated as empty).
        total_duration: Duration of the whole run, in seconds.
    """
    # Headline metric per known dataset; other datasets fall back to their
    # first metric (or "N/A" when they have none).
    key_metrics = {
        "medqa": "top3_accuracy",
        "mtsamples": "parse_success",
        "pmc": "diagnostic_accuracy",
    }

    print("\n" + "=" * 70)
    print(" COMBINED VALIDATION REPORT")
    print("=" * 70)
    # Header
    print(f"\n {'Dataset':<15} {'Cases':>6} {'Success':>8} {'Key Metric':>25} {'Value':>8}")
    print(f" {'-'*15} {'-'*6} {'-'*8} {'-'*25} {'-'*8}")
    for name, summary in results.items():
        metrics = summary.metrics or {}
        if name in key_metrics:
            key_metric = key_metrics[name]
        else:
            key_metric = next(iter(metrics), "N/A")
        value = metrics.get(key_metric)
        # A missing or non-numeric headline metric is shown as N/A rather than
        # a misleading "0.0%" (or a crash in the % format). Width 8 matches
        # the "Value" header column.
        if isinstance(value, (int, float)):
            shown = f"{value:>8.1%}"
        else:
            shown = f"{'N/A':>8}"
        print(
            f" {name:<15} {summary.total_cases:>6} "
            f"{summary.successful_cases:>8} "
            f"{key_metric:>25} {shown}"
        )

    # All metrics
    print(f"\n {'-' * 66}")
    for name, summary in results.items():
        print(f"\n {name.upper()} metrics:")
        for metric, value in sorted((summary.metrics or {}).items()):
            if "time" in metric and isinstance(value, (int, float)):
                print(f" {metric:<35} {value:.0f}ms")
            elif isinstance(value, float):
                print(f" {metric:<35} {value:.1%}")
            elif isinstance(value, int):
                # Integer metrics (e.g. counts) were previously dropped silently.
                print(f" {metric:<35} {value}")
            # Non-scalar values (lists, dicts, ...) are omitted here; the full
            # metrics dict is what gets written to the JSON report.

    # Totals
    total_cases = sum(s.total_cases for s in results.values())
    total_success = sum(s.successful_cases for s in results.values())
    print(f"\n Total cases: {total_cases}")
    print(f" Total success: {total_success}")
    print(f" Total duration: {total_duration:.1f}s ({total_duration/60:.1f}min)")
    print(f" Timestamp: {datetime.now(timezone.utc).isoformat()}")
    print("=" * 70)
def _save_combined_report(results: dict, total_duration: float) -> Path:
    """Save combined report to JSON.

    Writes ``results/combined_<UTC YYYYmmdd_HHMMSS>.json`` next to this file,
    creating the directory if needed.

    Args:
        results: {dataset_name: summary}. Each summary must expose
            ``total_cases``, ``successful_cases``, ``failed_cases``,
            ``metrics`` and ``run_duration_sec``.
        total_duration: Duration of the whole run, in seconds.

    Returns:
        Path of the written JSON file.
    """
    results_dir = Path(__file__).resolve().parent / "results"
    results_dir.mkdir(parents=True, exist_ok=True)
    # Take the clock once so the filename and the "timestamp" field agree.
    now = datetime.now(timezone.utc)
    path = results_dir / f"combined_{now.strftime('%Y%m%d_%H%M%S')}.json"
    combined = {
        "timestamp": now.isoformat(),
        "total_duration_sec": total_duration,
        "datasets": {
            name: {
                "total_cases": summary.total_cases,
                "successful_cases": summary.successful_cases,
                "failed_cases": summary.failed_cases,
                "metrics": summary.metrics,
                "run_duration_sec": summary.run_duration_sec,
            }
            for name, summary in results.items()
        },
    }
    # default=str stringifies any metric value json cannot encode natively.
    path.write_text(json.dumps(combined, indent=2, default=str), encoding="utf-8")
    print(f"\n Combined report saved to: {path}")
    return path
def main():
    """CLI entry point: parse arguments and run the selected validations.

    With no dataset flag, all three datasets are run. Invalid numeric options
    (``--max-cases`` < 1, ``--delay`` < 0) are rejected via
    ``parser.error``, which prints usage to stderr and exits with status 2.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="CDS Agent Validation Suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python -m validation.run_validation --all --max-cases 10
python -m validation.run_validation --medqa --max-cases 50
python -m validation.run_validation --fetch-only
python -m validation.run_validation --medqa --pmc --max-cases 20 --no-drugs
""",
    )
    # Dataset selection
    data_group = parser.add_argument_group("Datasets")
    data_group.add_argument("--all", action="store_true", help="Run all three datasets")
    data_group.add_argument("--medqa", action="store_true", help="Run MedQA validation")
    data_group.add_argument("--mtsamples", action="store_true", help="Run MTSamples validation")
    data_group.add_argument("--pmc", action="store_true", help="Run PMC Case Reports validation")
    # Configuration
    config_group = parser.add_argument_group("Configuration")
    config_group.add_argument("--max-cases", type=int, default=10, help="Cases per dataset (default: 10)")
    config_group.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
    config_group.add_argument("--delay", type=float, default=2.0, help="Delay between cases in seconds (default: 2.0)")
    config_group.add_argument("--no-drugs", action="store_true", help="Skip drug interaction checks")
    config_group.add_argument("--no-guidelines", action="store_true", help="Skip guideline retrieval")
    config_group.add_argument("--resume", action="store_true", help="Resume from checkpoint (skip already-completed cases)")
    config_group.add_argument("--fetch-only", action="store_true", help="Only download data, don't run pipeline")
    args = parser.parse_args()

    # Reject nonsensical values up front rather than forwarding them to the
    # fetchers/harnesses, whose handling of them is unspecified.
    if args.max_cases < 1:
        parser.error("--max-cases must be a positive integer")
    if args.delay < 0:
        parser.error("--delay must be non-negative")

    # Default to --all if nothing specified
    if not any([args.all, args.medqa, args.mtsamples, args.pmc]):
        args.all = True
    run_medqa = args.all or args.medqa
    run_mtsamples = args.all or args.mtsamples
    run_pmc = args.all or args.pmc

    print("=" * 58)
    print(" Clinical Decision Support Agent - Validation Suite")
    print("=" * 58)
    print(f"\n Datasets: {'MedQA ' if run_medqa else ''}{'MTSamples ' if run_mtsamples else ''}{'PMC ' if run_pmc else ''}")
    print(f" Cases/dataset: {args.max_cases}")
    print(f" Drug check: {'Yes' if not args.no_drugs else 'No'}")
    print(f" Guidelines: {'Yes' if not args.no_guidelines else 'No'}")
    print(f" Resume: {'Yes' if args.resume else 'No'}")
    print(f" Fetch only: {'Yes' if args.fetch_only else 'No'}")

    asyncio.run(run_all_validations(
        run_medqa=run_medqa,
        run_mtsamples=run_mtsamples,
        run_pmc=run_pmc,
        max_cases=args.max_cases,
        seed=args.seed,
        include_drug_check=not args.no_drugs,
        include_guidelines=not args.no_guidelines,
        delay=args.delay,
        fetch_only=args.fetch_only,
        resume=args.resume,
    ))


if __name__ == "__main__":
    main()