| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Deep networks.""" |
| |
|
| | from copy import deepcopy |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| |
|
def init_weights(m):
    """Initialise a linear-style layer with truncated-normal weights and zero bias.

    Meant to be passed to ``nn.Module.apply``. Only modules whose exact type is
    ``nn.Linear`` or that are ``EnsembleFC`` instances are modified; any other
    module is left untouched.

    Args:
        m: Candidate module. When it matches, it must expose ``weight``,
            ``bias`` (possibly ``None``) and ``in_features``.
    """

    @torch.no_grad()
    def truncated_normal_init(t, mean=0.0, std=0.01):
        """Fill ``t`` in place with N(mean, std) samples kept within 2 std of mean.

        Out-of-range entries are re-drawn until none remain.

        Args:
            t: Tensor to overwrite.
            mean: Mean of the normal distribution.
            std: Standard deviation of the normal distribution.

        Returns:
            The same tensor ``t``, now holding the truncated samples.
        """
        lower = mean - 2 * std
        upper = mean + 2 * std
        t.normal_(mean, std)
        while True:
            cond = torch.logical_or(t < lower, t > upper)
            if not cond.any():
                break
            w = torch.empty(t.shape, device=t.device, dtype=t.dtype)
            w.normal_(mean, std)
            # Write back in place: rebinding the local name would leave the
            # caller's parameter with its un-truncated values.
            t.copy_(torch.where(cond, w, t))
        return t

    if type(m) is nn.Linear or isinstance(m, EnsembleFC):
        # Standard deviation shrinks with fan-in: 1 / (2 * sqrt(in_features)).
        truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(m.in_features)))
        if m.bias is not None:
            m.bias.data.fill_(0.0)
| |
|
| |
|
def init_weights_uniform(m):
    """Initialise a layer with uniform weights in +/- 1/sqrt(in_features) and zero bias.

    Args:
        m: Module exposing ``in_features``, ``weight`` and ``bias`` (the bias
            may be ``None``). Unlike ``init_weights`` no type check is done,
            so the caller must only pass suitable modules.
    """
    bound = 1 / np.sqrt(m.in_features)
    # ``uniform_`` is the in-place initialiser; the un-suffixed ``uniform``
    # is a deprecated alias that emits a warning.
    torch.nn.init.uniform_(m.weight, -bound, bound)
    if m.bias is not None:
        m.bias.data.fill_(0.0)
| |
|
| |
|
class Swish(nn.Module):
    """Swish activation: each element is multiplied by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)`` computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate
