import warnings

import numpy as np
import torch
import torch.nn as nn


class GaussianProcessRegressor(nn.Module):
    """Zero-mean Gaussian-process regressor with a squared-exponential kernel.

    All three hyperparameters are stored as logarithms in learnable
    parameters (``length_scale_``, ``noise_scale_``, ``amplitude_scale_``)
    and exponentiated before use, which keeps them positive during
    gradient-based optimisation.

    The kernel evaluated by :meth:`Kxy` is
    ``exp(amplitude_scale_) * exp(-0.5 * ||x / s - z / s||^2)`` with
    ``s = sqrt(exp(length_scale_))``; i.e. ``exp(length_scale_)`` divides the
    *squared* coordinate differences.  :meth:`Kxx` additionally adds
    ``exp(noise_scale_)`` on the diagonal.
    """

    def __init__(
        self,
        length_scale=1.0,
        noise_scale=1.0,
        amplitude_scale=1.0,
    ):
        """Create the learnable log-hyperparameters.

        Args:
            length_scale: A positive scalar (one scale shared by every input
                feature) or a 1-D ``np.ndarray`` holding one scale per
                feature.
            noise_scale: Positive initial value added to the diagonal of the
                training covariance.
            amplitude_scale: Positive initial multiplier of the kernel.

        Raises:
            TypeError: If ``length_scale`` is neither a real scalar nor an
                ``np.ndarray``.
            ValueError: If ``length_scale`` is an array that is not 1-D.
        """
        super().__init__()
        if isinstance(length_scale, (int, float)):
            # Plain ints used to be rejected; treat them like floats.
            length_scale = np.array([float(length_scale)])
        elif isinstance(length_scale, np.ndarray):
            if length_scale.ndim != 1:
                raise ValueError(
                    "length_scale array must be 1-D, got "
                    f"{length_scale.ndim} dimensions"
                )
        else:
            raise TypeError(
                "length_scale must be a real scalar or a 1-D numpy array, "
                f"got {type(length_scale).__name__}"
            )

        self.register_parameter(
            "length_scale_",
            param=nn.Parameter(
                # Same dtype the legacy ``torch.Tensor(ndarray)`` constructor
                # produced: the current default floating-point dtype.
                torch.tensor(
                    np.log(length_scale), dtype=torch.get_default_dtype()
                ),
                requires_grad=True,
            ),
        )
        self.register_parameter(
            "noise_scale_",
            param=nn.Parameter(
                torch.tensor(np.log(noise_scale)), requires_grad=True
            ),
        )
        self.register_parameter(
            "amplitude_scale_",
            param=nn.Parameter(
                torch.tensor(np.log(amplitude_scale)), requires_grad=True
            ),
        )
        # Placeholder attribute; nothing in this class assigns it afterwards.
        self.nll = None

    def forward(self, x):
        """Return the posterior mean at the query points ``x``.

        Uses the training inputs stored by :meth:`fit` and the ``alpha``
        vector cached by :meth:`log_marginal_likelihood`, so one of those
        must have populated the cache before calling this method.

        Args:
            x: 2-D tensor of query points, one row per point.

        Returns:
            Tensor ``Kxy(X_train, x).T @ alpha``.
        """
        alpha = self.alpha
        cross_cov = self.Kxy(self.X, x)
        return cross_cov.T.mm(alpha)

    def log_marginal_likelihood(self, X, y):
        """Compute the log marginal likelihood of ``y`` given inputs ``X``.

        Evaluates ``-0.5 * y^T K^-1 y - 0.5 * log|K| - 0.5 * N * log(2*pi)``
        through a Cholesky factorisation of ``K = Kxx(X)``.  As a side effect
        the factor ``L``, the solution ``alpha = K^-1 y`` and ``K`` itself are
        cached on the instance for later use by :meth:`forward`.

        Args:
            X: 2-D tensor of training inputs, one row per observation.
            y: 2-D tensor of targets with one row per observation (the matrix
                products below require a 2-D tensor).

        Returns:
            Tensor of shape ``(y.shape[1], y.shape[1])`` holding the log
            marginal likelihood (``(1, 1)`` for a single target column).

        Raises:
            torch.linalg.LinAlgError: If ``K`` is not positive definite.
        """
        # The Gaussian normalisation constant scales with the number of
        # observations, not with the number of input features.
        num_obs = X.shape[0]
        K = self.Kxx(X)
        L = torch.linalg.cholesky(K)
        # Solve K @ alpha = y using the triangular structure of L.
        alpha = torch.cholesky_solve(y, L)
        marginal_likelihood = (
            -0.5 * y.T.mm(alpha)
            - torch.log(torch.diag(L)).sum()  # equals 0.5 * log|K|
            - num_obs * 0.5 * np.log(2 * np.pi)
        )
        self.L = L
        self.alpha = alpha
        self.K = K
        return marginal_likelihood

    def Kxx(self, X):
        """Return the training covariance: kernel of ``X`` with itself plus noise.

        Args:
            X: 2-D tensor of inputs, one row per observation.

        Returns:
            Square tensor ``Kxy(X, X) + exp(noise_scale_) * I``.
        """
        noise = self.noise_scale_.exp() * torch.eye(
            len(X), dtype=X.dtype, device=X.device
        )
        return self.Kxy(X, X) + noise

    def Kxy(self, X, Z):
        """Return the noise-free kernel matrix between rows of ``X`` and ``Z``.

        Args:
            X: 2-D tensor, one row per point.
            Z: 2-D tensor with the same number of columns as ``X``.

        Returns:
            Tensor of shape ``(len(X), len(Z))``.
        """
        scale = self.length_scale_.exp().sqrt()
        # scale[None] has shape (1, D) (or (1, 1)), so it broadcasts over rows.
        sqdist = torch.cdist(X / scale[None], Z / scale[None]) ** 2
        return self.amplitude_scale_.exp() * torch.exp(-0.5 * sqdist)

    def fit(self, X, y, opt, num_steps):
        """Optimise the hyperparameters by minimising the negative log marginal likelihood.

        Args:
            X: 2-D tensor of training inputs, one row per observation.
            y: 2-D tensor of training targets, one row per observation.
            opt: A ``torch.optim`` optimiser already constructed over this
                module's parameters.
            num_steps: Maximum number of optimisation steps.

        Returns:
            List with the negative log marginal likelihood (as a float)
            recorded at every completed step.

        Raises:
            ValueError: If the number of length scales is neither 1 nor the
                number of columns of ``X``.
        """
        num_scales = len(self.length_scale_)
        # A single length scale broadcasts over all features in the kernel,
        # so it is valid for inputs of any dimensionality.
        if num_scales not in (1, X.shape[1]):
            raise ValueError(
                f"X has {X.shape[1]} features but the model has "
                f"{num_scales} length scales"
            )
        self.y = y
        self.X = X

        scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.9)
        self.train()

        nll_hist = []
        for it in range(num_steps):
            opt.zero_grad()
            try:
                nll = -self.log_marginal_likelihood(self.X, self.y).sum()
            except torch.linalg.LinAlgError:
                # The covariance stopped being positive definite; keep what
                # has been learned so far but tell the caller about it.
                warnings.warn(
                    f"Cholesky factorisation failed at step {it}; "
                    "stopping optimisation early.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                break
            nll.backward()
            opt.step()
            # Decay the learning rate every 10th step during the first 1000.
            if it % 10 == 0 and it < 1000:
                scheduler.step()
            nll_hist.append(nll.item())

        self._refresh_posterior_cache()
        return nll_hist

    def _refresh_posterior_cache(self):
        """Recompute the cached ``L``/``alpha``/``K`` for the current hyperparameters."""
        # The cache left by the training loop belongs to the parameters
        # *before* the last optimiser step; predictions should use the final
        # ones.  No autograd graph is needed for prediction.
        try:
            with torch.no_grad():
                self.log_marginal_likelihood(self.X, self.y)
        except torch.linalg.LinAlgError:
            # Final parameters are not factorisable: keep the most recent
            # successfully computed cache, if any.
            pass