# PFNEngineeringConstrainedBO/pfns4bo/scripts/acquisition_functions.py
import contextlib
from .. import transformer
from .. import bar_distribution
import torch
import scipy.optimize
import math
from sklearn.preprocessing import power_transform, PowerTransformer


def log01(x, eps=.0000001, input_between_zero_and_one=False):
    logx = torch.log(x + eps)
    if input_between_zero_and_one:
        return (logx - math.log(eps)) / (math.log(1 + eps) - math.log(eps))
    return (logx - logx.min(0)[0]) / (logx.max(0)[0] - logx.min(0)[0])


def log01_batch(x, eps=.0000001, input_between_zero_and_one=False):
    x = x.repeat(1, x.shape[-1] + 1, 1)
    for b in range(x.shape[-1]):
        x[:, b, b] = log01(x[:, b, b], eps=eps, input_between_zero_and_one=input_between_zero_and_one)
    return x


def lognormed_batch(x, eval_pos, eps=.0000001):
    x = x.repeat(1, x.shape[-1] + 1, 1)
    for b in range(x.shape[-1]):
        logx = torch.log(x[:, b, b] + eps)
        x[:, b, b] = (logx - logx[:eval_pos].mean(0)) / logx[:eval_pos].std(0)
    return x


def _rank_transform(x_train, x):
    assert len(x_train.shape) == len(x.shape) == 1
    relative_to = torch.cat((torch.zeros_like(x_train[:1]), x_train.unique(sorted=True), torch.ones_like(x_train[-1:])), -1)
    higher_comparison = (relative_to < x[..., None]).sum(-1).clamp(min=1)
    pos_inside_interval = (x - relative_to[higher_comparison - 1]) / (relative_to[higher_comparison] - relative_to[higher_comparison - 1])
    x_transformed = higher_comparison - 1 + pos_inside_interval
    return x_transformed / (len(relative_to) - 1.)


def rank_transform(x_train, x):
    assert x.shape[1] == x_train.shape[1], f"{x.shape=} and {x_train.shape=}"
    # make sure everything is between 0 and 1
    assert (x_train >= 0.).all() and (x_train <= 1.).all(), f"{x_train=}"
    assert (x >= 0.).all() and (x <= 1.).all(), f"{x=}"
    return_x = x.clone()
    for feature_dim in range(x.shape[1]):
        return_x[:, feature_dim] = _rank_transform(x_train[:, feature_dim], x[:, feature_dim])
    return return_x


def general_power_transform(x_train, x_apply, eps, less_safe=False):
    """Fit a power transform on x_train and apply it to x_apply.

    Uses Box-Cox (on inputs shifted by eps) if eps > 0, otherwise Yeo-Johnson. Falls back to
    simple mean-centering whenever fitting fails or the transformed values are not finite.
    """
    if eps > 0:
        try:
            pt = PowerTransformer(method='box-cox')
            pt.fit(x_train.cpu() + eps)
            x_out = torch.tensor(pt.transform(x_apply.cpu() + eps), dtype=x_apply.dtype, device=x_apply.device)
        except ValueError as e:
            print(e)
            x_out = x_apply - x_train.mean(0)
    else:
        pt = PowerTransformer(method='yeo-johnson')
        if not less_safe and (x_train.std() > 1_000 or x_train.mean().abs() > 1_000):
            x_apply = (x_apply - x_train.mean(0)) / x_train.std(0)
            x_train = (x_train - x_train.mean(0)) / x_train.std(0)
            # print('inputs are large, normalizing them')
        try:
            pt.fit(x_train.cpu().double())
        except ValueError as e:
            print('caught this error', e)
            # compute the statistics before overwriting x_train, so that x_apply is normalized
            # with the statistics of the original x_train
            x_train_mean, x_train_std = x_train.mean(0), x_train.std(0)
            if less_safe:
                x_train = (x_train - x_train_mean) / x_train_std
                x_apply = (x_apply - x_train_mean) / x_train_std
            else:
                x_train = x_train - x_train_mean
                x_apply = x_apply - x_train_mean
            pt.fit(x_train.cpu().double())
        x_out = torch.tensor(pt.transform(x_apply.cpu()), dtype=x_apply.dtype, device=x_apply.device)
    if torch.isnan(x_out).any() or torch.isinf(x_out).any():
        print('WARNING: power transform failed')
        print(f"{x_train=} and {x_apply=}")
        x_out = x_apply - x_train.mean(0)
    return x_out
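
# Hedged usage sketch for `general_power_transform` (illustrative only; the tensors below are
# made-up placeholders, not values used anywhere in this module):
#
#     y = torch.rand(16, 1) * 100.
#     y_gauss = general_power_transform(y, y, eps=0.)     # Yeo-Johnson, fit on y and applied to y
#     y_boxcox = general_power_transform(y, y, eps=1e-3)  # Box-Cox on the strictly positive y + eps
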
#@torch.inference_mode()
def general_acq_function(model: transformer.TransformerModel, x_given, y_given, x_eval, apply_power_transform=True,
                         rand_sample=False, znormalize=False, pre_normalize=False, pre_znormalize=False, predicted_mean_fbest=False,
                         input_znormalize=False, max_dataset_size=10_000, remove_features_with_one_value_only=False,
                         return_actual_ei=False, acq_function='ei', ucb_rest_prob=.05, ensemble_log_dims=False,
                         ensemble_type='mean_probs',  # in ('mean_probs', 'max_acq')
                         input_power_transform=False, power_transform_eps=.0, input_power_transform_eps=.0,
                         input_rank_transform=False, ensemble_input_rank_transform=False,
                         ensemble_power_transform=False, ensemble_feature_rotation=False,
                         style=None, outlier_stretching_interval=0.0, verbose=False, unsafe_power_transform=False,
                         ):
    """
    Compute acquisition values for a set of candidate points with a PFN transformer.

    Differences to HEBO:
        - The noise cannot be set in the same way, as it depends on the tuning of hyperparameters via VI.
        - Log EI and PI are always used directly instead of using the approximation.

    This is a stochastic function, relying on torch.randn.

    :param model: a trained PFN transformer whose `criterion` is a `bar_distribution.BarDistribution`
    :param x_given: torch.Tensor of shape (N, D), observed inputs
    :param y_given: torch.Tensor of shape (N, 1) or (N,), observed targets
    :param x_eval: torch.Tensor of shape (M, D), candidate inputs to score
    :return: acquisition values per candidate if `return_actual_ei` is True, otherwise a boolean mask
        marking the best candidate(s)
    """
    assert ensemble_type in ('mean_probs', 'max_acq')

    if rand_sample is not False \
            and (len(x_given) == 0 or
                 ((1 + x_given.shape[1] if rand_sample is None else max(2, rand_sample)) > x_given.shape[0])):
        print('not enough observations yet, suggesting a random point')
        return torch.zeros_like(x_eval[:, 0])  # torch.randperm(x_eval.shape[0])[0]
    y_given = y_given.reshape(-1)
    assert len(y_given) == len(x_given)

    if apply_power_transform:
        if pre_normalize:
            y_normed = y_given / y_given.std()
            if not torch.isinf(y_normed).any() and not torch.isnan(y_normed).any():
                y_given = y_normed
        elif pre_znormalize:
            y_znormed = (y_given - y_given.mean()) / y_given.std()
            if not torch.isinf(y_znormed).any() and not torch.isnan(y_znormed).any():
                y_given = y_znormed
        y_given = general_power_transform(y_given.unsqueeze(1), y_given.unsqueeze(1), power_transform_eps, less_safe=unsafe_power_transform).squeeze(1)
        if verbose:
            print(f"{y_given=}")
        # y_given = torch.tensor(power_transform(y_given.cpu().unsqueeze(1), method='yeo-johnson', standardize=znormalize), device=y_given.device, dtype=y_given.dtype,).squeeze(1)

    y_given_std = torch.tensor(1., device=y_given.device, dtype=y_given.dtype)
    if znormalize and not apply_power_transform:
        if len(y_given) > 1:
            y_given_std = y_given.std()
        y_given_mean = y_given.mean()
        y_given = (y_given - y_given_mean) / y_given_std
    if remove_features_with_one_value_only:
        x_all = torch.cat([x_given, x_eval], dim=0)
        only_one_value_feature = torch.tensor([len(torch.unique(x_all[:, i])) for i in range(x_all.shape[1])]) == 1
        x_given = x_given[:, ~only_one_value_feature]
        x_eval = x_eval[:, ~only_one_value_feature]
    if outlier_stretching_interval > 0.:
        tx = torch.cat([x_given, x_eval], dim=0)
        m = outlier_stretching_interval
        eps = 1e-10
        small_values = (tx < m) & (tx > 0.)
        tx[small_values] = m * (torch.log(tx[small_values] + eps) - math.log(eps)) / (math.log(m + eps) - math.log(eps))

        large_values = (tx > 1. - m) & (tx < 1.)
        tx[large_values] = 1. - m * (torch.log(1 - tx[large_values] + eps) - math.log(eps)) / (
                math.log(m + eps) - math.log(eps))
        x_given = tx[:len(x_given)]
        x_eval = tx[len(x_given):]
    if input_znormalize:  # implementation that relies on the test set, too...
        std = x_given.std(dim=0)
        std[std == 0.] = 1.
        mean = x_given.mean(dim=0)
        x_given = (x_given - mean) / std
        x_eval = (x_eval - mean) / std

    if input_power_transform:
        x_given = general_power_transform(x_given, x_given, input_power_transform_eps)
        x_eval = general_power_transform(x_given, x_eval, input_power_transform_eps)
    if input_rank_transform is True or input_rank_transform == 'full':  # uses test set x statistics...
        x_all = torch.cat((x_given, x_eval), dim=0)
        for feature_dim in range(x_all.shape[-1]):
            uniques = torch.sort(torch.unique(x_all[..., feature_dim])).values
            x_eval[..., feature_dim] = torch.searchsorted(uniques, x_eval[..., feature_dim]).float() / (len(uniques) - 1)
            x_given[..., feature_dim] = torch.searchsorted(uniques, x_given[..., feature_dim]).float() / (len(uniques) - 1)
    elif input_rank_transform is False:
        pass
    elif input_rank_transform == 'train':
        x_given = rank_transform(x_given, x_given)
        x_eval = rank_transform(x_given, x_eval)
    elif input_rank_transform.startswith('train'):
        likelihood = float(input_rank_transform.split('_')[-1])
        if torch.rand(1).item() < likelihood:
            print('rank transform')
            x_given = rank_transform(x_given, x_given)
            x_eval = rank_transform(x_given, x_eval)
    else:
        raise NotImplementedError
    # compute logits
    criterion: bar_distribution.BarDistribution = model.criterion
    x_predict = torch.cat([x_given, x_eval], dim=0)

    logits_list = []
    for x_feed in torch.split(x_predict, max_dataset_size, dim=0):
        x_full_feed = torch.cat([x_given, x_feed], dim=0).unsqueeze(1)
        y_full_feed = y_given.unsqueeze(1)
        if ensemble_log_dims == '01':
            x_full_feed = log01_batch(x_full_feed)
        elif ensemble_log_dims == 'global01' or ensemble_log_dims is True:
            x_full_feed = log01_batch(x_full_feed, input_between_zero_and_one=True)
        elif ensemble_log_dims == '01-10':
            x_full_feed = torch.cat((log01_batch(x_full_feed)[:, :-1], log01_batch(1. - x_full_feed)), 1)
        elif ensemble_log_dims == 'norm':
            x_full_feed = lognormed_batch(x_full_feed, len(x_given))
        elif ensemble_log_dims is not False:
            raise NotImplementedError

        if ensemble_feature_rotation:
            x_full_feed = torch.cat([x_full_feed[:, :, (i + torch.arange(x_full_feed.shape[2])) % x_full_feed.shape[2]] for i in range(x_full_feed.shape[2])], dim=1)

        if ensemble_input_rank_transform == 'train' or ensemble_input_rank_transform is True:
            x_full_feed = torch.cat([rank_transform(x_given, x_full_feed[:, i, :])[:, None] for i in range(x_full_feed.shape[1])] + [x_full_feed], dim=1)

        if ensemble_power_transform:
            assert apply_power_transform is False
            y_full_feed = torch.cat((general_power_transform(y_full_feed, y_full_feed, power_transform_eps), y_full_feed), dim=1)

        if style is not None:
            if callable(style):
                style = style()

            if isinstance(style, torch.Tensor):
                style = style.to(x_full_feed.device)
            else:
                style = torch.tensor(style, device=x_full_feed.device).view(1, 1).repeat(x_full_feed.shape[1], 1)

        logits = model(
            (style,
             x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1]),
             y_full_feed.repeat(1, x_full_feed.shape[1])),
            single_eval_pos=len(x_given)
        )
        if ensemble_type == 'mean_probs':
            logits = logits.softmax(-1).mean(1, keepdim=True).log_()  # (num given + num eval, 1, num buckets)
            # print('in ensemble_type == mean_probs')
        logits_list.append(logits)  # (< max_dataset_size, 1, num_buckets)

    logits = torch.cat(logits_list, dim=0)  # (num given + num eval, 1 or (num_features+1), num buckets)
    del logits_list, x_full_feed
    if torch.isnan(logits).any():
        print('nan logits')
        print(f"y_given: {y_given}, x_given: {x_given}, x_eval: {x_eval}")
        print(f"logits: {logits}")
        return torch.zeros_like(x_eval[:, 0])

    # logits = model((torch.cat([x_given, x_given, x_eval], dim=0).unsqueeze(1),
    #                 torch.cat([y_given, torch.zeros(len(x_eval)+len(x_given), device=y_given.device)], dim=0).unsqueeze(1)),
    #                single_eval_pos=len(x_given))[:, 0]  # (N + M, num_buckets)
    logits_given = logits[:len(x_given)]
    logits_eval = logits[len(x_given):]

    # tau = criterion.mean(logits_given)[torch.argmax(y_given)]  # predicted mean at the best y
    if predicted_mean_fbest:
        tau = criterion.mean(logits_given)[torch.argmax(y_given)].squeeze(0)
    else:
        tau = torch.max(y_given)
    # log_ei = torch.stack([criterion.ei(logits_eval[:,i], noisy_best_f[i]).log() for i in range(len(logits_eval))], 0)

    def acq_ensembling(acq_values):  # (points, ensemble dim)
        return acq_values.max(1).values
    if isinstance(acq_function, (dict, list)):
        acq_function = acq_function[style]

    if acq_function == 'ei':
        acq_value = acq_ensembling(criterion.ei(logits_eval, tau))
    elif acq_function == 'ei_or_rand':
        if torch.rand(1).item() < 0.5:
            acq_value = torch.rand(len(x_eval))
        else:
            acq_value = acq_ensembling(criterion.ei(logits_eval, tau))
    elif acq_function == 'pi':
        acq_value = acq_ensembling(criterion.pi(logits_eval, tau))
    elif acq_function == 'ucb':
        acq_function = criterion.ucb
        if ucb_rest_prob is not None:
            acq_function = lambda *args: criterion.ucb(*args, rest_prob=ucb_rest_prob)
        acq_value = acq_ensembling(acq_function(logits_eval, tau))
    elif acq_function == 'mean':
        acq_value = acq_ensembling(criterion.mean(logits_eval))
    elif acq_function.startswith('hebo'):
        noise, upsi, delta, eps = (float(v) for v in acq_function.split('_')[1:])
        noise = y_given_std * math.sqrt(2 * noise)
        kappa = math.sqrt(upsi * 2 * ((2.0 + x_given.shape[1] / 2.0) * math.log(max(1, len(x_given))) + math.log(
            3 * math.pi ** 2 / (3 * delta))))
        rest_prob = 1. - .5 * (1 + torch.erf(torch.tensor(kappa / math.sqrt(2), device=logits.device)))
        ucb = acq_ensembling(criterion.ucb(logits_eval, None, rest_prob=rest_prob)) \
              + torch.randn(len(logits_eval), device=logits_eval.device) * noise
        noisy_best_f = tau + eps + \
                       noise * torch.randn(len(logits_eval), device=logits_eval.device)[:, None].repeat(1, logits_eval.shape[1])

        log_pi = acq_ensembling(criterion.pi(logits_eval, noisy_best_f).log())
        # log_ei = torch.stack([criterion.ei(logits_eval[:,i], noisy_best_f[i]).log() for i in range(len(logits_eval))], 0)
        log_ei = acq_ensembling(criterion.ei(logits_eval, noisy_best_f).log())

        acq_values = torch.stack([ucb, log_ei, log_pi], dim=1)
        def is_pareto_efficient(costs):
            """
            Find the pareto-efficient points
            :param costs: An (n_points, n_costs) array
            :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient
            """
            is_efficient = torch.ones(costs.shape[0], dtype=bool, device=costs.device)
            for i, c in enumerate(costs):
                if is_efficient[i]:
                    is_efficient[is_efficient.clone()] = (costs[is_efficient] < c).any(1)  # Keep any point with a lower cost
                    is_efficient[i] = True  # And keep self
            return is_efficient

        acq_value = is_pareto_efficient(-acq_values)
    else:
        raise ValueError(f'Unknown acquisition function: {acq_function}')

    max_acq = acq_value.max()
    return acq_value if return_actual_ei else (acq_value == max_acq)
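
# Minimal usage sketch for `general_acq_function` (illustrative only; `trained_model` is a placeholder
# for a PFN transformer whose `criterion` is a `bar_distribution.BarDistribution`, not defined here):
#
#     x_obs = torch.rand(8, 3)       # 8 observed points in [0, 1]^3
#     y_obs = torch.rand(8)          # their objective values
#     x_cand = torch.rand(128, 3)    # candidates to score
#     with torch.no_grad():
#         ei = general_acq_function(trained_model, x_obs, y_obs, x_cand,
#                                    acq_function='ei', return_actual_ei=True)
#     x_next = x_cand[ei.argmax()]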


def optimize_acq(model, known_x, known_y, num_grad_steps=10, num_random_samples=100, lr=.01, **kwargs):
    """
    Optimize the acquisition function over the unit hypercube with Adam, starting from random samples.

    Intervals are assumed to be between 0 and 1.
    Only works with EI.
    Recommended extra kwarg: ensemble_input_rank_transform='train'.

    :param model: model to optimize, should already handle different num_features with its encoder.
        You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)`.
    :param known_x: (N, num_features)
    :param known_y: (N,)
    :param num_grad_steps: int
    :param num_random_samples: int
    :param lr: float
    :param kwargs: will be given to `general_acq_function`
    :return: the best input found, of shape (num_features,)
    """
    x_eval = torch.rand(num_random_samples, known_x.shape[1]).requires_grad_(True)
    opt = torch.optim.Adam(params=[x_eval], lr=lr)
    best_acq, best_x = -float('inf'), x_eval[0].detach()

    for grad_step in range(num_grad_steps):
        acq = general_acq_function(model, known_x, known_y, x_eval, return_actual_ei=True, **kwargs)
        max_acq = acq.detach().max().item()
        if max_acq > best_acq:
            best_x = x_eval[acq.argmax()].detach()
            best_acq = max_acq

        (-acq.mean()).backward()
        assert (x_eval.grad != 0.).any()
        if torch.isfinite(x_eval.grad).all():
            opt.step()
        opt.zero_grad()

        with torch.no_grad():
            x_eval.clamp_(min=0., max=1.)

    return best_x
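
# Hedged example for `optimize_acq` (a sketch; `trained_model` is a placeholder and the
# hyperparameters are arbitrary, not tuned recommendations):
#
#     x_next = optimize_acq(trained_model, known_x=torch.rand(10, 2), known_y=torch.rand(10),
#                           num_grad_steps=20, num_random_samples=256, lr=1e-2)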


def optimize_acq_w_lbfgs(model, known_x, known_y, num_grad_steps=15_000, num_candidates=100, pre_sample_size=100_000, device='cpu',
                         verbose=False, dims_wo_gradient_opt=[], rand_sample_func=None, **kwargs):
    """
    Optimize the acquisition function with a random search, followed by L-BFGS-B on the best candidates.

    Intervals are assumed to be between 0 and 1.
    Only works with a deterministic acquisition function.
    Recommended extra kwarg: ensemble_input_rank_transform='train'.

    :param model: model to optimize, should already handle different num_features with its encoder.
        You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)`.
    :param known_x: (N, num_features)
    :param known_y: (N,)
    :param num_grad_steps: int: how many steps to take inside of scipy; can be left high, as it usually stops early on its own
    :param num_candidates: int: how many candidates to optimize with L-BFGS-B; higher values increase cost
    :param pre_sample_size: int: how many settings to try first with a random search, before optimizing the best ones with gradients
    :param dims_wo_gradient_opt: list of int: which dimensions to optimize with random search only, instead of gradients
    :param rand_sample_func: function: how to sample random points; should take a number of samples and return a tensor,
        for example `lambda n: torch.rand(n, known_x.shape[1])`.
    :param kwargs: will be given to `general_acq_function`
    :return: tuple of (best x, optimized candidates sorted by acquisition value, their acquisition values,
        all random samples sorted by their initial acquisition value, those initial acquisition values)
    """
    num_features = known_x.shape[1]
    dims_w_gradient_opt = sorted(set(range(num_features)) - set(dims_wo_gradient_opt))
    known_x = known_x.to(device)
    known_y = known_y.to(device)
    pre_sample_size = max(pre_sample_size, num_candidates)
    rand_sample_func = rand_sample_func or (lambda n: torch.rand(n, num_features, device=device))

    if len(known_x) < pre_sample_size:
        x_initial = torch.cat((rand_sample_func(pre_sample_size - len(known_x)).to(device), known_x), 0)
    else:
        x_initial = rand_sample_func(pre_sample_size)
    x_initial = x_initial.clamp(min=0., max=1.)
    x_initial_all = x_initial
    model.to(device)

    with torch.no_grad():
        acq = general_acq_function(model, known_x, known_y, x_initial.to(device), return_actual_ei=True, **kwargs)
        if verbose:
            import matplotlib.pyplot as plt
            if x_initial.shape[1] == 2:
                plt.title('initial acq values, red -> blue')
                plt.scatter(x_initial[:, 0][:100], x_initial[:, 1][:100], c=acq.cpu().numpy()[:100], cmap='RdBu')

    x_initial = x_initial[acq.argsort(descending=True)[:num_candidates].cpu()].detach()  # num_candidates x num_features
    x_initial_all_ei = acq.cpu().detach()

    def opt_f(x):
        x_eval = torch.tensor(x).view(-1, len(dims_w_gradient_opt)).float().to(device).requires_grad_(True)
        x_eval_new = x_initial.clone().detach().to(device)
        x_eval_new[:, dims_w_gradient_opt] = x_eval

        assert x_eval_new.requires_grad
        assert not torch.isnan(x_eval_new).any()
        model.requires_grad_(False)
        acq = general_acq_function(model, known_x, known_y, x_eval_new, return_actual_ei=True, **kwargs)
        neg_mean_acq = -acq.mean()
        neg_mean_acq.backward()
        # print(neg_mean_acq.detach().numpy(), x_eval.grad.detach().view(*x.shape).numpy())
        with torch.no_grad():
            x_eval.grad[x_eval.grad != x_eval.grad] = 0.  # replace NaN gradients (x != x) with 0
        return neg_mean_acq.detach().cpu().to(torch.float64).numpy(), \
            x_eval.grad.detach().view(*x.shape).cpu().to(torch.float64).numpy()

    # Optimize best candidates with LBFGS
    if num_grad_steps > 0 and len(dims_w_gradient_opt) > 0:
        # the columns not in dims_wo_gradient_opt will be optimized with gradients
        x_initial_for_gradient_opt = x_initial[:, dims_w_gradient_opt].detach().cpu().flatten().numpy()  # x_initial.cpu().flatten().numpy()
        res = scipy.optimize.minimize(opt_f, x_initial_for_gradient_opt, method='L-BFGS-B', jac=True,
                                      bounds=[(0, 1)] * x_initial_for_gradient_opt.size,
                                      options={'maxiter': num_grad_steps})
        results = x_initial.cpu()
        results[:, dims_w_gradient_opt] = torch.tensor(res.x).float().view(-1, len(dims_w_gradient_opt))
    else:
        results = x_initial.cpu()

    results = results.clamp(min=0., max=1.)

    # Recalculate the acq values for the best candidates
    with torch.no_grad():
        acq = general_acq_function(model, known_x, known_y, results.to(device), return_actual_ei=True, verbose=verbose, **kwargs)
        # print(acq)

    if verbose:
        from scipy.stats import rankdata
        import matplotlib.pyplot as plt
        if results.shape[1] == 2:
            plt.scatter(results[:, 0], results[:, 1], c=rankdata(acq.cpu().numpy()), marker='x', cmap='RdBu')
        plt.show()

    best_x = results[acq.argmax().item()].detach()

    acq_order = acq.argsort(descending=True).cpu()
    all_order = x_initial_all_ei.argsort(descending=True).cpu()
    return best_x.detach(), results[acq_order].detach(), acq.cpu()[acq_order].detach(), x_initial_all.cpu()[all_order].detach(), x_initial_all_ei.cpu()[all_order].detach()
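
# Hedged example for `optimize_acq_w_lbfgs` (a sketch assuming `trained_model` is a compatible PFN;
# the returned tuple is ordered as documented in the docstring above):
#
#     best_x, cands, cand_acq, rand_x, rand_acq = optimize_acq_w_lbfgs(
#         trained_model, known_x=torch.rand(10, 2), known_y=torch.rand(10),
#         num_candidates=10, pre_sample_size=1_000, device='cpu')
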
from ..utils import to_tensor


class TransformerBOMethod:

    def __init__(self, model, acq_f=general_acq_function, device='cpu:0', fit_encoder=None, **kwargs):
        self.model = model
        self.device = device
        self.kwargs = kwargs
        self.acq_function = acq_f
        self.fit_encoder = fit_encoder

    @torch.no_grad()
    def observe_and_suggest(self, X_obs, y_obs, X_pen, return_actual_ei=False):
        # assert X_pen is not None
        # assumptions about X_obs and X_pen:
        # X_obs is a numpy array of shape (n_samples, n_features)
        # y_obs is a numpy array of shape (n_samples,), between 0 and 1
        # X_pen is a numpy array of shape (n_samples_left, n_features)
        X_obs = to_tensor(X_obs, device=self.device).to(torch.float32)
        y_obs = to_tensor(y_obs, device=self.device).to(torch.float32).view(-1)
        X_pen = to_tensor(X_pen, device=self.device).to(torch.float32)

        assert len(X_obs) == len(y_obs), "make sure both X_obs and y_obs have the same length."

        self.model.to(self.device)

        if self.fit_encoder is not None:
            w = self.fit_encoder(self.model, X_obs, y_obs)
            X_obs = w(X_obs)
            X_pen = w(X_pen)

        # with (torch.cuda.amp.autocast() if self.device.type != 'cpu' else contextlib.nullcontext()):
        with (torch.cuda.amp.autocast() if self.device[:3] != 'cpu' else contextlib.nullcontext()):
            acq_values = self.acq_function(self.model, X_obs, y_obs,
                                           X_pen, return_actual_ei=return_actual_ei, **self.kwargs).cpu().clone()  # bool array
            acq_mask = acq_values.max() == acq_values
            possible_next = torch.arange(len(X_pen))[acq_mask]
            if len(possible_next) == 0:
                possible_next = torch.arange(len(X_pen))

        r = possible_next[torch.randperm(len(possible_next))[0]].cpu().item()

        if return_actual_ei:
            return r, acq_values
        else:
            return r
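
# Hedged sketch of a Bayesian-optimization step built on `TransformerBOMethod` (illustrative only;
# `trained_model`, the candidate pool and the objective evaluation are placeholders supplied by the caller):
#
#     import numpy as np
#     bo = TransformerBOMethod(trained_model, device='cpu:0')
#     X_obs, y_obs = np.random.rand(5, 3), np.random.rand(5)   # initial design, y assumed in [0, 1]
#     X_pen = np.random.rand(1000, 3)                          # pending candidates
#     idx = bo.observe_and_suggest(X_obs, y_obs, X_pen)
#     x_next = X_pen[idx]                                      # evaluate the objective here next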