Spaces:
Sleeping
Sleeping
import functools | |
import torch | |
import torch.nn as nn | |
import numpy as np | |
from model.pvcnn.modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Swish | |
def _linear_gn_relu(in_channels, out_channels):
    """Return ``nn.Sequential(Linear, GroupNorm(8 groups), Swish)``.

    Despite the ``relu`` in the name, the activation is ``Swish`` (imported
    from ``model.pvcnn.modules``). ``nn.GroupNorm`` requires ``out_channels``
    to be divisible by the 8 groups.
    """
    linear = nn.Linear(in_channels, out_channels)
    norm = nn.GroupNorm(8, out_channels)
    return nn.Sequential(linear, norm, Swish())
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
    """Build the layers of an MLP head.

    Args:
        in_channels: number of input channels.
        out_channels: one width, or a list/tuple of widths. Every entry except
            the last is a hidden layer. A hidden entry ``< 1`` adds
            ``nn.Dropout`` with that probability instead of a layer. ``None``
            or an empty sequence means "no layers at all".
        classifier: if True, the last layer is a bare ``nn.Linear``
            (``dim == 1``) or a kernel-size-1 ``nn.Conv1d`` (otherwise), and its
            width is not scaled by ``width_multiplier``.
        dim: ``1`` builds layers with ``_linear_gn_relu``; any other value uses
            ``SharedMLP``.
        width_multiplier: factor applied (then truncated with ``int``) to every
            layer width except a classifier output.

    Returns:
        ``(layers, out_ch)``: a list of modules and the number of output
        channels. When no layers are requested this is ``([], in_channels)``.
    """
    if not isinstance(out_channels, (list, tuple)):
        out_channels = [out_channels]
    if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
        # Same (layers, out_ch) pair as the normal path. This branch used to
        # return a 3-tuple (nn.Sequential(), in_channels, in_channels), which
        # broke callers unpacking two values.
        return [], in_channels

    r = width_multiplier
    block = _linear_gn_relu if dim == 1 else SharedMLP

    layers = []
    for oc in out_channels[:-1]:
        if oc < 1:
            # Values below 1 are dropout probabilities, not widths.
            layers.append(nn.Dropout(oc))
        else:
            oc = int(r * oc)
            layers.append(block(in_channels, oc))
            in_channels = oc

    if classifier:
        final_channels = out_channels[-1]
        if dim == 1:
            layers.append(nn.Linear(in_channels, final_channels))
        else:
            layers.append(nn.Conv1d(in_channels, final_channels, 1))
    else:
        final_channels = int(r * out_channels[-1])
        layers.append(block(in_channels, final_channels))
    return layers, final_channels
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
                               width_multiplier=1, voxel_resolution_multiplier=1):
    """Build a flat list of PointNet/PVConv modules.

    ``blocks`` is an iterable of ``(out_channels, num_blocks, voxel_resolution)``.
    A ``voxel_resolution`` of ``None`` selects ``SharedMLP``; otherwise a
    ``PVConv`` with kernel size 3 is used. Attention is requested only on the
    first repeat of stages 2, 4, ... (0-based).

    Every module except the very first gets ``embed_dim`` extra input channels.

    Returns:
        ``(layers, out_channels_of_last_module, sum_of_all_module_out_channels)``.
    """
    r, vr = width_multiplier, voxel_resolution_multiplier
    modules = []
    total_channels = 0
    built = 0  # modules created so far, counted across all stages
    for stage_idx, (stage_channels, repeats, resolution) in enumerate(blocks):
        stage_channels = int(r * stage_channels)
        for rep in range(repeats):
            if resolution is None:
                factory = SharedMLP
            else:
                use_attention = stage_idx > 0 and stage_idx % 2 == 0 and rep == 0
                factory = functools.partial(PVConv, kernel_size=3, resolution=int(vr * resolution),
                                            attention=use_attention, with_se=with_se,
                                            normalize=normalize, eps=eps)
            input_width = in_channels if built == 0 else in_channels + embed_dim
            modules.append(factory(input_width, stage_channels))
            in_channels = stage_channels
            total_channels += stage_channels
            built += 1
    return modules, in_channels, total_channels
