Spaces:
Build error
Build error
| import sys | |
| tabpfn_path = 'TabPFN' | |
| sys.path.insert(0, tabpfn_path) # our submodule of the TabPFN repo (at 045c8400203ebd062346970b4f2c0ccda5a40618) | |
| from TabPFN.scripts.transformer_prediction_interface import TabPFNClassifier | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import gradio as gr | |
| import openml | |
| from sklearn.model_selection import cross_val_score | |
| def compute(file, y_attribute, cv_folds): | |
| if file is None: | |
| return 'Please upload a .arff file', y_attribute | |
| if file.name.endswith('.arff'): | |
| dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=file.name) | |
| X_, _, categorical_indicator_, attribute_names_ = dataset.get_data( | |
| dataset_format="array") | |
| if y_attribute not in attribute_names_: | |
| return f"**Select attribute from {', '.join(attribute_names_)}**", y_attribute | |
| X, y, categorical_indicator_, attribute_names_ = dataset.get_data( | |
| dataset_format="array", target=y_attribute) | |
| else: | |
| return 'Please upload a .arff file', y_attribute | |
| order = np.arange(y.shape[0]) | |
| np.random.seed(13) | |
| np.random.shuffle(order) | |
| X, y = torch.tensor(X[order]), torch.tensor(y[order]) | |
| classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu') | |
| scores = cross_val_score(classifier, X, y, cv=cv_folds, scoring='roc_auc_ovo') | |
| print(scores) | |
| # classifier.fit(x_train, y_train) | |
| # y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True) | |
| # print(file, type(file)) | |
| return f"ROC AUC OVO Cross Val mean is {sum(scores) / len(scores)} from {scores}. " + ( | |
| "The PFN is only trained for datasets with up to 1024 training examples and it had to extrapolate to greater datasets for this evaluation." if len( | |
| y) // cv_folds > 1024 else ""), y_attribute | |
def upload_file(file):
    """Inspect an uploaded ARFF file and suggest a default target attribute.

    Args:
        file: the uploaded file object from ``gr.File`` (only ``.name`` is read),
            or ``None`` when no file is present.

    Returns:
        A ``(markdown_message, default_y_attribute)`` pair. The event this is
        wired to has two outputs, so a pair is returned on every path.
        ``default_y_attribute`` is the last attribute of the file, or ``None``
        when no valid file or attribute is available.
    """
    if file is None or not file.name.endswith('.arff'):
        return 'Please upload a .arff file', None

    dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=file.name)
    _, _, _, attribute_names = dataset.get_data(dataset_format="array")
    # Suggest the last attribute as the target; guard against an empty list.
    default_target = attribute_names[-1] if attribute_names else None
    return f"Select attribute from {', '.join(attribute_names)}", default_target
with gr.Blocks() as demo:
    gr.Markdown("""This demo allows you to play with the **TabPFN**.
Upload a .arff file, select an attribute to predict and the number of cross validation folds and get the ROC AUC OVO score for one seed.
""")
    cv_folds = gr.Dropdown([2, 3, 4, 5], value=2, label='Number of CV folds')
    out_text = gr.Markdown()
    y_attribute = gr.Textbox(label='y attribute')
    inp_file = gr.File(label='Drop a .arff file.')
    # Selecting the bundled example runs upload_file, filling the message and
    # the suggested y attribute; cache_examples=True lets Gradio precompute it.
    examples = gr.Examples(examples=['balance-scale.arff'],
                           inputs=[inp_file],
                           outputs=[out_text, y_attribute],
                           fn=upload_file,
                           cache_examples=True)
    btn = gr.Button("Calculate ROC AUC OVO")
    # A changed upload suggests a target attribute; the button runs the evaluation.
    inp_file.change(fn=upload_file, inputs=inp_file, outputs=[out_text, y_attribute])
    btn.click(fn=compute, inputs=[inp_file, y_attribute, cv_folds], outputs=[out_text, y_attribute])

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()