# Mithridatium / scripts/check_evaluator.py
# Pelumi Oluwategbe · commit c887cca ("changes")
import argparse
import mithridatium.evaluator as evaluator
import mithridatium.loader as loader
from mithridatium.data import build_dataloader
from mithridatium.io import load_preprocess_config
def test_build_dataloader_one_batch():
    """Smoke-test that the CIFAR-10 test loader yields one well-formed batch.

    Loads the preprocessing config for ``models/resnet18_bd.pth``, builds a
    CIFAR-10 "test" loader with batch size 8, and checks the first batch's
    shapes. Per the original note this relies on ``models/resnet18_bd.json``
    (from Issue 1) being present.
    NOTE(review): presumably ``load_preprocess_config`` resolves that JSON
    from the ``.pth`` path; confirm.
    """
    pp = load_preprocess_config("models/resnet18_bd.pth")
    # Named ``dl`` rather than ``loader`` so it does not shadow the
    # module-level ``mithridatium.loader`` import.
    dl = build_dataloader("cifar10", "test", pp, batch_size=8)
    x, y = next(iter(dl))

    # Separate asserts so a failure points at the specific dimension.
    assert x.ndim == 4, f"expected NCHW images, got shape {tuple(x.shape)}"
    assert x.shape[1] == 3, f"expected 3 (RGB) channels, got {x.shape[1]}"
    assert y.ndim == 1, f"expected 1-D labels, got shape {tuple(y.shape)}"

    # Verify spatial dims match the config. Compare as plain tuples: a
    # tuple (or tuple subclass such as torch.Size) never equals a list, so
    # a config that stores input_size as a list would otherwise fail here.
    spatial = tuple(x.shape[-2:])
    expected = tuple(pp.input_size)
    assert spatial == expected, f"spatial dims {spatial} != config {expected}"
def main():
    """Load a ResNet-18 checkpoint and report embedding shape and test metrics.

    Steps:
      1. Load the model and its feature module via ``loader.load_resnet18``.
      2. Build the CIFAR-10 "test" dataloader from the checkpoint's
         preprocessing config.
      3. Print the shape of the embeddings from
         ``evaluator.extract_embeddings``.
      4. Print accuracy and loss from ``evaluator.evaluate``.

    Example::

        .venv/bin/python -m scripts.check_evaluator --model models/resnet18_poison.pth
    """
    parser = argparse.ArgumentParser(
        description="Sanity-check the evaluator on a ResNet-18 checkpoint.",
        epilog=(
            "example:\n"
            "  .venv/bin/python -m scripts.check_evaluator "
            "--model models/resnet18_poison.pth"
        ),
        # Keep the epilog's line breaks instead of re-wrapping them.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model", type=str, default="models/resnet18_bd.pth", help="Path to model checkpoint")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size for evaluation")
    args = parser.parse_args()
    if args.batch_size <= 0:
        # Fail fast with a CLI-style message rather than deep inside the loader.
        parser.error("--batch_size must be a positive integer")

    # Load model from checkpoint
    model, feature_module = loader.load_resnet18(args.model)

    # Prepare CIFAR-10 test set using the preprocessing tied to this checkpoint
    pp = load_preprocess_config(args.model)
    test_loader = build_dataloader("cifar10", "test", pp, batch_size=args.batch_size)

    # Extract embeddings (labels are returned but not needed for this check)
    embs, _labels = evaluator.extract_embeddings(model, test_loader, feature_module)
    print(f"Embeddings shape: {embs.shape}")

    # Evaluate accuracy.
    # NOTE(review): accy is scaled by 100 for display, so it is presumably a
    # fraction in [0, 1]; confirm against evaluator.evaluate.
    loss, accy = evaluator.evaluate(model, test_loader)
    print(f"Test accuracy: {accy*100:.2f}% | Test loss: {loss:.4f}")


if __name__ == "__main__":
    main()