# src/evaluation/eval_accuracy.py

import argparse

import numpy as np
import torch
from sklearn.metrics import accuracy_score, classification_report
from torchvision.datasets import OxfordIIITPet
from tqdm import tqdm
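
# Typical invocation (assuming src/ is a package importable from the repo root):
#   python -m src.evaluation.eval_accuracy --data-root data/oxford-iiit-pet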


def load_test_dataset(data_root: str):
    """
    Load the Oxford-IIIT Pet test split without transforms, so we get raw PIL images.
    Targets are integer class indices (0..36).
    """
    dataset = OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,  # we want raw PIL images here
    )
    return dataset
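
# NOTE: OxfordIIITPet raises a RuntimeError if the files are missing under data_root.
# On a first run you may need to pass download=True in load_test_dataset above
# (assumption: the machine has network access to fetch the dataset).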


def load_model_direct(model_id: str):
    """
    Workaround loader that bypasses the registry and constructs each model
    with its actual constructor signature.
    Adjust only the checkpoint/label paths here if needed.
    """
    if model_id == "lr_raw":
        from src.inference.lr_model import LRModel
        # Adjust the arguments to match your actual LRModel.__init__
        return LRModel("checkpoints/lr_model.joblib", "configs/labels.json")
    elif model_id == "svm_raw":
        from src.inference.svm_model import SVMModel
        return SVMModel("checkpoints/svm_model.joblib", "configs/labels.json")
    elif model_id == "resnet_pt_lr":
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        # If the constructor also takes a device argument, match your working signature.
        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path="configs/labels.json",
        )
    elif model_id == "resnet_pt_svm":
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path="configs/labels.json",
        )
    else:
        raise ValueError(f"Unsupported model_id: {model_id}")
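
# The evaluation loop below assumes each model exposes roughly this predict() API
# (a sketch of the expected return shape, inferred from how the loop consumes it;
# not an authoritative spec, adjust if your model wrappers differ):
#
#   result = model.predict(pil_image, top_k=5)
#   # result ~ {
#   #     "class_id": <int>,        # top-1 class index (required)
#   #     "class_name": <str>,      # optional, used only in the fallback below
#   #     "top_k": [                # optional list of candidate predictions
#   #         {"class_id": <int>, "class_name": <str>, "probability": <float>},
#   #         ...
#   #     ],
#   # }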


def evaluate_model_on_dataset(model_id: str, data_root: str):
    """
    Evaluate a single model (loaded via load_model_direct) on the Oxford-IIIT Pet
    test split, using the model.predict(PIL.Image, top_k=5) API.
    Returns a dict with:
        - top1_acc
        - top5_acc
        - report (per-class and aggregate precision/recall/F1)
    """
    print(f"\n=== Evaluating model: {model_id} ===")
    dataset = load_test_dataset(data_root)
    model = load_model_direct(model_id)

    y_true = []
    y_pred_top1 = []
    top5_correct = 0

    for idx in tqdm(range(len(dataset)), desc=f"Running {model_id}"):
        img, target = dataset[idx]  # img: PIL.Image, target: int

        # Try to call with top_k; if the model doesn't support it, fall back gracefully.
        try:
            result = model.predict(img, top_k=5)
        except TypeError:
            # Older / simpler API: predict(img) without top_k
            result = model.predict(img)

        # Top-1 prediction (must exist)
        pred_id = int(result["class_id"])
        y_true.append(int(target))
        y_pred_top1.append(pred_id)

        # Try to get the top_k list; if absent, synthesize one from the top-1 prediction.
        top_k = result.get("top_k")
        if not top_k:
            # Fallback: treat the top-1 prediction as the only candidate, so
            # Top-5 == Top-1 for such models, which is acceptable as a workaround.
            top_k = [{
                "class_id": pred_id,
                "class_name": result.get("class_name", ""),
                "probability": 1.0,
            }]

        # Top-5 correct if the ground-truth label appears anywhere in the top_k list.
        if any(int(entry["class_id"]) == int(target) for entry in top_k):
            top5_correct += 1

    y_true = np.array(y_true)
    y_pred_top1 = np.array(y_pred_top1)
    n = len(y_true)

    # Overall Top-1 accuracy
    top1_acc = accuracy_score(y_true, y_pred_top1)
    # Overall Top-5 accuracy
    top5_acc = top5_correct / float(n)

    # Detailed precision/recall/F1 per class plus aggregates
    report = classification_report(
        y_true,
        y_pred_top1,
        digits=4,
        output_dict=True,  # returns a dict we can log/inspect
    )
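    # With output_dict=True, `report` maps each class label (as a string) to its
    # precision/recall/f1-score/support, plus "accuracy", "macro avg" and
    # "weighted avg" entries, e.g. report["macro avg"]["f1-score"].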
| print(f"Top-1 accuracy ({model_id}): {top1_acc:.4f}") | |
| print(f"Top-5 accuracy ({model_id}): {top5_acc:.4f}") | |
| print("\nMacro avg (from classification_report):") | |
| print(report["macro avg"]) | |
| print("\nWeighted avg (from classification_report):") | |
| print(report["weighted avg"]) | |
| return { | |
| "model_id": model_id, | |
| "top1_acc": top1_acc, | |
| "top5_acc": top5_acc, | |
| "report": report, | |
| } | |
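
# Standalone usage sketch (assumes the default checkpoint/label paths in
# load_model_direct exist on disk):
#   res = evaluate_model_on_dataset("resnet_pt_lr", "data/oxford-iiit-pet")
#   print(res["top1_acc"], res["top5_acc"])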


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory of the Oxford-IIIT Pet dataset.",
    )
    args = parser.parse_args()

    # List all models you want to evaluate
    model_ids = [
        "lr_raw",
        "svm_raw",
        "resnet_pt_lr",
        "resnet_pt_svm",
    ]

    all_results = []
    for mid in model_ids:
        res = evaluate_model_on_dataset(mid, args.data_root)
        all_results.append(res)

    # Print a compact summary table at the end
    print("\n===== Summary (Top-1 & Top-5) =====")
    print(f"{'Model':25s} {'Top-1':>8s} {'Top-5':>8s}")
    print("-" * 50)
    for res in all_results:
        name = res["model_id"]
        t1 = res["top1_acc"]
        t5 = res["top5_acc"]
        print(f"{name:25s} {t1:8.4f} {t5:8.4f}")


if __name__ == "__main__":
    # Limit torch's thread count to avoid oversubscription on some systems.
    torch.set_num_threads(4)
    main()