# Code for GP final layer adapted from this great repo:
# https://github.com/kimjeyoung/SNGP-BERT-Pytorch
# We simplify things here a bit by removing the spectral normalisation, as the
# authors of the Plex paper say that this isn't strictly necessary, so we just
# have a GP classification head on the model.

import torch
import math
import copy
from torch import nn


def RandomFeatureLinear(i_dim, o_dim, bias=True, require_grad=False):
    m = nn.Linear(i_dim, o_dim, bias)
    nn.init.normal_(m.weight, mean=0.0, std=0.05)
    m.weight.requires_grad = require_grad  # Freeze weights
    if bias:
        nn.init.uniform_(m.bias, a=0.0, b=2.0 * math.pi)
        m.bias.requires_grad = require_grad  # Freeze bias
    return m


class GPClassificationHead(nn.Module):
    def __init__(
        self,
        hidden_size=768,
        gp_kernel_scale=1.0,
        num_inducing=1024,
        gp_output_bias=0.0,
        layer_norm_eps=1e-12,
        scale_random_features=True,
        normalize_input=True,
        gp_cov_momentum=0.999,
        gp_cov_ridge_penalty=1e-3,
        epochs=40,
        num_classes=3,
        device="cpu",
    ):
        super(GPClassificationHead, self).__init__()
        self.final_epochs = epochs - 1
        self.gp_cov_ridge_penalty = gp_cov_ridge_penalty
        self.gp_cov_momentum = gp_cov_momentum
        self.pooled_output_dim = hidden_size
        self.gp_input_scale = 1.0 / math.sqrt(gp_kernel_scale)
        self.gp_feature_scale = math.sqrt(2.0 / float(num_inducing))
        self.gp_output_bias = gp_output_bias
        self.scale_random_features = scale_random_features
        self.normalize_input = normalize_input
        self.device = device

        self._gp_input_normalize_layer = torch.nn.LayerNorm(
            hidden_size, eps=layer_norm_eps
        )
        self._gp_output_layer = nn.Linear(num_inducing, num_classes, bias=False)
        # gp_output_bias set to not trainable
        self._gp_output_bias = torch.tensor([self.gp_output_bias] * num_classes).to(
            device
        )
        self._random_feature = RandomFeatureLinear(self.pooled_output_dim, num_inducing)

        # Inverse covariance matrix corresponding to RFF-GP posterior
        self.initial_precision_matrix = self.gp_cov_ridge_penalty * torch.eye(
            num_inducing
        ).to(device)
        self.precision_matrix = torch.nn.Parameter(
            copy.deepcopy(self.initial_precision_matrix), requires_grad=False
        )

    def gp_layer(self, gp_inputs, update_cov=True):
        if self.normalize_input:
            gp_inputs = self._gp_input_normalize_layer(gp_inputs)

        gp_feature = self._random_feature(gp_inputs)
        gp_feature = torch.cos(gp_feature)

        if self.scale_random_features:
            gp_feature = gp_feature * self.gp_input_scale

        gp_output = self._gp_output_layer(gp_feature).to(
            self.device
        ) + self._gp_output_bias.to(self.device)

        if update_cov:
            self.update_cov(gp_feature)
        return gp_feature, gp_output

    def reset_cov(self):
        self.precision_matrix = torch.nn.Parameter(
            copy.deepcopy(self.initial_precision_matrix), requires_grad=False
        )

    def update_cov(self, gp_feature):
        # https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L346
        batch_size = gp_feature.size()[0]
        precision_matrix_minibatch = torch.matmul(gp_feature.t(), gp_feature)
        # Moving average updates to precision matrix
        precision_matrix_minibatch = precision_matrix_minibatch / batch_size
        precision_matrix_new = (
            self.gp_cov_momentum * self.precision_matrix
            + (1.0 - self.gp_cov_momentum) * precision_matrix_minibatch
        )
        self.precision_matrix = torch.nn.Parameter(
            precision_matrix_new, requires_grad=False
        )

    def compute_predictive_covariance(self, gp_feature):
        # https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L403
        # Covariance matrix of feature coefficient
        feature_cov_matrix = torch.linalg.inv(self.precision_matrix)

        # Predictive covariance matrix for the GP
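        # With random features Phi (one row per example), the posterior covariance
        # of the output weights is roughly ridge_penalty * inv(precision_matrix)
        # (up to the moving-average approximation used in update_cov), so the
        # predictive covariance for the batch is
        # Phi @ (ridge_penalty * inv(precision_matrix)) @ Phi^T,
        # which the two matmuls below compute.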
        cov_feature_product = (
            torch.matmul(feature_cov_matrix, gp_feature.t())
            * self.gp_cov_ridge_penalty
        )
        gp_cov_matrix = torch.matmul(gp_feature, cov_feature_product)
        return gp_cov_matrix

    def forward(
        self,
        input_features,
        return_gp_cov: bool = False,
        update_cov: bool = True,
    ):
        gp_feature, gp_output = self.gp_layer(input_features, update_cov=update_cov)
        if return_gp_cov:
            gp_cov_matrix = self.compute_predictive_covariance(gp_feature)
            return gp_output, gp_cov_matrix
        return gp_output
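

# Example usage (a minimal sketch, not part of the adapted repo): the head is meant
# to sit on top of an encoder (e.g. BERT) whose pooled output has size `hidden_size`.
# `update_cov=True` accumulates the precision matrix during training; the
# `final_epochs` attribute suggests the intended recipe is to call `reset_cov()` at
# the start of the last epoch so the covariance is estimated from that epoch alone,
# as in SNGP. The random tensor below merely stands in for real encoder features.
if __name__ == "__main__":
    head = GPClassificationHead(hidden_size=768, num_classes=3, device="cpu")
    pooled = torch.randn(8, 768)  # stand-in for pooled encoder outputs

    # Training-style call: returns logits and accumulates covariance statistics.
    logits = head(pooled, update_cov=True)

    # Start of the (presumed) final epoch: re-estimate the covariance from scratch.
    head.reset_cov()
    logits = head(pooled, update_cov=True)

    # Eval-style call: freeze the covariance and also return the predictive covariance.
    logits, gp_cov = head(pooled, return_gp_cov=True, update_cov=False)
    print(logits.shape, gp_cov.shape)  # torch.Size([8, 3]) torch.Size([8, 8])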