import torch

from datasets import get_ds
from cfg import get_cfg
from methods import get_method
from eval.sgd import eval_sgd
from eval.knn import eval_knn
from eval.lbfgs import eval_lbfgs
from eval.get_data import get_data

# Maps every supported ``cfg.clf`` value to the evaluator that handles it.
_EVALUATORS = {
    "sgd": eval_sgd,
    "knn": eval_knn,
    "lbfgs": eval_lbfgs,
}


def main():
    """Evaluate a (possibly pretrained) model with the configured classifier.

    Steps:
      1. Read the configuration and build the method selected by ``cfg.method``.
      2. Load weights from ``cfg.fname`` when given; otherwise the randomly
         initialised model is evaluated.
      3. Extract features for ``ds.clf`` and ``ds.test`` via ``get_data``.
      4. Run the evaluator selected by ``cfg.clf`` and print its result.

    Raises:
        ValueError: if ``cfg.clf`` is not one of the supported classifiers.
    """
    cfg = get_cfg()

    # Validate the classifier choice *before* any expensive work. Previously
    # an unsupported value only surfaced as a NameError on ``acc`` after the
    # model had been loaded and all features had been extracted.
    evaluate = _EVALUATORS.get(cfg.clf)
    if evaluate is None:
        raise ValueError(
            f"unknown classifier {cfg.clf!r}; "
            f"expected one of {sorted(_EVALUATORS)}"
        )

    model_full = get_method(cfg.method)(cfg)
    model_full.cuda().eval()
    if cfg.fname is None:
        print("evaluating random model")
    else:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        model_full.load_state_dict(torch.load(cfg.fname))

    ds = get_ds(cfg.dataset)(None, cfg, cfg.num_workers)

    # The L-BFGS evaluator gets its features on the CPU; the others on the GPU.
    device = "cpu" if cfg.clf == "lbfgs" else "cuda"

    if cfg.eval_head:
        # Evaluate the representation produced after the head.
        def model(x):
            return model_full.head(model_full.model(x))

        out_size = cfg.emb
    else:
        # Evaluate the output of the underlying model directly.
        model = model_full.model
        out_size = model_full.out_size

    x_train, y_train = get_data(model, ds.clf, out_size, device)
    x_test, y_test = get_data(model, ds.test, out_size, device)

    acc = evaluate(x_train, y_train, x_test, y_test)
    print(acc)


if __name__ == "__main__":
    main()