File size: 5,595 Bytes
1ccdd5a
 
 
 
 
 
d213847
1ccdd5a
 
 
 
d213847
 
 
1ccdd5a
 
d213847
1ccdd5a
 
d213847
1ccdd5a
 
d213847
1ccdd5a
d213847
1ccdd5a
 
 
 
 
 
 
 
 
 
 
 
d213847
1ccdd5a
 
 
 
 
 
d213847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee8b768
d213847
 
 
 
 
 
 
 
 
1ccdd5a
 
d213847
 
1ccdd5a
 
 
 
 
 
ee8b768
 
1ccdd5a
d213847
ee8b768
 
d213847
ee8b768
d213847
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee8b768
d213847
 
 
 
 
 
 
 
 
 
 
1ccdd5a
 
ee8b768
cc56a6c
ee8b768
 
 
 
 
 
 
1ccdd5a
 
 
 
ee8b768
 
 
cc56a6c
ee8b768
 
 
 
 
 
 
1ccdd5a
 
 
ee8b768
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import sys
tabpfn_path = 'TabPFN'
sys.path.insert(0, tabpfn_path) # our submodule of the TabPFN repo (at 045c8400203ebd062346970b4f2c0ccda5a40618)
from TabPFN.scripts.transformer_prediction_interface import TabPFNClassifier

import numpy as np
from pathlib import Path
import pandas as pd
import torch
import gradio as gr
import openml
import os
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

def compute(table: np.array):
    """Predict the blank label cells of ``table`` with TabPFN and plot a decision boundary.

    Rows consisting only of empty cells are dropped. Every remaining cell that is
    blank (``''``) or holds the legacy ``'(predict)'`` marker is a cell to predict;
    all such cells must sit in one column, which becomes the label column. Rows
    without a marker are the training set, rows with one are evaluated.

    Args:
        table: 2-D array-like of cell values (the editable grid's content).

    Returns:
        A ``(message, predictions, figure)`` triple. On invalid input ``message``
        is a Markdown error string and the other two entries are ``None``. On
        success ``message`` is ``None``, ``predictions`` is a DataFrame with the
        evaluated rows (label cell rendered as ``"<label> (p=<probability>)"``),
        and ``figure`` is a matplotlib figure showing the first two feature
        columns with a "first class vs. rest" decision boundary, or ``None``
        when fewer than two feature columns exist.
    """
    no_blank_msg = "⚠️ **ERROR: Please leave at least one field blank for prediction.**"

    table = np.asarray(table, dtype=object)
    if table.ndim != 2 or table.size == 0:
        return no_blank_msg, None, None

    # Drop rows whose cells are all empty strings (spare rows of the grid).
    # otypes keeps np.vectorize usable even when no rows are left.
    cell_len = np.vectorize(lambda cell: len(str(cell)), otypes=[int])
    table = table[cell_len(table).sum(axis=1) != 0]

    # upload_file blanks label cells with '' and the messages below ask the user
    # to leave fields blank, so '' must count as a marker; '(predict)' is still
    # accepted for tables that use the older explicit marker.
    blank_mask = np.isin(table.astype(str), ['', '(predict)'])
    eval_lines, blank_cols = np.nonzero(blank_mask)
    if not len(eval_lines):
        return no_blank_msg, None, None
    if not np.all(blank_cols == blank_cols[0]):
        return "⚠️ **Please only leave fields of one column blank for prediction.**", None, None
    y_column = blank_cols[0]

    train_table = np.delete(table, eval_lines, axis=0)
    eval_table = table[eval_lines]
    if not len(train_table):
        return "⚠️ **Please provide at least one row with a label to learn from.**", None, None

    try:
        x_train_np = np.delete(train_table, y_column, axis=1).astype(np.float32)
        x_eval_np = np.delete(eval_table, y_column, axis=1).astype(np.float32)
    except (ValueError, TypeError):
        return "⚠️ **Please only add numbers (to the inputs) or leave fields empty.**", None, None
    x_train = torch.tensor(x_train_np)
    x_eval = torch.tensor(x_eval_np)
    # Labels are handed over as strings so that an object column holding numbers
    # is still treated as a set of class names.
    y_train = train_table[:, y_column].astype(str)

    classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu')
    classifier.fit(x_train, y_train)
    y_eval, p_eval = classifier.predict(x_eval, return_winning_probability=True)

    out_table = pd.DataFrame(table.astype(str))
    out_table.iloc[eval_lines, y_column] = [f"{y_e} (p={p_e:.2f})" for y_e, p_e in zip(y_eval, p_eval)]
    out_table = out_table.iloc[eval_lines, :]
    # The user may have added or removed columns since the last upload; only
    # relabel when the remembered headers still fit.
    headers_fit = len(headers) == table.shape[1]
    if headers_fit:
        out_table.columns = headers

    # PLOTTING needs two feature axes.
    if x_train_np.shape[1] < 2:
        return None, out_table, None

    # Imported here because the file never imports it at module level; the
    # plotting code below relies on scikit-learn's DecisionBoundaryDisplay.
    from sklearn.inspection import DecisionBoundaryDisplay

    # The plot is binary: samples of the first class vs. everything else.
    is_first_class = y_train == classifier.classes_[0]
    xy_train = x_train_np[:, 0:2]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    cm_bright = ListedColormap(["#FF0000", "#0000FF"])
    ax.scatter(xy_train[:, 0], xy_train[:, 1], c=is_first_class.astype(int), cmap=cm_bright)

    # A second model, restricted to the two plotted features, supplies the boundary.
    boundary_classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu')
    boundary_classifier.fit(xy_train, is_first_class)

    DecisionBoundaryDisplay.from_estimator(
        boundary_classifier, xy_train, alpha=0.6, ax=ax, eps=2.0, grid_resolution=100,
        response_method="predict_proba"
    )

    # Label the axes with the names of the plotted *feature* columns, i.e. the
    # headers with the label column removed.
    if headers_fit:
        feature_names = np.delete(np.asarray(headers, dtype=object), y_column)
    else:
        feature_names = np.arange(x_train_np.shape[1])
    ax.set_xlabel(str(feature_names[0]))
    ax.set_ylabel(str(feature_names[1]))

    return None, out_table, fig


