|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
def pos2posemb3d(pos, num_pos_feats=128, temperature=10000):
    """Build a sinusoidal embedding from the first three channels of ``pos``.

    Each of the channels ``pos[..., 0]``, ``pos[..., 1]`` and ``pos[..., 2]``
    is multiplied by ``2 * pi``, divided by a geometric series of
    ``temperature``-based divisors and turned into interleaved
    ``sin``/``cos`` features.

    Args:
        pos (torch.Tensor): Tensor whose last dimension holds at least three
            channels.
        num_pos_feats (int): Number of features produced per channel. The
            sin and cos halves are stacked pairwise, so this is expected to
            be even.
        temperature (float): Base of the divisor series.

    Returns:
        torch.Tensor: The three per-channel embeddings concatenated on the
        last dimension in the order channel 1, channel 0, channel 2.
    """
    scaled = pos * (2 * math.pi)

    feat_index = torch.arange(num_pos_feats, dtype=torch.float32, device=scaled.device)
    # Consecutive feature pairs (2k, 2k + 1) share one divisor.
    pair_index = torch.div(feat_index, 2, rounding_mode='floor')
    divisor = temperature ** (2 * pair_index / num_pos_feats)

    def _embed(channel):
        # Even feature slots take the sine, odd slots the cosine.
        phase = scaled[..., channel, None] / divisor
        halves = (phase[..., 0::2].sin(), phase[..., 1::2].cos())
        return torch.stack(halves, dim=-1).flatten(-2)

    # Channel 1 is emitted before channel 0; callers rely on this ordering.
    return torch.cat([_embed(channel) for channel in (1, 0, 2)], dim=-1)
|
|
|
|
|
|
def pos2posemb1d(pos, num_pos_feats=256, temperature=10000):
    """Build a sinusoidal embedding from the first channel of ``pos``.

    ``pos[..., 0]`` is multiplied by ``2 * pi``, divided by a geometric
    series of ``temperature``-based divisors and turned into interleaved
    ``sin``/``cos`` features.

    Args:
        pos (torch.Tensor): Tensor whose last dimension holds at least one
            channel; only channel 0 is used.
        num_pos_feats (int): Number of output features. The sin and cos
            halves are stacked pairwise, so this is expected to be even.
        temperature (float): Base of the divisor series.

    Returns:
        torch.Tensor: Embedding whose last dimension has ``num_pos_feats``
        entries (sine in even slots, cosine in odd slots).
    """
    two_pi = 2 * math.pi
    scaled = pos * two_pi

    feat_index = torch.arange(num_pos_feats, dtype=torch.float32, device=scaled.device)
    # Consecutive feature pairs (2k, 2k + 1) share one divisor.
    pair_index = torch.div(feat_index, 2, rounding_mode='floor')
    divisor = temperature ** (2 * pair_index / num_pos_feats)

    phase = scaled[..., 0, None] / divisor
    sin_half = phase[..., 0::2].sin()
    cos_half = phase[..., 1::2].cos()
    return torch.stack((sin_half, cos_half), dim=-1).flatten(-2)
|
|
|
|
|
|
def nerf_positional_encoding(
    tensor, num_encoding_functions=6, include_input=False, log_sampling=True
) -> torch.Tensor:
    r"""Apply a sinusoidal positional encoding to the input.

    For every frequency band ``f`` the encoding appends ``sin(tensor * f)``
    followed by ``cos(tensor * f)`` along the last dimension, so the layout
    of the result is::

        [tensor (optional), sin(f_0 * tensor), cos(f_0 * tensor),
         sin(f_1 * tensor), cos(f_1 * tensor), ...]

    Args:
        tensor (torch.Tensor): Input tensor to be positionally encoded.
        num_encoding_functions (optional, int): Number of frequency bands
            used to compute the positional encoding (default: 6).
        include_input (optional, bool): Whether or not to prepend the raw
            input to the positional encoding (default: False).
        log_sampling (optional, bool): If True the bands are
            ``2**0, 2**1, ..., 2**(N-1)``; otherwise they are ``N`` values
            linearly spaced between ``2**0`` and ``2**(N-1)``
            (default: True).

    Returns:
        (torch.Tensor): Positional encoding of the input tensor. When the
        encoding consists of a single piece (only possible with
        ``include_input=True`` and zero bands) that piece is returned as-is.
    """
    encoding = [tensor] if include_input else []

    # Bands are created with the input's dtype/device so the products below
    # need no implicit conversion.
    band_options = dict(dtype=tensor.dtype, device=tensor.device)
    if log_sampling:
        frequency_bands = 2.0 ** torch.linspace(
            0.0,
            num_encoding_functions - 1,
            num_encoding_functions,
            **band_options,
        )
    else:
        frequency_bands = torch.linspace(
            2.0 ** 0.0,
            2.0 ** (num_encoding_functions - 1),
            num_encoding_functions,
            **band_options,
        )

    # Band-major layout: sin then cos for each band, in increasing frequency.
    for freq in frequency_bands:
        for func in (torch.sin, torch.cos):
            encoding.append(func(tensor * freq))

    if len(encoding) == 1:
        return encoding[0]
    return torch.cat(encoding, dim=-1)
