import numpy as np
import torch
import torch.nn as nn


class PointEmbed(nn.Module):
    """Sinusoidal (Fourier) positional embedding for 3D point coordinates."""

    def __init__(self, hidden_dim=48):
        super().__init__()

        # Each of the three axes gets hidden_dim // 6 frequencies, and every
        # frequency contributes both a sin and a cos feature, so
        # 3 axes * (hidden_dim // 6) * 2 = hidden_dim.
        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        # Frequencies are powers of two scaled by pi: pi, 2*pi, 4*pi, ...
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        # Block-diagonal (3, hidden_dim // 2) basis: each coordinate axis is
        # projected onto its own copy of the frequency set.
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        # Buffer, not Parameter: moves with .to()/.cuda() but is not trained.
        self.register_buffer('basis', e)

    @staticmethod
    def embed(input, basis):
        # Project (B, N, 3) coordinates onto the frequency basis:
        # (B, N, 3) x (3, hidden_dim // 2) -> (B, N, hidden_dim // 2).
        projections = torch.einsum('bnd,de->bne', input, basis)
        # Concatenate sin and cos features -> (B, N, hidden_dim).
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        # input: (B, N, 3) point coordinates -> (B, N, hidden_dim) features.
        return self.embed(input, self.basis)
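

# Minimal usage sketch (illustrative assumptions only: a batch of 2 point
# clouds with 1024 points each, coordinates assumed to lie in [-1, 1]):
if __name__ == '__main__':
    embedder = PointEmbed(hidden_dim=48)
    points = torch.rand(2, 1024, 3) * 2 - 1   # (B, N, 3) xyz coordinates
    features = embedder(points)
    print(features.shape)                     # torch.Size([2, 1024, 48])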