def create_pointnet2_sa_components(sa_blocks_config, extra_feature_channels, embed_dim=64, use_att=False,
                                   dropout=0.1, with_se=False, normalize=True, eps=0,
                                   width_multiplier=1, voxel_resolution_multiplier=1,
                                   in_ch_multiplier=1,
                                   extra_in_channel=0):
    """Build the set-abstraction (encoder) stages.

    Each entry of ``sa_blocks_config`` is ``(conv_configs, sa_configs)``:

    * ``conv_configs``: ``None`` or ``(out_channels, num_blocks, voxel_resolution)``.
      A ``voxel_resolution`` of ``None`` selects ``SharedMLP``, otherwise ``PVConv``.
    * ``sa_configs``: ``(num_centers, radius, num_neighbors, out_channels)``.
      A ``num_centers`` of ``None`` selects ``PointNetAModule``, otherwise
      ``PointNetSAModule``.

    Args:
        sa_blocks_config: sized sequence of stage configs (``len`` is used).
        extra_feature_channels: per-point feature channels in addition to the
            3 coordinates.
        embed_dim: extra input channels reserved on the first module of every
            stage after the first (presumably for a concatenated embedding;
            the concatenation itself is not visible here).
        use_att: if True, the first conv repeat of stages 1, 3, ... (0-based)
            is built with ``attention=True``. Defaults to False.
        dropout, with_se, normalize, eps: forwarded to ``PVConv``.
        width_multiplier: scales every channel width (truncated with ``int``).
        voxel_resolution_multiplier: scales every voxel resolution.
        in_ch_multiplier: multiplies each stage's output channels to get the
            next stage's input channels (e.g. 2 when features are concatenated).
        extra_in_channel: added to the next stage's input channels after every
            stage except the last (e.g. for cross attention).

    Returns:
        ``(sa_layers, sa_in_channels, in_channels, num_centers)``:
        one module per stage (bare if the stage has a single module, else
        ``nn.Sequential``); the input channel count of each stage; the channel
        count after the last stage (scaled by ``in_ch_multiplier``); and the last
        stage's ``num_centers`` (1 when it is ``None``).

    Raises:
        ValueError: if ``sa_blocks_config`` is empty.
    """
    if len(sa_blocks_config) == 0:
        # Without this, the final return hits an unbound ``num_centers``.
        raise ValueError("sa_blocks_config must contain at least one stage")

    r, vr = width_multiplier, voxel_resolution_multiplier
    in_channels = extra_feature_channels + 3
    sa_layers, sa_in_channels = [], []
    block_count = 0  # stages built so far
    for conv_configs, sa_configs in sa_blocks_config:
        k = 0  # conv repeats processed in this stage
        sa_in_channels.append(in_channels)
        sa_blocks = []

        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            for p in range(num_blocks):
                attention = (block_count + 1) % 2 == 0 and use_att and p == 0
                if voxel_resolution is None:
                    block = SharedMLP
                else:
                    block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                                              attention=attention,
                                              dropout=dropout,
                                              with_se=with_se, with_se_relu=True,
                                              normalize=normalize, eps=eps)
                if block_count == 0:
                    sa_blocks.append(block(in_channels, out_channels))
                elif k == 0:
                    sa_blocks.append(block(in_channels + embed_dim, out_channels))
                # NOTE(review): in stages after the first, repeats p >= 1 build
                # no module at all, so only one conv module exists per such
                # stage regardless of num_blocks. Kept as-is because changing it
                # alters the architecture (and any saved weights).
                in_channels = out_channels
                k += 1
            extra_feature_channels = in_channels

        num_centers, radius, num_neighbors, out_channels = sa_configs
        out_channels = [
            [int(r * _oc) for _oc in oc] if isinstance(oc, (list, tuple)) else int(r * oc)
            for oc in out_channels
        ]
        if num_centers is None:
            block = PointNetAModule
        else:
            block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
                                      num_neighbors=num_neighbors)
        # The embedding channels go to the SA module only when no conv module
        # was processed in this stage.
        sa_blocks.append(block(in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
                               out_channels=out_channels,
                               include_coordinates=True))
        block_count += 1

        # XH: double the channel for concat, or add additional channel for cross
        # attention. No cross attention follows the last stage, so skip the extra
        # channels there.
        next_channels = sa_blocks[-1].out_channels * in_ch_multiplier
        if block_count < len(sa_blocks_config):
            next_channels += extra_in_channel
        in_channels = extra_feature_channels = int(next_channels)

        if len(sa_blocks) == 1:
            sa_layers.append(sa_blocks[0])
        else:
            sa_layers.append(nn.Sequential(*sa_blocks))

    return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
                                dropout=0.1,
                                with_se=False, normalize=True, eps=0,
                                width_multiplier=1, voxel_resolution_multiplier=1,
                                in_ch_multiplier=1, extra_in_channel=0):
    """Build the feature-propagation (decoder) stages.

    Each entry of ``fp_blocks`` is ``(fp_configs, conv_configs)``:

    * ``fp_configs``: sequence of channel widths for the ``PointNetFPModule``.
    * ``conv_configs``: ``None`` or ``(out_channels, num_blocks, voxel_resolution)``.
      A ``voxel_resolution`` of ``None`` selects ``SharedMLP``, otherwise ``PVConv``.

    :param fp_blocks: stage configs, consumed in order. Stage ``i`` is paired with
        the skip width ``sa_in_channels[-1 - i]``.
    :param in_channels: channels entering the first stage.
    :param sa_in_channels: skip-connection widths (e.g. the per-stage input
        channels returned by ``create_pointnet2_sa_components``).
    :param embed_dim: extra input channels reserved on every ``PointNetFPModule``.
    :param use_att: requested attention for PVConv blocks. See the NOTE in the body:
        with the current condition it never takes effect.
    :param dropout: forwarded to ``PVConv``.
    :param with_se: forwarded to ``PVConv``.
    :param normalize: forwarded to ``PVConv``.
    :param eps: forwarded to ``PVConv``.
    :param width_multiplier: scales every channel width (truncated with ``int``).
    :param voxel_resolution_multiplier: scales every voxel resolution.
    :param in_ch_multiplier: increases the input channel dimension of every stage
        after the first (``int(in_channels * in_ch_multiplier + extra_in_channel)``).
    :param extra_in_channel: channels added to the input of every stage.
    :return: ``(fp_layers, in_channels)``: one module per stage (bare if the stage
        has a single module, else ``nn.Sequential``) and the channel count
        produced by the last stage.
    :raises ValueError: if there are more stages than entries in ``sa_in_channels``.
    """
    stage_configs = list(fp_blocks)
    if len(stage_configs) > len(sa_in_channels):
        raise ValueError(
            f"{len(stage_configs)} FP stages need as many skip widths, "
            f"but sa_in_channels has only {len(sa_in_channels)}"
        )

    r, vr = width_multiplier, voxel_resolution_multiplier
    fp_layers = []
    for fp_idx, (fp_configs, conv_configs) in enumerate(stage_configs):
        # Named so it no longer shadows the ``fp_blocks`` parameter.
        stage_modules = []
        out_channels = tuple(int(r * oc) for oc in fp_configs)
        if fp_idx > 0:
            # to handle additional channel from concatenating human + object features
            sa_in_concat = int(in_channels * in_ch_multiplier + extra_in_channel)
        else:
            # this is for simple-coord3d, where the decoder first layer also has cross attention
            sa_in_concat = in_channels + extra_in_channel
        # interpolate + Conv1d, does not change number of points (per original note)
        stage_modules.append(
            PointNetFPModule(in_channels=sa_in_concat + sa_in_channels[-1 - fp_idx] + embed_dim,
                             out_channels=out_channels)
        )
        in_channels = out_channels[-1]

        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            for p in range(num_blocks):
                # NOTE(review): this originally read ``len(fp_blocks)`` after the
                # parameter had been rebound to the per-stage list. At p == 0 that
                # list holds only the FP module, so the bound is fp_idx < 0 and
                # attention is always False. Preserved verbatim because enabling it
                # would change the architecture (and any saved weights).
                attention = (fp_idx + 1) % 2 == 0 and fp_idx < len(stage_modules) - 1 and use_att and p == 0
                if voxel_resolution is None:
                    block = SharedMLP
                else:
                    block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                                              attention=attention,
                                              dropout=dropout,
                                              with_se=with_se, with_se_relu=True,
                                              normalize=normalize, eps=eps)
                stage_modules.append(block(in_channels, out_channels))
                in_channels = out_channels  # this should not change!

        if len(stage_modules) == 1:
            fp_layers.append(stage_modules[0])  # only the FP module, no conv blocks
        else:
            fp_layers.append(nn.Sequential(*stage_modules))

    return fp_layers, in_channels
