Spaces:
Runtime error
Runtime error
# Code for the GP final layer, adapted from this great repo:
# https://github.com/kimjeyoung/SNGP-BERT-Pytorch
# We simplify things a bit here by removing the spectral
# normalisation, as the authors of the Plex paper say that it
# isn't strictly necessary, so we just have a GP classification head on the model.
import torch | |
import math | |
import copy | |
from torch import nn | |
def RandomFeatureLinear(
    i_dim: int, o_dim: int, bias: bool = True, require_grad: bool = False
) -> nn.Linear:
    """Build the linear projection used to draw random features.

    The weight is initialised from a normal distribution (mean 0, std 0.05)
    and, when present, the bias from a uniform distribution on [0, 2*pi).
    Both are frozen unless ``require_grad`` is true.

    Args:
        i_dim: Number of input features.
        o_dim: Number of output (random) features.
        bias: Whether the layer has a bias term.
        require_grad: Value assigned to ``requires_grad`` on the weight and,
            if present, the bias.

    Returns:
        The initialised ``nn.Linear`` layer.
    """
    layer = nn.Linear(i_dim, o_dim, bias)
    nn.init.normal_(layer.weight, mean=0.0, std=0.05)
    layer.weight.requires_grad = require_grad  # frozen by default
    if bias:
        # The bias only exists when requested; with bias=False ``layer.bias``
        # is None, so both statements must stay inside this guard.
        nn.init.uniform_(layer.bias, a=0.0, b=2.0 * math.pi)
        layer.bias.requires_grad = require_grad  # frozen by default
    return layer
