# TabPFNEvaluationDemo/TabPFN/tabular_evaluation.py
import time
import os
from pathlib import Path
from tqdm import tqdm
import random
import numpy as np
import torch
from torch import nn
from utils import torch_nanmean
from datasets import *
from model_builder import load_model
from scripts.tabular_baselines import get_scoring_string
from scripts import tabular_metrics
from scripts.transformer_prediction_interface import *
from scripts.baseline_prediction_interface import *
"""
===============================
PUBLIC FUNCTIONS FOR EVALUATION
===============================
"""
def eval_model(i, e, valid_datasets, test_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
metrics_test, config_sample, model_path = eval_model_on_ds(i, e, test_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
metrics_valid, _, _ = eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device=device, eval_addition=eval_addition, **kwargs)
return {'mean_auc_test': metrics_test['mean_roc_at_1000'], 'mean_auc_valid': metrics_valid['mean_roc_at_1000'], 'mean_ce_test': metrics_test['mean_ce_at_1000'], 'mean_ce_valid': metrics_valid['mean_ce_at_1000'], 'config_sample': config_sample, 'model_path': model_path}
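
# A minimal usage sketch for eval_model (hypothetical values; `model_string`, `base_path`,
# `valid_datasets` and `test_datasets` are assumed to be defined by the surrounding scripts):
#
#   results = eval_model(i=0, e=-1, valid_datasets=valid_datasets, test_datasets=test_datasets,
#                        eval_positions=[1000], bptt=2000, add_name=model_string,
#                        base_path=base_path, device='cuda')
#   print(results['mean_auc_test'], results['mean_ce_test'])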
def eval_model_on_ds(i, e, valid_datasets, eval_positions, bptt, add_name, base_path, device='cpu', eval_addition='', **kwargs):
    # Example usage: eval_model_on_ds(i, 0, valid_datasets, [1024], 100000, add_name=model_string, base_path=base_path)
def check_file(e):
model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
model_path = os.path.join(base_path, model_file)
# print('Evaluate ', model_path)
results_file = os.path.join(base_path,
f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
if not Path(model_path).is_file(): # or Path(results_file).is_file():
# print('checkpoint exists: ', Path(model_file).is_file(), ', results are written:', Path(results_file).is_file())
return None, None, None
return model_file, model_path, results_file
    if e == -1:  # use the last available checkpoint
for e_ in range(100, -1, -1):
model_file_, model_path_, results_file_ = check_file(e_)
if model_file_ is not None:
e = e_
model_file, model_path, results_file = model_file_, model_path_, results_file_
break
else:
model_file, model_path, results_file = check_file(e)
model, config_sample = load_model(base_path, model_file, device, None, verbose=False)
params = {'max_features': config_sample['num_features']
, 'rescale_features': config_sample["normalize_by_used_features"]
, 'normalize_to_ranking': config_sample["normalize_to_ranking"]
, 'normalize_with_sqrt': config_sample.get("normalize_with_sqrt", False)
}
metrics_valid = evaluate(datasets=valid_datasets, model=model[2], method='transformer', device=device, overwrite=True,
extend_features=True
                             # the style keyword is not passed here; the transformer was trained with a style input, which is left empty
, save=False
, metric_used=tabular_metrics.cross_entropy
, return_tensor=True
, verbose=False
, eval_positions=eval_positions
, bptt=bptt
, base_path=None
, inference_mode=True
, **params
, **kwargs)
tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics_valid, valid_datasets, eval_positions)
tabular_metrics.calculate_score_per_method(tabular_metrics.cross_entropy, 'ce', metrics_valid, valid_datasets, eval_positions)
return metrics_valid, config_sample, model_path
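
# Note: calculate_score_per_method adds aggregated keys such as 'mean_roc_at_<eval_position>' and
# 'mean_ce_at_<eval_position>' to the metrics dict; these are the keys eval_model reads above
# ('mean_roc_at_1000' / 'mean_ce_at_1000').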
def evaluate(datasets, bptt, eval_positions, metric_used, model
, verbose=False
, return_tensor=False
, **kwargs):
"""
Evaluates a list of datasets for a model function.
:param datasets: List of datasets
:param bptt: maximum sequence length
:param eval_positions: List of positions where to evaluate models
:param verbose: If True, is verbose.
:param metric_used: Which metric is optimized for.
:param return_tensor: Wheater to return results as a pytorch.tensor or numpy, this is only relevant for transformer.
:param kwargs:
:return:
"""
overall_result = {'metric_used': get_scoring_string(metric_used)
, 'bptt': bptt
, 'eval_positions': eval_positions}
aggregated_metric_datasets, num_datasets = torch.tensor(0.0), 0
# For each dataset
    for [ds_name, X, y, categorical_feats, _, _] in (tqdm(datasets, desc='Iterate over datasets') if verbose else datasets):
dataset_bptt = min(len(X), bptt)
# if verbose and dataset_bptt < bptt:
# print(f'Dataset too small for given sequence length, reducing to {len(X)} ({bptt})')
aggregated_metric, num = torch.tensor(0.0), 0
ds_result = {}
        for eval_position in eval_positions:
eval_position_real = int(dataset_bptt * 0.5) if 2 * eval_position > dataset_bptt else eval_position
eval_position_bptt = int(eval_position_real * 2.0)
r = evaluate_position(X, y, model=model
, num_classes=len(torch.unique(y))
, categorical_feats = categorical_feats
, bptt = eval_position_bptt
, ds_name=ds_name
, eval_position = eval_position_real
, metric_used = metric_used
,**kwargs)
if r is None:
continue
_, outputs, ys, best_configs, time_used = r
            if torch.is_tensor(outputs):
                # Move the labels to the same device as the model outputs before computing metrics.
                ys = ys.to(outputs.device)
ys = ys.T
ds_result[f'{ds_name}_best_configs_at_{eval_position}'] = best_configs
ds_result[f'{ds_name}_outputs_at_{eval_position}'] = outputs
ds_result[f'{ds_name}_ys_at_{eval_position}'] = ys
ds_result[f'{ds_name}_time_at_{eval_position}'] = time_used
new_metric = torch_nanmean(torch.stack([metric_used(ys[i], outputs[i]) for i in range(ys.shape[0])]))
if not return_tensor:
make_scalar = lambda x: float(x.detach().cpu().numpy()) if (torch.is_tensor(x) and (len(x.shape) == 0)) else x
new_metric = make_scalar(new_metric)
ds_result = {k: make_scalar(ds_result[k]) for k in ds_result.keys()}
lib = torch if return_tensor else np
if not lib.isnan(new_metric).any():
aggregated_metric, num = aggregated_metric + new_metric, num + 1
overall_result.update(ds_result)
if num > 0:
aggregated_metric_datasets, num_datasets = (aggregated_metric_datasets + (aggregated_metric / num)), num_datasets + 1
overall_result['mean_metric'] = aggregated_metric_datasets / num_datasets
return overall_result
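
# Sketch of a direct call to `evaluate`, mirroring the call in eval_model_on_ds above (placeholder
# values; `my_datasets` and `model` are assumed to exist, and any extra keyword arguments are
# forwarded to evaluate_position / transformer_predict):
#
#   metrics = evaluate(datasets=my_datasets, model=model, method='transformer', device='cpu',
#                      bptt=2000, eval_positions=[1000], metric_used=tabular_metrics.cross_entropy,
#                      overwrite=True, save=False, base_path=None, path_interfix='', return_tensor=True)
#   tabular_metrics.calculate_score_per_method(tabular_metrics.auc_metric, 'roc', metrics, my_datasets, [1000])
#   print(metrics['mean_metric'], metrics['mean_roc_at_1000'])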
"""
===============================
INTERNAL HELPER FUNCTIONS
===============================
"""
def check_file_exists(path):
"""Checks if a pickle file exists. Returns None if not, else returns the unpickled file."""
if (os.path.isfile(path)):
print(f'loading results from {path}')
with open(path, 'rb') as f:
return np.load(f, allow_pickle=True).tolist()
return None
def generate_valid_split(X, y, bptt, eval_position, split_number=1):
"""Generates a deteministic train-(test/valid) split. Both splits must contain the same classes and all classes in
the entire datasets. If no such split can be sampled in 7 passes, returns None.
:param X: torch tensor, feature values
:param y: torch tensor, class values
:param bptt: Number of samples in train + test
:param eval_position: Number of samples in train, i.e. from which index values are in test
:param split_number: The split id
:return:
"""
done, seed = False, 13
torch.manual_seed(split_number)
perm = torch.randperm(X.shape[0]) if split_number > 1 else torch.arange(0, X.shape[0])
X, y = X[perm], y[perm]
while not done:
        if seed > 20:
            return None, None  # No valid split could be generated within 8 attempts (seeds 13-20)
random.seed(seed)
i = random.randint(0, len(X) - bptt) if len(X) - bptt > 0 else 0
y_ = y[i:i + bptt]
        # Check that the sampled window contains all classes of the dataset and that the train and
        # test parts contain the same set of classes.
done = len(torch.unique(y_)) == len(torch.unique(y))
done = done and torch.all(torch.unique(y_) == torch.unique(y))
done = done and len(torch.unique(y_[:eval_position])) == len(torch.unique(y_[eval_position:]))
done = done and torch.all(torch.unique(y_[:eval_position]) == torch.unique(y_[eval_position:]))
seed = seed + 1
eval_xs = torch.stack([X[i:i + bptt].clone()], 1)
eval_ys = torch.stack([y[i:i + bptt].clone()], 1)
return eval_xs, eval_ys
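
# Shape sketch for generate_valid_split (assuming X has shape (n_samples, n_features)):
#
#   eval_xs, eval_ys = generate_valid_split(X, y, bptt=2000, eval_position=1000, split_number=1)
#   # eval_xs: (2000, 1, n_features), eval_ys: (2000, 1); rows [:1000] are the train part and
#   # rows [1000:] the test part. Both are None if no class-balanced window could be found.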
def evaluate_position(X, y, categorical_feats, model, bptt
, eval_position, overwrite, save, base_path, path_interfix, method, ds_name, fetch_only=False
, max_time=300, split_number=1
, per_step_normalization=False, **kwargs):
"""
Evaluates a dataset with a 'bptt' number of training samples.
:param X: Dataset X
:param y: Dataset labels
:param categorical_feats: Indices of categorical features.
:param model: Model function
:param bptt: Sequence length.
:param eval_position: Number of training samples.
:param overwrite: Wheater to ove
:param overwrite: If True, results on disk are overwritten.
:param save:
:param path_interfix: Used for constructing path to write on disk.
:param method: Model name.
:param ds_name: Datset name.
:param fetch_only: Wheater to calculate or only fetch results.
:param per_step_normalization:
:param kwargs:
:return:
"""
    # The results path is needed both for saving and for loading cached results.
    if save or not overwrite:
        path = os.path.join(base_path, f'results/tabular/{path_interfix}/results_{method}_{ds_name}_{eval_position}_{bptt}_{split_number}.npy')
## Load results if on disk
if not overwrite:
result = check_file_exists(path)
if result is not None:
if not fetch_only:
print(f'Loaded saved result for {path}')
return result
elif fetch_only:
print(f'Could not load saved result for {path}')
return None
## Generate data splits
eval_xs, eval_ys = generate_valid_split(X, y, bptt, eval_position, split_number=split_number)
    if eval_xs is None:
        print(f"No dataset could be generated {ds_name} {bptt}")
        return None
    # Re-map class labels to consecutive integers 0 .. n_classes-1 (rank of each label among the unique labels).
    eval_ys = (eval_ys > torch.unique(eval_ys).unsqueeze(0)).sum(axis=1).unsqueeze(-1)
start_time = time.time()
if isinstance(model, nn.Module): # Two separate predict interfaces for transformer and baselines
outputs, best_configs = transformer_predict(model, eval_xs, eval_ys, eval_position, categorical_feats=categorical_feats, **kwargs), None
else:
_, outputs, best_configs = baseline_predict(model, eval_xs, eval_ys, categorical_feats
, eval_pos=eval_position
, max_time=max_time, **kwargs)
    eval_ys = eval_ys[eval_position:]  # keep only the labels of the test part
if outputs is None:
return None
if torch.is_tensor(outputs): # Transfers data to cpu for saving
outputs = outputs.cpu()
eval_ys = eval_ys.cpu()
ds_result = None, outputs, eval_ys, best_configs, time.time() - start_time
if save:
with open(path, 'wb') as f:
np.save(f, ds_result)
print(f'saved results to {path}')
return ds_result
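
# The tuple returned by evaluate_position is what `evaluate` above unpacks as
# (_, outputs, ys, best_configs, time_used). A stand-alone call might look like this
# (placeholder variables; further keyword arguments are forwarded to transformer_predict
# or baseline_predict):
#
#   r = evaluate_position(X, y, categorical_feats=[], model=my_model, bptt=2000, eval_position=1000,
#                         overwrite=True, save=False, base_path=None, path_interfix='',
#                         method='transformer', ds_name='my_dataset')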