| |
| |
| |
| |
| |
| |
|
|
import math

import torch
import torch.nn as nn
|
|
|
|
class PointEmbed(nn.Module):
    """Embed 3-D point coordinates with sinusoidal (Fourier) features.

    Each point is projected onto a fixed frequency basis. The sines and
    cosines of those projections are concatenated with the raw coordinates.
    The result is mapped to ``dim`` features by a linear layer followed by
    LayerNorm.

    Args:
        hidden_dim: Number of sinusoidal features produced by :meth:`embed`.
            Must be divisible by 6 (3 coordinates x {sin, cos}).
        dim: Output feature size.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by 6.
    """

    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()

        # Explicit exception rather than ``assert``: asserts vanish under -O.
        if hidden_dim % 6 != 0:
            raise ValueError(
                f"hidden_dim must be divisible by 6, got {hidden_dim}"
            )

        self.embedding_dim = hidden_dim
        n_freqs = self.embedding_dim // 6

        # Frequencies pi * 2**k for k = 0 .. n_freqs - 1, as float32.
        freqs = torch.pow(2, torch.arange(n_freqs)).float() * math.pi

        # Block-diagonal basis of shape (3, 3 * n_freqs). Row i holds the
        # frequencies in the column slice for coordinate i and zeros
        # elsewhere, so each coordinate is projected only onto its own
        # frequencies. block_diag treats each 1-D input as a 1 x n row.
        basis = torch.block_diag(freqs, freqs, freqs)

        # Registered as a buffer, not a Parameter: it is part of the
        # state_dict and follows .to()/.cuda(), but it is not trainable.
        self.register_buffer("basis", basis)

        # The projection input is the sin/cos features plus the 3 raw
        # coordinates. The name ``mlp`` is kept for state_dict compatibility,
        # even though it is a single Linear layer.
        self.mlp = nn.Linear(self.embedding_dim + 3, dim)
        self.norm = nn.LayerNorm(dim)

    @staticmethod
    def embed(input, basis):
        """Return sin/cos features of ``input`` projected onto ``basis``.

        Args:
            input: Tensor of shape ``(..., D)``, where ``D == basis.shape[0]``
                (3 for the basis built in ``__init__``).
            basis: Tensor of shape ``(D, E)``.

        Returns:
            Tensor of shape ``(..., 2 * E)``: the sines of all projections
            followed by their cosines.
        """
        # Ellipsis keeps arbitrary leading dims; for (B, N, D) input this is
        # identical to "bnd,de->bne".
        projections = torch.einsum("...d,de->...e", input, basis)
        return torch.cat([projections.sin(), projections.cos()], dim=-1)

    def forward(self, input):
        """Embed point coordinates.

        Args:
            input: Tensor of shape ``(..., 3)``, e.g. ``(B, N, 3)``.

        Returns:
            Tensor of shape ``(..., dim)``.
        """
        # Fourier features first, then the raw coordinates, along the last axis.
        features = torch.cat([self.embed(input, self.basis), input], dim=-1)
        return self.norm(self.mlp(features))
|
|