# Source: AI_Project/scripts/evaluate_classification.py
# Uploaded by Untraceable09 ("Add files using upload-large-folder tool"),
# commit ad34663 (verified), 6.02 kB.
"""Evaluate a trained classification model (ViT + LoRA) on the TN5000 test set.
Produces:
- AUC-ROC, F1-Score, Sensitivity, Specificity, ECE
- Confusion matrix
- Per-class classification report
- Inference latency measurements (per-sample latency and throughput, usable for a Teacher vs Student comparison)
Usage::
python scripts/evaluate_classification.py --checkpoint outputs/classification/best
python scripts/evaluate_classification.py --checkpoint outputs/classification/best --split test
"""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification
from thyroid_vfm.config import ClassificationConfig, load_yaml_config
from thyroid_vfm.data.transforms import build_ultrasound_transform
from thyroid_vfm.data.voc import Tn5000ClassificationDataset
from thyroid_vfm.evaluation.metrics import compute_classification_metrics
def _build_test_loader(config: ClassificationConfig, processor, split: str = "test"):
    """Build an unshuffled evaluation DataLoader for one TN5000 split.

    Samples are passed through the non-training ultrasound transform; the
    collate function then hands the images to ``processor`` (returning
    PyTorch tensors) and attaches the integer class ids under ``"labels"``.
    Class ids follow the order of ``config.dataset.class_names``.
    """
    ds_cfg = config.dataset
    name_to_index = {
        class_name: position for position, class_name in enumerate(ds_cfg.class_names)
    }
    eval_transform = build_ultrasound_transform(ds_cfg.image_size, train=False)

    dataset = Tn5000ClassificationDataset(
        root_dir=ds_cfg.root_dir,
        split=split,
        class_to_id=name_to_index,
        transform=eval_transform,
        images_dir=ds_cfg.images_dir,
        annotations_dir=ds_cfg.annotations_dir,
        splits_dir=ds_cfg.splits_dir,
    )

    def _collate(samples: list[dict]) -> dict[str, torch.Tensor]:
        # Labels are gathered first, then the processor encodes the images.
        pictures = [sample["image"] for sample in samples]
        targets = torch.tensor([sample["label"] for sample in samples], dtype=torch.long)
        batch_encoding = processor(images=pictures, return_tensors="pt")
        batch_encoding["labels"] = targets
        return batch_encoding

    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,  # keep a fixed order so results are reproducible
        num_workers=config.num_workers,
        collate_fn=_collate,
    )