class GPClassificationHead(nn.Module):
    """Gaussian-process classification head built on random cosine features.

    Inputs are optionally layer-normalised, projected through a frozen random
    linear layer, passed through ``cos`` and mapped to class logits by a
    bias-free linear layer plus a constant output bias.  A running precision
    matrix over the random features is maintained so that a predictive
    covariance can be computed on request.
    """

    def __init__(
        self,
        hidden_size=768,
        gp_kernel_scale=1.0,
        num_inducing=1024,
        gp_output_bias=0.0,
        layer_norm_eps=1e-12,
        scale_random_features=True,
        normalize_input=True,
        gp_cov_momentum=0.999,
        gp_cov_ridge_penalty=1e-3,
        epochs=40,
        num_classes=3,
        device="cpu",
    ):
        """Create the head.

        Args:
            hidden_size: Size of the incoming feature vector.
            gp_kernel_scale: Kernel scale; stored as
                ``gp_input_scale = 1 / sqrt(gp_kernel_scale)``.
            num_inducing: Number of random features.
            gp_output_bias: Constant added to every logit (not trainable).
            layer_norm_eps: Epsilon of the input ``LayerNorm``.
            scale_random_features: Whether to rescale the cosine features.
            normalize_input: Whether to layer-normalise the inputs.
            gp_cov_momentum: Momentum of the moving-average precision update.
            gp_cov_ridge_penalty: Ridge factor; the precision matrix starts as
                ``gp_cov_ridge_penalty * I``.
            epochs: Number of training epochs; stored as
                ``final_epochs = epochs - 1``.
            num_classes: Number of output logits.
            device: Device the logits are moved to and on which the constant
                tensors are created.
        """
        super().__init__()
        self.final_epochs = epochs - 1
        self.gp_cov_ridge_penalty = gp_cov_ridge_penalty
        self.gp_cov_momentum = gp_cov_momentum

        self.pooled_output_dim = hidden_size
        self.gp_input_scale = 1.0 / math.sqrt(gp_kernel_scale)
        self.gp_feature_scale = math.sqrt(2.0 / float(num_inducing))
        self.gp_output_bias = gp_output_bias
        self.scale_random_features = scale_random_features
        self.normalize_input = normalize_input
        self.device = device

        # Sub-modules are created in a fixed order so that seeded
        # initialisation and parameter ordering stay stable.
        self._gp_input_normalize_layer = torch.nn.LayerNorm(
            hidden_size, eps=layer_norm_eps
        )
        # The output bias is a constant tensor, so the linear layer has none.
        self._gp_output_layer = nn.Linear(num_inducing, num_classes, bias=False)
        self._gp_output_bias = torch.tensor(
            [self.gp_output_bias] * num_classes
        ).to(device)
        self._random_feature = RandomFeatureLinear(
            self.pooled_output_dim, num_inducing
        )

        # Inverse covariance matrix corresponding to the RFF-GP posterior.
        self.initial_precision_matrix = self.gp_cov_ridge_penalty * torch.eye(
            num_inducing
        ).to(device)
        self.precision_matrix = torch.nn.Parameter(
            self.initial_precision_matrix.clone(), requires_grad=False
        )

    def gp_layer(self, gp_inputs, update_cov=True):
        """Compute the random features and the logits for ``gp_inputs``.

        Args:
            gp_inputs: Input features whose last dimension is ``hidden_size``.
            update_cov: Whether to fold this batch into the precision matrix.

        Returns:
            Tuple ``(gp_feature, gp_output)``: the random features and the
            logits, the latter moved to ``self.device``.
        """
        if self.normalize_input:
            gp_inputs = self._gp_input_normalize_layer(gp_inputs)

        gp_feature = torch.cos(self._random_feature(gp_inputs))

        if self.scale_random_features:
            # NOTE(review): the features are scaled by ``gp_input_scale`` and
            # ``gp_feature_scale`` is never used.  The edward2 implementation
            # this code cites appears to scale the features by the feature
            # scale instead - confirm before changing, as that would alter the
            # numerics of already-trained heads.
            gp_feature = gp_feature * self.gp_input_scale

        logits = self._gp_output_layer(gp_feature).to(self.device)
        gp_output = logits + self._gp_output_bias.to(self.device)

        if update_cov:
            self.update_cov(gp_feature)
        return gp_feature, gp_output

    def reset_cov(self):
        """Reset the precision matrix to its initial ridge value.

        The reset is done in place so the parameter object, its device and its
        dtype are kept even if the module was moved after construction.
        """
        with torch.no_grad():
            self.precision_matrix.copy_(self.initial_precision_matrix)

    def update_cov(self, gp_feature):
        """Fold one batch of features into the running precision matrix.

        Args:
            gp_feature: 2-D tensor ``(batch, num_inducing)`` of random features.
        """
        # https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L346
        batch_size = gp_feature.size(0)
        if batch_size == 0:
            # Dividing by an empty batch would write NaNs into the matrix.
            return

        # The precision matrix is a statistic, not a trained quantity, so no
        # autograd graph is needed for this computation.
        with torch.no_grad():
            precision_matrix_minibatch = (
                torch.matmul(gp_feature.t(), gp_feature) / batch_size
            )
            # Moving-average update of the precision matrix.
            precision_matrix_new = (
                self.gp_cov_momentum * self.precision_matrix
                + (1.0 - self.gp_cov_momentum) * precision_matrix_minibatch
            )
            # In-place write keeps the same Parameter object registered.
            self.precision_matrix.copy_(precision_matrix_new)

    def compute_predictive_covariance(self, gp_feature):
        """Return the predictive covariance between the rows of ``gp_feature``.

        Args:
            gp_feature: 2-D tensor ``(batch, num_inducing)`` of random features.

        Returns:
            A ``(batch, batch)`` covariance matrix.
        """
        # https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L403
        # Covariance matrix of the feature coefficients.
        feature_cov_matrix = torch.linalg.inv(self.precision_matrix)
        # Predictive covariance matrix for the GP.
        cov_feature_product = (
            torch.matmul(feature_cov_matrix, gp_feature.t())
            * self.gp_cov_ridge_penalty
        )
        return torch.matmul(gp_feature, cov_feature_product)

    def forward(
        self,
        input_features,
        return_gp_cov: bool = False,
        update_cov: bool = True,
    ):
        """Run the head.

        Args:
            input_features: Features fed to :meth:`gp_layer`.
            return_gp_cov: Also return the predictive covariance matrix.
            update_cov: Whether to update the precision matrix with this
                batch.  It defaults to true irrespective of train/eval mode,
                so callers must pass ``False`` to leave the matrix untouched.

        Returns:
            The logits, or ``(logits, covariance)`` when ``return_gp_cov``.
        """
        gp_feature, gp_output = self.gp_layer(input_features, update_cov=update_cov)
        if return_gp_cov:
            gp_cov_matrix = self.compute_predictive_covariance(gp_feature)
            return gp_output, gp_cov_matrix
        return gp_output