import contextlib

from .. import transformer
from .. import bar_distribution
import torch
import scipy.optimize  # needed for scipy.optimize.minimize below
import math
from sklearn.preprocessing import power_transform, PowerTransformer

def log01(x, eps=.0000001, input_between_zero_and_one=False):
    logx = torch.log(x + eps)
    if input_between_zero_and_one:
        return (logx - math.log(eps)) / (math.log(1 + eps) - math.log(eps))
    return (logx - logx.min(0)[0]) / (logx.max(0)[0] - logx.min(0)[0])

def log01_batch(x, eps=.0000001, input_between_zero_and_one=False):
    x = x.repeat(1, x.shape[-1] + 1, 1)
    for b in range(x.shape[-1]):
        x[:, b, b] = log01(x[:, b, b], eps=eps, input_between_zero_and_one=input_between_zero_and_one)
    return x

def lognormed_batch(x, eval_pos, eps=.0000001):
    x = x.repeat(1, x.shape[-1] + 1, 1)
    for b in range(x.shape[-1]):
        logx = torch.log(x[:, b, b] + eps)
        x[:, b, b] = (logx - logx[:eval_pos].mean(0)) / logx[:eval_pos].std(0)
    return x

def _rank_transform(x_train, x):
    assert len(x_train.shape) == len(x.shape) == 1
    relative_to = torch.cat((torch.zeros_like(x_train[:1]), x_train.unique(sorted=True), torch.ones_like(x_train[-1:])), -1)
    higher_comparison = (relative_to < x[..., None]).sum(-1).clamp(min=1)
    pos_inside_interval = (x - relative_to[higher_comparison - 1]) / (relative_to[higher_comparison] - relative_to[higher_comparison - 1])
    x_transformed = higher_comparison - 1 + pos_inside_interval
    return x_transformed / (len(relative_to) - 1.)

def rank_transform(x_train, x):
    assert x.shape[1] == x_train.shape[1], f"{x.shape=} and {x_train.shape=}"
    # make sure everything is between 0 and 1
    assert (x_train >= 0.).all() and (x_train <= 1.).all(), f"{x_train=}"
    assert (x >= 0.).all() and (x <= 1.).all(), f"{x=}"
    return_x = x.clone()
    for feature_dim in range(x.shape[1]):
        return_x[:, feature_dim] = _rank_transform(x_train[:, feature_dim], x[:, feature_dim])
    return return_x
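
# Hedged usage sketch (added for illustration, not part of the original module):
# rank_transform maps every feature to its position among the unique training
# values (plus 0/1 borders), rescaled to [0, 1]. The tensors below are made up.
def _example_rank_transform():
    x_train = torch.tensor([[.1], [.5], [.9]])
    x = torch.tensor([[.5]])
    # .5 is the middle of the unique training values between the 0/1 borders,
    # so it is mapped to 0.5
    return rank_transform(x_train, x)  # tensor([[0.5000]])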

def general_power_transform(x_train, x_apply, eps, less_safe=False):
    if eps > 0:
        try:
            pt = PowerTransformer(method='box-cox')
            pt.fit(x_train.cpu() + eps)
            x_out = torch.tensor(pt.transform(x_apply.cpu() + eps), dtype=x_apply.dtype, device=x_apply.device)
        except ValueError as e:
            print(e)
            x_out = x_apply - x_train.mean(0)
    else:
        pt = PowerTransformer(method='yeo-johnson')
        if not less_safe and (x_train.std() > 1_000 or x_train.mean().abs() > 1_000):
            x_apply = (x_apply - x_train.mean(0)) / x_train.std(0)
            x_train = (x_train - x_train.mean(0)) / x_train.std(0)
            # print('inputs are large, normalizing them')
        try:
            pt.fit(x_train.cpu().double())
        except ValueError as e:
            print('caught this error', e)
            if less_safe:
                x_train = (x_train - x_train.mean(0)) / x_train.std(0)
                x_apply = (x_apply - x_train.mean(0)) / x_train.std(0)
            else:
                x_train = x_train - x_train.mean(0)
                x_apply = x_apply - x_train.mean(0)
            pt.fit(x_train.cpu().double())
        x_out = torch.tensor(pt.transform(x_apply.cpu()), dtype=x_apply.dtype, device=x_apply.device)
    if torch.isnan(x_out).any() or torch.isinf(x_out).any():
        print('WARNING: power transform failed')
        print(f"{x_train=} and {x_apply=}")
        x_out = x_apply - x_train.mean(0)
    return x_out
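
# Hedged usage sketch (added for illustration, not part of the original module):
# with eps == 0 a Yeo-Johnson transform (standardized by sklearn's default) is fit
# on `x_train` and applied to `x_apply`; with eps > 0 a Box-Cox transform on the
# eps-shifted data is used instead. The tensors below are made up.
def _example_general_power_transform():
    y_train = torch.rand(20, 1) * 10.
    y_all = torch.rand(25, 1) * 10.
    return general_power_transform(y_train, y_all, eps=0.)  # (25, 1), roughly zero mean / unit variance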

# @torch.inference_mode()
def general_acq_function(model: transformer.TransformerModel, x_given, y_given, x_eval, apply_power_transform=True,
                         rand_sample=False, znormalize=False, pre_normalize=False, pre_znormalize=False, predicted_mean_fbest=False,
                         input_znormalize=False, max_dataset_size=10_000, remove_features_with_one_value_only=False,
                         return_actual_ei=False, acq_function='ei', ucb_rest_prob=.05, ensemble_log_dims=False,
                         ensemble_type='mean_probs',  # in ('mean_probs', 'max_acq')
                         input_power_transform=False, power_transform_eps=.0, input_power_transform_eps=.0,
                         input_rank_transform=False, ensemble_input_rank_transform=False,
                         ensemble_power_transform=False, ensemble_feature_rotation=False,
                         style=None, outlier_stretching_interval=0.0, verbose=False, unsafe_power_transform=False,
                         ):
    """
    Differences to HEBO:
        - The noise cannot be set in the same way, as it depends on the tuning of hyperparameters via VI.
        - Log EI and PI are always used directly instead of using the approximation.

    This is a stochastic function, relying on torch.randn.

    :param model: trained PFN (transformer.TransformerModel)
    :param x_given: torch.Tensor of shape (N, D)
    :param y_given: torch.Tensor of shape (N, 1) or (N,)
    :param x_eval: torch.Tensor of shape (M, D)
    :return: acquisition values of shape (M,) if `return_actual_ei`, else a boolean mask over `x_eval` marking the maximizers
    """
    assert ensemble_type in ('mean_probs', 'max_acq')
    if rand_sample is not False \
            and (len(x_given) == 0 or
                 ((1 + x_given.shape[1] if rand_sample is None else max(2, rand_sample)) > x_given.shape[0])):
        print('too few observations, falling back to random sampling')
        return torch.zeros_like(x_eval[:, 0])  # torch.randperm(x_eval.shape[0])[0]
    y_given = y_given.reshape(-1)
    assert len(y_given) == len(x_given)
    if apply_power_transform:
        if pre_normalize:
            y_normed = y_given / y_given.std()
            if not torch.isinf(y_normed).any() and not torch.isnan(y_normed).any():
                y_given = y_normed
        elif pre_znormalize:
            y_znormed = (y_given - y_given.mean()) / y_given.std()
            if not torch.isinf(y_znormed).any() and not torch.isnan(y_znormed).any():
                y_given = y_znormed
        y_given = general_power_transform(y_given.unsqueeze(1), y_given.unsqueeze(1), power_transform_eps, less_safe=unsafe_power_transform).squeeze(1)
        if verbose:
            print(f"{y_given=}")
        # y_given = torch.tensor(power_transform(y_given.cpu().unsqueeze(1), method='yeo-johnson', standardize=znormalize), device=y_given.device, dtype=y_given.dtype,).squeeze(1)

    y_given_std = torch.tensor(1., device=y_given.device, dtype=y_given.dtype)
    if znormalize and not apply_power_transform:
        if len(y_given) > 1:
            y_given_std = y_given.std()
        y_given_mean = y_given.mean()
        y_given = (y_given - y_given_mean) / y_given_std
    if remove_features_with_one_value_only:
        x_all = torch.cat([x_given, x_eval], dim=0)
        only_one_value_feature = torch.tensor([len(torch.unique(x_all[:, i])) for i in range(x_all.shape[1])]) == 1
        x_given = x_given[:, ~only_one_value_feature]
        x_eval = x_eval[:, ~only_one_value_feature]

    if outlier_stretching_interval > 0.:
        tx = torch.cat([x_given, x_eval], dim=0)
        m = outlier_stretching_interval
        eps = 1e-10
        small_values = (tx < m) & (tx > 0.)
        tx[small_values] = m * (torch.log(tx[small_values] + eps) - math.log(eps)) / (math.log(m + eps) - math.log(eps))

        large_values = (tx > 1. - m) & (tx < 1.)
        tx[large_values] = 1. - m * (torch.log(1 - tx[large_values] + eps) - math.log(eps)) / (
                math.log(m + eps) - math.log(eps))
        x_given = tx[:len(x_given)]
        x_eval = tx[len(x_given):]
    if input_znormalize:  # implementation that relies on the test set, too...
        std = x_given.std(dim=0)
        std[std == 0.] = 1.
        mean = x_given.mean(dim=0)
        x_given = (x_given - mean) / std
        x_eval = (x_eval - mean) / std

    if input_power_transform:
        x_given = general_power_transform(x_given, x_given, input_power_transform_eps)
        x_eval = general_power_transform(x_given, x_eval, input_power_transform_eps)

    if input_rank_transform is True or input_rank_transform == 'full':  # uses test set x statistics...
        x_all = torch.cat((x_given, x_eval), dim=0)
        for feature_dim in range(x_all.shape[-1]):
            uniques = torch.sort(torch.unique(x_all[..., feature_dim])).values
            x_eval[..., feature_dim] = torch.searchsorted(uniques, x_eval[..., feature_dim]).float() / (len(uniques) - 1)
            x_given[..., feature_dim] = torch.searchsorted(uniques, x_given[..., feature_dim]).float() / (len(uniques) - 1)
    elif input_rank_transform is False:
        pass
    elif input_rank_transform == 'train':
        x_given = rank_transform(x_given, x_given)
        x_eval = rank_transform(x_given, x_eval)
    elif input_rank_transform.startswith('train'):
        likelihood = float(input_rank_transform.split('_')[-1])
        if torch.rand(1).item() < likelihood:
            print('rank transform')
            x_given = rank_transform(x_given, x_given)
            x_eval = rank_transform(x_given, x_eval)
    else:
        raise NotImplementedError
    # compute logits
    criterion: bar_distribution.BarDistribution = model.criterion
    x_predict = torch.cat([x_given, x_eval], dim=0)

    logits_list = []
    for x_feed in torch.split(x_predict, max_dataset_size, dim=0):
        x_full_feed = torch.cat([x_given, x_feed], dim=0).unsqueeze(1)
        y_full_feed = y_given.unsqueeze(1)
        if ensemble_log_dims == '01':
            x_full_feed = log01_batch(x_full_feed)
        elif ensemble_log_dims == 'global01' or ensemble_log_dims is True:
            x_full_feed = log01_batch(x_full_feed, input_between_zero_and_one=True)
        elif ensemble_log_dims == '01-10':
            x_full_feed = torch.cat((log01_batch(x_full_feed)[:, :-1], log01_batch(1. - x_full_feed)), 1)
        elif ensemble_log_dims == 'norm':
            x_full_feed = lognormed_batch(x_full_feed, len(x_given))
        elif ensemble_log_dims is not False:
            raise NotImplementedError

        if ensemble_feature_rotation:
            x_full_feed = torch.cat([x_full_feed[:, :, (i + torch.arange(x_full_feed.shape[2])) % x_full_feed.shape[2]] for i in range(x_full_feed.shape[2])], dim=1)

        if ensemble_input_rank_transform == 'train' or ensemble_input_rank_transform is True:
            x_full_feed = torch.cat([rank_transform(x_given, x_full_feed[:, i, :])[:, None] for i in range(x_full_feed.shape[1])] + [x_full_feed], dim=1)

        if ensemble_power_transform:
            assert apply_power_transform is False
            y_full_feed = torch.cat((general_power_transform(y_full_feed, y_full_feed, power_transform_eps), y_full_feed), dim=1)

        if style is not None:
            if callable(style):
                style = style()

            if isinstance(style, torch.Tensor):
                style = style.to(x_full_feed.device)
            else:
                style = torch.tensor(style, device=x_full_feed.device).view(1, 1).repeat(x_full_feed.shape[1], 1)

        logits = model(
            (style,
             x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1]),
             y_full_feed.repeat(1, x_full_feed.shape[1])),
            single_eval_pos=len(x_given)
        )
        if ensemble_type == 'mean_probs':
            logits = logits.softmax(-1).mean(1, keepdim=True).log_()  # (num given + num eval, 1, num buckets)
            # print('in ensemble_type == mean_probs ')
        logits_list.append(logits)  # (< max_dataset_size, 1, num_buckets)

    logits = torch.cat(logits_list, dim=0)  # (num given + num eval, 1 or (num_features + 1), num buckets)
    del logits_list, x_full_feed
    if torch.isnan(logits).any():
        print('nan logits')
        print(f"y_given: {y_given}, x_given: {x_given}, x_eval: {x_eval}")
        print(f"logits: {logits}")
        return torch.zeros_like(x_eval[:, 0])

    # logits = model((torch.cat([x_given, x_given, x_eval], dim=0).unsqueeze(1),
    #                 torch.cat([y_given, torch.zeros(len(x_eval) + len(x_given), device=y_given.device)], dim=0).unsqueeze(1)),
    #                single_eval_pos=len(x_given))[:, 0]  # (N + M, num_buckets)
    logits_given = logits[:len(x_given)]
    logits_eval = logits[len(x_given):]

    # tau = criterion.mean(logits_given)[torch.argmax(y_given)]  # predicted mean at the best y
    if predicted_mean_fbest:
        tau = criterion.mean(logits_given)[torch.argmax(y_given)].squeeze(0)
    else:
        tau = torch.max(y_given)
    # log_ei = torch.stack([criterion.ei(logits_eval[:, i], noisy_best_f[i]).log() for i in range(len(logits_eval))], 0)

    def acq_ensembling(acq_values):  # (points, ensemble dim)
        return acq_values.max(1).values

    if isinstance(acq_function, (dict, list)):
        acq_function = acq_function[style]

    if acq_function == 'ei':
        acq_value = acq_ensembling(criterion.ei(logits_eval, tau))
    elif acq_function == 'ei_or_rand':
        if torch.rand(1).item() < 0.5:
            acq_value = torch.rand(len(x_eval))
        else:
            acq_value = acq_ensembling(criterion.ei(logits_eval, tau))
    elif acq_function == 'pi':
        acq_value = acq_ensembling(criterion.pi(logits_eval, tau))
    elif acq_function == 'ucb':
        acq_function = criterion.ucb
        if ucb_rest_prob is not None:
            acq_function = lambda *args: criterion.ucb(*args, rest_prob=ucb_rest_prob)
        acq_value = acq_ensembling(acq_function(logits_eval, tau))
    elif acq_function == 'mean':
        acq_value = acq_ensembling(criterion.mean(logits_eval))
    elif acq_function.startswith('hebo'):
        noise, upsi, delta, eps = (float(v) for v in acq_function.split('_')[1:])
        noise = y_given_std * math.sqrt(2 * noise)
        kappa = math.sqrt(upsi * 2 * ((2.0 + x_given.shape[1] / 2.0) * math.log(max(1, len(x_given))) + math.log(
            3 * math.pi ** 2 / (3 * delta))))
        rest_prob = 1. - .5 * (1 + torch.erf(torch.tensor(kappa / math.sqrt(2), device=logits.device)))
        ucb = acq_ensembling(criterion.ucb(logits_eval, None, rest_prob=rest_prob)) \
              + torch.randn(len(logits_eval), device=logits_eval.device) * noise
        noisy_best_f = tau + eps + \
                       noise * torch.randn(len(logits_eval), device=logits_eval.device)[:, None].repeat(1, logits_eval.shape[1])

        log_pi = acq_ensembling(criterion.pi(logits_eval, noisy_best_f).log())
        # log_ei = torch.stack([criterion.ei(logits_eval[:, i], noisy_best_f[i]).log() for i in range(len(logits_eval))], 0)
        log_ei = acq_ensembling(criterion.ei(logits_eval, noisy_best_f).log())

        acq_values = torch.stack([ucb, log_ei, log_pi], dim=1)
        def is_pareto_efficient(costs):
            """
            Find the Pareto-efficient points
            :param costs: An (n_points, n_costs) array
            :return: A (n_points,) boolean array, indicating whether each point is Pareto efficient
            """
            is_efficient = torch.ones(costs.shape[0], dtype=bool, device=costs.device)
            for i, c in enumerate(costs):
                if is_efficient[i]:
                    is_efficient[is_efficient.clone()] = (costs[is_efficient] < c).any(1)  # Keep any point with a lower cost
                    is_efficient[i] = True  # And keep self
            return is_efficient

        acq_value = is_pareto_efficient(-acq_values)
    else:
        raise ValueError(f'Unknown acquisition function: {acq_function}')

    max_acq = acq_value.max()

    return acq_value if return_actual_ei else (acq_value == max_acq)
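
# Hedged usage sketch (added for illustration, not part of the original module):
# scores a batch of candidate points with the default EI acquisition for a loaded
# PFN `model` (a transformer.TransformerModel). The tensors below are made up.
def _example_general_acq_function(model):
    x_given = torch.rand(8, 3)    # 8 observed points with 3 features in [0, 1]
    y_given = torch.rand(8)       # their objective values (maximization)
    x_eval = torch.rand(100, 3)   # candidate points to score
    # boolean mask marking the acquisition maximizers among x_eval
    mask = general_acq_function(model, x_given, y_given, x_eval)
    # actual acquisition values instead of the argmax mask
    ei = general_acq_function(model, x_given, y_given, x_eval, return_actual_ei=True)
    return mask, ei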

def optimize_acq(model, known_x, known_y, num_grad_steps=10, num_random_samples=100, lr=.01, **kwargs):
    """
    Intervals are assumed to be between 0 and 1.
    Only works with EI.
    Recommended extra kwarg: ensemble_input_rank_transform='train'

    :param model: model to optimize, should already handle different num_features with its encoder
        You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)`
    :param known_x: (N, num_features)
    :param known_y: (N,)
    :param num_grad_steps: int
    :param num_random_samples: int
    :param lr: float
    :param kwargs: will be given to `general_acq_function`
    :return: the best point found, of shape (num_features,)
    """
    x_eval = torch.rand(num_random_samples, known_x.shape[1]).requires_grad_(True)
    opt = torch.optim.Adam(params=[x_eval], lr=lr)
    best_acq, best_x = -float('inf'), x_eval[0].detach()

    for grad_step in range(num_grad_steps):
        acq = general_acq_function(model, known_x, known_y, x_eval, return_actual_ei=True, **kwargs)
        max_acq = acq.detach().max().item()
        if max_acq > best_acq:
            best_x = x_eval[acq.argmax()].detach()
            best_acq = max_acq

        (-acq.mean()).backward()
        assert (x_eval.grad != 0.).any()
        if torch.isfinite(x_eval.grad).all():
            opt.step()
        opt.zero_grad()
        with torch.no_grad():
            x_eval.clamp_(min=0., max=1.)

    return best_x
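
# Hedged usage sketch (added for illustration, not part of the original module):
# a handful of Adam steps on random candidates in the unit cube, as described in
# the docstring above; `model` is assumed to be a loaded PFN.
def _example_optimize_acq(model):
    known_x = torch.rand(10, 2)
    known_y = torch.rand(10)
    return optimize_acq(model, known_x, known_y, num_grad_steps=5, num_random_samples=32)  # (2,) suggestion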

def optimize_acq_w_lbfgs(model, known_x, known_y, num_grad_steps=15_000, num_candidates=100, pre_sample_size=100_000, device='cpu',
                         verbose=False, dims_wo_gradient_opt=[], rand_sample_func=None, **kwargs):
    """
    Intervals are assumed to be between 0 and 1.
    Only works with deterministic acquisition functions.
    Recommended extra kwarg: ensemble_input_rank_transform='train'

    :param model: model to optimize, should already handle different num_features with its encoder
        You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)`
    :param known_x: (N, num_features)
    :param known_y: (N,)
    :param num_grad_steps: int: maximum number of steps to take inside scipy; can be left high, as it usually stops early on its own
    :param num_candidates: int: how many candidates to optimize with L-BFGS-B, increases cost when higher
    :param pre_sample_size: int: how many settings to try first with a random search, before optimizing the best with gradients
    :param dims_wo_gradient_opt: list of int: which dimensions not to optimize with gradients, but with random search only
    :param rand_sample_func: function: how to sample random points; should take a number of samples and return a tensor,
        for example `lambda n: torch.rand(n, known_x.shape[1])`.
    :param kwargs: will be given to `general_acq_function`
    :return: (best point, candidates sorted by acq, their acq values, all pre-sampled points sorted by acq, their acq values)
    """
    num_features = known_x.shape[1]
    dims_w_gradient_opt = sorted(set(range(num_features)) - set(dims_wo_gradient_opt))
    known_x = known_x.to(device)
    known_y = known_y.to(device)
    pre_sample_size = max(pre_sample_size, num_candidates)
    rand_sample_func = rand_sample_func or (lambda n: torch.rand(n, num_features, device=device))
    if len(known_x) < pre_sample_size:
        x_initial = torch.cat((rand_sample_func(pre_sample_size - len(known_x)).to(device), known_x), 0)
    else:
        x_initial = rand_sample_func(pre_sample_size)
    x_initial = x_initial.clamp(min=0., max=1.)
    x_initial_all = x_initial
    model.to(device)

    with torch.no_grad():
        acq = general_acq_function(model, known_x, known_y, x_initial.to(device), return_actual_ei=True, **kwargs)
    if verbose:
        import matplotlib.pyplot as plt
        if x_initial.shape[1] == 2:
            plt.title('initial acq values, red -> blue')
            plt.scatter(x_initial[:, 0][:100], x_initial[:, 1][:100], c=acq.cpu().numpy()[:100], cmap='RdBu')
    x_initial = x_initial[acq.argsort(descending=True)[:num_candidates].cpu()].detach()  # num_candidates x num_features
    x_initial_all_ei = acq.cpu().detach()
    def opt_f(x):
        x_eval = torch.tensor(x).view(-1, len(dims_w_gradient_opt)).float().to(device).requires_grad_(True)
        x_eval_new = x_initial.clone().detach().to(device)
        x_eval_new[:, dims_w_gradient_opt] = x_eval

        assert x_eval_new.requires_grad
        assert not torch.isnan(x_eval_new).any()
        model.requires_grad_(False)
        acq = general_acq_function(model, known_x, known_y, x_eval_new, return_actual_ei=True, **kwargs)
        neg_mean_acq = -acq.mean()
        neg_mean_acq.backward()
        # print(neg_mean_acq.detach().numpy(), x_eval.grad.detach().view(*x.shape).numpy())
        with torch.no_grad():
            x_eval.grad[x_eval.grad != x_eval.grad] = 0.
        return neg_mean_acq.detach().cpu().to(torch.float64).numpy(), \
            x_eval.grad.detach().view(*x.shape).cpu().to(torch.float64).numpy()
    # Optimize best candidates with L-BFGS-B
    if num_grad_steps > 0 and len(dims_w_gradient_opt) > 0:
        # the columns not in dims_wo_gradient_opt will be optimized with gradients
        x_initial_for_gradient_opt = x_initial[:, dims_w_gradient_opt].detach().cpu().flatten().numpy()  # x_initial.cpu().flatten().numpy()
        res = scipy.optimize.minimize(opt_f, x_initial_for_gradient_opt, method='L-BFGS-B', jac=True,
                                      bounds=[(0, 1)] * x_initial_for_gradient_opt.size,
                                      options={'maxiter': num_grad_steps})
        results = x_initial.cpu()
        results[:, dims_w_gradient_opt] = torch.tensor(res.x).float().view(-1, len(dims_w_gradient_opt))
    else:
        results = x_initial.cpu()

    results = results.clamp(min=0., max=1.)

    # Recalculate the acq values for the best candidates
    with torch.no_grad():
        acq = general_acq_function(model, known_x, known_y, results.to(device), return_actual_ei=True, verbose=verbose, **kwargs)
    # print(acq)
    if verbose:
        from scipy.stats import rankdata
        import matplotlib.pyplot as plt
        if results.shape[1] == 2:
            plt.scatter(results[:, 0], results[:, 1], c=rankdata(acq.cpu().numpy()), marker='x', cmap='RdBu')
            plt.show()
    best_x = results[acq.argmax().item()].detach()

    acq_order = acq.argsort(descending=True).cpu()
    all_order = x_initial_all_ei.argsort(descending=True).cpu()

    return best_x.detach(), results[acq_order].detach(), acq.cpu()[acq_order].detach(), x_initial_all.cpu()[all_order].detach(), x_initial_all_ei.cpu()[all_order].detach()
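
# Hedged usage sketch (added for illustration, not part of the original module):
# random pre-sampling followed by L-BFGS-B refinement of the best candidates;
# `model` is assumed to be a loaded PFN whose encoder handles two features.
def _example_optimize_acq_w_lbfgs(model):
    known_x = torch.rand(10, 2)
    known_y = torch.rand(10)
    best_x, candidates, candidate_acq, presampled, presampled_acq = optimize_acq_w_lbfgs(
        model, known_x, known_y, num_candidates=10, pre_sample_size=1_000)
    return best_x  # (2,) suggestion inside [0, 1]^2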

from ..utils import to_tensor


class TransformerBOMethod:

    def __init__(self, model, acq_f=general_acq_function, device='cpu:0', fit_encoder=None, **kwargs):
        self.model = model
        self.device = device
        self.kwargs = kwargs
        self.acq_function = acq_f
        self.fit_encoder = fit_encoder
    def observe_and_suggest(self, X_obs, y_obs, X_pen, return_actual_ei=False):
        # assert X_pen is not None
        # assumptions about X_obs and X_pen:
        # X_obs is a numpy array of shape (n_samples, n_features)
        # y_obs is a numpy array of shape (n_samples,), between 0 and 1
        # X_pen is a numpy array of shape (n_samples_left, n_features)
        X_obs = to_tensor(X_obs, device=self.device).to(torch.float32)
        y_obs = to_tensor(y_obs, device=self.device).to(torch.float32).view(-1)
        X_pen = to_tensor(X_pen, device=self.device).to(torch.float32)

        assert len(X_obs) == len(y_obs), "make sure both X_obs and y_obs have the same length."

        self.model.to(self.device)

        if self.fit_encoder is not None:
            w = self.fit_encoder(self.model, X_obs, y_obs)
            X_obs = w(X_obs)
            X_pen = w(X_pen)

        # with (torch.cuda.amp.autocast() if self.device.type != 'cpu' else contextlib.nullcontext()):
        with (torch.cuda.amp.autocast() if self.device[:3] != 'cpu' else contextlib.nullcontext()):
            acq_values = self.acq_function(self.model, X_obs, y_obs,
                                           X_pen, return_actual_ei=return_actual_ei, **self.kwargs).cpu().clone()  # bool array
            acq_mask = acq_values.max() == acq_values
            possible_next = torch.arange(len(X_pen))[acq_mask]
            if len(possible_next) == 0:
                possible_next = torch.arange(len(X_pen))

            r = possible_next[torch.randperm(len(possible_next))[0]].cpu().item()

        if return_actual_ei:
            return r, acq_values
        else:
            return r
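
# Hedged usage sketch (added for illustration, not part of the original module):
# wrap a loaded PFN in TransformerBOMethod and ask for the index of the pending
# configuration to evaluate next. The arrays below are made up.
def _example_transformer_bo_method(model):
    import numpy as np
    bo = TransformerBOMethod(model, device='cpu:0')
    X_obs = np.random.rand(5, 3)    # observed configurations, features in [0, 1]
    y_obs = np.random.rand(5)       # observed objective values in [0, 1]
    X_pen = np.random.rand(20, 3)   # pending candidate configurations
    return bo.observe_and_suggest(X_obs, y_obs, X_pen)  # index into X_pen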