from __future__ import annotations import random from collections import OrderedDict import torch from .. import utils from ..priors.hebo_prior import Warp from gpytorch.priors import LogNormalPrior # from botorch.optim import module_to_array, set_params_with_array # from .. import module_to_array, set_params_with_array import scipy from scipy.optimize import Bounds from typing import OrderedDict import numpy as np from functools import partial device = 'cpu' def fit_lbfgs(x, w, nll, num_grad_steps=10, ignore_prior=True, params0=None): bounds_ = {} if hasattr(w, "named_parameters_and_constraints"): for param_name, _, constraint in w.named_parameters_and_constraints(): if constraint is not None and not constraint.enforced: bounds_[param_name] = constraint.lower_bound, constraint.upper_bound params0_, property_dict, bounds_ = module_to_array( module=w, bounds=bounds_, exclude=None ) if params0 is None: params0 = params0_ bounds = Bounds(lb=bounds_[0], ub=bounds_[1], keep_feasible=True) def loss_f(params, w): w = set_params_with_array(w, params, property_dict) w.requires_grad_(True) loss = 0. if not ignore_prior: for name, module, prior, closure, _ in w.named_priors(): prior_term = prior.log_prob(closure(module)) loss -= prior_term.sum(dim=-1) negll = nll(w(x.to(torch.float64)).to(torch.float)).sum() #if loss != 0.: # print(loss.item(), negll.item()) loss = loss + negll return w, loss def opt_f(params, w): w, loss = loss_f(params, w) w.zero_grad() loss.backward() grad = [] param_dict = OrderedDict(w.named_parameters()) for p_name in property_dict: t = param_dict[p_name].grad if t is None: # this deals with parameters that do not affect the loss grad.append(np.zeros(property_dict[p_name].shape.numel())) else: grad.append(t.detach().view(-1).cpu().double().clone().numpy()) w.zero_grad() # print(neg_mean_acq.detach().numpy(), x_eval.grad.detach().view(*x.shape).numpy()) return loss.item(), np.concatenate(grad) if num_grad_steps: return scipy.optimize.minimize(partial(opt_f, w=w), params0, method='L-BFGS-B', jac=True, bounds=bounds, options={'maxiter': num_grad_steps}) else: with torch.no_grad(): return loss_f(params0, w), params0 def log_vs_nonlog(x, w, *args, **kwargs): if "true_nll" in kwargs: true_nll = kwargs["true_nll"] del kwargs["true_nll"] else: true_nll = None params, property_dict, _ = module_to_array(module=w) no_log = np.ones_like(params) log = np.array([1.9, 0.11] * (int(len(property_dict) / 2))) loss_no_log = fit_lbfgs(x, w, *args, **{**kwargs, 'num_grad_steps': 0}, params0=no_log) loss_log = fit_lbfgs(x, w, *args, **{**kwargs, 'num_grad_steps': 0}, params0=log) print("loss no log", loss_no_log[0][1], "loss log", loss_log[0][1]) if loss_no_log[0][1] < loss_log[0][1]: set_params_with_array(module=w, x=loss_no_log[1], property_dict=property_dict) if true_nll: best_params, _, _ = module_to_array(module=w) print("true nll", fit_lbfgs(x, w, true_nll, **{**kwargs, 'num_grad_steps': 0}, params0=best_params)) def fit_lbfgs_with_restarts(x, w, *args, old_solution=None, rs_size=50, **kwargs): if "true_nll" in kwargs: true_nll = kwargs["true_nll"] del kwargs["true_nll"] else: true_nll = None rs_results = [] if old_solution: rs_results.append(fit_lbfgs(x, old_solution, *args, **{**kwargs, 'num_grad_steps': 0})) for i in range(rs_size): with torch.no_grad(): w.concentration0[:] = w.concentration0_prior() w.concentration1[:] = w.concentration1_prior() rs_results.append(fit_lbfgs(x, w, *args, **{**kwargs, 'num_grad_steps': 0})) best_r = min(rs_results, key=lambda r: r[0][1]) print('best r', best_r) with torch.set_grad_enabled(True): r = fit_lbfgs(x, w, *args, **kwargs, params0=best_r[1]) _, property_dict, _ = module_to_array(module=w) set_params_with_array(module=w, x=r.x, property_dict=property_dict) print('final r', r) if true_nll: print("true nll", fit_lbfgs(x, w, true_nll, **{**kwargs, 'num_grad_steps': 0}, params0=r.x)) return r # use seed 0 for sampling indices, and reset seed afterwards old_seed = random.getstate() random.seed(0) one_out_indices_sampled_per_num_obs = [None]+[random.sample(range(i), min(10, i)) for i in range(1, 100)] random.setstate(old_seed) # use seed 0 for sampling subsets old_seed = random.getstate() random.seed(0) subsets = [None]+[[random.sample(range(i), i//2) for _ in range(10)] for i in range(1, 100)] neg_subsets = [None]+[[list(set(range(i)) - set(s)) for s in ss] for i, ss in enumerate(subsets[1:], 1)] random.setstate(old_seed) def fit_input_warping(model, x, y, nll_type='fast', old_solution=None, opt_method="lbfgs", **kwargs): """ :param model: :param x: shape (n, d) :param y: shape (n, 1) :param nll_type: :param kwargs: Possible kwargs: `num_grad_steps`, `rs_size` :return: """ device = x.device assert y.device == device, y.device model.requires_grad_(False) w = Warp(range(x.shape[1]), concentration1_prior=LogNormalPrior(torch.tensor(0.0, device=device), torch.tensor(.75, device=device)), concentration0_prior=LogNormalPrior(torch.tensor(0.0, device=device), torch.tensor(.75, device=device)), eps=1e-12) w.to(device) def fast_nll(x): # noqa actually used with `eval` below model.requires_grad_(False) if model.style_encoder is not None: style = torch.zeros(1, 1, dtype=torch.int64, device=device) utils.print_once("WARNING: using style 0 for input warping, this is set for nonmyopic BO setting.") else: style = None logits = model(x[:, None], y[:, None], x[:, None], style=style, only_return_standard_out=True) loss = model.criterion(logits, y[:, None]).squeeze(1) return loss def true_nll(x): # noqa actually used with `eval` below assert model.style_encoder is None, "true nll not implemented for style encoder, see above for an example impl" model.requires_grad_(False) total_nll = 0. for cutoff in range(len(x)): logits = model(x[:cutoff, None], y[:cutoff, None], x[cutoff:cutoff + 1, None]) total_nll = total_nll + model.criterion(logits, y[cutoff:cutoff + 1, None]).squeeze() assert len(total_nll.shape) == 0, f"{total_nll.shape=}" return total_nll def repeated_true_nll(x): # noqa actually used with `eval` below assert model.style_encoder is None, "true nll not implemented for style encoder, see above for an example impl" model.requires_grad_(False) total_nll = 0. for i in range(5): rs = np.random.RandomState(i) shuffle_idx = rs.permutation(len(x)) x_ = x.clone()[shuffle_idx] y_ = y.clone()[shuffle_idx] for cutoff in range(len(x)): logits = model(x_[:cutoff, None], y_[:cutoff, None], x_[cutoff:cutoff + 1, None]) total_nll = total_nll + model.criterion(logits, y_[cutoff:cutoff + 1, None]).squeeze() assert len(total_nll.shape) == 0, f"{total_nll.shape=}" return total_nll def repeated_true_100_nll(x): # noqa actually used with `eval` below assert model.style_encoder is None, "true nll not implemented for style encoder, see above for an example impl" model.requires_grad_(False) total_nll = 0. for i in range(100): rs = np.random.RandomState(i) shuffle_idx = rs.permutation(len(x)) x_ = x.clone()[shuffle_idx] y_ = y.clone()[shuffle_idx] for cutoff in range(len(x)): logits = model(x_[:cutoff, None], y_[:cutoff, None], x_[cutoff:cutoff + 1, None]) total_nll = total_nll + model.criterion(logits, y_[cutoff:cutoff + 1, None]).squeeze() assert len(total_nll.shape) == 0, f"{total_nll.shape=}" return total_nll / 100 def batched_repeated_chunked_true_nll(x): # noqa actually used with `eval` below assert model.style_encoder is None, "true nll not implemented for style encoder, see above for an example impl" assert len(x.shape) == 2 and len(y.shape) == 1 model.requires_grad_(False) n_features = x.shape[1] if len(x.shape) > 1 else 1 batch_size = 10 X = [] Y = [] for i in range(batch_size): #if i == 0: # shuffle_idx = list(range(len(x))) #else: rs = np.random.RandomState(i) shuffle_idx = rs.permutation(len(x)) X.append(x.clone()[shuffle_idx]) Y.append(y.clone()[shuffle_idx]) X = torch.stack(X, dim=1).view((x.shape[0], batch_size, n_features)) Y = torch.stack(Y, dim=1).view((x.shape[0], batch_size, 1)) total_nll = 0. batch_indizes = sorted(list(set(np.linspace(0, len(x), 10, dtype=int)))) for chunk_start, chunk_end in zip(batch_indizes[:-1], batch_indizes[1:]): X_cutoff = X[:chunk_start] Y_cutoff = Y[:chunk_start] X_after_cutoff = X[chunk_start:chunk_end] Y_after_cutoff = Y[chunk_start:chunk_end] pending_x = X_after_cutoff.reshape(X_after_cutoff.shape[0], batch_size, n_features) # n_pen x batch_size x n_feat observed_x = X_cutoff.reshape(X_cutoff.shape[0], batch_size, n_features) # n_obs x batch_size x n_feat X_tmp = torch.cat((observed_x, pending_x), dim=0) # (n_obs+n_pen) x batch_size x n_feat logits = model((X_tmp, Y_cutoff), single_eval_pos=int(chunk_start)) total_nll = total_nll + model.criterion(logits, Y_after_cutoff).sum() assert len(total_nll.shape) == 0, f"{total_nll.shape=}" return total_nll def batched_repeated_true_nll(x): # noqa actually used with `eval` below assert model.style_encoder is None, "true nll not implemented for style encoder, see above for an example impl" model.requires_grad_(False) n_features = x.shape[1] if len(x.shape) > 1 else 1 batch_size = 10 X = [] Y = [] for i in range(batch_size): #if i == 0: # shuffle_idx = list(range(len(x))) #else: rs = np.random.RandomState(i) shuffle_idx = rs.permutation(len(x)) X.append(x.clone()[shuffle_idx]) Y.append(y.clone()[shuffle_idx]) X = torch.cat(X, dim=1).reshape((x.shape[0], batch_size, n_features)) Y = torch.cat(Y, dim=1).reshape((x.shape[0], batch_size, 1)) total_nll = 0. for cutoff in range(0, len(x)): X_cutoff = X[:cutoff] Y_cutoff = Y[:cutoff] X_after_cutoff = X[cutoff:cutoff+1] Y_after_cutoff = Y[cutoff:cutoff+1] pending_x = X_after_cutoff.reshape(X_after_cutoff.shape[0], batch_size, n_features) # n_pen x batch_size x n_feat observed_x = X_cutoff.reshape(X_cutoff.shape[0], batch_size, n_features) # n_obs x batch_size x n_feat X_tmp = torch.cat((observed_x, pending_x), dim=0) # (n_obs+n_pen) x batch_size x n_feat pad_y = torch.zeros((X_after_cutoff.shape[0], batch_size, 1)) # (n_obs+n_pen) x batch_size Y_tmp = torch.cat((Y_cutoff, pad_y), dim=0) logits = model((X_tmp, Y_tmp), single_eval_pos=cutoff) total_nll = total_nll + model.criterion(logits, Y_after_cutoff).sum() assert len(total_nll.shape) == 0, f"{total_nll.shape=}" return total_nll def one_out_nll(x): # noqa actually used with `eval` below assert model.style_encoder is None, "one out nll not implemented for style encoder, see above for an example impl" # x shape: (n, d) # iterate over a pre-defined set of model.requires_grad_(False) #indices = one_out_indices_sampled_per_num_obs[len(x)] indices = list(range(x.shape[0])) # create batch by moving the one out index to the end eval_x = x[indices][None] # shape (1, 10, d) eval_y = y[indices][None] # shape (1, 10, 1) # all other indices are used for training train_x = torch.stack([torch.cat([x[:i], x[i + 1:]]) for i in indices], 1) train_y = torch.stack([torch.cat([y[:i], y[i + 1:]]) for i in indices], 1) logits = model(train_x, train_y, eval_x) return model.criterion(logits, eval_y).squeeze(0) def subset_nll(x): # noqa actually used with `eval` below assert model.style_encoder is None, "subset nll not implemented for style encoder, see above for an example impl" # x shape: (n, d) # iterate over a pre-defined set of model.requires_grad_(False) eval_indices = torch.tensor(subsets[len(x)]) train_indices = torch.tensor(neg_subsets[len(x)]) # batch by using all eval_indices eval_x = x[eval_indices.flatten()].view(eval_indices.shape + (-1,)) # shape (10, n//2, d) eval_y = y[eval_indices.flatten()].view(eval_indices.shape + (-1,)) # shape (10, n//2, 1) # all other indices are used for training train_x = x[train_indices.flatten()].view(train_indices.shape + (-1,)) # shape (10, n//2, d) train_y = y[train_indices.flatten()].view(train_indices.shape + (-1,)) # shape (10, n//2, 1) logits = model(train_x.transpose(0, 1), train_y.transpose(0, 1), eval_x.transpose(0, 1)) return model.criterion(logits, eval_y.transpose(0, 1)) if opt_method == "log_vs_nolog": log_vs_nonlog(x, w, eval(nll_type + '_nll'), ignore_prior=True, # true_nll=repeated_true_100_nll, **kwargs) elif opt_method == "lbfgs": fit_lbfgs_with_restarts( x, w, eval(nll_type + '_nll'), ignore_prior=True, old_solution=old_solution, # true_nll=repeated_true_100_nll, **kwargs) elif opt_method == "lbfgs_w_prior": fit_lbfgs_with_restarts( x, w, eval(nll_type + '_nll'), ignore_prior=False, old_solution=old_solution, # true_nll=repeated_true_100_nll, **kwargs) else: raise ValueError(opt_method) return w #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. r""" A converter that simplifies using numpy-based optimizers with generic torch `nn.Module` classes. This enables using a `scipy.optim.minimize` optimizer for optimizing module parameters. """ from collections import OrderedDict from math import inf from numbers import Number from typing import Dict, List, Optional, Set, Tuple from warnings import warn import numpy as np import torch from botorch.optim.utils import ( _get_extra_mll_args, _handle_numerical_errors, get_name_filter, get_parameters_and_bounds, TorchAttr, ) from gpytorch.mlls import MarginalLogLikelihood from torch.nn import Module def module_to_array( module: Module, bounds: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]] = None, exclude: Optional[Set[str]] = None, ) -> Tuple[np.ndarray, Dict[str, TorchAttr], Optional[np.ndarray]]: r"""Extract named parameters from a module into a numpy array. Only extracts parameters with requires_grad, since it is meant for optimizing. Args: module: A module with parameters. May specify parameter constraints in a `named_parameters_and_constraints` method. bounds: A dictionary mapping parameter names t lower and upper bounds. of lower and upper bounds. Bounds specified here take precedence over bounds on the same parameters specified in the constraints registered with the module. exclude: A list of parameter names that are to be excluded from extraction. Returns: 3-element tuple containing - The parameter values as a numpy array. - An ordered dictionary with the name and tensor attributes of each parameter. - A `2 x n_params` numpy array with lower and upper bounds if at least one constraint is finite, and None otherwise. Example: >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) >>> parameter_array, property_dict, bounds_out = module_to_array(mll) """ warn( "`module_to_array` is marked for deprecation, consider using " "`get_parameters_and_bounds`, `get_parameters_as_ndarray_1d`, or " "`get_bounds_as_ndarray` instead.", DeprecationWarning, ) param_dict, bounds_dict = get_parameters_and_bounds( module=module, name_filter=None if exclude is None else get_name_filter(exclude), requires_grad=True, ) if bounds is not None: bounds_dict.update(bounds) # Record tensor metadata and read parameter values to the tape param_tape: List[Number] = [] property_dict = OrderedDict() with torch.no_grad(): for name, param in param_dict.items(): property_dict[name] = TorchAttr(param.shape, param.dtype, param.device) param_tape.extend(param.view(-1).cpu().double().tolist()) # Extract lower and upper bounds start = 0 bounds_np = None params_np = np.asarray(param_tape) for name, param in param_dict.items(): numel = param.numel() if name in bounds_dict: for row, bound in enumerate(bounds_dict[name]): if bound is None: continue if torch.is_tensor(bound): if (bound == (2 * row - 1) * inf).all(): continue bound = bound.detach().cpu() elif bound == (2 * row - 1) * inf: continue if bounds_np is None: bounds_np = np.full((2, len(params_np)), ((-inf,), (inf,))) bounds_np[row, start : start + numel] = bound start += numel return params_np, property_dict, bounds_np def set_params_with_array( module: Module, x: np.ndarray, property_dict: Dict[str, TorchAttr] ) -> Module: r"""Set module parameters with values from numpy array. Args: module: Module with parameters to be set x: Numpy array with parameter values property_dict: Dictionary of parameter names and torch attributes as returned by module_to_array. Returns: Module: module with parameters updated in-place. Example: >>> mll = ExactMarginalLogLikelihood(model.likelihood, model) >>> parameter_array, property_dict, bounds_out = module_to_array(mll) >>> parameter_array += 0.1 # perturb parameters (for example only) >>> mll = set_params_with_array(mll, parameter_array, property_dict) """ warn( "`_set_params_with_array` is marked for deprecation, consider using " "`set_parameters_from_ndarray_1d` instead.", DeprecationWarning, ) param_dict = OrderedDict(module.named_parameters()) start_idx = 0 for p_name, attrs in property_dict.items(): # Construct the new tensor if len(attrs.shape) == 0: # deal with scalar tensors end_idx = start_idx + 1 new_data = torch.tensor( x[start_idx], dtype=attrs.dtype, device=attrs.device ) else: end_idx = start_idx + np.prod(attrs.shape) new_data = torch.tensor( x[start_idx:end_idx], dtype=attrs.dtype, device=attrs.device ).view(*attrs.shape) start_idx = end_idx # Update corresponding parameter in-place. Disable autograd to update. param_dict[p_name].requires_grad_(False) param_dict[p_name].copy_(new_data) param_dict[p_name].requires_grad_(True) return module