def get_timestep_embedding(embed_dim, timesteps, device):
    """Sinusoidal timestep embedding.

    Note that this should work just as well for continuous values as for
    discrete values.

    Args:
        embed_dim: embedding size. Columns ``[0, embed_dim // 2)`` hold sines and
            the next ``embed_dim // 2`` hold cosines. An odd ``embed_dim`` gets a
            trailing zero column.
        timesteps: 1-D tensor with one timestep per batch element.
        device: device for the frequency table. It is multiplied with
            ``timesteps``, so it should match ``timesteps.device``.

    Returns:
        Tensor of shape ``(len(timesteps), embed_dim)``.

    Raises:
        ValueError: if ``timesteps`` is not 1-D.
    """
    if len(timesteps.shape) != 1:
        raise ValueError(f"timesteps must be 1-D, got shape {tuple(timesteps.shape)}")
    half_dim = embed_dim // 2
    # Frequencies form a geometric ladder from 1 down to 1/10000. Clamp the
    # denominator: for half_dim == 1 (embed_dim 2 or 3) it would be 0, giving an
    # infinite scale and NaN embeddings. Results for other sizes are unchanged.
    emb = np.log(10000) / max(half_dim - 1, 1)
    emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
    emb = timesteps[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embed_dim % 2 == 1:  # zero pad
        emb = nn.functional.pad(emb, (0, 1), "constant", 0)
    assert emb.shape == torch.Size([timesteps.shape[0], embed_dim])
    return emb