TabPFN commited on
Commit
def7fc3
1 Parent(s): a0b5459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -14,7 +14,9 @@ import os
14
  import matplotlib.pyplot as plt
15
  from matplotlib.colors import ListedColormap
16
 
17
- classifier = TabPFNClassifier(base_path=tabpfn_path, device='cpu', N_ensemble_configurations=4)
 
 
18
 
19
 
20
  def compute(df_table):
 
14
  import matplotlib.pyplot as plt
15
  from matplotlib.colors import ListedColormap
16
 
17
+ default_device = "cuda:0" if torch.cuda.is_available() else "cpu:0"
18
+
19
+ classifier = TabPFNClassifier(base_path=tabpfn_path, device=default_device, N_ensemble_configurations=4)
20
 
21
 
22
  def compute(df_table):