from ..custom_types import *
from abc import ABC
import math

# Explicit imports for the names used below; the tensor alias T is expected to
# come from the custom_types star import.
import torch
import torch.nn as nn
import numpy as np
from typing import Callable, List, Tuple, Union


def torch_no_grad(func):
    """Decorator that runs the wrapped function under torch.no_grad()."""
    def wrapper(*args, **kwargs):
        with torch.no_grad():
            result = func(*args, **kwargs)
        return result
    return wrapper


class Model(nn.Module, ABC):
    """Base model that delegates saving to an externally supplied callback."""

    def __init__(self):
        super(Model, self).__init__()
        self.save_model: Union[None, Callable[..., None]] = None

    def save(self, **kwargs):
        self.save_model(self, **kwargs)


class Concatenate(nn.Module):
    """Concatenates a sequence of tensors along a fixed dimension."""

    def __init__(self, dim):
        super(Concatenate, self).__init__()
        self.dim = dim

    def forward(self, x):
        return torch.cat(x, dim=self.dim)


class View(nn.Module):
    """Reshapes the input to a fixed shape."""

    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(*self.shape)


class Transpose(nn.Module):
    """Swaps two dimensions of the input tensor."""

    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)


class Dummy(nn.Module):
    """Identity placeholder: ignores extra arguments and returns the first input."""

    def __init__(self, *args):
        super(Dummy, self).__init__()

    def forward(self, *args):
        return args[0]


class SineLayer(nn.Module):
    """
    From the siren repository
    https://colab.research.google.com/github/vsitzmann/siren/blob/master/explore_siren.ipynb
    """

    def __init__(self, in_features, out_features, bias=True, is_first=False, omega_0=30):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first
        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.output_channels = out_features
        self.init_weights()

    def init_weights(self):
        # SIREN initialization: the first layer is drawn uniformly from
        # [-1/in_features, 1/in_features]; later layers are scaled by 1/omega_0.
        with torch.no_grad():
            if self.is_first:
                self.linear.weight.uniform_(-1 / self.in_features,
                                            1 / self.in_features)
            else:
                self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0,
                                            np.sqrt(6 / self.in_features) / self.omega_0)

    def forward(self, input):
        return torch.sin(self.omega_0 * self.linear(input))


class MLP(nn.Module):
    """Multilayer perceptron with an activation between every pair of linear layers."""

    def __init__(self, ch: Union[List[int], Tuple[int, ...]], act: nn.Module = nn.ReLU,
                 weight_norm=False):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(ch) - 1):
            layers.append(nn.Linear(ch[i], ch[i + 1]))
            if weight_norm:
                layers[-1] = nn.utils.weight_norm(layers[-1])
            if i < len(ch) - 2:
                layers.append(act(True))  # inplace=True for ReLU-like activations
        self.net = nn.Sequential(*layers)

    def forward(self, x, *_):
        return self.net(x)


class GMAttend(nn.Module):
    """Self-attention over grouped features with a learnable residual gate (gamma)."""

    def __init__(self, hidden_dim: int):
        super(GMAttend, self).__init__()
        self.key_dim = hidden_dim // 8
        self.query_w = nn.Linear(hidden_dim, self.key_dim)
        self.key_w = nn.Linear(hidden_dim, self.key_dim)
        self.value_w = nn.Linear(hidden_dim, hidden_dim)
        self.softmax = nn.Softmax(dim=3)
        self.gamma = nn.Parameter(torch.zeros(1))
        # Standard 1/sqrt(d_k) attention scaling factor, kept as an attribute.
        self.scale = 1 / torch.sqrt(torch.tensor(self.key_dim, dtype=torch.float32))

    def forward(self, x):
        # x: (batch, groups, elements, hidden_dim)
        queries = self.query_w(x)
        keys = self.key_w(x)
        vals = self.value_w(x)
        attention = self.softmax(torch.einsum('bgqf,bgkf->bgqk', queries, keys))
        out = torch.einsum('bgvf,bgqv->bgqf', vals, attention)
        out = self.gamma * out + x
        return out


def recursive_to(item, device):
    """Moves a tensor, or a (possibly nested) list/tuple of tensors, to the given device."""
    if type(item) is T:
        return item.to(device)
    elif type(item) is tuple or type(item) is list:
        return [recursive_to(item[i], device) for i in range(len(item))]
    return item
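

# --- Usage sketch (not part of the original module) ------------------------
# A minimal smoke test showing how the helpers above compose. It assumes the
# custom_types star import provides T = torch.Tensor; the shapes and layer
# sizes below are arbitrary. Run it with `python -m <package>.<module>` so
# the relative import at the top of the file resolves.
if __name__ == '__main__':
    siren_first = SineLayer(2, 64, is_first=True)   # first SIREN layer, omega_0 = 30
    attend = GMAttend(hidden_dim=64)                 # grouped self-attention
    mlp = MLP((64, 64, 1))                           # 64 -> 64 -> 1 with ReLU in between

    coords = torch.rand(4, 8, 16, 2)                 # (batch, groups, points, in_features)
    feats = siren_first(coords)                      # -> (4, 8, 16, 64)
    feats = attend(feats)                            # same shape, attention over points
    out = mlp(feats.mean(dim=(1, 2)))                # pool groups/points -> (4, 1)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = recursive_to([coords, (feats, out)], device)
    print(out.shape, [type(b) for b in batch])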