import functools import torch import torch.nn as nn import numpy as np from model.pvcnn.modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Swish def _linear_gn_relu(in_channels, out_channels): return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8,out_channels), Swish()) def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1): r = width_multiplier if dim == 1: block = _linear_gn_relu else: block = SharedMLP if not isinstance(out_channels, (list, tuple)): out_channels = [out_channels] if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None): return nn.Sequential(), in_channels, in_channels layers = [] for oc in out_channels[:-1]: if oc < 1: layers.append(nn.Dropout(oc)) else: oc = int(r * oc) layers.append(block(in_channels, oc)) in_channels = oc if dim == 1: if classifier: layers.append(nn.Linear(in_channels, out_channels[-1])) else: layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1]))) else: if classifier: layers.append(nn.Conv1d(in_channels, out_channels[-1], 1)) else: layers.append(SharedMLP(in_channels, int(r * out_channels[-1]))) return layers, out_channels[-1] if classifier else int(r * out_channels[-1]) def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0, width_multiplier=1, voxel_resolution_multiplier=1): r, vr = width_multiplier, voxel_resolution_multiplier layers, concat_channels = [], 0 c = 0 for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks): out_channels = int(r * out_channels) for p in range(num_blocks): attention = k % 2 == 0 and k > 0 and p == 0 if voxel_resolution is None: block = SharedMLP else: block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, with_se=with_se, normalize=normalize, eps=eps) if c == 0: layers.append(block(in_channels, out_channels)) else: layers.append(block(in_channels+embed_dim, out_channels)) in_channels = out_channels concat_channels += out_channels c += 1 return layers, in_channels, concat_channels def create_pointnet2_sa_components(sa_blocks_config, extra_feature_channels, embed_dim=64, use_att=False, dropout=0.1, with_se=False, normalize=True, eps=0, width_multiplier=1, voxel_resolution_multiplier=1, in_ch_multiplier=1, extra_in_channel=0): "use_att is True by default, in_ch_multiplier: increase the input channel dimension" r, vr = width_multiplier, voxel_resolution_multiplier in_channels = extra_feature_channels + 3 sa_layers, sa_in_channels = [], [] block_count = 0 for conv_configs, sa_configs in sa_blocks_config: k = 0 sa_in_channels.append(in_channels) sa_blocks = [] if conv_configs is not None: out_channels, num_blocks, voxel_resolution = conv_configs out_channels = int(r * out_channels) for p in range(num_blocks): # pconv is repeated attention = (block_count+1) % 2 == 0 and use_att and p == 0 if voxel_resolution is None: block = SharedMLP else: block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, dropout=dropout, with_se=with_se, with_se_relu=True, normalize=normalize, eps=eps) if block_count == 0: sa_blocks.append(block(in_channels, out_channels)) elif k ==0: sa_blocks.append(block(in_channels+embed_dim, out_channels)) in_channels = out_channels k += 1 extra_feature_channels = in_channels num_centers, radius, num_neighbors, out_channels = sa_configs _out_channels = [] for oc in out_channels: if isinstance(oc, (list, tuple)): _out_channels.append([int(r * _oc) for _oc in oc]) else: _out_channels.append(int(r * oc)) out_channels = _out_channels if num_centers is None: block = PointNetAModule # always not-none else: block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius, num_neighbors=num_neighbors) sa_blocks.append(block(in_channels=extra_feature_channels+(embed_dim if k==0 else 0 ), out_channels=out_channels, include_coordinates=True)) block_count += 1 # XH: double the channel for concat, or add additional channel for cross attention if block_count < len(sa_blocks_config): in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier + extra_in_channel) else: # no cross attention before the self attention module in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier) if len(sa_blocks) == 1: sa_layers.append(sa_blocks[0]) # first pconv is repeated ? else: sa_layers.append(nn.Sequential(*sa_blocks)) return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False, dropout=0.1, with_se=False, normalize=True, eps=0, width_multiplier=1, voxel_resolution_multiplier=1, in_ch_multiplier=1, extra_in_channel=0): """ :param fp_blocks: :param in_channels: :param sa_in_channels: :param embed_dim: :param use_att: :param dropout: :param with_se: :param normalize: :param eps: :param width_multiplier: :param voxel_resolution_multiplier: :param in_ch_multiplier: increase the input channel dimension :return: """ r, vr = width_multiplier, voxel_resolution_multiplier fp_layers = [] c = 0 for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks): fp_blocks = [] out_channels = tuple(int(r * oc) for oc in fp_configs) if fp_idx > 0: # to handle additional channel from concatenating human + object features sa_in_concat = int(in_channels*in_ch_multiplier + extra_in_channel) else: sa_in_concat = in_channels + extra_in_channel # this is for simple-coord3d, where the decoder first layer also has cross attention fp_blocks.append( PointNetFPModule(in_channels=sa_in_concat + sa_in_channels[-1 - fp_idx] + embed_dim, out_channels=out_channels) ) # interpolate + Conv1d, does not change number of points in_channels = out_channels[-1] if conv_configs is not None: out_channels, num_blocks, voxel_resolution = conv_configs out_channels = int(r * out_channels) for p in range(num_blocks): attention = (c+1) % 2 == 0 and c < len(fp_blocks) - 1 and use_att and p == 0 if voxel_resolution is None: block = SharedMLP else: block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution), attention=attention, dropout=dropout, with_se=with_se, with_se_relu=True, normalize=normalize, eps=eps) fp_blocks.append(block(in_channels, out_channels)) in_channels = out_channels # this should not change! if len(fp_blocks) == 1: fp_layers.append(fp_blocks[0]) # this is the last block, no PVConv layer else: fp_layers.append(nn.Sequential(*fp_blocks)) c += 1 return fp_layers, in_channels def get_timestep_embedding(embed_dim, timesteps, device): """ Timestep embedding function. Not that this should work just as well for continuous values as for discrete values. """ assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 half_dim = embed_dim // 2 emb = np.log(10000) / (half_dim - 1) emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device) emb = timesteps[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embed_dim % 2 == 1: # zero pad emb = nn.functional.pad(emb, (0, 1), "constant", 0) assert emb.shape == torch.Size([timesteps.shape[0], embed_dim]) return emb