import numpy as np
import torch
import torch.nn as nn
from model.pvcnn.modules import Attention
from model.pvcnn.pvcnn_utils import create_mlp_components, create_pointnet2_sa_components, create_pointnet2_fp_modules
from model.pvcnn.pvcnn_utils import get_timestep_embedding


class PVCNN2Base(nn.Module):
def __init__(
self,
num_classes: int,
embed_dim: int,
use_att: bool = True,
dropout: float = 0.1,
extra_feature_channels: int = 3,
width_multiplier: int = 1,
voxel_resolution_multiplier: int = 1
):
super().__init__()
assert extra_feature_channels >= 0
self.embed_dim = embed_dim
self.dropout = dropout
self.width_multiplier = width_multiplier
self.in_channels = extra_feature_channels + 3
        # Set-abstraction (encoder) components of the PointNet++ backbone
sa_layers, sa_in_channels, channels_sa_features, _ = create_pointnet2_sa_components(
sa_blocks_config=self.sa_blocks,
extra_feature_channels=extra_feature_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.sa_layers = nn.ModuleList(sa_layers)
        # Optional global attention over the bottleneck features (enabled by default)
self.global_att = None if not use_att else Attention(channels_sa_features, 8, D=1)
        # The last FP module should receive only the extra (non-xyz) input features as its skip connection
sa_in_channels[0] = extra_feature_channels
fp_layers, channels_fp_features = create_pointnet2_fp_modules(
fp_blocks=self.fp_blocks,
in_channels=channels_sa_features,
sa_in_channels=sa_in_channels,
with_se=True,
embed_dim=embed_dim,
use_att=use_att,
dropout=dropout,
width_multiplier=width_multiplier,
voxel_resolution_multiplier=voxel_resolution_multiplier
)
self.fp_layers = nn.ModuleList(fp_layers)
        # Final classifier head (shared MLP applied per point)
self.channels_fp_features = channels_fp_features
layers, _ = create_mlp_components(
in_channels=channels_fp_features,
            out_channels=[128, dropout, num_classes],  # a float entry is interpreted as a dropout probability (was hard-coded to 0.5)
classifier=True,
dim=2,
width_multiplier=width_multiplier
)
self.classifier = nn.Sequential(*layers) # applied to point features directly
        # MLP applied on top of the sinusoidal timestep embedding
self.embedf = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.LeakyReLU(0.1, inplace=True),
nn.Linear(embed_dim, embed_dim),
)

    def forward(self, inputs: torch.Tensor, t: torch.Tensor, ret_feats=False):
"""
The inputs have size (B, 3 + S, N), where S is the number of additional
feature channels and N is the number of points. The timesteps t can be either
continuous or discrete. This model has a sort of U-Net-like structure I think,
which is why it first goes down and then up in terms of resolution (?)
torch.Size([16, 394, 16384])
Downscaling step 0 feature shape: torch.Size([16, 64, 1024])
Downscaling step 1 feature shape: torch.Size([16, 128, 256])
Downscaling step 2 feature shape: torch.Size([16, 256, 64])
Downscaling step 3 feature shape: torch.Size([16, 512, 16])
Upscaling step 0 feature shape: torch.Size([16, 256, 64])
Upscaling step 1 feature shape: torch.Size([16, 256, 256])
Upscaling step 2 feature shape: torch.Size([16, 128, 1024])
Upscaling step 3 feature shape: torch.Size([16, 64, 16384])
"""
        # Embed timesteps with a sinusoidal encoding followed by a small MLP
        # (a reference sketch of the encoding follows this class)
t_emb = get_timestep_embedding(self.embed_dim, t, inputs.device).float()
t_emb = self.embedf(t_emb)[:, :, None].expand(-1, -1, inputs.shape[-1])
# Separate input coordinates and features
coords = inputs[:, :3, :].contiguous() # (B, 3, N) range (-3.5, 3.5)
features = inputs # (B, 3 + S, N)
# Downscaling layers
coords_list = []
in_features_list = []
for i, sa_blocks in enumerate(self.sa_layers):
in_features_list.append(features)
coords_list.append(coords)
            if i == 0:
                # The first SA block consumes the raw features, without the timestep embedding
                features, coords, t_emb = sa_blocks((features, coords, t_emb))
            else:
                # Later SA blocks also receive the timestep embedding as extra feature channels
                features, coords, t_emb = sa_blocks((torch.cat([features, t_emb], dim=1), coords, t_emb))
        # For the last FP module, keep only the extra (non-xyz) input features
        # as the skip connection, matching sa_in_channels[0] above
        in_features_list[0] = inputs[:, 3:, :].contiguous()
# Apply global attention layer
if self.global_att is not None:
features = self.global_att(features)
# Upscaling layers
feats_list = [] # save intermediate features from the decoder layers
        for fp_idx, fp_blocks in enumerate(self.fp_layers):
            features, coords, t_emb = fp_blocks(
                (  # packed as a single tuple because nn.Sequential passes one argument through
                    coords_list[-1 - fp_idx],  # target coordinates from the encoder (in reverse order)
                    coords,  # coordinates at the current, coarser resolution
                    torch.cat([features, t_emb], dim=1),  # keep concatenating upsampled features with the timestep embedding
                    in_features_list[-1 - fp_idx],  # skip features from the encoder (in reverse order)
                    t_emb  # timestep embedding, broadcast over points
                )  # point-voxel convolution happens inside; the point branch preserves point order
            )
            feats_list.append((features, coords))  # t_emb is always the same
# Output MLP layers
output = self.classifier(features)
if ret_feats:
return output, feats_list # return intermediate features
return output
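

# Reference only: a minimal sketch of the standard sinusoidal timestep
# embedding (the usual DDPM/Transformer formulation). This is an assumption
# about what model.pvcnn.pvcnn_utils.get_timestep_embedding computes, not a
# drop-in replacement; the model itself always calls the imported function.
def _timestep_embedding_sketch(embed_dim: int, t: torch.Tensor, device) -> torch.Tensor:
    """Map timesteps t of shape (B,) to embeddings of shape (B, embed_dim)."""
    half = embed_dim // 2
    # Geometrically spaced frequencies from 1 down to ~1/10000
    freqs = torch.exp(-np.log(10000.0) * torch.arange(half, dtype=torch.float32, device=device) / half)
    args = t.to(device).float()[:, None] * freqs[None, :]  # (B, half)
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # (B, 2 * half)
    if embed_dim % 2 == 1:  # zero-pad when embed_dim is odd
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb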


class PVCNN2(PVCNN2Base):
    # Exact same configuration as PVD:
    # https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/train_completion.py#L375
    # Each entry is (conv_config, sa_config):
    #   conv_config: (out_channels, num_blocks, voxel_resolution)
    #   sa_config:   (num_centers, radius, num_neighbors, out_channels)
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (256, 256, 512))),
    ]
    # Each entry is (fp_mlp_channels, conv_config)
    fp_blocks = [
        ((256, 256), (256, 3, 8)),
        ((256, 256), (256, 3, 8)),
        ((256, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, num_classes, embed_dim, use_att=True, dropout=0.1, extra_feature_channels=3,
                 width_multiplier=1, voxel_resolution_multiplier=1):
super().__init__(
num_classes=num_classes, embed_dim=embed_dim, use_att=use_att,
dropout=dropout, extra_feature_channels=extra_feature_channels,
width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier
)
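

if __name__ == "__main__":
    # Minimal smoke test, a sketch with assumed values (not part of the original
    # demo): num_classes=3 to predict per-point 3D noise, embed_dim=64, and the
    # default 3 extra feature channels, giving 3 + 3 = 6 input channels per point.
    # Note: the PVConv voxelization kernels may require a CUDA build to run.
    model = PVCNN2(num_classes=3, embed_dim=64, extra_feature_channels=3)
    B, N = 2, 2048
    x = torch.randn(B, 3 + 3, N)      # (B, 3 + S, N): xyz plus extra features
    t = torch.randint(0, 1000, (B,))  # one diffusion timestep per batch element
    out = model(x, t)
    print(out.shape)                  # expected: torch.Size([2, 3, 2048])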