File size: 13,750 Bytes
744eb4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 |
import torch.nn as nn
import torch
import torch.nn.functional as F
from . import misc
# from knn_cuda import KNN
# knn = KNN(k=4, transpose_mode=False)
class DGCNN(nn.Module):
    """EdgeConv (DGCNN-style) network over per-group features.

    Each of the four stages builds edge features ``[x_j - x_i, x_i]`` from the
    ``k`` nearest neighbours (in coordinate space) of every group centre
    (see ``get_graph_feature``, k = 4), applies a 1x1 ``Conv2d`` +
    ``GroupNorm`` + ``LeakyReLU`` and max-pools over the neighbours.  The four
    stage outputs (256 + 512 + 512 + 1024 = 2304 channels) are concatenated and
    projected to ``output_channel`` channels.

    Args:
        encoder_channel: channel count C of the input features.
        output_channel: channel count C' of the output; must be divisible by 4
            because of ``nn.GroupNorm(4, output_channel)``.
    """

    def __init__(self, encoder_channel, output_channel):
        super().__init__()
        self.input_trans = nn.Conv1d(encoder_channel, 128, 1)
        # Each stage input is 2x the previous stage's output (edge feature cat).
        self.layer1 = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, bias=False),
                                    nn.GroupNorm(4, 256),
                                    nn.LeakyReLU(negative_slope=0.2)
                                    )
        self.layer2 = nn.Sequential(nn.Conv2d(512, 512, kernel_size=1, bias=False),
                                    nn.GroupNorm(4, 512),
                                    nn.LeakyReLU(negative_slope=0.2)
                                    )
        self.layer3 = nn.Sequential(nn.Conv2d(1024, 512, kernel_size=1, bias=False),
                                    nn.GroupNorm(4, 512),
                                    nn.LeakyReLU(negative_slope=0.2)
                                    )
        self.layer4 = nn.Sequential(nn.Conv2d(1024, 1024, kernel_size=1, bias=False),
                                    nn.GroupNorm(4, 1024),
                                    nn.LeakyReLU(negative_slope=0.2)
                                    )
        self.layer5 = nn.Sequential(nn.Conv1d(2304, output_channel, kernel_size=1, bias=False),
                                    nn.GroupNorm(4, output_channel),
                                    nn.LeakyReLU(negative_slope=0.2)
                                    )

    @staticmethod
    def get_graph_feature(coor_q, x_q, coor_k, x_k, k=4):
        """Build EdgeConv features for query points from their k nearest key points.

        Args:
            coor_q: query coordinates, (B, 3, Nq).
            x_q: query features, (B, C, Nq).
            coor_k: key coordinates, (B, 3, Nk).
            x_k: key features, (B, C, Nk); must have the same C as ``x_q``.
            k: number of neighbours (default 4); must not exceed Nk.

        Returns:
            Tensor of shape (B, 2C, Nq, k) holding ``cat([x_j - x_i, x_i])``
            along the channel axis.
        """
        batch_size = x_k.size(0)
        num_points_k = x_k.size(2)
        num_points_q = x_q.size(2)
        with torch.no_grad():
            # knn_point works on (B, N, 3) points and returns (B, Nq, k);
            # transpose in and out so idx is (B, k, Nq), the layout the
            # flattening below relies on.
            idx = knn_point(k,
                            coor_k.transpose(1, 2).contiguous(),
                            coor_q.transpose(1, 2).contiguous())
            idx = idx.transpose(1, 2).contiguous()  # B k Nq
            assert idx.shape[1] == k
            # Offset per-batch indices so they address rows of the
            # flattened (B*Nk, C) feature matrix.
            idx_base = torch.arange(0, batch_size, device=x_q.device).view(-1, 1, 1) * num_points_k
            idx = (idx + idx_base).view(-1)
        num_dims = x_k.size(1)
        x_k = x_k.transpose(2, 1).contiguous()  # B Nk C
        feature = x_k.view(batch_size * num_points_k, -1)[idx, :]
        feature = feature.view(batch_size, k, num_points_q, num_dims).permute(0, 3, 2, 1).contiguous()  # B C Nq k
        x_q = x_q.view(batch_size, num_dims, num_points_q, 1).expand(-1, -1, -1, k)
        return torch.cat((feature - x_q, x_q), dim=1)

    def forward(self, f, coor):
        """
        f: B G C group features
        coor: B G 3 group centres
        -------
        returns: B G C' (C' = output_channel)
        """
        coor = coor.transpose(1, 2).contiguous()  # B 3 G
        f = self.input_trans(f.transpose(1, 2).contiguous())  # B 128 G
        feature_list = []
        # Stage widths (in -> out): 256->256, 512->512, 1024->512, 1024->1024.
        for layer in (self.layer1, self.layer2, self.layer3, self.layer4):
            f = self.get_graph_feature(coor, f, coor, f)  # B 2C G k
            f = layer(f).max(dim=-1, keepdim=False)[0]  # B C_out G
            feature_list.append(f)
        f = torch.cat(feature_list, dim=1)  # B 2304 G
        f = self.layer5(f)  # B C' G
        return f.transpose(-1, -2)
