import functools
import torch
import torch.nn as nn
import numpy as np
from model.pvcnn.modules import SharedMLP, PVConv, PointNetSAModule, PointNetAModule, PointNetFPModule, Swish


def _linear_gn_relu(in_channels, out_channels):
    return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())


def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1):
    r = width_multiplier
    if dim == 1:
        block = _linear_gn_relu
    else:
        block = SharedMLP
    if not isinstance(out_channels, (list, tuple)):
        out_channels = [out_channels]
    if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
        # empty config: nothing to build, channel count is unchanged
        return nn.Sequential(), in_channels
    layers = []
    for oc in out_channels[:-1]:
        if oc < 1:  # fractional entries are interpreted as dropout probabilities
            layers.append(nn.Dropout(oc))
        else:
            oc = int(r * oc)
            layers.append(block(in_channels, oc))
            in_channels = oc
    if dim == 1:
        if classifier:
            layers.append(nn.Linear(in_channels, out_channels[-1]))
        else:
            layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
    else:
        if classifier:
            layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
        else:
            layers.append(SharedMLP(in_channels, int(r * out_channels[-1])))
    return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
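
# Usage sketch (illustrative, with assumed channel sizes): build a classifier
# head over per-point features of shape (B, 256, N). Fractional entries in
# out_channels become Dropout layers, as handled in the loop above.
#
#   layers, out_dim = create_mlp_components(in_channels=256,
#                                           out_channels=[128, 0.5, 10],
#                                           classifier=True, dim=2)
#   head = nn.Sequential(*layers)  # SharedMLP(256->128), Dropout(0.5), Conv1d(128->10)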


def create_pointnet_components(blocks, in_channels, embed_dim, with_se=False, normalize=True, eps=0,
                               width_multiplier=1, voxel_resolution_multiplier=1):
    r, vr = width_multiplier, voxel_resolution_multiplier
    layers, concat_channels = [], 0
    c = 0  # global block counter across all stages
    for k, (out_channels, num_blocks, voxel_resolution) in enumerate(blocks):
        out_channels = int(r * out_channels)
        for p in range(num_blocks):
            # attention only in the first block of every other stage (after the first)
            attention = k % 2 == 0 and k > 0 and p == 0
            if voxel_resolution is None:
                block = SharedMLP
            else:
                block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                                          attention=attention, with_se=with_se, normalize=normalize, eps=eps)
            if c == 0:
                layers.append(block(in_channels, out_channels))
            else:
                # later blocks also take the (e.g. timestep) embedding as extra input channels
                layers.append(block(in_channels + embed_dim, out_channels))
            in_channels = out_channels
            concat_channels += out_channels
            c += 1
    return layers, in_channels, concat_channels
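
# Usage sketch (illustrative; the block tuples below are hypothetical, not the
# model's actual config): each entry is (out_channels, num_blocks, voxel_resolution),
# and voxel_resolution=None selects a SharedMLP instead of a PVConv.
#
#   blocks = ((64, 1, 32), (128, 2, 16), (512, 1, None))
#   layers, channels, concat_channels = create_pointnet_components(
#       blocks, in_channels=3, embed_dim=64)
#   # channels == 512, concat_channels == 64 + 128 + 128 + 512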


def create_pointnet2_sa_components(sa_blocks_config, extra_feature_channels, embed_dim=64, use_att=False,
                                   dropout=0.1, with_se=False, normalize=True, eps=0,
                                   width_multiplier=1, voxel_resolution_multiplier=1,
                                   in_ch_multiplier=1,
                                   extra_in_channel=0):
    """
    Build the set-abstraction (encoder) stages of a PointNet++/PVCNN backbone.

    :param in_ch_multiplier: multiplier on the input channel dimension of each stage,
        e.g. 2 when human and object features are concatenated
    :param extra_in_channel: additional input channels, e.g. from a cross-attention module
    """
    r, vr = width_multiplier, voxel_resolution_multiplier
    in_channels = extra_feature_channels + 3  # +3 for the xyz coordinates
    sa_layers, sa_in_channels = [], []
    block_count = 0
    for conv_configs, sa_configs in sa_blocks_config:
        k = 0
        sa_in_channels.append(in_channels)
        sa_blocks = []
        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            for p in range(num_blocks):  # the point-voxel conv is repeated num_blocks times
                attention = (block_count + 1) % 2 == 0 and use_att and p == 0
                if voxel_resolution is None:
                    block = SharedMLP
                else:
                    block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                                              attention=attention, dropout=dropout,
                                              with_se=with_se, with_se_relu=True,
                                              normalize=normalize, eps=eps)
                if block_count == 0:
                    sa_blocks.append(block(in_channels, out_channels))
                elif k == 0:
                    # later stages take the embedding as extra input channels; note that
                    # only their first conv (k == 0) is instantiated, the remaining
                    # iterations just update the channel bookkeeping
                    sa_blocks.append(block(in_channels + embed_dim, out_channels))
                in_channels = out_channels
                k += 1
            extra_feature_channels = in_channels
        num_centers, radius, num_neighbors, out_channels = sa_configs
        _out_channels = []
        for oc in out_channels:
            if isinstance(oc, (list, tuple)):
                _out_channels.append([int(r * _oc) for _oc in oc])
            else:
                _out_channels.append(int(r * oc))
        out_channels = _out_channels
        if num_centers is None:
            block = PointNetAModule  # aggregates over all points; in practice num_centers is always not None
        else:
            block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
                                      num_neighbors=num_neighbors)
        sa_blocks.append(block(in_channels=extra_feature_channels + (embed_dim if k == 0 else 0),
                               out_channels=out_channels, include_coordinates=True))
        block_count += 1
        # XH: double the channel for concat, or add additional channel for cross attention
        if block_count < len(sa_blocks_config):
            in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier + extra_in_channel)
        else:
            # no cross attention before the self-attention module
            in_channels = extra_feature_channels = int(sa_blocks[-1].out_channels * in_ch_multiplier)
        if len(sa_blocks) == 1:
            sa_layers.append(sa_blocks[0])  # single SA module, no preceding convs
        else:
            sa_layers.append(nn.Sequential(*sa_blocks))
    # num_centers here is the value from the last SA stage
    return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
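
# Config sketch (hypothetical, PVCNN2-style; the real configs live in the model
# definitions): each entry is (conv_configs, sa_configs), with conv_configs =
# (out_channels, num_blocks, voxel_resolution) or None, and sa_configs =
# (num_centers, radius, num_neighbors, out_channels).
#
#   sa_blocks_config = [
#       ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
#       ((64, 3, 16), (256, 0.2, 32, (64, 128))),
#       ((128, 3, 8), (64, 0.4, 32, (128, 256))),
#       (None, (16, 0.8, 32, (256, 256, 512))),
#   ]
#   sa_layers, sa_in_channels, channels, num_centers = create_pointnet2_sa_components(
#       sa_blocks_config, extra_feature_channels=0, embed_dim=64, use_att=True)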


def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
                                dropout=0.1,
                                with_se=False, normalize=True, eps=0,
                                width_multiplier=1, voxel_resolution_multiplier=1,
                                in_ch_multiplier=1, extra_in_channel=0):
    """
    Build the feature-propagation (decoder) stages of a PointNet++/PVCNN backbone.

    :param fp_blocks: per-stage configs, each a tuple (fp_configs, conv_configs)
    :param in_channels: number of channels coming out of the encoder
    :param sa_in_channels: input channel counts of the encoder stages, used for skip links
    :param embed_dim: dimension of the embedding concatenated at every stage
    :param in_ch_multiplier: multiplier on the input channel dimension of each stage
    :param extra_in_channel: additional input channels, e.g. from a cross-attention module
    :return: the FP layers and the final number of output channels
    """
    r, vr = width_multiplier, voxel_resolution_multiplier
    fp_layers = []
    c = 0
    for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
        stage_blocks = []  # modules making up this decoder stage
        out_channels = tuple(int(r * oc) for oc in fp_configs)
        if fp_idx > 0:
            # handle the additional channels from concatenating human + object features
            sa_in_concat = int(in_channels * in_ch_multiplier + extra_in_channel)
        else:
            # for simple-coord3d, the first decoder layer also has cross attention
            sa_in_concat = in_channels + extra_in_channel
        stage_blocks.append(
            PointNetFPModule(in_channels=sa_in_concat + sa_in_channels[-1 - fp_idx] + embed_dim,
                             out_channels=out_channels)
        )  # interpolate + Conv1d, does not change the number of points
        in_channels = out_channels[-1]
        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            for p in range(num_blocks):
                # stage_blocks holds only the FP module when p == 0, so the
                # c < len(stage_blocks) - 1 condition keeps attention disabled here
                attention = (c + 1) % 2 == 0 and c < len(stage_blocks) - 1 and use_att and p == 0
                if voxel_resolution is None:
                    block = SharedMLP
                else:
                    block = functools.partial(PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                                              attention=attention, dropout=dropout,
                                              with_se=with_se, with_se_relu=True,
                                              normalize=normalize, eps=eps)
                stage_blocks.append(block(in_channels, out_channels))
                in_channels = out_channels  # this should not change!
        if len(stage_blocks) == 1:
            fp_layers.append(stage_blocks[0])  # a lone FP module, no PVConv layers
        else:
            fp_layers.append(nn.Sequential(*stage_blocks))
        c += 1
    return fp_layers, in_channels
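
# Config sketch (hypothetical, mirroring the SA example above): each entry is
# (fp_configs, conv_configs), where fp_configs lists the FP-MLP widths and
# conv_configs follows the same (out_channels, num_blocks, voxel_resolution) format.
#
#   fp_blocks = (
#       ((256, 256), (256, 3, 8)),
#       ((256, 256), (256, 3, 16)),
#       ((256, 128), (128, 2, 32)),
#       ((128, 128, 64), (64, 2, 32)),
#   )
#   fp_layers, channels = create_pointnet2_fp_modules(
#       fp_blocks, in_channels=channels, sa_in_channels=sa_in_channels, embed_dim=64)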


def get_timestep_embedding(embed_dim, timesteps, device):
    """
    Sinusoidal timestep embedding. Note that this works just as well for
    continuous values as for discrete ones.
    """
    assert len(timesteps.shape) == 1  # expects a 1-D tensor of timesteps
    half_dim = embed_dim // 2
    emb = np.log(10000) / (half_dim - 1)
    emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
    emb = timesteps[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embed_dim % 2 == 1:  # zero pad the odd dimension
        emb = nn.functional.pad(emb, (0, 1), "constant", 0)
    assert emb.shape == torch.Size([timesteps.shape[0], embed_dim])
    return emb
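
# Usage sketch: embed a batch of diffusion timesteps into sinusoidal features,
# following the DDPM-style formulation implemented above.
#
#   t = torch.randint(0, 1000, (16,), device='cpu').float()
#   emb = get_timestep_embedding(64, t, 'cpu')  # -> shape (16, 64)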