import functools
import torch
import torch.nn as nn
import numpy as np
from model.pvcnn.modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Swish


def _linear_gn_relu(in_channels, out_channels):
    """Linear -> GroupNorm(8 groups) -> Swish: the 1D counterpart of SharedMLP."""
    return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())


def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
    """Build an MLP head; entries of `out_channels` below 1 are treated as dropout probabilities."""
    r = width_multiplier
if dim == 1:
block = _linear_gn_relu
else:
block = SharedMLP
if not isinstance(out_channels, (list, tuple)):
out_channels = [out_channels]
    if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
        return nn.Sequential(), in_channels  # nothing to build; keep the two-value return shape of the main path
layers = []
for oc in out_channels[:-1]:
        if oc < 1:
            layers.append(nn.Dropout(oc))  # fractional entries denote dropout probabilities
else:
oc = int(r * oc)
layers.append(block(in_channels, oc))
in_channels = oc
if dim == 1:
if classifier:
layers.append(nn.Linear(in_channels, out_channels[-1]))
else:
layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
else:
if classifier:
layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
else:
layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
    return layers, (out_channels[-1] if classifier else int(r * out_channels[-1]))
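
# Usage sketch (values are illustrative, not from any config in this repo):
#   layers, out_ch = create_mlp_components(in_channels=256, out_channels=[128, 0.5, 64],
#                                          classifier=True, dim=2)
#   head = nn.Sequential(*layers)  # SharedMLP(256, 128) -> Dropout(0.5) -> Conv1d(128, 64, 1)
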
def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
                               width_multiplier=1, voxel_resolution_multiplier=1):
    """Build conv stages from `blocks`, a sequence of (out_channels, num_blocks, voxel_resolution) tuples."""
    r, vr = width_multiplier, voxel_resolution_multiplier
    layers, concat_channels = [], 0
    c = 0  # global block counter; every block after the first expects an extra embed_dim input channels
for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks):
out_channels = int(r * out_channels)
for p in range(num_blocks):
            attention = k % 2 == 0 and k > 0 and p == 0  # attention only in the first block of even-indexed stages after the first
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
with_se=with_se, normalize=normalize, eps=eps)
if c == 0:
layers.append(block(in_channels, out_channels))
else:
layers.append(block(in_channels+embed_dim, out_channels))
in_channels = out_channels
concat_channels += out_channels
c += 1
return layers, in_channels, concat_channels
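
# Usage sketch (illustrative values): each entry of `blocks` is
# (out_channels, num_blocks, voxel_resolution); a None resolution falls back to SharedMLP.
#   blocks = ((64, 2, 32), (128, 1, 16), (256, 1, None))
#   layers, out_ch, concat_ch = create_pointnet_components(blocks, in_channels=3, embed_dim=64)
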
def create_pointnet2_sa_components(sa_blocks_config, extra_feature_channels, embed_dim=64, use_att=False,
                                   dropout=0.1, with_se=False, normalize=True, eps=0,
                                   width_multiplier=1, voxel_resolution_multiplier=1,
                                   in_ch_multiplier=1, extra_in_channel=0):
    """Build the set-abstraction (SA) stages of the PointNet++-style encoder.

    in_ch_multiplier: scales each stage's input channels, e.g. to account for
        concatenated human + object features.
    extra_in_channel: additional input channels added between stages for cross attention.
    """
r, vr = width_multiplier, voxel_resolution_multiplier
in_channels = extra_feature_channels + 3
sa_layers, sa_in_channels = [], []
block_count = 0
for conv_configs, sa_configs in sa_blocks_config:
k = 0
sa_in_channels.append(in_channels)
sa_blocks = []
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
            for p in range(num_blocks):  # the conv block is repeated num_blocks times
                attention = (block_count + 1) % 2 == 0 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, with_se_relu=True,
normalize=normalize, eps=eps)
                if block_count == 0:
                    sa_blocks.append(block(in_channels, out_channels))
                elif k == 0:
                    # from the second stage on, the first conv also consumes the embed_dim-channel embedding
                    sa_blocks.append(block(in_channels + embed_dim, out_channels))
                in_channels = out_channels
                k += 1
extra_feature_channels = in_channels
num_centers, radius, num_neighbors, out_channels = sa_configs
_out_channels = []
for oc in out_channels:
if isinstance(oc, (list, tuple)):
_out_channels.append([int(r * _oc) for _oc in oc])
else:
_out_channels.append(int(r * oc))
out_channels = _out_channels
        if num_centers is None:
            block = PointNetAModule  # global aggregation; the configs used here always set num_centers
else:
block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
num_neighbors=num_neighbors)
        sa_blocks.append(block(in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
                               out_channels=out_channels, include_coordinates=True))
block_count += 1
# XH: double the channel for concat, or add additional channel for cross attention
if block_count < len(sa_blocks_config):
in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier + extra_in_channel)
else:
# no cross attention before the self attention module
in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier)
        if len(sa_blocks) == 1:
            sa_layers.append(sa_blocks[0])  # stage has only the SA module, no preceding conv blocks
else:
sa_layers.append(nn.Sequential(*sa_blocks))
    # the last return value is the number of points remaining after the final SA stage
    return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
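
# Config sketch (illustrative values, in the spirit of PVD-style configs): each entry of
# sa_blocks_config is a (conv_configs, sa_configs) pair, where
#   conv_configs = (out_channels, num_blocks, voxel_resolution) or None
#   sa_configs   = (num_centers, radius, num_neighbors, out_channels)
#   sa_blocks_config = (
#       ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
#       (None, (16, 0.8, 32, (64, 64, 128))),
#   )
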
def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
dropout=0.1,
with_se=False, normalize=True, eps=0,
width_multiplier=1, voxel_resolution_multiplier=1,
in_ch_multiplier=1, extra_in_channel=0):
"""
:param fp_blocks:
:param in_channels:
:param sa_in_channels:
:param embed_dim:
:param use_att:
:param dropout:
:param with_se:
:param normalize:
:param eps:
:param width_multiplier:
:param voxel_resolution_multiplier:
:param in_ch_multiplier: increase the input channel dimension
:return:
"""
r, vr = width_multiplier, voxel_resolution_multiplier
fp_layers = []
    c = 0  # FP stage counter, used to decide where attention could be enabled
    for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
        fp_modules = []  # modules of the current stage; a distinct name avoids shadowing the fp_blocks argument
out_channels = tuple(int(r * oc) for oc in fp_configs)
        if fp_idx > 0:
            # handle the additional channels from concatenating human + object features
            sa_in_concat = int(in_channels * in_ch_multiplier + extra_in_channel)
        else:
            # for simple-coord3d the first decoder layer also has cross attention
            sa_in_concat = in_channels + extra_in_channel
        fp_modules.append(
            PointNetFPModule(in_channels=sa_in_concat + sa_in_channels[-1 - fp_idx] + embed_dim,
                             out_channels=out_channels)
        )  # interpolation + Conv1d; does not change the number of points
in_channels = out_channels[-1]
if conv_configs is not None:
out_channels, num_blocks, voxel_resolution = conv_configs
out_channels = int(r * out_channels)
for p in range(num_blocks):
                # fp_modules holds a single module when p == 0, so the second clause keeps attention disabled throughout the FP stages
                attention = (c + 1) % 2 == 0 and c < len(fp_modules) - 1 and use_att and p == 0
if voxel_resolution is None:
block = SharedMLP
else:
block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention,
dropout=dropout,
with_se=with_se, with_se_relu=True,
normalize=normalize, eps=eps)
                fp_modules.append(block(in_channels, out_channels))
                in_channels = out_channels  # constant within the stage, so repeats keep the same width
        if len(fp_modules) == 1:
            fp_layers.append(fp_modules[0])  # stage has only the FP module, no conv blocks after it
        else:
            fp_layers.append(nn.Sequential(*fp_modules))
c += 1
return fp_layers, in_channels
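
# Config sketch (illustrative values): each entry of `fp_blocks` pairs the FP widths with an
# optional conv stage, mirroring the SA config layout:
#   fp_blocks = (
#       ((128, 128), (128, 2, 16)),  # ((fp_out_channels, ...), (out_channels, num_blocks, voxel_resolution))
#       ((128, 64), None),
#   )
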
def get_timestep_embedding(embed_dim, timesteps, device):
    """
    Sinusoidal timestep embedding. Note that this should work just as well for
    continuous values as for discrete values.
    """
    assert len(timesteps.shape) == 1  # expects a 1D batch of timesteps
    half_dim = embed_dim // 2
    emb = np.log(10000) / (half_dim - 1)
    # geometric frequency scale from 1 down to 1/10000, as in the Transformer positional encoding
    emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
    emb = timesteps[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embed_dim % 2 == 1: # zero pad
emb = nn.functional.pad(emb, (0, 1), "constant", 0)
assert emb.shape == torch.Size([timesteps.shape[0], embed_dim])
return emb
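

if __name__ == "__main__":
    # Minimal sanity check for the embedding helper; the dimensions here are illustrative only.
    t = torch.randint(0, 1000, (16,)).float()
    emb = get_timestep_embedding(64, t, t.device)
    print(emb.shape)  # torch.Size([16, 64])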