| """DiaFoot.AI v2 — Evaluation Entry Point. |
| |
| Phase 4: Evaluate trained models on test set. |
| |
| Usage: |
| # Evaluate classifier |
| python scripts/evaluate.py --task classify \ |
| |
| # Evaluate segmentation |
| python scripts/evaluate.py --task segment \ |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
|
|
| from src.data.augmentation import get_val_transforms |
| from src.data.torch_dataset import DFUDataset |
| from src.evaluation.classification_metrics import ( |
| compute_classification_metrics, |
| print_classification_report, |
| ) |
| from src.evaluation.metrics import ( |
| aggregate_metrics, |
| compute_segmentation_metrics, |
| print_segmentation_report, |
| ) |
| from src.models.classifier import TriageClassifier |
| from src.models.unetpp import build_unetpp |
|
|
|
|
def evaluate_classifier(
    checkpoint_path: str,
    splits_dir: str,
    device: str,
    output_path: str = "results/classification_metrics.json",
) -> None:
    """Evaluate the triage classifier on the test split.

    Loads the checkpoint, runs inference over ``test.csv``, prints a
    classification report, and writes the metrics (minus the human-readable
    ``"report"`` entry) to *output_path* as JSON.

    Args:
        checkpoint_path: Checkpoint file containing ``model_state_dict``.
        splits_dir: Directory containing ``test.csv``.
        device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``).
        output_path: Destination for the JSON metrics file.
    """
    logger = logging.getLogger("eval_classifier")

    # Rebuild the training-time architecture; weights come from the
    # checkpoint, so no pretrained download is needed.
    model = TriageClassifier(backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False)
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    test_ds = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=32, shuffle=False, num_workers=4
    )

    all_labels: list = []
    all_preds: list = []
    all_probs: list = []

    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            labels = batch["label"]
            logits = model(images)
            probs = torch.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)

            all_labels.extend(labels.numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    y_true = np.array(all_labels)
    y_pred = np.array(all_preds)
    y_prob = np.array(all_probs)

    metrics = compute_classification_metrics(y_true, y_pred, y_prob)
    print_classification_report(metrics)

    def _jsonable(o: object) -> object:
        """Convert numpy values to JSON-native types; stringify as last resort."""
        if isinstance(o, np.generic):
            return o.item()
        if isinstance(o, np.ndarray):
            return o.tolist()
        return str(o)

    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    # Drop the human-readable "report" entry; keep the numeric metrics.
    save_metrics = {k: v for k, v in metrics.items() if k != "report"}
    with open(out, "w") as f:
        # Metrics are computed from numpy arrays, so values may be numpy
        # scalars/arrays; without a converter json.dump would raise TypeError.
        json.dump(save_metrics, f, indent=2, default=_jsonable)
    logger.info("Results saved to %s", out)
|
|
|
|
def evaluate_segmentation(
    checkpoint_path: str,
    splits_dir: str,
    device: str,
    output_path: str = "results/segmentation_metrics.json",
) -> None:
    """Evaluate the segmentation model on the test split.

    Loads the checkpoint, runs inference over ``test.csv``, prints an overall
    report plus per-subgroup reports (by image label), and writes the overall
    summary to *output_path* as JSON.

    Args:
        checkpoint_path: Checkpoint file containing ``model_state_dict``.
        splits_dir: Directory containing ``test.csv``.
        device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``).
        output_path: Destination for the JSON summary file.
    """
    logger = logging.getLogger("eval_segmentation")

    # Rebuild the training-time architecture; weights come from the checkpoint.
    model = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    test_ds = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
        return_metadata=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=8, shuffle=False, num_workers=4
    )

    all_metrics = []
    # NOTE(review): label 2 is treated as DFU and label 1 as non-DFU here;
    # confirm against the dataset's label encoding.
    dfu_metrics = []
    non_dfu_metrics = []

    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            masks = batch["mask"].numpy()
            labels = batch["label"].numpy()

            logits = model(images)
            # Binarize at 0.5 and drop the singleton channel dim -> (B, H, W).
            preds = (torch.sigmoid(logits) > 0.5).squeeze(1).cpu().numpy().astype(np.uint8)

            for i in range(len(images)):
                m = compute_segmentation_metrics(preds[i], masks[i])
                all_metrics.append(m)
                if labels[i] == 2:
                    dfu_metrics.append(m)
                elif labels[i] == 1:
                    non_dfu_metrics.append(m)

    summary = aggregate_metrics(all_metrics)
    print_segmentation_report(summary)

    if dfu_metrics:
        print("DFU images only:")
        print_segmentation_report(aggregate_metrics(dfu_metrics))

    if non_dfu_metrics:
        print("Non-DFU images only:")
        print_segmentation_report(aggregate_metrics(non_dfu_metrics))

    def _jsonable(o: object) -> object:
        """Convert numpy values to JSON-native types; stringify as last resort."""
        if isinstance(o, np.generic):
            return o.item()
        if isinstance(o, np.ndarray):
            return o.tolist()
        return str(o)

    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "w") as f:
        # A bare default=str would silently stringify numpy numbers, producing
        # a JSON file of quoted strings; convert numerics properly instead.
        json.dump(summary, f, indent=2, default=_jsonable)
    logger.info("Results saved to %s", out)
|
|
|
|
def main() -> None:
    """Parse CLI arguments and dispatch to the requested evaluation task."""
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 Evaluation")
    parser.add_argument(
        "--task", type=str, required=True, choices=["classify", "segment"],
        help="Which model to evaluate.",
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True,
        help="Path to the trained model checkpoint.",
    )
    parser.add_argument(
        "--splits-dir", type=str, default="data/splits",
        help="Directory containing the split CSV files.",
    )
    parser.add_argument(
        "--device", type=str, default="cuda",
        help="Torch device; falls back to CPU if CUDA is unavailable.",
    )
    parser.add_argument("--verbose", action="store_true", help="Enable DEBUG logging.")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )

    # Fall back to CPU when CUDA is unavailable — and say so, rather than
    # silently evaluating on a device the user did not ask for.
    if torch.cuda.is_available():
        dev = args.device
    else:
        dev = "cpu"
        if args.device != "cpu":
            logging.getLogger("evaluate").warning(
                "CUDA is not available; falling back to CPU."
            )

    if args.task == "classify":
        evaluate_classifier(args.checkpoint, args.splits_dir, dev)
    elif args.task == "segment":
        evaluate_segmentation(args.checkpoint, args.splits_dir, dev)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|