import torch
import torch.nn as nn

from .voxelization import Voxelization
from .shared_mlp import SharedMLP
from .se import SE3d
from . import functional as F

__all__ = ['PVConv', 'Attention', 'Swish', 'PVConvReLU']


class Swish(nn.Module):
    """Swish activation: f(x) = x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)
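
# Aside (illustrative only, not used by the module): Swish as defined above
# matches PyTorch's built-in nn.SiLU, so the two are interchangeable:
#
#   x = torch.randn(4, 8)
#   assert torch.allclose(Swish()(x), torch.nn.SiLU()(x))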


class Attention(nn.Module):
    def __init__(self, in_ch, num_groups, D=3):
        super(Attention, self).__init__()
        assert in_ch % num_groups == 0
        # learnable 1x1-convolution projections for query, key, value and output
        if D == 3:
            self.q = nn.Conv3d(in_ch, in_ch, 1)
            self.k = nn.Conv3d(in_ch, in_ch, 1)
            self.v = nn.Conv3d(in_ch, in_ch, 1)
            self.out = nn.Conv3d(in_ch, in_ch, 1)
        elif D == 1:
            self.q = nn.Conv1d(in_ch, in_ch, 1)
            self.k = nn.Conv1d(in_ch, in_ch, 1)
            self.v = nn.Conv1d(in_ch, in_ch, 1)
            self.out = nn.Conv1d(in_ch, in_ch, 1)
        else:
            raise ValueError(f'Attention supports D=1 or D=3, got D={D}')

        self.norm = nn.GroupNorm(num_groups, in_ch)
        self.nonlin = Swish()
        self.sm = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Self-attention over all spatial positions.

        Observed shapes during training:
          reso 32: x=torch.Size([16, 64, 16, 16, 16]), q=k=v=torch.Size([16, 64, 4096])
          reso 48: x=torch.Size([16, 64, 24, 24, 24]), q=k=v=torch.Size([16, 64, 13824])

        The (N, N) attention matrix, with N the number of voxels, grows
        quadratically in N and can cause OOM at high resolutions.

        :param x: (B, C, reso, reso, reso) voxel features (or (B, C, N) when D=1)
        :return: tensor of the same shape as x
        """
        B, C = x.shape[:2]
        h = x
        # flatten spatial dimensions: (B, C, N)
        q = self.q(h).reshape(B, C, -1)
        k = self.k(h).reshape(B, C, -1)
        v = self.v(h).reshape(B, C, -1)

        # (B, N, N) attention logits; the usual C ** -0.5 scaling is left out
        qk = torch.matmul(q.permute(0, 2, 1), k)  # * (int(C) ** (-0.5))
        w = self.sm(qk)

        # aggregate values with the attention weights and restore spatial shape
        h = torch.matmul(v, w.permute(0, 2, 1)).reshape(B, C, *x.shape[2:])
        h = self.out(h)

        # residual connection followed by GroupNorm + Swish
        x = h + x
        x = self.nonlin(self.norm(x))
        return x
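
# Illustrative helper (an exposition-only addition; nothing in the codebase
# calls it): the softmax weight matrix in Attention.forward has shape (B, N, N)
# with N = reso**3 voxels, so its memory alone scales as reso**6.
def _attention_matrix_bytes(batch, reso, dtype_bytes=4):
    """Rough fp32 size of the (B, N, N) attention matrix, N = reso ** 3."""
    n = reso ** 3
    return batch * n * n * dtype_bytes

# e.g. _attention_matrix_bytes(16, 16) ~ 1.1e9 bytes for the 16^3 grid above,
# and _attention_matrix_bytes(16, 24) ~ 1.2e10 bytes for the 24^3 grid, which
# is why the docstring warns about OOM.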


class PVConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False,
                 dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.resolution = resolution

        self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps)
        voxel_layers = [
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
            nn.GroupNorm(num_groups=8, num_channels=out_channels),
            Swish()
        ]
        voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
        voxel_layers += [
            nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
            nn.GroupNorm(num_groups=8, num_channels=out_channels),
            Attention(out_channels, 8) if attention else Swish()
        ]
        if with_se:
            voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
        self.voxel_layers = nn.Sequential(*voxel_layers)
        # point branch: SharedMLP applies the same MLP to every point independently
        self.point_features = SharedMLP(in_channels, out_channels)

    def forward(self, inputs):
        # features: (B, F, N) per-point features; coords: (B, 3, N) coordinates;
        # temb: sinusoidal embedding of the diffusion timesteps
        features, coords, temb = inputs
        voxel_features, voxel_coords = self.voxelization(features, coords)
        voxel_features = self.voxel_layers(voxel_features)
        voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
        # fuse the devoxelized voxel branch with the per-point branch
        fused_features = voxel_features + self.point_features(features)
        return fused_features, coords, temb  # coords and temb pass through unchanged
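
# Usage sketch (hypothetical shapes; running it requires this package's CUDA
# voxelization/devoxelization kernels, so it is left as a comment):
#
#   layer = PVConv(in_channels=64, out_channels=64, kernel_size=3, resolution=16).cuda()
#   features = torch.randn(2, 64, 2048, device='cuda')   # (B, C, N)
#   coords = torch.rand(2, 3, 2048, device='cuda')       # (B, 3, N)
#   temb = torch.randn(2, 64, 2048, device='cuda')       # passed through untouched
#   out, coords, temb = layer((features, coords, temb))  # out: (2, 64, 2048)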


class PVConvReLU(nn.Module):
    """PVConv variant using BatchNorm3d + LeakyReLU instead of GroupNorm + Swish."""

    def __init__(self, in_channels, out_channels, kernel_size, resolution, attention=False, leak=0.2,
                 dropout=0.1, with_se=False, with_se_relu=False, normalize=True, eps=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.resolution = resolution

        self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps)
        voxel_layers = [
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
            nn.BatchNorm3d(out_channels),
            nn.LeakyReLU(leak, True)
        ]
        voxel_layers += [nn.Dropout(dropout)] if dropout is not None else []
        voxel_layers += [
            nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2),
            nn.BatchNorm3d(out_channels),
            Attention(out_channels, 8) if attention else nn.LeakyReLU(leak, True)
        ]
        if with_se:
            voxel_layers.append(SE3d(out_channels, use_relu=with_se_relu))
        self.voxel_layers = nn.Sequential(*voxel_layers)
        self.point_features = SharedMLP(in_channels, out_channels)

    def forward(self, inputs):
        features, coords, temb = inputs
        voxel_features, voxel_coords = self.voxelization(features, coords)
        voxel_features = self.voxel_layers(voxel_features)
        voxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
        fused_features = voxel_features + self.point_features(features)
        return fused_features, coords, temb
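
# Design note: PVConv and PVConvReLU share the same point-voxel fusion and
# differ only in the voxel branch's norm/activation pair (GroupNorm + Swish
# vs. BatchNorm3d + LeakyReLU). GroupNorm's statistics are independent of the
# batch size, which tends to matter at the small batch sizes seen in the
# shape traces above.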