def _json_safe(value):
    """Return ``value`` converted to a built-in Python type if it is a NumPy object.

    ``json.dumps`` cannot serialize NumPy scalars such as ``np.float32``.
    Metrics derived from the float32 softmax output may come back in that
    form, so they are unwrapped before writing. Other values are returned
    unchanged.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    return value


def _load_model(checkpoint_dir: Path, config: ClassificationConfig):
    """Rebuild the base classifier, attach the PEFT adapter and restore saved weights."""
    base_model = AutoModelForImageClassification.from_pretrained(
        config.model_name,
        num_labels=config.num_labels,
        # Allows a head whose size differs from the pretrained one; such a
        # head is re-initialised and must be restored from the checkpoint.
        ignore_mismatched_sizes=True,
    )
    model = PeftModel.from_pretrained(base_model, str(checkpoint_dir))

    # Load full model state (includes re-initialized classifier head)
    full_state = checkpoint_dir / "full_model.pt"
    if full_state.exists():
        state = torch.load(str(full_state), map_location="cpu", weights_only=True)
        incompatible = model.load_state_dict(state, strict=False)
        print(" Loaded full model state (with classifier head).")
        # strict=False silently skips keys that do not match the model;
        # surface them so a prefix/structure mismatch is not mistaken for a
        # successful load.
        if incompatible.unexpected_keys:
            print(
                f" WARNING: {len(incompatible.unexpected_keys)} key(s) in "
                f"{full_state.name} did not match the model and were ignored "
                f"(e.g. {incompatible.unexpected_keys[0]!r})."
            )
    else:
        print(
            f" WARNING: {full_state.name} not found; classifier head weights "
            "come only from the PEFT adapter checkpoint."
        )
    return model


@torch.no_grad()
def evaluate(
    checkpoint_dir: Path,
    config: ClassificationConfig,
    split: str = "test",
) -> dict[str, object]:
    """Evaluate a ViT + LoRA checkpoint on one dataset split and save the results.

    Prints the metrics summary and inference-latency statistics. It also
    writes them to ``checkpoint_dir.parent / f"eval_results_{split}.json"``.

    Args:
        checkpoint_dir: Directory containing the saved PEFT adapter and image
            processor, and optionally ``full_model.pt``.
        config: Classification config (model name, label count, dataset and
            loader settings).
        split: Dataset split to evaluate.

    Returns:
        The results dictionary that was written to the JSON file.

    Raises:
        ValueError: If the selected split contains no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")

    # Load model and processor
    print(f" Loading model from {checkpoint_dir} ...")
    model = _load_model(checkpoint_dir, config)
    model.to(device).eval()
    processor = AutoImageProcessor.from_pretrained(str(checkpoint_dir))

    loader = _build_test_loader(config, processor, split=split)
    num_dataset_samples = len(loader.dataset)
    if num_dataset_samples == 0:
        raise ValueError(f"Split {split!r} contains no samples; nothing to evaluate.")
    print(f" Evaluating on {num_dataset_samples} samples (split={split!r})...\n")

    all_labels = []
    all_preds = []
    all_probs = []
    latencies = []
    # NOTE(review): the first batch includes one-off CUDA/kernel warm-up cost;
    # consider a warm-up pass if latency numbers are compared across models.
    for batch in tqdm(loader, desc="Inference"):
        batch = {k: v.to(device) for k, v in batch.items()}
        # Keep labels out of the forward pass so the timed region measures
        # inference only (passing labels makes the model also compute a loss).
        batch_labels = batch.pop("labels")
        if device.type == "cuda":
            # Drain queued work (e.g. host->device copies) before timing.
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        outputs = model(**batch)
        if device.type == "cuda":
            # CUDA kernels run asynchronously; wait for them before stopping.
            torch.cuda.synchronize()
        latencies.append(time.perf_counter() - t0)

        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)
        preds = logits.argmax(dim=-1)
        all_labels.append(batch_labels.cpu().numpy())
        all_preds.append(preds.cpu().numpy())
        all_probs.append(probs.cpu().numpy())

    labels = np.concatenate(all_labels)
    preds = np.concatenate(all_preds)
    probs = np.concatenate(all_probs)

    metrics = compute_classification_metrics(labels, preds, probs, config.dataset.class_names)
    print(metrics.summary())

    # Latency stats
    total_samples = len(labels)
    total_time = sum(latencies)
    per_sample_ms = (total_time / total_samples) * 1000
    throughput = total_samples / total_time
    print("\n Inference Latency:")
    print(f" Total time : {total_time:.3f} s")
    print(f" Per sample : {per_sample_ms:.2f} ms")
    print(f" Throughput : {throughput:.1f} samples/s")

    # Save results
    results = {
        "split": split,
        "num_samples": total_samples,
        "accuracy": _json_safe(metrics.accuracy),
        "auc_roc": _json_safe(metrics.auc_roc),
        "f1": _json_safe(metrics.f1),
        "sensitivity": _json_safe(metrics.sensitivity),
        "specificity": _json_safe(metrics.specificity),
        "ece": _json_safe(metrics.ece),
        "latency_per_sample_ms": round(per_sample_ms, 3),
        "throughput_samples_per_s": round(throughput, 1),
    }
    results_path = checkpoint_dir.parent / f"eval_results_{split}.json"
    results_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\n Results saved to {results_path}")
    return results
def main() -> None:
    """Command-line entry point: parse options, load the YAML config, run evaluation."""
    cli = argparse.ArgumentParser(description="Evaluate a trained ViT+LoRA model.")

    # (flag, add_argument keyword options) in the order they appear in --help.
    option_table = (
        (
            "--checkpoint",
            {
                "type": Path,
                "default": Path("outputs/classification/best"),
                "help": "Path to the saved PEFT checkpoint.",
            },
        ),
        (
            "--config",
            {
                "type": Path,
                "default": Path("configs/classification_vit_lora.yaml"),
                "help": "Path to the classification config.",
            },
        ),
        (
            "--split",
            {
                "type": str,
                "default": "test",
                "choices": ["train", "val", "test"],
                "help": "Which split to evaluate on.",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()
    cfg = load_yaml_config(opts.config, ClassificationConfig)
    evaluate(opts.checkpoint, cfg, opts.split)
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()