import torch


def eval_knn(x_train, y_train, x_test, y_test, k=200):
    """Return the accuracy of a k-nearest-neighbours classifier.

    Each test point is assigned the label that occurs most often among its
    ``k`` closest training points (distances from ``torch.cdist`` with its
    default settings). When several labels tie for the most votes, the
    smallest label value wins.

    Args:
        x_train: Training features, one row per sample.
        y_train: Training labels, one per row of ``x_train``. Label values
            may be arbitrary (non-contiguous, negative, ...).
        x_test: Test features, one row per sample.
        y_test: Ground-truth test labels, one per row of ``x_test``.
        k: Number of neighbours that vote. Values larger than the number of
            training samples are clamped to that number.

    Returns:
        The fraction of correctly classified test points as a Python float.
        An empty test set yields ``nan`` (mean over zero elements).

    Raises:
        ValueError: If ``k`` is smaller than 1 or ``x_train`` has no rows.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    n_train = x_train.shape[0]
    if n_train == 0:
        raise ValueError("x_train must contain at least one sample")
    # torch.topk raises if k exceeds the size of the dimension it selects
    # from, so never ask for more neighbours than there are training points.
    k = min(k, n_train)

    # Pure evaluation: no autograd graph is needed for the distance matrix.
    with torch.no_grad():
        dists = torch.cdist(x_test, x_train)
        neighbors = torch.topk(dists, k=k, dim=1, largest=False).indices

        # Map raw labels onto compact indices 0..C-1. ``unique`` returns the
        # classes sorted ascending, so a lower index means a smaller label.
        classes, class_of_train = y_train.reshape(-1).unique(
            return_inverse=True
        )
        neighbor_classes = class_of_train[neighbors]  # (n_test, k)

        # Count the votes per class for every test point in one shot.
        votes = torch.zeros(
            neighbor_classes.shape[0],
            classes.numel(),
            dtype=torch.long,
            device=neighbor_classes.device,
        )
        votes.scatter_add_(
            1, neighbor_classes, torch.ones_like(neighbor_classes)
        )

        # argmax returns the first maximal index, i.e. the smallest label
        # among tied classes.
        pred = classes[votes.argmax(dim=1)]
        pred = pred.to(y_test.device).reshape(y_test.shape)
        return (pred == y_test).float().mean().item()