{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TabPFN prediction with the scikit-learn interface\n",
    "\n",
    "This notebook shows how to use TabPFN for tabular prediction through a scikit-learn style wrapper:\n",
    "\n",
    "```python\n",
    "classifier = TabPFNClassifier(device='cpu')\n",
    "classifier.fit(train_xs, train_ys)\n",
    "prediction_ = classifier.predict(test_xs)\n",
    "```\n",
    "\n",
    "The `fit` function does not perform any computation; it only stores the training data. All computation happens at inference time, when `predict` is called.\n",
    "\n",
    "Note that the pre-saved models were trained with up to 100 features, 10 classes and 1000 samples. The model has no hard bound on the number of samples, but the numbers of features and classes are restricted, and larger sizes lead to an error."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import openml  # used directly by the single-dataset fetch cell below\n",
    "import torch\n",
    "\n",
    "from model_builder import get_model, get_default_spec, save_model, load_model\n",
    "from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier\n",
    "\n",
    "from datasets import get_openml_classification, load_openml_list, open_cc_dids, open_cc_valid_dids\n",
    "\n",
    "from scripts import tabular_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_path = '.'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Load datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "max_samples = 10000\n",
    "bptt = 10000\n",
    "\n",
    "cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(open_cc_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
    "cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(open_cc_valid_dids, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = max_samples, num_feats=100, return_capped=True)\n",
    "\n",
    "# Loading longer OpenML Datasets for generalization experiments (optional)\n",
    "# test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)\n",
    "\n",
    "random.seed(0)\n",
    "random.shuffle(cc_valid_datasets_multiclass)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of fetching a single OpenML dataset directly by ID; X and y are not used further below.\n",
    "dataset = openml.datasets.get_dataset(31)\n",
    "X, y, categorical_indicator, attribute_names = dataset.get_data(\n",
    "    dataset_format=\"array\", target=dataset.default_target_attribute\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_datasets(selector, task_type, suite='cc'):\n",
    "    \"\"\"Return the dataset list for a split, task type and benchmark suite.\n",
    "\n",
    "    selector: 'valid' selects the validation split; any other value selects the test split.\n",
    "    task_type: 'binary' selects the binary-classification lists; any other value selects multiclass.\n",
    "    suite: 'cc' or 'openml'; only consulted for multiclass tasks.\n",
    "\n",
    "    Only the 'cc' multiclass lists are loaded in this notebook. The binary lists and the\n",
    "    'openml' suite lists must be defined before those branches are used, otherwise a\n",
    "    NameError is raised.\n",
    "\n",
    "    Raises ValueError for an unknown multiclass suite.\n",
    "    \"\"\"\n",
    "    if task_type == 'binary':\n",
    "        return valid_datasets_binary if selector == 'valid' else test_datasets_binary\n",
    "    if suite == 'openml':\n",
    "        return valid_datasets_multiclass if selector == 'valid' else test_datasets_multiclass\n",
    "    if suite == 'cc':\n",
    "        return cc_valid_datasets_multiclass if selector == 'valid' else cc_test_datasets_multiclass\n",
    "    raise ValueError(f\"Unknown suite: {suite!r}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_string, longer, task_type = '', 1, 'multiclass'\n",
    "eval_positions = [1000]\n",
    "bptt = 2000\n",
    "\n",
    "test_datasets, valid_datasets = get_datasets('test', task_type, suite='cc'), get_datasets('valid', task_type, suite='cc')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true,
    "tags": []
   },
   "source": [
    "### Select a dataset for prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "[(i, entry[0]) for i, entry in enumerate(test_datasets)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_dataset_index = 4 # Index of the dataset to predict\n",
    "ds = test_datasets[evaluation_dataset_index]\n",
    "print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clone so the stored dataset is never modified; the first half of the rows is the training\n",
    "# context and the second half is held out for evaluation.\n",
    "xs, ys = ds[1].clone(), ds[2].clone()\n",
    "eval_position = xs.shape[0] // 2\n",
    "train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]\n",
    "test_xs, test_ys = xs[eval_position:], ys[eval_position:]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Predict using a Fitted and Tuned Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = TabPFNClassifier(device='cpu')\n",
    "classifier.fit(train_xs, train_ys)\n",
    "prediction_ = classifier.predict_proba(test_xs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "roc, ce = tabular_metrics.auc_metric(test_ys, prediction_), tabular_metrics.cross_entropy(test_ys, prediction_)\n",
    "'AUC', float(roc), 'Cross Entropy', float(ce)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}