Spaces:
Running
Running
| import argparse | |
| import mithridatium.evaluator as evaluator | |
| import mithridatium.loader as loader | |
| from mithridatium.data import build_dataloader | |
| from mithridatium.io import load_preprocess_config | |
def test_build_dataloader_one_batch():
    """Smoke-test that ``build_dataloader`` yields one well-formed batch.

    Loads the preprocessing config for ``models/resnet18_bd.pth``, builds the
    ``"cifar10"`` / ``"test"`` dataloader with a batch size of 8, pulls a
    single ``(x, y)`` batch and checks its rank, channel count and spatial
    size.
    """
    # NOTE(review): the original note says this expects models/resnet18_bd.json
    # (from Issue 1) while the call passes the .pth path; presumably
    # load_preprocess_config resolves the JSON sidecar itself. Confirm.
    pp = load_preprocess_config("models/resnet18_bd.pth")
    # Named `batches` rather than `loader` so the local does not shadow the
    # module-level `mithridatium.loader` alias imported as `loader`.
    batches = build_dataloader("cifar10", "test", pp, batch_size=8)
    x, y = next(iter(batches))
    assert x.ndim == 4 and x.shape[1] == 3  # NCHW RGB
    assert y.ndim == 1
    # Optional: verify spatial dims match config. Both sides are normalised to
    # plain tuples because a tuple (or tuple subclass) never compares equal to
    # a list, which is what input_size would be if it came straight from JSON
    # (assumption; confirm against load_preprocess_config).
    assert tuple(x.shape[-2:]) == tuple(pp.input_size)
def main():
    """Run the evaluator sanity check against a ResNet-18 checkpoint.

    Example invocation:
        .venv/bin/python -m scripts.check_evaluator --model models/resnet18_poison.pth

    Command-line options:
        --model: path to the model checkpoint
            (default ``models/resnet18_bd.pth``).
        --batch_size: batch size for evaluation (default 256).

    Prints the shape of the extracted embeddings, then the test accuracy
    (as a percentage) and the test loss.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model",
        type=str,
        default="models/resnet18_bd.pth",
        help="Path to model checkpoint",
    )
    cli.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="Batch size for evaluation",
    )
    opts = cli.parse_args()
    checkpoint_path = opts.model

    # Load the model and its feature module from the checkpoint.
    net, feature_layer = loader.load_resnet18(checkpoint_path)

    # CIFAR-10 test set, built with the preprocessing config for the checkpoint.
    preprocess_cfg = load_preprocess_config(checkpoint_path)
    cifar_test = build_dataloader(
        "cifar10",
        "test",
        preprocess_cfg,
        batch_size=opts.batch_size,
    )

    # Embedding extraction; the returned labels are not used by this script.
    embeddings, _labels = evaluator.extract_embeddings(
        net, cifar_test, feature_layer
    )
    print(f"Embeddings shape: {embeddings.shape}")

    # Loss / accuracy evaluation on the same dataloader.
    test_loss, test_acc = evaluator.evaluate(net, cifar_test)
    print(f"Test accuracy: {test_acc*100:.2f}% | Test loss: {test_loss:.4f}")
# Script entry point: run the check only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()