### ref https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py ###
def knn_point(nsample, xyz, new_xyz):
    """Find the ``nsample`` nearest points of ``xyz`` for every query in ``new_xyz``.

    Input:
        nsample: number of neighbours to return per query
        xyz: candidate points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: indices into the N axis of ``xyz``, [B, S, nsample];
        the neighbours of one query are not returned in any particular order
    """
    pairwise = square_distance(new_xyz, xyz)  # B S N
    nearest = pairwise.topk(nsample, dim=-1, largest=False, sorted=False)
    return nearest.indices
def square_distance(src, dst):
    """Squared Euclidean distance between every pair of points in ``src`` and ``dst``.

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 <s, d>, so the whole
    matrix comes from one batched matmul plus two broadcast additions.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-pair squared distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    src_sq = src.pow(2).sum(dim=-1).view(batch, n_src, 1)
    dst_sq = dst.pow(2).sum(dim=-1).view(batch, 1, n_dst)
    dist = -2 * torch.matmul(src, dst.transpose(1, 2))
    dist += src_sq
    dist += dst_sq
    return dist
class Group(nn.Module):
    """Partition a point cloud into ``num_group`` local patches.

    Patch centres are picked with ``misc.fps``; each patch gathers the
    ``group_size`` points nearest to its centre (``knn_point``).
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz):
        '''
        input: B N C  -- C == 3 for xyz only; for C > 3 the first three
                         channels are xyz and the rest (e.g. rgb) ride along
        ---------------------------
        output: B G M C  -- patches; xyz re-expressed relative to the patch
                            centre, extra channels copied unchanged
        center : B G 3
        '''
        batch_size, num_points, num_channels = xyz.shape
        extra = None
        if num_channels > 3:
            data = xyz
            # Slicing the channel axis yields a non-contiguous view.
            # NOTE(review): misc.fps is not visible here; FPS kernels commonly
            # require contiguous input, so make it contiguous up front.
            xyz = data[:, :, :3].contiguous()
            extra = data[:, :, 3:]

        # fps the centers out
        center = misc.fps(xyz, self.num_group)  # B G 3
        # knn to get the neighborhood
        idx = knn_point(self.group_size, xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size
        # Offset per-batch indices so they address rows of the flattened
        # (B*N, C) tensors below.
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        idx = (idx + idx_base).view(-1)

        neighborhood_xyz = xyz.view(batch_size * num_points, -1)[idx, :]
        neighborhood_xyz = neighborhood_xyz.view(batch_size, self.num_group, self.group_size, 3).contiguous()
        # normalize xyz: express every patch relative to its centre
        neighborhood_xyz = neighborhood_xyz - center.unsqueeze(2)

        if extra is None:
            return neighborhood_xyz, center

        neighborhood_extra = extra.view(batch_size * num_points, -1)[idx, :]
        neighborhood_extra = neighborhood_extra.view(batch_size, self.num_group, self.group_size, -1).contiguous()
        neighborhood = torch.cat((neighborhood_xyz, neighborhood_extra), dim=-1)
        return neighborhood, center
class Encoder(nn.Module):
    """PointNet-style mini encoder: one feature vector per point group.

    A shared per-point MLP (``first_conv``) is max-pooled into a group summary.
    That summary is concatenated back onto every point and refined by
    ``second_conv``, and a final max-pool yields the group embedding.

    Args:
        encoder_channel: size C of the output embedding.
        point_input_dims: channels per input point (3 for xyz).
    """

    def __init__(self, encoder_channel, point_input_dims=3):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.point_input_dims = point_input_dims
        self.first_conv = nn.Sequential(
            nn.Conv1d(point_input_dims, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
        )
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, encoder_channel, 1),
        )

    def forward(self, point_groups):
        '''
        point_groups : B G N c  (c == point_input_dims)
        -----------------
        returns      : B G C    (C == encoder_channel)
        '''
        batch, groups, pts, dims = point_groups.shape
        flat = point_groups.reshape(batch * groups, pts, dims).transpose(2, 1)  # BG c n
        local = self.first_conv(flat)  # BG 256 n
        summary = local.max(dim=2, keepdim=True).values  # BG 256 1
        fused = torch.cat([summary.expand(-1, -1, pts), local], dim=1)  # BG 512 n
        fused = self.second_conv(fused)  # BG C n
        embedding = fused.max(dim=2, keepdim=False).values  # BG C
        return embedding.reshape(batch, groups, self.encoder_channel)
