# File size: 2,942 Bytes (revision 2252f3d)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import pytorch_lightning as pl
import numpy as np
class SpatialEncoder(pl.LightningModule):
    """Encode query points by their offset along axis 2 to a set of keypoints.

    For every (point, keypoint) pair, ``forward`` takes the difference of the
    coordinate stored at index 2 of the last axis, expands it with a
    sinusoidal positional embedding, and scales the result by a Gaussian of
    the full point-to-keypoint offset.

    NOTE(review): ``forward`` always computes this one encoding; ``sp_type``
    is only consulted by ``get_dim``. The ``get_dim`` values for the
    non-relative and ``xyz`` types therefore do not describe what ``forward``
    returns -- confirm against callers before relying on them.
    """

    def __init__(self,
                 sp_level=1,
                 sp_type="rel_z_decay",
                 scale=1.0,
                 n_kpt=24,
                 sigma=0.2):
        """Store the encoder configuration (the module has no parameters).

        Args:
            sp_level: number of frequency levels of the positional embedding.
            sp_type: name of the encoding flavour; read by ``get_dim``.
            scale: multiplier applied to every embedding frequency.
            n_kpt: number of keypoints assumed by ``get_dim``.
            sigma: standard deviation of the Gaussian distance weighting.
        """
        super().__init__()
        self.sp_type = sp_type
        self.sp_level = sp_level
        self.n_kpt = n_kpt
        self.scale = scale
        self.sigma = sigma

    @staticmethod
    def position_embedding(x, nlevels, scale=1.0):
        """Append sin/cos features of ``x`` at ``nlevels`` frequencies.

        Args:
            x: tensor of shape (B, N, C).
            nlevels: number of frequency levels; ``<= 0`` disables the
                embedding and returns ``x`` unchanged.
            scale: multiplier applied to every frequency.

        Returns:
            Tensor of shape (B, N, C * (2 * nlevels + 1)). The last axis is
            laid out as ``x`` itself followed, for each level in increasing
            frequency, by ``sin`` (C values) and then ``cos`` (C values).
        """
        if nlevels <= 0:
            return x
        freqs = SpatialEncoder.pe_vector(nlevels, x.device, scale)
        B, N, _ = x.shape
        # (B, N, 1, C) * (1, 1, L, 1) -> (B, N, L, C)
        phase = x[:, :, None, :] * freqs[None, None, :, None]
        # (B, N, L, 2C) flattened to (B, N, L * 2C)
        waves = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
        return torch.cat([x, waves.reshape(B, N, -1)], dim=-1)

    @staticmethod
    def pe_vector(nlevels, device, scale=1.0):
        """Return the float32 frequencies ``scale * pi * 2**i`` for i < nlevels.

        The result is a 1-D tensor of length ``nlevels`` placed on ``device``.
        """
        freqs = [scale * np.pi * (2 ** i) for i in range(nlevels)]
        return torch.tensor(freqs, dtype=torch.float32, device=device)

    def get_dim(self) -> int:
        """Return the feature width associated with ``sp_type``.

        Each encoded coordinate contributes ``1 + 2 * sp_level`` values. The
        ``z`` family encodes one coordinate and the ``xyz`` family three;
        types containing ``"rel"`` are additionally multiplied by ``n_kpt``.
        Unrecognised types yield 0.
        """
        # position_embedding treats nlevels <= 0 as "no embedding", so a
        # negative level must not shrink the reported width below 1.
        per_coord = 1 + 2 * max(self.sp_level, 0)
        if self.sp_type in ("z", "rel_z", "rel_z_decay"):
            n_coords = 1
        elif "xyz" in self.sp_type:
            n_coords = 3
        else:
            return 0
        dim = per_coord * n_coords
        if "rel" in self.sp_type:
            dim *= self.n_kpt
        return dim

    def forward(self, cxyz, kptxyz):
        """Compute the distance-weighted relative encoding.

        Args:
            cxyz: query points, shape (B, N, D) with D >= 3
                (presumably xyz coordinates -- confirm with callers).
            kptxyz: keypoints, shape (B, K, D), same D as ``cxyz``.

        Returns:
            Tensor of shape (B, K * (2 * sp_level + 1), N).
        """
        B, N = cxyz.shape[:2]
        K = kptxyz.shape[1]

        # Pairwise offsets between every point and every keypoint: (B, N, K, D)
        dxyz = cxyz[:, :, None] - kptxyz[:, None, :]
        # Offset along axis index 2 only: (B, N, K)
        dz = dxyz[..., 2]

        # Gaussian fall-off of the squared offset norm: (B, N, K)
        weight = torch.exp(-(dxyz ** 2).sum(-1) / (2.0 * (self.sigma ** 2)))

        # (B, N, K * (2 * sp_level + 1)); the configured scale is forwarded so
        # that it actually affects the embedding frequencies.
        out = self.position_embedding(dz, self.sp_level, self.scale)

        # (B, N, 2L+1, K) * (B, N, 1, K): each keypoint's features are scaled
        # by that keypoint's weight.
        out = out.reshape(B, N, -1, K) * weight[:, :, None]
        # (B, N, K * (2L+1)) -> (B, K * (2L+1), N)
        return out.reshape(B, N, -1).permute(0, 2, 1)
if __name__ == "__main__":
    # Smoke test: run the encoder on random points/keypoints and print the
    # output shape. Falls back to CPU so it also runs without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pts = torch.randn(2, 10000, 3).to(device)
    kpts = torch.randn(2, 24, 3).to(device)
    sp_encoder = SpatialEncoder(sp_level=3,
                                sp_type="rel_z_decay",
                                scale=1.0,
                                n_kpt=24,
                                sigma=0.1).to(device)
    out = sp_encoder(pts, kpts)
    # Expected: (2, 24 * (1 + 2 * 3), 10000) == (2, 168, 10000)
    print(out.shape)
|