| import importlib |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from . import vb_layers_initialize as init |
|
|
|
|
@torch.compiler.disable
def kernel_triangular_mult(
    x,
    direction,
    mask,
    norm_in_weight,
    norm_in_bias,
    p_in_weight,
    g_in_weight,
    norm_out_weight,
    norm_out_bias,
    p_out_weight,
    g_out_weight,
    eps,
):
    """Dispatch to the cuEquivariance triangle multiplicative update kernel.

    The import is deferred to call time so ``cuequivariance_torch`` is only
    required when this path is actually taken, and the decorator keeps the
    opaque kernel call out of ``torch.compile`` graphs.
    """
    triangle = importlib.import_module(
        "cuequivariance_torch.primitives.triangle"
    )
    kwargs = {
        "direction": direction,
        "mask": mask,
        "norm_in_weight": norm_in_weight,
        "norm_in_bias": norm_in_bias,
        "p_in_weight": p_in_weight,
        "g_in_weight": g_in_weight,
        "norm_out_weight": norm_out_weight,
        "norm_out_bias": norm_out_bias,
        "p_out_weight": p_out_weight,
        "g_out_weight": g_out_weight,
        "eps": eps,
    }
    return triangle.triangle_multiplicative_update(x, **kwargs)
|
|
|
|
class TriangleMultiplicationOutgoing(nn.Module):
    """Triangle multiplicative update over outgoing edges.

    Gated input projections are combined with an einsum over the shared
    third index k (``bikd,bjkd->bijd``), then projected back to the input
    dimension with an output gate.
    """

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangularUpdate module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        # eps spelled out explicitly so norm_out stays in sync with
        # norm_in and with the eps forwarded to the kernel in forward().
        self.norm_out = nn.LayerNorm(dim, eps=1e-5)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

        # Project-specific initialization helpers; the norms start as the
        # identity transform (weight=1, bias=0).
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        # NOTE(review): final_init_ presumably zeros p_out so the update
        # starts inactive — confirm in vb_layers_initialize.
        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to use the cuEquivariance kernel

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        if use_kernels:
            return kernel_triangular_mult(
                x,
                direction="outgoing",
                mask=mask,
                norm_in_weight=self.norm_in.weight,
                norm_in_bias=self.norm_in.bias,
                p_in_weight=self.p_in.weight,
                g_in_weight=self.g_in.weight,
                norm_out_weight=self.norm_out.weight,
                norm_out_bias=self.norm_out.bias,
                p_out_weight=self.p_out.weight,
                g_out_weight=self.g_out.weight,
                # Tie the kernel's eps to the module's LayerNorm instead of
                # a separately hard-coded constant, so the two paths cannot
                # silently diverge.
                eps=self.norm_in.eps,
            )

        x = self.norm_in(x)
        # Keep the normalized input around for the output gate.
        x_in = x
        # Gated projection to 2*D channels (a/b halves).
        x = self.p_in(x) * self.g_in(x).sigmoid()

        # Zero out masked pairs before the contraction.
        x = x * mask.unsqueeze(-1)

        # Split the halves; upcast to float32 for the einsum accumulation.
        a, b = torch.chunk(x.float(), 2, dim=-1)

        # Outgoing-edge contraction over the shared index k.
        x = torch.einsum("bikd,bjkd->bijd", a, b)

        # Output projection, gated by the normalized input.
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()

        return x
|
|
|
|
class TriangleMultiplicationIncoming(nn.Module):
    """Triangle multiplicative update over incoming edges.

    Gated input projections are combined with an einsum over the shared
    first index k (``bkid,bkjd->bijd``), then projected back to the input
    dimension with an output gate.
    """

    def __init__(self, dim: int = 128) -> None:
        """Initialize the TriangularUpdate module.

        Parameters
        ----------
        dim: int
            The dimension of the input, default 128

        """
        super().__init__()

        self.norm_in = nn.LayerNorm(dim, eps=1e-5)
        self.p_in = nn.Linear(dim, 2 * dim, bias=False)
        self.g_in = nn.Linear(dim, 2 * dim, bias=False)

        # eps spelled out explicitly so norm_out stays in sync with
        # norm_in and with the eps forwarded to the kernel in forward().
        self.norm_out = nn.LayerNorm(dim, eps=1e-5)
        self.p_out = nn.Linear(dim, dim, bias=False)
        self.g_out = nn.Linear(dim, dim, bias=False)

        # Project-specific initialization helpers; the norms start as the
        # identity transform (weight=1, bias=0).
        init.bias_init_one_(self.norm_in.weight)
        init.bias_init_zero_(self.norm_in.bias)

        init.lecun_normal_init_(self.p_in.weight)
        init.gating_init_(self.g_in.weight)

        init.bias_init_one_(self.norm_out.weight)
        init.bias_init_zero_(self.norm_out.bias)

        # NOTE(review): final_init_ presumably zeros p_out so the update
        # starts inactive — confirm in vb_layers_initialize.
        init.final_init_(self.p_out.weight)
        init.gating_init_(self.g_out.weight)

    def forward(self, x: Tensor, mask: Tensor, use_kernels: bool = False) -> Tensor:
        """Perform a forward pass.

        Parameters
        ----------
        x: torch.Tensor
            The input data of shape (B, N, N, D)
        mask: torch.Tensor
            The input mask of shape (B, N, N)
        use_kernels: bool
            Whether to use the cuEquivariance kernel

        Returns
        -------
        x: torch.Tensor
            The output data of shape (B, N, N, D)

        """
        if use_kernels:
            return kernel_triangular_mult(
                x,
                direction="incoming",
                mask=mask,
                norm_in_weight=self.norm_in.weight,
                norm_in_bias=self.norm_in.bias,
                p_in_weight=self.p_in.weight,
                g_in_weight=self.g_in.weight,
                norm_out_weight=self.norm_out.weight,
                norm_out_bias=self.norm_out.bias,
                p_out_weight=self.p_out.weight,
                g_out_weight=self.g_out.weight,
                # Tie the kernel's eps to the module's LayerNorm instead of
                # a separately hard-coded constant, so the two paths cannot
                # silently diverge.
                eps=self.norm_in.eps,
            )

        x = self.norm_in(x)
        # Keep the normalized input around for the output gate.
        x_in = x
        # Gated projection to 2*D channels (a/b halves).
        x = self.p_in(x) * self.g_in(x).sigmoid()

        # Zero out masked pairs before the contraction.
        x = x * mask.unsqueeze(-1)

        # Split the halves; upcast to float32 for the einsum accumulation.
        a, b = torch.chunk(x.float(), 2, dim=-1)

        # Incoming-edge contraction over the shared index k.
        x = torch.einsum("bkid,bkjd->bijd", a, b)

        # Output projection, gated by the normalized input.
        x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()

        return x
|
|