| """Evaluate a trained classification model (ViT + LoRA) on the TN5000 test set. |
| |
| Produces: |
| - AUC-ROC, F1-Score, Sensitivity, Specificity, ECE |
| - Confusion matrix |
| - Per-class classification report |
| - Inference latency measurements (Teacher vs Student comparison) |
| |
| Usage:: |
| |
| python scripts/evaluate_classification.py --checkpoint outputs/classification/best |
| python scripts/evaluate_classification.py --checkpoint outputs/classification/best --split test |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from peft import PeftModel |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import AutoImageProcessor, AutoModelForImageClassification |
|
|
| from thyroid_vfm.config import ClassificationConfig, load_yaml_config |
| from thyroid_vfm.data.transforms import build_ultrasound_transform |
| from thyroid_vfm.data.voc import Tn5000ClassificationDataset |
| from thyroid_vfm.evaluation.metrics import compute_classification_metrics |
|
|
|
|
def _collate_with_processor(batch: list[dict], processor) -> dict[str, torch.Tensor]:
    """Collate dataset items into a single processor-encoded batch.

    This lives at module level rather than as a closure so that it can be
    pickled and sent to ``DataLoader`` worker processes on platforms whose
    multiprocessing start method is ``spawn``.

    Args:
        batch: Dataset items; each item is indexed with ``"image"`` and
            ``"label"``.
        processor: Callable invoked as
            ``processor(images=..., return_tensors="pt")``; its result must
            support item assignment.

    Returns:
        The processor output with an extra ``"labels"`` entry holding a
        ``torch.long`` tensor of the item labels.
    """
    images = [item["image"] for item in batch]
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    encoded = processor(images=images, return_tensors="pt")
    encoded["labels"] = labels
    return encoded


def _build_test_loader(config: ClassificationConfig, processor, split: str = "test"):
    """Build a non-shuffled ``DataLoader`` over one split of the dataset.

    Args:
        config: Classification config; its ``dataset`` section supplies the
            class names, image size and directory layout, and its top level
            supplies ``batch_size`` and ``num_workers``.
        processor: Image processor applied to every batch at collate time.
            NOTE(review): with ``num_workers > 0`` under ``spawn`` the
            processor itself must be picklable -- confirm for the processor
            class in use.
        split: Name of the split to load.

    Returns:
        A ``DataLoader`` yielding processor-encoded batches that include a
        ``"labels"`` tensor.
    """
    # Function-scope import keeps this block self-contained.
    from functools import partial

    # Class ids follow the order of the configured class names.
    class_to_id = {name: idx for idx, name in enumerate(config.dataset.class_names)}
    dataset = Tn5000ClassificationDataset(
        root_dir=config.dataset.root_dir,
        split=split,
        class_to_id=class_to_id,
        transform=build_ultrasound_transform(config.dataset.image_size, train=False),
        images_dir=config.dataset.images_dir,
        annotations_dir=config.dataset.annotations_dir,
        splits_dir=config.dataset.splits_dir,
    )

    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,  # keep sample order stable for evaluation
        num_workers=config.num_workers,
        # A partial over a module-level function is picklable, unlike a
        # nested closure, so worker processes can receive it.
        collate_fn=partial(_collate_with_processor, processor=processor),
    )
|
|
|
|
@torch.no_grad()
def evaluate(
    checkpoint_dir: Path,
    config: ClassificationConfig,
    split: str = "test",
) -> None:
    """Evaluate a PEFT-adapted classifier on one split and save the results.

    The base model named by ``config.model_name`` is loaded, the adapter in
    ``checkpoint_dir`` is attached, and, if ``full_model.pt`` exists in that
    directory, its state dict is loaded on top (non-strictly). The split is
    then run through the model, the metric summary and latency figures are
    printed, and a JSON report is written to
    ``<checkpoint_dir.parent>/eval_results_<split>.json``.

    Args:
        checkpoint_dir: Directory holding the PEFT adapter, the image
            processor files and, optionally, ``full_model.pt``.
        config: Classification config providing the model name, label count,
            dataset settings and loader settings.
        split: Name of the dataset split to evaluate.

    Raises:
        ValueError: If the requested split yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")

    # --- Load model -------------------------------------------------------
    print(f" Loading model from {checkpoint_dir} ...")
    base_model = AutoModelForImageClassification.from_pretrained(
        config.model_name,
        num_labels=config.num_labels,
        ignore_mismatched_sizes=True,
    )
    model = PeftModel.from_pretrained(base_model, str(checkpoint_dir))

    # --- Optionally overlay the full state dict ---------------------------
    full_state = checkpoint_dir / "full_model.pt"
    if full_state.exists():
        state = torch.load(str(full_state), map_location="cpu", weights_only=True)
        load_result = model.load_state_dict(state, strict=False)
        # strict=False drops non-matching keys silently; surface that so a
        # key-prefix mismatch is not mistaken for a successful load.
        if load_result.unexpected_keys:
            print(
                f" Warning: {len(load_result.unexpected_keys)} unexpected key(s) "
                f"in {full_state.name} were ignored."
            )
        print(" Loaded full model state (with classifier head).")

    model.to(device).eval()
    processor = AutoImageProcessor.from_pretrained(str(checkpoint_dir))

    loader = _build_test_loader(config, processor, split=split)
    print(f" Evaluating on {len(loader.dataset)} samples (split={split!r})...\n")

    all_labels = []
    all_preds = []
    all_probs = []
    latencies = []

    for batch in tqdm(loader, desc="Inference"):
        # Labels stay on the CPU and are not passed to the model, so the
        # timed forward pass covers inference only and not loss computation.
        batch_labels = batch["labels"]
        inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}

        t0 = time.perf_counter()
        outputs = model(**inputs)
        if device.type == "cuda":
            # CUDA kernels launch asynchronously; wait before stopping the clock.
            torch.cuda.synchronize()
        latencies.append(time.perf_counter() - t0)

        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)
        preds = logits.argmax(dim=-1)

        all_labels.append(batch_labels.cpu().numpy())
        all_preds.append(preds.cpu().numpy())
        all_probs.append(probs.cpu().numpy())

    if not all_labels:
        # np.concatenate([]) would raise an opaque ValueError and the latency
        # maths would divide by zero; fail early with a clear message.
        raise ValueError(f"No samples found for split {split!r}; nothing to evaluate.")

    labels = np.concatenate(all_labels)
    preds = np.concatenate(all_preds)
    probs = np.concatenate(all_probs)

    metrics = compute_classification_metrics(labels, preds, probs, config.dataset.class_names)
    print(metrics.summary())

    # --- Latency ----------------------------------------------------------
    total_samples = len(labels)
    total_time = sum(latencies)
    per_sample_ms = (total_time / total_samples) * 1000
    throughput = total_samples / total_time
    print("\n Inference Latency:")
    print(f" Total time : {total_time:.3f} s")
    print(f" Per sample : {per_sample_ms:.2f} ms")
    print(f" Throughput : {throughput:.1f} samples/s")

    # --- Persist results --------------------------------------------------
    results = {
        "split": split,
        "num_samples": total_samples,
        "accuracy": metrics.accuracy,
        "auc_roc": metrics.auc_roc,
        "f1": metrics.f1,
        "sensitivity": metrics.sensitivity,
        "specificity": metrics.specificity,
        "ece": metrics.ece,
        "latency_per_sample_ms": round(per_sample_ms, 3),
        "throughput_samples_per_s": round(throughput, 1),
    }
    results_path = checkpoint_dir.parent / f"eval_results_{split}.json"
    results_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\n Results saved to {results_path}")
|
|
|
|
def main() -> None:
    """Parse the command-line options and run the evaluation.

    Options:
        --checkpoint: Path to the saved PEFT checkpoint.
        --config: Path to the classification YAML config.
        --split: One of ``train``, ``val`` or ``test``.
    """
    cli = argparse.ArgumentParser(description="Evaluate a trained ViT+LoRA model.")

    # Both path-valued options share the same shape, so register them from a table.
    path_options = (
        (
            "--checkpoint",
            Path("outputs/classification/best"),
            "Path to the saved PEFT checkpoint.",
        ),
        (
            "--config",
            Path("configs/classification_vit_lora.yaml"),
            "Path to the classification config.",
        ),
    )
    for flag, default_path, help_text in path_options:
        cli.add_argument(flag, type=Path, default=default_path, help=help_text)

    cli.add_argument(
        "--split",
        type=str,
        default="test",
        choices=["train", "val", "test"],
        help="Which split to evaluate on.",
    )

    options = cli.parse_args()
    loaded_config = load_yaml_config(options.config, ClassificationConfig)
    evaluate(options.checkpoint, loaded_config, options.split)
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|