| |
|
| |
|
class MLPModel(nn.Module):
    """MLP with two hidden layers that maps an encoding to a single score.

    Args:
        encoding_dim: Size of the last axis of the input encoding.
        hidden_dim: Width of both hidden layers.
        activation: ``"relu"`` or ``"swish"``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, encoding_dim, hidden_dim=128, activation="relu") -> None:
        super().__init__()
        self.hidden_size = hidden_dim
        self.output_dim = 1

        self.nn1 = nn.Linear(encoding_dim, hidden_dim)
        self.nn2 = nn.Linear(hidden_dim, hidden_dim)
        self.nn_out = nn.Linear(hidden_dim, self.output_dim)

        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def get_params(self) -> torch.Tensor:
        """Return all parameters flattened and concatenated into one 1-D tensor."""
        return torch.cat([pp.view(-1) for pp in self.parameters()])

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        """Run the encoding through both hidden layers and the output layer."""
        x = self.activation(self.nn1(encoding))
        x = self.activation(self.nn2(x))
        return self.nn_out(x)

    def init(self):
        """Snapshot the current parameters as the anchor for ``regularization``.

        The snapshot is a detached copy living on the same device as the
        parameters. Forcing it onto CUDA whenever a GPU happens to be present
        breaks ``regularization`` for a model that stays on the CPU.
        """
        self.init_params = self.get_params().detach().clone()

    def regularization(self):
        """Prior towards independent initialization.

        Returns the mean squared difference between the current parameters and
        the snapshot taken by ``init``, which must have been called first.
        """
        params = self.get_params()
        # The snapshot is a plain attribute and does not follow ``.to()`` /
        # ``.cuda()`` calls on the module, so align devices here.
        anchor = self.init_params.to(params.device)
        return ((params - anchor) ** 2).mean()
| |
|
| |
|
class EnsembleFC(nn.Module):
    """Fully connected layer with a separate weight matrix per ensemble member.

    The weight has shape ``(ensemble_size, in_features, out_features)`` and the
    optional bias ``(ensemble_size, out_features)``. Both are allocated with
    ``torch.empty`` and are NOT initialised here; apply an initialiser before
    use.

    Args:
        in_features: Size of the last input axis.
        out_features: Size of the last output axis.
        ensemble_size: Number of ensemble members (leading axis).
        bias: Whether to add a learnable per-member bias.
        dtype: Data type of the parameters.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    ensemble_size: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        ensemble_size: int,
        bias: bool = True,
        dtype=torch.float32,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.ensemble_size = ensemble_size

        self.weight = nn.Parameter(torch.empty(ensemble_size, in_features, out_features, dtype=dtype))
        if bias:
            self.bias = nn.Parameter(torch.empty(ensemble_size, out_features, dtype=dtype))
        else:
            self.register_parameter("bias", None)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply each member's affine map to its own slice of ``input``.

        Args:
            input: 4-D tensor whose first axis indexes the ensemble member and
                whose last axis has size ``in_features``. It is cast to the
                weight dtype before the product.

        Returns:
            Tensor with the same three leading axes as ``input`` and a last
            axis of size ``out_features``.
        """
        input = input.to(self.weight.dtype)
        wx = torch.einsum("eblh,ehm->eblm", input, self.weight)

        # ``bias=False`` registers the bias as None; indexing it would raise.
        if self.bias is None:
            return wx
        # Broadcast the (ensemble, out) bias over the two middle axes.
        return torch.add(wx, self.bias[:, None, None, :])
| |
|
| |
|
def get_params(model):
    """Flatten every parameter of ``model`` and join them into one 1-D tensor."""
    flat_chunks = []
    for tensor in model.parameters():
        flat_chunks.append(tensor.view(-1))
    return torch.cat(flat_chunks)
| |
|
| |
|
class _EnsembleModel(nn.Module):
    """Ensemble of two-hidden-layer MLPs built from ``EnsembleFC`` layers.

    Each member maps an encoding to a single score.

    Args:
        encoding_dim: Size of the last axis of the input encoding.
        num_ensemble: Number of ensemble members.
        hidden_dim: Width of both hidden layers.
        activation: ``"relu"`` or ``"swish"``.
        dtype: Data type of the layer parameters.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super().__init__()
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim
        self.output_dim = 1

        self.nn1 = EnsembleFC(encoding_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn2 = EnsembleFC(hidden_dim, hidden_dim, num_ensemble, dtype=dtype)
        self.nn_out = EnsembleFC(hidden_dim, self.output_dim, num_ensemble, dtype=dtype)

        self.apply(init_weights)

        if activation == "swish":
            self.activation = Swish()
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation {activation}")

    def get_params(self) -> torch.Tensor:
        """Return all parameters flattened and concatenated into one 1-D tensor."""
        return torch.cat([p.view(-1) for p in self.parameters()])

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        """Run the encoding through both hidden layers and the output layer."""
        x = self.activation(self.nn1(encoding))
        x = self.activation(self.nn2(x))
        return self.nn_out(x)

    def init(self):
        """Snapshot the current parameters as the anchor for ``regularization``."""
        self.init_params = self.get_params().detach().clone()

    def regularization(self):
        """Prior towards independent initialization.

        Returns the mean squared difference between the current parameters and
        the snapshot taken by ``init``, which must have been called first.
        Both ``get_params`` and ``init`` are defined on this class so that this
        method can actually run.
        """
        params = self.get_params()
        # The snapshot is a plain attribute and does not follow device moves
        # of the module, so align devices here.
        anchor = self.init_params.to(params.device)
        return ((params - anchor) ** 2).mean()
| |
|
| |
|
class EnsembleModel(nn.Module):
    """Trainable ensemble paired with a frozen deep copy of its initial state.

    ``model`` is the network that is evaluated in ``forward``; ``reg_model`` is
    a copy taken at construction whose parameters have gradients disabled and
    serve as the anchor for ``regularization``.
    """

    def __init__(self, encoding_dim, num_ensemble, hidden_dim=128, activation="relu", dtype=torch.float32) -> None:
        super().__init__()
        self.encoding_dim = encoding_dim
        self.num_ensemble = num_ensemble
        self.hidden_dim = hidden_dim

        self.model = _EnsembleModel(encoding_dim, num_ensemble, hidden_dim, activation, dtype)

        # Keep an untouched copy of the freshly initialised weights.
        self.reg_model = deepcopy(self.model)
        for frozen in self.reg_model.parameters():
            frozen.requires_grad = False

    def forward(self, encoding: torch.Tensor) -> torch.Tensor:
        """Evaluate the trainable ensemble on ``encoding``."""
        return self.model(encoding)

    def regularization(self):
        """Prior towards independent initialization.

        Mean squared difference between the flattened parameters of ``model``
        and those of the frozen ``reg_model``.
        """
        anchor = get_params(self.reg_model).detach()
        current = get_params(self.model)
        squared_gap = (current - anchor) ** 2
        return squared_gap.mean()
| |
|