import torch
import random
from torch.utils.checkpoint import checkpoint
from utils import normalize_data, to_ranking_low_mem, remove_outliers, NOP
from priors.utils import normalize_by_used_features_f
from sklearn.preprocessing import PowerTransformer, QuantileTransformer, RobustScaler
from notebook_utils import CustomUnpickler
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_X_y, check_array, check_is_fitted
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils import column_or_1d
from pathlib import Path
from model_builder import load_model
import os
def load_model_workflow(i, e, add_name, base_path, device='cpu', eval_addition=''):
    """
    Workflow for loading a model and setting appropriate parameters for diffable hparam tuning.

    :param i: identifier of the model, used in the checkpoint file name
    :param e: epoch of the checkpoint to load; -1 loads the latest epoch found
    :param add_name: additional name suffix of the checkpoint file
    :param base_path: directory that contains the models_diff/ folder
    :param device: device to load the model onto
    :param eval_addition: suffix appended to the results file name
    :return: (model, config, results_file), or None if no checkpoint was found
    """
    def check_file(e):
        model_file = f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{e}.cpkt'
        model_path = os.path.join(base_path, model_file)
        # print('Evaluate ', model_path)
        results_file = os.path.join(base_path,
                                    f'models_diff/prior_diff_real_results{add_name}_n_{i}_epoch_{e}_{eval_addition}.pkl')
        if not Path(model_path).is_file():  # or Path(results_file).is_file():
            return None, None, None
        return model_file, model_path, results_file

    model_file = None
    if e == -1:
        for e_ in range(100, -1, -1):
            model_file_, model_path_, results_file_ = check_file(e_)
            if model_file_ is not None:
                e = e_
                model_file, model_path, results_file = model_file_, model_path_, results_file_
                break
    else:
        model_file, model_path, results_file = check_file(e)

    if model_file is None:
        print('No checkpoint found')
        return None

    print(f'Loading {model_file}')

    model, c = load_model(base_path, model_file, device, eval_positions=[], verbose=False)

    return model, c, results_file
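
# Example call (sketch): assuming a checkpoint such as
# {base_path}/models_diff/prior_diff_real_checkpoint_n_8x_lr0.0003_epoch_42.cpkt exists,
# the latest epoch can be picked up automatically with e=-1:
#   model, c, results_file = load_model_workflow('8x_lr0.0003', -1, add_name='', base_path='.')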

class TabPFNClassifier(BaseEstimator, ClassifierMixin):

    def __init__(self, device='cpu', base_path='.'):
        # Model file specification (Model name, Epoch)
        model_string = ''
        i, e = '8x_lr0.0003', -1

        # File which contains the result of the hyperparameter tuning run:
        # style (i.e. hyperparameters) and a dataframe with results.
        style_file = 'prior_tuning_result.pkl'

        model, c, results_file = load_model_workflow(i, e, add_name=model_string, base_path=base_path, device=device,
                                                     eval_addition='')
        style, temperature = self.load_result_minimal(style_file, i, e, base_path=base_path)

        self.device = device
        self.base_path = base_path
        self.model = model
        self.c = c
        self.style = style
        self.temperature = temperature
        self.max_num_features = self.c['num_features']
        self.max_num_classes = self.c['max_num_classes']

    def load_result_minimal(self, path, i, e, base_path='.'):
        with open(os.path.join(base_path, path), 'rb') as output:
            _, _, _, style, temperature, optimization_route = CustomUnpickler(output).load()
            return style, temperature
    def fit(self, X, y):
        # Check that X and y have the correct shape
        X, y = check_X_y(X, y)
        y = self._validate_targets(y)

        self.X_ = X
        self.y_ = y

        if X.shape[1] > self.max_num_features:
            raise ValueError(f"The number of features for this classifier is restricted to {self.max_num_features}")
        if len(np.unique(y)) > self.max_num_classes:
            raise ValueError(f"The number of classes for this classifier is restricted to {self.max_num_classes}")

        # Return the classifier
        return self

    def _validate_targets(self, y):
        y_ = column_or_1d(y, warn=True)
        check_classification_targets(y)
        cls, y = np.unique(y_, return_inverse=True)
        if len(cls) < 2:
            raise ValueError(
                "The number of classes has to be greater than one; got %d class"
                % len(cls)
            )

        self.classes_ = cls

        return np.asarray(y, dtype=np.float64, order="C")
    def predict_proba(self, X):
        # Check that fit has been called
        check_is_fitted(self)

        # Input validation
        X = check_array(X)

        X_full = np.concatenate([self.X_, X], axis=0)
        X_full = torch.tensor(X_full, device=self.device).float().unsqueeze(1)
        y_full = np.concatenate([self.y_, self.y_[0] + np.zeros_like(X[:, 0])], axis=0)
        y_full = torch.tensor(y_full, device=self.device).float().unsqueeze(1)
        eval_pos = self.X_.shape[0]

        prediction = transformer_predict(self.model[2], X_full, y_full, eval_pos,
                                         device=self.device,
                                         style=self.style,
                                         inference_mode=True,
                                         N_ensemble_configurations=10,
                                         softmax_temperature=self.temperature,
                                         **get_params_from_config(self.c))
        prediction_ = prediction.squeeze(0)

        return prediction_.detach().cpu().numpy()
    def predict(self, X, return_winning_probability=False):
        p = self.predict_proba(X)
        y = np.argmax(p, axis=-1)
        y = self.classes_.take(np.asarray(y, dtype=np.intp))
        if return_winning_probability:
            return y, p.max(axis=-1)
        return y

def transformer_predict(model, eval_xs, eval_ys, eval_position,
                        device='cpu',
                        max_features=100,
                        style=None,
                        inference_mode=False,
                        num_classes=2,
                        extend_features=True,
                        normalize_to_ranking=False,
                        softmax_temperature=0.0,
                        multiclass_decoder='permutation',
                        preprocess_transform='mix',
                        categorical_feats=[],
                        feature_shift_decoder=True,
                        N_ensemble_configurations=10,
                        average_logits=True,
                        normalize_with_sqrt=False, **kwargs):
    """
    :param model:
    :param eval_xs:
    :param eval_ys: labels should be 0-indexed and every class up to num_classes - 1 should be present
    :param eval_position:
    :param rescale_features:
    :param device:
    :param max_features:
    :param style:
    :param inference_mode:
    :param num_classes:
    :param extend_features:
    :param normalize_to_ranking:
    :param softmax_temperature:
    :param multiclass_decoder:
    :param preprocess_transform:
    :param categorical_feats:
    :param feature_shift_decoder:
    :param N_ensemble_configurations:
    :param average_logits:
    :param normalize_with_sqrt:
    :param metric_used:
    :return:
    """
    num_classes = len(torch.unique(eval_ys))

    def predict(eval_xs, eval_ys, used_style, softmax_temperature, return_logits):
        # Returns a results array of size (S, B, Classes)
        inference_mode_call = torch.inference_mode() if inference_mode else NOP()
        with inference_mode_call:
            output = model(
                (used_style.repeat(eval_xs.shape[1], 1) if used_style is not None else None, eval_xs, eval_ys.float()),
                single_eval_pos=eval_position)[:, :, 0:num_classes]

            output = output / torch.exp(softmax_temperature)
            if not return_logits:
                output = torch.nn.functional.softmax(output, dim=-1)
            # else:
            #     output[:, :, 1] = model((style.repeat(eval_xs.shape[1], 1) if style is not None else None, eval_xs, eval_ys.float()),
            #                             single_eval_pos=eval_position)
            #     output[:, :, 1] = torch.sigmoid(output[:, :, 1]).squeeze(-1)
            #     output[:, :, 0] = 1 - output[:, :, 1]

        # print('RESULTS', eval_ys.shape, torch.unique(eval_ys, return_counts=True), output.mean(axis=0))

        return output
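
    # preprocess_input (below) normalizes the data, drops constant columns, optionally applies a
    # per-column power/quantile/robust transform fitted on the training rows only, removes outliers
    # (or converts features to rankings), and rescales features relative to max_features.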
    def preprocess_input(eval_xs, preprocess_transform):
        import warnings

        if eval_xs.shape[1] > 1:
            raise Exception("Transforms only allow one batch dim - TODO")
        if preprocess_transform != 'none':
            if preprocess_transform == 'power' or preprocess_transform == 'power_all':
                pt = PowerTransformer(standardize=True)
            elif preprocess_transform == 'quantile' or preprocess_transform == 'quantile_all':
                pt = QuantileTransformer(output_distribution='normal')
            elif preprocess_transform == 'robust' or preprocess_transform == 'robust_all':
                pt = RobustScaler(unit_variance=True)

        # eval_xs, eval_ys = normalize_data(eval_xs), normalize_data(eval_ys)
        eval_xs = normalize_data(eval_xs)

        # Removing empty features
        eval_xs = eval_xs[:, 0, :].cpu().numpy()
        sel = [len(np.unique(eval_xs[0:eval_ys.shape[0], col])) > 1 for col in range(eval_xs.shape[1])]
        eval_xs = np.array(eval_xs[:, sel])

        warnings.simplefilter('error')
        if preprocess_transform != 'none':
            feats = set(range(eval_xs.shape[1])) if 'all' in preprocess_transform else set(
                range(eval_xs.shape[1])) - set(categorical_feats)
            for col in feats:
                try:
                    pt.fit(eval_xs[0:eval_ys.shape[0], col:col + 1])
                    trans = pt.transform(eval_xs[:, col:col + 1])
                    # print(scipy.stats.spearmanr(trans[~np.isnan(eval_xs[:, col:col+1])], eval_xs[:, col:col+1][~np.isnan(eval_xs[:, col:col+1])]))
                    eval_xs[:, col:col + 1] = trans
                except Exception:
                    # If a column cannot be transformed (e.g. the transform fails), keep it unchanged.
                    pass
        warnings.simplefilter('default')

        eval_xs = torch.tensor(eval_xs).float().unsqueeze(1).to(device)

        # eval_xs = normalize_data(eval_xs)

        # TODO: Caution, there is information leakage when to_ranking is used; we should not use it.
        eval_xs = remove_outliers(eval_xs) if not normalize_to_ranking else normalize_data(to_ranking_low_mem(eval_xs))

        # Rescale X
        eval_xs = normalize_by_used_features_f(eval_xs, eval_xs.shape[-1], max_features,
                                               normalize_with_sqrt=normalize_with_sqrt)

        return eval_xs.detach()
    eval_xs, eval_ys = eval_xs.to(device), eval_ys.to(device)
    eval_ys = eval_ys[:eval_position]

    model.to(device)
    style = style.to(device)

    model.eval()

    import itertools
    style = style.unsqueeze(0) if len(style.shape) == 1 else style
    num_styles = style.shape[0]
    styles_configurations = range(0, num_styles)
    preprocess_transform_configurations = [preprocess_transform if i % 2 == 0 else 'none' for i in range(0, num_styles)]
    if preprocess_transform == 'mix':
        def get_preprocess(i):
            if i == 0:
                return 'power_all'
            if i == 1:
                return 'robust_all'
            if i == 2:
                return 'none'
        preprocess_transform_configurations = [get_preprocess(i) for i in range(0, num_styles)]
    styles_configurations = zip(styles_configurations, preprocess_transform_configurations)

    feature_shift_configurations = range(0, eval_xs.shape[2]) if feature_shift_decoder else [0]
    class_shift_configurations = range(0, len(torch.unique(eval_ys))) if multiclass_decoder == 'permutation' else [0]
    ensemble_configurations = list(itertools.product(styles_configurations, feature_shift_configurations, class_shift_configurations))
    random.shuffle(ensemble_configurations)
    ensemble_configurations = ensemble_configurations[0:N_ensemble_configurations]
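
    # Each ensemble member is a combination of (style + preprocessing, feature shift, class-label shift).
    # The cartesian product above is shuffled and truncated to N_ensemble_configurations members, whose
    # outputs (logits or probabilities, depending on average_logits) are averaged below.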
    output = None

    eval_xs_transformed = {}
    for ensemble_configuration in ensemble_configurations:
        (styles_configuration, preprocess_transform_configuration), feature_shift_configuration, class_shift_configuration = ensemble_configuration

        style_ = style[styles_configuration:styles_configuration + 1, :]
        softmax_temperature_ = softmax_temperature[styles_configuration]

        eval_xs_, eval_ys_ = eval_xs.clone(), eval_ys.clone()

        # Cache the preprocessed inputs per transform so each transform is only computed once.
        if preprocess_transform_configuration in eval_xs_transformed:
            eval_xs_ = eval_xs_transformed[preprocess_transform_configuration].clone()
        else:
            eval_xs_ = preprocess_input(eval_xs_, preprocess_transform=preprocess_transform_configuration)
            eval_xs_transformed[preprocess_transform_configuration] = eval_xs_

        eval_ys_ = ((eval_ys_ + class_shift_configuration) % num_classes).float()

        eval_xs_ = torch.cat([eval_xs_[..., feature_shift_configuration:], eval_xs_[..., :feature_shift_configuration]], dim=-1)

        # Extend X
        if extend_features:
            eval_xs_ = torch.cat(
                [eval_xs_,
                 torch.zeros((eval_xs_.shape[0], eval_xs_.shape[1], max_features - eval_xs_.shape[2])).to(device)], -1)

        # preprocess_transform_ = preprocess_transform if styles_configuration % 2 == 0 else 'none'
        import warnings
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None")
            output_ = checkpoint(predict, eval_xs_, eval_ys_, style_, softmax_temperature_, True)
        output_ = torch.cat([output_[..., class_shift_configuration:], output_[..., :class_shift_configuration]], dim=-1)

        # output_ = predict(eval_xs, eval_ys, style_, preprocess_transform_)
        if not average_logits:
            output_ = torch.nn.functional.softmax(output_, dim=-1)
        output = output_ if output is None else output + output_

    output = output / len(ensemble_configurations)
    if average_logits:
        output = torch.nn.functional.softmax(output, dim=-1)

    output = torch.transpose(output, 0, 1)

    return output

def get_params_from_config(c):
    return {'max_features': c['num_features'],
            'rescale_features': c["normalize_by_used_features"],
            'normalize_to_ranking': c["normalize_to_ranking"],
            'normalize_with_sqrt': c.get("normalize_with_sqrt", False),
            }
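

if __name__ == '__main__':
    # Minimal usage sketch. Assumes the checkpoint under models_diff/ and the
    # 'prior_tuning_result.pkl' style file are available in the current directory;
    # the synthetic dataset below is purely illustrative.
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=200, n_features=20, n_informative=10, n_classes=2, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    classifier = TabPFNClassifier(device='cpu', base_path='.')
    classifier.fit(X_train, y_train)
    y_pred, p_max = classifier.predict(X_test, return_winning_probability=True)
    print('Accuracy:', (y_pred == y_test).mean())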