# gp-uq-tester / gp.py
# Source: Hugging Face Space "gp-uq-tester" by tombm,
# commit 5212a08 ("Add functionality to app").
# Code for GP final layer adapted from this great repo:
# https://github.com/kimjeyoung/SNGP-BERT-Pytorch .
# We simplify things here a bit by removing the spectral
# normalisation as the authors of the Plex paper say that this
# isn't strictly necessary, so we just have a GP classification head on the model.
import torch
import math
import copy
from torch import nn
def RandomFeatureLinear(i_dim, o_dim, bias=True, require_grad=False):
    """Build a frozen linear layer acting as a random Fourier feature projection.

    The weights are drawn once from N(0, 0.05^2) and the bias (the RFF phase
    offset) uniformly from [0, 2*pi); both are frozen by default so the
    projection stays fixed during training.
    """
    layer = nn.Linear(i_dim, o_dim, bias)
    nn.init.normal_(layer.weight, mean=0.0, std=0.05)
    layer.weight.requires_grad = require_grad
    if not bias:
        return layer
    # Random phase offsets for the cosine feature map, also frozen.
    nn.init.uniform_(layer.bias, a=0.0, b=2.0 * math.pi)
    layer.bias.requires_grad = require_grad
    return layer
class GPClassificationHead(nn.Module):
    """Gaussian-process classification head built on random Fourier features.

    Follows the SNGP construction, minus spectral normalisation: a frozen
    random-feature projection approximates an RBF kernel, a trainable linear
    layer maps features to class logits, and a running precision matrix over
    the features provides a Laplace-style predictive covariance.
    """

    def __init__(
        self,
        hidden_size=768,
        gp_kernel_scale=1.0,
        num_inducing=1024,
        gp_output_bias=0.0,
        layer_norm_eps=1e-12,
        scale_random_features=True,
        normalize_input=True,
        gp_cov_momentum=0.999,
        gp_cov_ridge_penalty=1e-3,
        epochs=40,
        num_classes=3,
        device="cpu",
    ):
        super().__init__()
        # Index of the last training epoch (epochs are 0-based).
        self.final_epochs = epochs - 1
        self.gp_cov_ridge_penalty = gp_cov_ridge_penalty
        self.gp_cov_momentum = gp_cov_momentum
        self.pooled_output_dim = hidden_size
        # Kernel/feature scale factors for the RFF approximation.
        self.gp_input_scale = 1.0 / math.sqrt(gp_kernel_scale)
        self.gp_feature_scale = math.sqrt(2.0 / float(num_inducing))
        self.gp_output_bias = gp_output_bias
        self.scale_random_features = scale_random_features
        self.normalize_input = normalize_input
        self.device = device

        self._gp_input_normalize_layer = torch.nn.LayerNorm(
            hidden_size, eps=layer_norm_eps
        )
        # Trainable output weights; the output bias is a fixed constant vector.
        self._gp_output_layer = nn.Linear(num_inducing, num_classes, bias=False)
        self._gp_output_bias = torch.tensor(
            [self.gp_output_bias] * num_classes
        ).to(device)
        self._random_feature = RandomFeatureLinear(self.pooled_output_dim, num_inducing)

        # Ridge-initialised precision (inverse covariance) of the RFF posterior.
        self.initial_precision_matrix = self.gp_cov_ridge_penalty * torch.eye(
            num_inducing
        ).to(device)
        self.precision_matrix = torch.nn.Parameter(
            copy.deepcopy(self.initial_precision_matrix), requires_grad=False
        )

    def gp_layer(self, gp_inputs, update_cov=True):
        """Project inputs through the random feature map; return (features, logits)."""
        hidden = gp_inputs
        if self.normalize_input:
            hidden = self._gp_input_normalize_layer(hidden)
        # cos(Wx + b) is the random Fourier feature map.
        phi = torch.cos(self._random_feature(hidden))
        if self.scale_random_features:
            phi = phi * self.gp_input_scale
        logits = self._gp_output_layer(phi).to(self.device) + self._gp_output_bias.to(
            self.device
        )
        if update_cov:
            self.update_cov(phi)
        return phi, logits

    def reset_cov(self):
        """Restore the precision matrix to its ridge initialisation."""
        self.precision_matrix = torch.nn.Parameter(
            copy.deepcopy(self.initial_precision_matrix), requires_grad=False
        )

    def update_cov(self, gp_feature):
        """Moving-average update of the feature precision matrix.

        Mirrors the covariance update in edward2's random_feature.py:
        https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L346
        """
        # Per-example outer product, averaged over the batch.
        batch_outer = gp_feature.t().matmul(gp_feature) / gp_feature.size(0)
        updated = (
            self.gp_cov_momentum * self.precision_matrix
            + (1.0 - self.gp_cov_momentum) * batch_outer
        )
        # Stored as a frozen Parameter so it travels with the module's state.
        self.precision_matrix = torch.nn.Parameter(updated, requires_grad=False)

    def compute_predictive_covariance(self, gp_feature):
        """Predictive GP covariance at the given random features.

        See https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L403
        """
        # Invert the precision to obtain the feature-coefficient covariance.
        feature_cov = torch.linalg.inv(self.precision_matrix)
        scaled = feature_cov.matmul(gp_feature.t()) * self.gp_cov_ridge_penalty
        return gp_feature.matmul(scaled)

    def forward(
        self,
        input_features,
        return_gp_cov: bool = False,
        update_cov: bool = True,
    ):
        """Return class logits; optionally also the predictive covariance matrix."""
        phi, logits = self.gp_layer(input_features, update_cov=update_cov)
        if not return_gp_cov:
            return logits
        return logits, self.compute_predictive_covariance(phi)