def upload_file(file, remove_entries=10):
    """Load an uploaded dataset into a DataFrame ready for the input table.

    Columns are renamed to ``0..n-1`` (also stored in the global ``headers``) and
    the label cells (last column) of the first ``remove_entries`` rows are blanked
    so they become the cells to predict.

    Args:
        file: Object whose ``name`` attribute is the path of a ``.arff``, ``.csv``
            or ``.data`` file, or ``None`` when the upload widget was cleared.
        remove_entries: Number of leading rows whose label cell is blanked.

    Returns:
        The prepared DataFrame, or a no-op ``gr.update()`` when ``file`` is ``None``.

    Raises:
        ValueError: If the file extension is not supported.
    """
    global headers
    if file is None:
        # Clearing the upload widget fires the change event with no file; leave
        # the current table as it is.
        return gr.update()

    path = file.name
    suffix_probe = path.lower()
    if suffix_probe.endswith('.arff'):
        dataset = openml.datasets.OpenMLDataset('t', 'test', data_file=path)
        X_, _, categorical_indicator_, attribute_names_ = dataset.get_data(
            dataset_format="array"
        )
        df = pd.DataFrame(X_, columns=attribute_names_)
    elif suffix_probe.endswith(('.csv', '.data')):
        # The UI asks for header-less files, so every line is data by default.
        df = pd.read_csv(path, header=None)
        # A first row without a single numeric cell is taken to be a header line
        # and is dropped, which is what the previous header='infer' did for
        # every file (silently eating the first sample of header-less ones).
        if len(df) > 1 and not pd.to_numeric(df.iloc[0], errors='coerce').notna().any():
            df = pd.read_csv(path, header=0)
    else:
        raise ValueError(
            f"Unsupported file type: {path!r}. Please provide a .csv, .data or .arff file."
        )

    headers = np.arange(len(df.columns))
    df.columns = headers

    if len(headers):
        # Make the label column object-typed first so that writing '' into a
        # numeric column does not rely on implicit dtype upcasting.
        label_col = headers[-1]
        df[label_col] = df[label_col].astype(object)
        df.iloc[0:remove_entries, -1] = ''
    return df


def update_table(table):
    """Normalise an edited table: drop empty rows and relabel its columns.

    Rows in which every cell has an empty string representation are removed.
    If no ``''`` cell remains, the filtered frame is returned as is. Otherwise
    ``''`` is written into the column of the first blank cell for every row
    that contains a blank, and the columns are renamed to the module-level
    ``headers``.
    """
    frame = pd.DataFrame(table)

    # Total number of characters per row; zero means the whole row is empty.
    char_count = np.vectorize(lambda cell: len(str(cell)))
    keep_rows = char_count(frame).sum(1) != 0
    frame = frame[keep_rows]

    blank_rows, blank_cols = np.where(frame == '')
    if len(blank_rows) == 0:
        return frame

    first_blank_col = blank_cols[0]
    frame.iloc[blank_rows, first_blank_col] = ''
    frame.columns = headers
    return frame

# Column labels of the table currently shown in the UI; rewritten by upload_file
# and read by compute / update_table.
headers = []

with gr.Blocks() as demo:
    # Components are only rendered when they are created inside the Blocks
    # context, so the introduction has to live here rather than at module level.
    gr.Markdown("""This demo allows you to play with the **TabPFN**.
            The TabPFN will classify the values for all empty cells in the label column.
            Please, provide everything but the label column as numeric values.
            You can also upload datasets to fill the table automatically.
                """)

    with gr.Row():
        with gr.Column():
            # Editable input grid, pre-filled with the bundled iris data whose
            # first ten labels are blanked for prediction.
            inp_table = gr.DataFrame(type='numpy', value=upload_file(Path('iris.csv'), remove_entries=10)
                                     , headers=[''] * 3)

            inp_file = gr.File(
            label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')

            examples = gr.Examples(examples=['iris.csv', 'balance-scale.arff'],
                           inputs=[inp_file],
                           outputs=[inp_table],
                           fn=upload_file,
                           cache_examples=True)

            # NOTE: update_table is intentionally not wired to inp_table.change.

        with gr.Column():

            btn = gr.Button("Calculate Predictions")
            out_text = gr.Markdown()
            out_plot = gr.Plot()
            out_table = gr.DataFrame()

    # compute returns (message, predictions, figure) in exactly this order.
    btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table, out_plot])

    # A newly uploaded file replaces the content of the input grid.
    inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)

demo.launch(share=True)