class Decoder(nn.Module):
    """Coarse-to-fine (folding) decoder reconstructing each point group.

    An MLP predicts ``num_coarse`` coarse points per group.  Every coarse point
    is then replicated ``grid_size ** 2`` times.  Each copy is paired with one
    offset from a small 2-D grid (``folding_seed``) and refined by
    ``final_conv``, whose output is added to the coarse point as a residual.

    Args:
        encoder_channel: size C of the input group feature.
        num_fine: number of fine points per group; must be divisible by
            ``grid_size ** 2``.
        grid_size: side length of the folding grid (default 2, i.e. 4 fine
            points per coarse point).

    Raises:
        ValueError: if ``num_fine`` is not divisible by ``grid_size ** 2``.
    """

    def __init__(self, encoder_channel, num_fine, grid_size=2):
        super().__init__()
        num_seeds = grid_size ** 2
        if num_fine % num_seeds != 0:
            raise ValueError(
                f"num_fine ({num_fine}) must be divisible by grid_size ** 2 ({num_seeds})")
        self.num_fine = num_fine
        self.grid_size = grid_size
        self.num_coarse = self.num_fine // num_seeds
        self.mlp = nn.Sequential(
            nn.Linear(encoder_channel, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 3 * self.num_coarse)
        )
        self.final_conv = nn.Sequential(
            nn.Conv1d(encoder_channel + 3 + 2, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 3, 1)
        )
        steps = torch.linspace(-0.05, 0.05, steps=grid_size, dtype=torch.float)
        a = steps.view(1, grid_size).expand(grid_size, grid_size).reshape(1, -1)
        b = steps.view(grid_size, 1).expand(grid_size, grid_size).reshape(1, -1)
        # Registered as a buffer so .to()/.cuda()/.double() move it with the
        # module; non-persistent so the state_dict (and old checkpoints) are
        # unaffected.
        self.register_buffer('folding_seed',
                             torch.cat([a, b], dim=0).view(1, 2, num_seeds),  # 1 2 S
                             persistent=False)

    def forward(self, feature_global):
        '''
        feature_global : B G C
        -------
        coarse : B G M 3
        fine : B G N 3
        '''
        bs, g, c = feature_global.shape
        n_groups = bs * g
        num_seeds = self.grid_size ** 2
        feature_global = feature_global.reshape(n_groups, c)

        coarse = self.mlp(feature_global).reshape(n_groups, self.num_coarse, 3)  # BG M 3

        # Each coarse point repeated S times: used both as conv input and as
        # the residual base for the fine points.
        center = coarse.unsqueeze(2).expand(-1, -1, num_seeds, -1)  # BG M S 3
        center = center.reshape(n_groups, self.num_fine, 3).transpose(2, 1)  # BG 3 N

        seed = self.folding_seed.to(device=feature_global.device, dtype=feature_global.dtype)
        seed = seed.unsqueeze(2).expand(n_groups, -1, self.num_coarse, -1)  # BG 2 M S
        seed = seed.reshape(n_groups, 2, self.num_fine)  # BG 2 N

        expanded = feature_global.unsqueeze(2).expand(-1, -1, self.num_fine)  # BG C N
        feat = torch.cat([expanded, seed, center], dim=1)  # BG (C+5) N

        fine = self.final_conv(feat) + center  # BG 3 N
        fine = fine.reshape(bs, g, 3, self.num_fine).transpose(-1, -2)
        coarse = coarse.reshape(bs, g, self.num_coarse, 3)
        return coarse, fine
