TabPFN commited on
Commit
ee8b768
1 Parent(s): 9645115

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -26
app.py CHANGED
@@ -64,9 +64,7 @@ def compute(table: np.array):
64
 
65
  ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train_index, cmap=cm_bright)
66
 
67
- classifier = TabPFNClassifier(device='cpu', base_path='/home/hollmann/',
68
- model_string=model_string, N_ensemble_configurations=1
69
- , no_preprocess_mode=False, i=i, feature_shift_decoder=True, multiclass_decoder='permutation')
70
  classifier.fit(x_train[:, 0:2], y_train)
71
 
72
  DecisionBoundaryDisplay.from_estimator(
@@ -86,13 +84,14 @@ def upload_file(file, remove_entries=10):
86
  dataset_format="array"
87
  )
88
  df = pd.DataFrame(X_, columns=attribute_names_)
89
- headers = df.columns
 
90
  elif file.name.endswith('.csv') or file.name.endswith('.data'):
91
  df = pd.read_csv(file.name, header='infer')
92
- headers = df.columns
93
- #df.columns = np.arange(len(df.columns))
94
 
95
- df.iloc[0:remove_entries, -1] = '(predict)'
96
  return df
97
 
98
 
@@ -110,7 +109,7 @@ def update_table(table):
110
  y_column = empty_inds[1][0]
111
  eval_lines = empty_inds[0]
112
 
113
- table.iloc[eval_lines, y_column] = '(predict)'
114
  table.columns = headers
115
 
116
  return table
@@ -124,28 +123,31 @@ gr.Markdown("""This demo allows you to play with the **TabPFN**.
124
  """)
125
 
126
  with gr.Blocks() as demo:
127
- with gr.Tab("Enter Input Data"):
128
- inp_file = gr.File(
129
- label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')
130
-
131
- inp_table = gr.DataFrame(type='numpy', value=upload_file(Path('iris.csv'), remove_entries=10), headers=[''] * 3)
132
- inp_table.change(fn=update_table, inputs=inp_table, outputs=inp_table)
133
-
134
- with gr.Tab("Run Predictions"):
135
-
136
- btn = gr.Button("Start")
137
- out_text = gr.Markdown()
138
- out_table = gr.DataFrame()
139
- out_plot = gr.Plot()
140
-
141
- btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table, out_plot])
142
-
143
- examples = gr.Examples(examples=['iris.csv', 'balance-scale.arff'],
144
  inputs=[inp_file],
145
  outputs=[inp_table],
146
  fn=upload_file,
147
  cache_examples=True)
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)
150
 
151
- demo.launch()
 
64
 
65
  ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train_index, cmap=cm_bright)
66
 
67
+ classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu')
 
 
68
  classifier.fit(x_train[:, 0:2], y_train)
69
 
70
  DecisionBoundaryDisplay.from_estimator(
 
84
  dataset_format="array"
85
  )
86
  df = pd.DataFrame(X_, columns=attribute_names_)
87
+ headers = np.arange(len(df.columns))
88
+ df.columns = headers
89
  elif file.name.endswith('.csv') or file.name.endswith('.data'):
90
  df = pd.read_csv(file.name, header='infer')
91
+ headers = np.arange(len(df.columns))
92
+ df.columns = headers
93
 
94
+ df.iloc[0:remove_entries, -1] = ''
95
  return df
96
 
97
 
 
109
  y_column = empty_inds[1][0]
110
  eval_lines = empty_inds[0]
111
 
112
+ table.iloc[eval_lines, y_column] = ''
113
  table.columns = headers
114
 
115
  return table
 
123
  """)
124
 
125
  with gr.Blocks() as demo:
126
+ with gr.Row():
127
+ with gr.Column(scale=1):
128
+ inp_table = gr.DataFrame(type='numpy', value=upload_file(Path('iris.csv'), remove_entries=10)
129
+ , headers=[''] * 3)
130
+
131
+ inp_file = gr.File(
132
+ label='Drop either a .csv (without header, only numeric values for all but the labels) or a .arff file.')
133
+
134
+ examples = gr.Examples(examples=['iris.csv', 'balance-scale.arff'],
 
 
 
 
 
 
 
 
135
  inputs=[inp_file],
136
  outputs=[inp_table],
137
  fn=upload_file,
138
  cache_examples=True)
139
+
140
+ #inp_table.change(fn=update_table, inputs=inp_table, outputs=inp_table)
141
+
142
+ with gr.Column(scale=1):
143
+
144
+ btn = gr.Button("Calculate Predictions")
145
+ out_text = gr.Markdown()
146
+ out_plot = gr.Plot()
147
+ out_table = gr.DataFrame()
148
+
149
+ btn.click(fn=compute, inputs=inp_table, outputs=[out_text, out_table, out_plot])
150
 
151
  inp_file.change(fn=upload_file, inputs=inp_file, outputs=inp_table)
152
 
153
+ demo.launch(share=True)