|
|
|
|
|
|
def traj2nerf(traj):
    """Encode a trajectory tensor into positional-encoding features.

    The first two channels of the last dimension are passed through
    ``nerf_positional_encoding`` with its default settings; the last channel
    is represented by its cosine followed by its sine.

    Args:
        traj (torch.Tensor): Tensor whose last dimension holds the two
            leading channels to encode and a final channel that is treated
            as an angle (presumably ``(x, y, heading)`` -- confirm against
            callers).

    Returns:
        torch.Tensor: ``[encoding of traj[..., :2], cos(last), sin(last)]``
        concatenated along the last dimension.
    """
    xy_code = nerf_positional_encoding(traj[..., :2])
    last_channel = traj[..., -1]
    # Cosine first, then sine: nerf2traj reads them back in this order.
    angle_code = torch.stack((torch.cos(last_channel), torch.sin(last_channel)), dim=-1)
    return torch.cat((xy_code, angle_code), dim=-1)
|
|
|
|
|
|
def nerf2traj(nerf, num_encoding_functions=6, include_input=False, log_sampling=True):
    """Decode features produced by ``traj2nerf`` back into a trajectory.

    The last dimension of ``nerf`` is expected to be laid out the way
    ``nerf_positional_encoding`` (for a 2-channel position) followed by the
    angle terms of ``traj2nerf`` produce it::

        [position (2, only when include_input),
         sin(f_0 * position), cos(f_0 * position),
         ...,
         sin(f_{N-1} * position), cos(f_{N-1} * position),
         cos(angle), sin(angle)]

    Decoding strategy:
      * with ``include_input=True`` the raw position stored at the front is
        returned unchanged (exact for any value);
      * otherwise the position is recovered from the lowest band with
        ``atan2(sin, cos)``. That band has frequency ``2**0 == 1`` for both
        sampling schemes, so the result is exact only for positions in
        ``(-pi, pi]``; larger values come back wrapped into that range
        because the sinusoids are periodic.
      * the angle is recovered with ``atan2`` from the trailing
        ``(cos, sin)`` pair.

    The input tensor is never modified.

    Args:
        nerf (torch.Tensor): Encoded features, last dimension as above.
        num_encoding_functions (int): Number of frequency bands used when
            encoding.
        include_input (bool): Whether the raw position was prepended when
            encoding.
        log_sampling (bool): Sampling scheme used when encoding. The lowest
            band equals 1 in both schemes, so it does not influence the
            decoding; it is kept for interface compatibility.

    Returns:
        torch.Tensor: ``[position (2), angle (1)]`` along the last dimension.

    Raises:
        ValueError: If the last dimension of ``nerf`` is too short for the
            requested encoding settings.
    """
    original_dim = 2

    leading = original_dim if include_input else 0
    encoding_length = leading + original_dim * 2 * num_encoding_functions
    if nerf.shape[-1] < encoding_length + 2:
        raise ValueError(
            f"nerf last dimension is {nerf.shape[-1]}, expected at least "
            f"{encoding_length + 2} for num_encoding_functions="
            f"{num_encoding_functions}, include_input={include_input}"
        )

    if include_input:
        # The raw position was stored verbatim; torch.cat below copies it, so
        # the caller's tensor is left untouched.
        original_position = nerf[..., :original_dim]
    elif num_encoding_functions >= 1:
        # Lowest band (frequency 1): sin block first, then cos block.
        sin_position = nerf[..., :original_dim]
        cos_position = nerf[..., original_dim:2 * original_dim]
        original_position = torch.atan2(sin_position, cos_position)
    else:
        # No positional information was encoded at all.
        original_position = torch.zeros(
            (*nerf.shape[:-1], original_dim), dtype=nerf.dtype, device=nerf.device
        )

    cos_angle = nerf[..., -2]
    sin_angle = nerf[..., -1]
    angle = torch.atan2(sin_angle, cos_angle)

    return torch.cat([original_position, angle[..., None]], dim=-1)
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke run: load a saved trajectory array from a fixed local
    # path, encode it with traj2nerf and decode it again with nerf2traj.
    # The results are only bound to names; nothing is printed or saved.
    loaded = np.load('/mnt/f/e2e/navsim_ours/traj_final/test_4096_kmeans.npy')
    traj = torch.from_numpy(loaded)
    nerf = traj2nerf(traj)
    traj_2 = nerf2traj(nerf)
|
|
|