class DiscreteVAE(nn.Module):
    """Discrete VAE that maps point-cloud patches to codebook tokens and back.

    Pipeline (shapes as annotated below):
    points -> ``Group`` (patches B G M 3, centres B G 3) -> ``Encoder``
    (B G encoder_dims) -> ``dgcnn_1`` (logits B G num_tokens) -> Gumbel-softmax
    mixture of ``codebook`` rows (B G tokens_dims) -> ``dgcnn_2``
    (B G decoder_dims) -> ``Decoder`` (coarse / fine patch reconstructions).

    ``config`` must provide ``group_size``, ``num_group``, ``encoder_dims``,
    ``tokens_dims``, ``decoder_dims`` and ``num_tokens``.  Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.group_size = config.group_size
        self.num_group = config.num_group
        self.encoder_dims = config.encoder_dims
        self.tokens_dims = config.tokens_dims
        self.decoder_dims = config.decoder_dims
        self.num_tokens = config.num_tokens
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
        # NOTE(review): Encoder is built for 3-channel points; inputs with
        # extra channels (which Group passes through) would not fit it.
        self.encoder = Encoder(encoder_channel=self.encoder_dims)
        self.dgcnn_1 = DGCNN(encoder_channel=self.encoder_dims, output_channel=self.num_tokens)
        self.codebook = nn.Parameter(torch.randn(self.num_tokens, self.tokens_dims))
        self.dgcnn_2 = DGCNN(encoder_channel=self.tokens_dims, output_channel=self.decoder_dims)
        self.decoder = Decoder(encoder_channel=self.decoder_dims, num_fine=self.group_size)
        # The Chamfer losses used by recon_loss are not constructed here (their
        # construction is disabled in this file).  Attach one before calling
        # recon_loss/get_loss, e.g. ``model.loss_func_cdl1 = ChamferDistanceL1()``.

    def recon_loss(self, ret, gt):
        """Chamfer-L1 loss of coarse and fine reconstructions vs. the input patches.

        Args:
            ret: tuple returned by ``forward``.
            gt: unused here; kept for interface compatibility.

        Returns:
            Sum of the coarse and fine Chamfer-L1 losses.

        Raises:
            AttributeError: if no ``loss_func_cdl1`` has been attached.
        """
        loss_func = getattr(self, 'loss_func_cdl1', None)
        if loss_func is None:
            raise AttributeError(
                "DiscreteVAE.loss_func_cdl1 is not set; attach a Chamfer-L1 loss "
                "module (e.g. model.loss_func_cdl1 = ChamferDistanceL1()) before "
                "calling recon_loss/get_loss")
        _, _, coarse, fine, group_gt, _ = ret
        bs, g, _, _ = coarse.shape
        # Flatten groups into the batch axis so each patch is compared alone.
        coarse = coarse.reshape(bs * g, -1, 3).contiguous()
        fine = fine.reshape(bs * g, -1, 3).contiguous()
        group_gt = group_gt.reshape(bs * g, -1, 3).contiguous()
        loss_coarse_block = loss_func(coarse, group_gt)
        loss_fine_block = loss_func(fine, group_gt)
        return loss_coarse_block + loss_fine_block

    def get_loss(self, ret, gt):
        """Return ``(loss_recon, loss_klv)`` for the output of ``forward``.

        ``loss_klv`` is ``F.kl_div`` with ``log_target=True`` and reduction
        'batchmean' between the uniform distribution over tokens (target) and
        the token softmax averaged over groups.  ``gt`` supplies the device for
        the uniform target.
        """
        # reconstruction loss
        loss_recon = self.recon_loss(ret, gt)
        # kl divergence
        logits = ret[-1]  # B G N
        softmax = F.softmax(logits, dim=-1)
        mean_softmax = softmax.mean(dim=1)  # B N: token usage averaged over groups
        log_qy = torch.log(mean_softmax)
        log_uniform = torch.log(torch.tensor([1. / self.num_tokens], device=gt.device))
        loss_klv = F.kl_div(log_qy, log_uniform.expand(log_qy.size(0), log_qy.size(1)), None, None, 'batchmean',
                            log_target=True)
        return loss_recon, loss_klv

    def forward(self, inp, temperature=1., hard=False, **kwargs):
        """Tokenize and reconstruct ``inp``.

        Returns:
            (whole_coarse, whole_fine, coarse, fine, neighborhood, logits):
            ``whole_*`` are reconstructions shifted back to world coordinates
            and flattened to (B, -1, 3), computed without gradients;
            ``coarse``/``fine`` are per-group, centre-relative reconstructions;
            ``neighborhood`` is the grouped input; ``logits`` are B G num_tokens.
        """
        neighborhood, center = self.group_divider(inp)
        logits = self.encoder(neighborhood)  # B G C
        logits = self.dgcnn_1(logits, center)  # B G N
        soft_one_hot = F.gumbel_softmax(logits, tau=temperature, dim=2, hard=hard)  # B G N
        sampled = torch.einsum('b g n, n c -> b g c', soft_one_hot, self.codebook)  # B G C
        feature = self.dgcnn_2(sampled, center)
        coarse, fine = self.decoder(feature)
        with torch.no_grad():
            whole_fine = (fine + center.unsqueeze(2)).reshape(inp.size(0), -1, 3)
            whole_coarse = (coarse + center.unsqueeze(2)).reshape(inp.size(0), -1, 3)
        assert fine.size(2) == self.group_size
        ret = (whole_coarse, whole_fine, coarse, fine, neighborhood, logits)
        return ret