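# File-viewer residue captured with this source (not part of the module itself):
#   uploader: xiexh20
#   commit:   2fd6166 ("add hdm demo v1")
#   viewer links: raw / history / blame
#   scan status: no virus
#   file size: 9.46 kB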
import functools
import torch
import torch.nn as nn
import numpy as np
from model.pvcnn.modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Swish
def _linear_gn_relu(in_channels, out_channels):
    """
    Build a fully-connected block: ``nn.Linear`` -> ``nn.GroupNorm`` (8 groups) -> ``Swish``.

    Despite the ``relu`` in the name, the activation that is instantiated is ``Swish``.

    :param in_channels: input feature size of the linear layer
    :param out_channels: output feature size; also the channel count handed to GroupNorm
    :return: an ``nn.Sequential`` holding the three modules in that order
    """
    projection = nn.Linear(in_channels, out_channels)
    normalization = nn.GroupNorm(8, out_channels)
    activation = Swish()
    return nn.Sequential(projection, normalization, activation)
def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
    """
    Build the layer list of an MLP head and report its output width.

    :param in_channels: number of input channels of the first layer
    :param out_channels: a single width or a list/tuple of widths. Any entry
        except the last one that is smaller than 1 is turned into
        ``nn.Dropout(entry)`` instead of a layer.
    :param classifier: if True, the last layer is a bare ``nn.Linear`` (dim == 1)
        or 1x1 ``nn.Conv1d`` (otherwise) whose width is NOT scaled by
        ``width_multiplier``; if False, the last layer is a regular block.
    :param dim: 1 selects ``_linear_gn_relu`` blocks, any other value selects ``SharedMLP``
    :param width_multiplier: factor applied to every block width
    :return: ``(layers, channels)`` - the list of modules and the width of the
        last layer. For an empty config (``[]``, ``None`` or ``[None]``) this is
        ``([], in_channels)``; the original code returned a 3-tuple in that case,
        which did not match the 2-tuple returned by every other path.
    """
    if not isinstance(out_channels, (list, tuple)):
        out_channels = [out_channels]
    if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
        # Nothing to build: keep the same (layers, channels) shape as the normal path.
        return [], in_channels

    def make_block(block_in, block_out):
        # Resolved lazily so a config that never needs a block does not touch either factory.
        if dim == 1:
            return _linear_gn_relu(block_in, block_out)
        return SharedMLP(block_in, block_out)

    *hidden_widths, last_width = out_channels

    layers = []
    for oc in hidden_widths:
        if oc < 1:
            # Values below 1 are dropout probabilities, not widths.
            layers.append(nn.Dropout(oc))
        else:
            oc = int(width_multiplier * oc)
            layers.append(make_block(in_channels, oc))
            in_channels = oc

    if classifier:
        # The classifier output keeps its exact, unscaled width.
        final_width = last_width
        if dim == 1:
            layers.append(nn.Linear(in_channels, final_width))
        else:
            layers.append(nn.Conv1d(in_channels, final_width, 1))
    else:
        final_width = int(width_multiplier * last_width)
        layers.append(make_block(in_channels, final_width))
    return layers, final_width
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
                               width_multiplier=1, voxel_resolution_multiplier=1):
    """
    Build a flat list of ``SharedMLP`` / ``PVConv`` layers from ``blocks``.

    :param blocks: iterable of ``(out_channels, num_blocks, voxel_resolution)`` tuples;
        a ``voxel_resolution`` of None selects ``SharedMLP``, otherwise ``PVConv``
    :param in_channels: input width of the very first layer
    :param embed_dim: extra input channels added to every layer except the first one
    :param with_se: forwarded to ``PVConv``
    :param normalize: forwarded to ``PVConv``
    :param eps: forwarded to ``PVConv``
    :param width_multiplier: factor applied to each ``out_channels``
    :param voxel_resolution_multiplier: factor applied to each ``voxel_resolution``
    :return: ``(layers, in_channels, concat_channels)`` - the modules, the width of
        the last layer, and the sum of the widths of all layers
    """
    layers = []
    concat_channels = 0
    for stage_idx, (stage_width, num_blocks, voxel_resolution) in enumerate(blocks):
        stage_width = int(width_multiplier * stage_width)
        for repeat_idx in range(num_blocks):
            if voxel_resolution is None:
                make_layer = SharedMLP
            else:
                make_layer = functools.partial(
                    PVConv,
                    kernel_size=3,
                    resolution=int(voxel_resolution_multiplier * voxel_resolution),
                    # attention only on the first repeat of even-indexed stages other than stage 0
                    attention=stage_idx % 2 == 0 and stage_idx > 0 and repeat_idx == 0,
                    with_se=with_se,
                    normalize=normalize,
                    eps=eps,
                )

            # Only the very first layer sees the raw input; all later ones get embed_dim extra channels.
            layer_in = in_channels + embed_dim if layers else in_channels
            layers.append(make_layer(layer_in, stage_width))

            in_channels = stage_width
            concat_channels += stage_width
    return layers, in_channels, concat_channels
def create_pointnet2_sa_components(sa_blocks_config, extra_feature_channels, embed_dim=64, use_att=False,
                                   dropout=0.1, with_se=False, normalize=True, eps=0,
                                   width_multiplier=1, voxel_resolution_multiplier=1,
                                   in_ch_multiplier=1,
                                   extra_in_channel=0):
    """
    Build the set-abstraction stages described by ``sa_blocks_config``.

    :param sa_blocks_config: sequence of ``(conv_configs, sa_configs)`` pairs.
        ``conv_configs`` is None or ``(out_channels, num_blocks, voxel_resolution)``;
        ``sa_configs`` is ``(num_centers, radius, num_neighbors, out_channels)``.
    :param extra_feature_channels: feature channels of the input; 3 is added to it
        for the first stage (presumably the point coordinates - confirm with caller)
    :param embed_dim: extra input channels added once per stage (see below)
    :param use_att: enables attention on the first conv of every second stage
        (default False)
    :param dropout: forwarded to ``PVConv``
    :param with_se: forwarded to ``PVConv``
    :param normalize: forwarded to ``PVConv``
    :param eps: forwarded to ``PVConv``
    :param width_multiplier: factor applied to all configured widths
    :param voxel_resolution_multiplier: factor applied to ``voxel_resolution``
    :param in_ch_multiplier: multiplies each stage's output width to obtain the
        next stage's input width
    :param extra_in_channel: added to the next stage's input width, except after
        the last stage
    :return: ``(sa_layers, sa_in_channels, in_channels, num_centers)``: one module per
        stage, the input width recorded at the start of each stage, the width after
        the last stage, and the last stage's ``num_centers`` (1 if that is None)
    """
    in_channels = extra_feature_channels + 3
    sa_layers = []
    sa_in_channels = []
    for stage_idx, (conv_configs, sa_configs) in enumerate(sa_blocks_config):
        sa_in_channels.append(in_channels)
        stage_modules = []
        conv_seen = False  # becomes True once the conv loop below has run an iteration

        if conv_configs is not None:
            conv_width, num_blocks, voxel_resolution = conv_configs
            conv_width = int(width_multiplier * conv_width)
            for p in range(num_blocks):
                if voxel_resolution is None:
                    make_conv = SharedMLP
                else:
                    make_conv = functools.partial(
                        PVConv,
                        kernel_size=3,
                        resolution=int(voxel_resolution_multiplier * voxel_resolution),
                        # odd-indexed stages only, first repeat only
                        attention=stage_idx % 2 == 1 and use_att and p == 0,
                        dropout=dropout,
                        with_se=with_se,
                        with_se_relu=True,
                        normalize=normalize,
                        eps=eps,
                    )
                if stage_idx == 0:
                    # The first stage instantiates every repeat, without embed_dim.
                    stage_modules.append(make_conv(in_channels, conv_width))
                elif not conv_seen:
                    # Later stages instantiate only their first repeat, which absorbs embed_dim;
                    # further repeats are counted but not built.
                    stage_modules.append(make_conv(in_channels + embed_dim, conv_width))
                in_channels = conv_width
                conv_seen = True
            extra_feature_channels = in_channels

        num_centers, radius, num_neighbors, sa_widths = sa_configs
        sa_widths = [
            [int(width_multiplier * w) for w in entry]
            if isinstance(entry, (list, tuple))
            else int(width_multiplier * entry)
            for entry in sa_widths
        ]

        if num_centers is None:
            make_sa = PointNetAModule
        else:
            make_sa = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
                                        num_neighbors=num_neighbors)
        # If no conv ran in this stage, the SA module is the one that receives embed_dim.
        stage_modules.append(make_sa(in_channels=extra_feature_channels + (0 if conv_seen else embed_dim),
                                     out_channels=sa_widths,
                                     include_coordinates=True))

        # Width fed to the next stage: scaled stage output, plus extra_in_channel unless
        # this was the last configured stage.
        if stage_idx + 1 < len(sa_blocks_config):
            in_channels = extra_feature_channels = int(
                stage_modules[-1].out_channels * in_ch_multiplier + extra_in_channel)
        else:
            in_channels = extra_feature_channels = int(stage_modules[-1].out_channels * in_ch_multiplier)

        if len(stage_modules) == 1:
            sa_layers.append(stage_modules[0])
        else:
            sa_layers.append(nn.Sequential(*stage_modules))

    return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
                                dropout=0.1,
                                with_se=False, normalize=True, eps=0,
                                width_multiplier=1, voxel_resolution_multiplier=1,
                                in_ch_multiplier=1, extra_in_channel=0):
    """
    Build one feature-propagation stage per entry of ``fp_blocks``: a
    ``PointNetFPModule`` optionally followed by ``SharedMLP`` / ``PVConv`` blocks.

    :param fp_blocks: iterable of ``(fp_configs, conv_configs)`` pairs; ``fp_configs``
        is an iterable of widths, ``conv_configs`` is None or
        ``(out_channels, num_blocks, voxel_resolution)``
    :param in_channels: width of the features entering the first stage
    :param sa_in_channels: per-stage input widths, consumed from the end
        (stage ``i`` reads ``sa_in_channels[-1 - i]``)
    :param embed_dim: extra input channels added to every ``PointNetFPModule``
    :param use_att: part of the attention condition below (see the NOTE there)
    :param dropout: forwarded to ``PVConv``
    :param with_se: forwarded to ``PVConv``
    :param normalize: forwarded to ``PVConv``
    :param eps: forwarded to ``PVConv``
    :param width_multiplier: factor applied to all configured widths
    :param voxel_resolution_multiplier: factor applied to ``voxel_resolution``
    :param in_ch_multiplier: increase the input channel dimension (applied from the
        second stage on)
    :param extra_in_channel: added to the input width of every stage
    :return: ``(fp_layers, in_channels)`` - one module per stage and the width
        produced by the last stage
    """
    r, vr = width_multiplier, voxel_resolution_multiplier
    fp_layers = []
    c = 0  # index of the stage being built
    for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
        # NOTE: this rebinds the parameter name ``fp_blocks`` to the list of modules of the
        # current stage. The ``for`` loop keeps iterating the original config, but every
        # later ``fp_blocks`` below refers to this per-stage list.
        fp_blocks = []
        out_channels = tuple(int(r * oc) for oc in fp_configs)
        if fp_idx > 0:
            # to handle additional channel from concatenating human + object features
            sa_in_concat = int(in_channels*in_ch_multiplier + extra_in_channel)
        else:
            sa_in_concat = in_channels + extra_in_channel # this is for simple-coord3d, where the decoder first layer also has cross attention
        fp_blocks.append(
            PointNetFPModule(in_channels=sa_in_concat + sa_in_channels[-1 - fp_idx] + embed_dim,
                             out_channels=out_channels)
        ) # interpolate + Conv1d, does not change number of points
        in_channels = out_channels[-1]
        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            for p in range(num_blocks):
                # NOTE(review): ``len(fp_blocks)`` is the size of the per-stage module list
                # (1 + p at this point), not the number of configured stages. Together with
                # ``p == 0`` the condition requires ``c < 0``, so ``attention`` always
                # evaluates to False here regardless of ``use_att``. Changing this would alter
                # which layers are created, so confirm against existing checkpoints first.
                attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0
                if voxel_resolution is None:
                    block = SharedMLP
                else:
                    block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
                                              dropout=dropout,
                                              with_se=with_se, with_se_relu=True,
                                              normalize=normalize, eps=eps)
                fp_blocks.append(block(in_channels, out_channels))
                in_channels = out_channels # this should not change!
        if len(fp_blocks) == 1:
            fp_layers.append(fp_blocks[0]) # this is the last block, no PVConv layer
        else:
            fp_layers.append(nn.Sequential(*fp_blocks))
        c += 1
    return fp_layers, in_channels
def get_timestep_embedding(embed_dim, timesteps, device):
    """
    Timestep embedding function. Note that this should work just as well for
    continuous values as for discrete values.

    The first ``embed_dim // 2`` columns are sines and the next ``embed_dim // 2``
    columns are cosines of ``timesteps * frequency``; the frequencies decay
    geometrically from 1 to 1/10000. An odd ``embed_dim`` gets one trailing zero
    column.

    :param embed_dim: size of the embedding produced for every timestep
    :param timesteps: 1-D tensor of timesteps
    :param device: device the frequency table is moved to before it is multiplied
        with ``timesteps``
    :return: tensor of shape ``(timesteps.shape[0], embed_dim)``
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embed_dim // 2
    # With embed_dim 2 or 3 there is a single frequency and ``half_dim - 1`` is 0;
    # dividing by it produced NaN embeddings. Clamping the denominator keeps that
    # single frequency at 1 and leaves every larger embed_dim unchanged.
    log_step = np.log(10000) / max(half_dim - 1, 1)
    frequencies = torch.from_numpy(np.exp(np.arange(0, half_dim) * -log_step)).float().to(device)
    angles = timesteps[:, None] * frequencies[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embed_dim % 2 == 1:  # zero pad
        emb = nn.functional.pad(emb, (0, 1), "constant", 0)
    assert emb.shape == torch.Size([timesteps.shape[0], embed_dim])
    return emb