# python3.8

"""Defines the feature extractor in the 3D generation pipeline."""

import math

import torch
from einops import rearrange

from .triplane_sampler import TriplaneSampler
from .utils import interpolate_feature

__all__ = ['FeatureExtractor']

_REF_MODE = ['none', 'tri_plane', 'feature_volume', 'bev_plane_clevr_256',
             'bev_plane_clevr_512', 'bev_plane_carla']
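
# How `forward()` interprets each reference mode (see the branches below):
#   - `none`:           no reference representation is sampled; raw point
#                       coordinates pass through.
#   - `tri_plane`:      sample features from three axis-aligned planes.
#   - `feature_volume`: trilinearly interpolate a 3D feature volume.
#   - `bev_plane_*`:    sample a bird's-eye-view feature map and append a
#                       sinusoidal encoding of height (and optionally x/y).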


class FeatureExtractor(torch.nn.Module):
    """Defines the feature extractor in the 3D generation pipeline.

    The feature extractor takes as input the latent code, the sampled points,
    and (optionally) a reference representation, and outputs per-point
    features encoding each point's color and density.
    """

def __init__(self, ref_mode='none', xyz_pe=False, reverse_xy=True):
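        """Initializes the feature extractor.

        Args (semantics inferred from the usage in `forward()`):
            ref_mode: Type of the reference representation, chosen from
                `_REF_MODE`.
            xyz_pe: In BEV modes, whether to concatenate the positional
                encodings of x and y in addition to that of z.
            reverse_xy: In CLEVR BEV modes, whether to swap the roles of the
                x and y axes when projecting points onto the BEV plane.
        """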
super().__init__()
self.ref_mode = ref_mode
self.xyz_pe = xyz_pe
self.reverse_xy = reverse_xy
assert ref_mode in _REF_MODE
if ref_mode == 'tri_plane':
self.plane_axes = TriplaneSampler.generate_planes()

    def forward(self,
                wp,
                points,
                rendering_options,
                position_encoder=None,
                ref_representation=None,
                post_module=None,
                post_module_kwargs=None,
                ray_dirs=None,
                cam_matrix=None):
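        """Extracts per-point features.

        Argument semantics below are inferred from the call sites in this
        function rather than stated by the original author:

        Args:
            wp: Latent codes, forwarded to `post_module`.
            points: Sampled points, with shape [N, R, K, 3] or [N, R, 3].
            rendering_options: Rendering settings, e.g. `box_warp`, `bounds`.
            position_encoder: Optional module for encoding point positions.
            ref_representation: Optional reference representation, interpreted
                according to `self.ref_mode`.
            post_module: Optional module applied to the sampled features.
            post_module_kwargs: Extra keyword arguments for `post_module`.
            ray_dirs: Ray directions (currently unused).
            cam_matrix: Camera matrix of shape [N, 4, 4] (currently unused;
                see the commented-out world-to-camera transform below).

        Returns:
            Per-point features, typically of shape [N, R, K, C] (or the shape
            produced by `post_module` / the chosen sampling branch).
        """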
        assert points.ndim in [3, 4]
        if points.ndim == 3:
            points = points.unsqueeze(2)  # shape: [N, R, C] -> [N, R, 1, C].
        N, R, K, _ = points.shape
        # Copy to avoid mutating the caller's dict (or a shared default)
        # in-place below.
        post_module_kwargs = dict(post_module_kwargs or {})

        # (Optional) positional encoding.
        points_encoding = None
        if position_encoder is not None:
            points_encoding = position_encoder(points)  # shape: [N, R, K, C].
            points_encoding = rearrange(
                points_encoding, 'N R K C -> N C (R K) 1').contiguous()

        # Reshape `points` to shape [N, R*K, 3].
        points = points.reshape(points.shape[0], -1, points.shape[-1])
        # Get pre-point features by sampling from the reference
        # representation (if it exists).
        pre_point_features = points
        if ref_representation is not None:
            assert self.ref_mode != 'none'
if self.ref_mode == 'tri_plane':
pre_point_features = TriplaneSampler.sample_from_planes(
self.plane_axes.to(points.device),
ref_representation,
points,
padding_mode='zeros',
box_warp=rendering_options.get('box_warp', 1.0))
# shape: [N, 3, num_points, C], where num_points = H*W*K.
            elif self.ref_mode == 'feature_volume':
                # The default is a dataset-specific axis-aligned bounding box.
                bounds = rendering_options.get(
                    'bounds',
                    [[-0.1886, -0.1671, -0.1956], [0.1887, 0.1692, 0.1872]])
                bounds = torch.Tensor(bounds).to(points.device)
                pre_point_features = interpolate_feature(
                    points, ref_representation, bounds)  # shape: [N, C, R*K].
                pre_point_features = pre_point_features.unsqueeze(-1)
                # shape: [N, C, R*K, 1].
                # This mode expects `position_encoder` to be provided;
                # otherwise `points_encoding` is None here.
                post_module_kwargs.update(points_encoding=points_encoding)
            elif 'bev_plane_clevr' in self.ref_mode:
                h = w = int(self.ref_mode[-3:])  # BEV map resolution.
                # First, transform points from world coordinates to BEV
                # coordinates (cam_matrix: [N, 4, 4]; points: [N, R*K, 3]).
                points_reshape = points  # shape: [N, R*K, 3].
                # points_homo = torch.cat([points_reshape, torch.ones([*points_reshape.shape[:2], 1]).to(points_reshape.device)], -1)
                # points_cam = torch.einsum('nxy,nby->nbx', cam_matrix, points_homo)  # [N, R*K, 4]
                if self.reverse_xy:
                    y = (0.5 * w - 128 + 256
                         - (points_reshape[..., 0] / 9 + 0.5) * 256) / w * 2 - 1
                    x = (0.5 * h - 128
                         + (points_reshape[..., 1] / 9 + 0.5) * 256) / h * 2 - 1
                else:
                    x = (0.5 * w - 128 + 256
                         - (points_reshape[..., 0] / 9 + 0.5) * 256) / w * 2 - 1
                    y = (0.5 * h - 128
                         + (points_reshape[..., 1] / 9 + 0.5) * 256) / h * 2 - 1
                z = points_reshape[..., -1] / 9
                points_bev = torch.stack([x, y, z], dim=-1)  # [N, R*K, 3].
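                # NOTE (assumption): the constants above suggest the CLEVR
                # scene spans roughly [-4.5, 4.5] world units per axis (hence
                # the division by 9), rasterized onto a 256-pixel base map
                # that is recentered within the h x w BEV plane and normalized
                # to [-1, 1] for `grid_sample`.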
                # Second, sample features from the BEV map
                # (ref_representation: [N, C, A, A]; points_bev: [N, R*K, 3]).
                xy = points_bev[..., :2]  # shape: [N, R*K, 2].
                xy = xy.unsqueeze(2)  # shape: [N, R*K, 1, 2].
                feat_xy = torch.nn.functional.grid_sample(
                    ref_representation, xy, mode='bilinear',
                    padding_mode='zeros',
                    align_corners=False)  # shape: [N, C, R*K, 1].
                feat_xy = feat_xy.squeeze(3)  # shape: [N, C, R*K].
                x = points_bev[..., 0]  # shape: [N, R*K].
                y = points_bev[..., 1]  # shape: [N, R*K].
                z = points_bev[..., -1]  # shape: [N, R*K].
                # Third, apply sinusoidal positional encoding to x, y, and z.
                d_model = 32
                div_term = torch.exp(
                    torch.arange(0, d_model, 2, dtype=torch.float)
                    * -(math.log(10000.0) / d_model)).to(z.device)
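                # Standard transformer-style sinusoidal encoding:
                #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
                #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))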
pe_x = torch.zeros([x.shape[0], x.shape[1], d_model]).to(x.device)
pe_x[..., 0::2] = torch.sin(x.unsqueeze(-1).float() * div_term)
pe_x[..., 1::2] = torch.cos(x.unsqueeze(-1).float() * div_term)
pe_y = torch.zeros([y.shape[0], y.shape[1], d_model]).to(y.device)
pe_y[..., 0::2] = torch.sin(y.unsqueeze(-1).float() * div_term)
pe_y[..., 1::2] = torch.cos(y.unsqueeze(-1).float() * div_term)
pe_z = torch.zeros([z.shape[0], z.shape[1], d_model]).to(z.device)
pe_z[..., 0::2] = torch.sin(z.unsqueeze(-1).float() * div_term)
pe_z[..., 1::2] = torch.cos(z.unsqueeze(-1).float() * div_term)
                if self.xyz_pe:
                    feat_xyz = torch.cat(
                        [feat_xy, pe_x.permute(0, 2, 1), pe_y.permute(0, 2, 1),
                         pe_z.permute(0, 2, 1)], 1)  # [N, C + 3 * d_model, R*K].
                else:
                    feat_xyz = torch.cat(
                        [feat_xy, pe_z.permute(0, 2, 1)], 1)  # [N, C + d_model, R*K].
                pre_point_features = feat_xyz.permute(0, 2, 1)  # [N, R*K, C'].
                pre_point_features = pre_point_features.view(N, R, K, -1)
            elif self.ref_mode == 'bev_plane_carla':
                # Map world coordinates to normalized BEV coordinates.
                x = (217.5 - 8 * points[..., 0]) / 128 - 1
                y = (128.0 + 8 * points[..., 1]) / 128 - 1
                z = points[..., 2]
                points_bev = torch.stack([x, y, z], dim=-1)  # [N, R*K, 3].
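                # NOTE (assumption): the constants above suggest a 256 x 256
                # CARLA BEV map at 8 pixels per world unit, with x mirrored
                # around a 217.5-pixel offset and y shifted by 128 pixels
                # before normalizing to [-1, 1] for `grid_sample`.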
                xy = points_bev[..., :2]  # shape: [N, R*K, 2].
                xy = xy.unsqueeze(2)  # shape: [N, R*K, 1, 2].
                feat_xy = torch.nn.functional.grid_sample(
                    ref_representation, xy, mode='bilinear',
                    padding_mode='zeros', align_corners=False)
                feat_xy = feat_xy.squeeze(3)  # shape: [N, C, R*K].
                z = points_bev[..., -1]
                # Sinusoidal positional encoding on x, y, and z (same scheme
                # as in the CLEVR branch above).
                d_model = 32
                div_term = torch.exp(
                    torch.arange(0, d_model, 2, dtype=torch.float)
                    * -(math.log(10000.0) / d_model)).to(z.device)
pe_x = torch.zeros([x.shape[0], x.shape[1], d_model]).to(x.device)
pe_x[..., 0::2] = torch.sin(x.unsqueeze(-1).float() * div_term)
pe_x[..., 1::2] = torch.cos(x.unsqueeze(-1).float() * div_term)
pe_y = torch.zeros([y.shape[0], y.shape[1], d_model]).to(y.device)
pe_y[..., 0::2] = torch.sin(y.unsqueeze(-1).float() * div_term)
pe_y[..., 1::2] = torch.cos(y.unsqueeze(-1).float() * div_term)
pe_z = torch.zeros([z.shape[0], z.shape[1], d_model]).to(z.device)
pe_z[..., 0::2] = torch.sin(z.unsqueeze(-1).float() * div_term)
pe_z[..., 1::2] = torch.cos(z.unsqueeze(-1).float() * div_term)
                if self.xyz_pe:
                    feat_xyz = torch.cat(
                        [feat_xy, pe_x.permute(0, 2, 1), pe_y.permute(0, 2, 1),
                         pe_z.permute(0, 2, 1)], 1)  # [N, C + 3 * d_model, R*K].
                else:
                    feat_xyz = torch.cat(
                        [feat_xy, pe_z.permute(0, 2, 1)], 1)  # [N, C + d_model, R*K].
                pre_point_features = feat_xyz.permute(0, 2, 1)  # [N, R*K, C'].
                pre_point_features = pre_point_features.view(N, R, K, -1)
            else:
                raise NotImplementedError(f'Unsupported reference mode: '
                                          f'`{self.ref_mode}`!')
        # Get post-point features by feeding pre-point features into the
        # post-module (if it exists).
        if post_module is not None:
            post_point_features = post_module(pre_point_features, wp,
                                              **post_module_kwargs)
        else:
            post_point_features = pre_point_features
        if post_point_features.ndim == 2:
            post_point_features = rearrange(post_point_features,
                                            '(N R K) C -> N R K C',
                                            N=N, R=R, K=K).contiguous()
        return post_point_features
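

# A minimal usage sketch (illustrative only): build a tri-plane extractor and
# query features for randomly sampled points. All tensor shapes below (latent
# size, plane resolution, channel count) are assumptions for demonstration,
# not values mandated by this module.
if __name__ == '__main__':
    extractor = FeatureExtractor(ref_mode='tri_plane')
    wp = torch.randn(2, 14, 512)            # Latent codes (shape assumed).
    points = torch.randn(2, 64, 48, 3)      # Points of shape [N, R, K, 3].
    planes = torch.randn(2, 3, 32, 64, 64)  # Tri-planes, assumed [N, 3, C, H, W].
    features = extractor(wp, points, {'box_warp': 1.0},
                         ref_representation=planes)
    print(features.shape)  # [N, 3, R*K, C] per the tri-plane comment above.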