# -*- coding: utf-8 -*-
#
# @File:   pt_v3.py
# @Author: Xiaoyang Wu
# @Date:   2024-04-01 16:31:36
# @Last Modified by: Haozhe Xie
# @Last Modified at: 2024-05-15 22:05:09
# @Email:  root@haozhexie.com
# Ref:
# - https://github.com/Pointcept/PointTransformerV3/blob/main/model.py
# - https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py

import addict
import collections
import functools
import math
import torch
import spconv.pytorch as spconv
import torch_scatter
import typing

try:
    # flash_attn is optional: SerializedAttention asserts its availability only
    # when enable_flash is True, so fall back to None if it is not installed.
    import flash_attn
except ImportError:
    flash_attn = None


@torch.inference_mode()
def offset2bincount(offset):
    return torch.diff(
        offset, prepend=torch.tensor([0], device=offset.device, dtype=torch.long)
    )


@torch.inference_mode()
def offset2batch(offset):
    bincount = offset2bincount(offset)
    return torch.arange(
        len(bincount), device=offset.device, dtype=torch.long
    ).repeat_interleave(bincount)


@torch.inference_mode()
def batch2offset(batch):
    return torch.cumsum(batch.bincount(), dim=0).long()
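

# Illustrative sketch (not part of the original file): the helpers above follow
# the Pointcept offset convention, where `offset` stores the cumulative number of
# points per sample. For two clouds with 3 and 2 points, offset = [3, 5] and
# batch = [0, 0, 0, 1, 1]. The hypothetical function below only demonstrates the
# round trip and is never called by the model.
def _demo_offset_batch_roundtrip():
    offset = torch.tensor([3, 5], dtype=torch.long)
    batch = offset2batch(offset)  # -> tensor([0, 0, 0, 1, 1])
    assert torch.equal(batch2offset(batch), offset)
    return batch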


class KeyLUT:
    def __init__(self):
        r256 = torch.arange(256, dtype=torch.int64)
        r512 = torch.arange(512, dtype=torch.int64)
        zero = torch.zeros(256, dtype=torch.int64)
        device = torch.device("cpu")
        self._encode = {
            device: (
                self.xyz2key(r256, zero, zero, 8),
                self.xyz2key(zero, r256, zero, 8),
                self.xyz2key(zero, zero, r256, 8),
            )
        }
        self._decode = {device: self.key2xyz(r512, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        if device not in self._encode:
            cpu = torch.device("cpu")
            self._encode[device] = tuple(e.to(device) for e in self._encode[cpu])
        return self._encode[device]

    def decode_lut(self, device=torch.device("cpu")):
        if device not in self._decode:
            cpu = torch.device("cpu")
            self._decode[device] = tuple(e.to(device) for e in self._decode[cpu])
        return self._decode[device]

    def xyz2key(self, x, y, z, depth):
        key = torch.zeros_like(x)
        for i in range(depth):
            mask = 1 << i
            key = (
                key
                | ((x & mask) << (2 * i + 2))
                | ((y & mask) << (2 * i + 1))
                | ((z & mask) << (2 * i + 0))
            )
        return key

    def key2xyz(self, key, depth):
        x = torch.zeros_like(key)
        y = torch.zeros_like(key)
        z = torch.zeros_like(key)
        for i in range(depth):
            x = x | ((key & (1 << (3 * i + 2))) >> (2 * i + 2))
            y = y | ((key & (1 << (3 * i + 1))) >> (2 * i + 1))
            z = z | ((key & (1 << (3 * i + 0))) >> (2 * i + 0))
        return x, y, z


class Serializator:
    def encode(self, grid_coord, grid_size=0.01, batch=None, depth=16, order="cord"):
        assert order in {"cord", "z", "z-trans", "hilbert", "hilbert-trans"}
        if order in ["z", "z-trans"]:
            self.key_lut = KeyLUT()

        if order == "cord":
            code = self.cord_encode(grid_coord, grid_size)
        elif order == "z":
            code = self.z_order_encode(grid_coord, depth=depth)
        elif order == "z-trans":
            code = self.z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
        elif order == "hilbert":
            code = self.hilbert_encode(grid_coord, depth=depth)
        elif order == "hilbert-trans":
            code = self.hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
        else:
            raise NotImplementedError

        if batch is not None:
            batch = batch.long()
            code = batch << depth * 3 | code

        return code

    def cord_encode(self, grid_coord: torch.Tensor, grid_size: float):
        x, y, z = (
            grid_coord[:, 0].long(),
            grid_coord[:, 1].long(),
            grid_coord[:, 2].long(),
        )
        # batch support is handled in the Point class; only per-point codes here
        code = x / grid_size**2 + y / grid_size + z
        return code.long()

    def z_order_encode(self, grid_coord: torch.Tensor, depth: int = 16):
        x, y, z = (
            grid_coord[:, 0].long(),
            grid_coord[:, 1].long(),
            grid_coord[:, 2].long(),
        )
        # batch support is handled in the Point class; only per-point codes here
        code = self._xyz2key(x, y, z, b=None, depth=depth)
        return code

    def _xyz2key(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        b: typing.Optional[typing.Union[torch.Tensor, int]] = None,
        depth: int = 16,
    ):
        r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
        based on pre-computed look-up tables. This is much faster than the
        for-loop-based method.

        Args:
          x (torch.Tensor): The x coordinate.
          y (torch.Tensor): The y coordinate.
          z (torch.Tensor): The z coordinate.
          b (torch.Tensor or int): The batch index of the coordinates, which should be
              smaller than 32768. If :attr:`b` is a :obj:`torch.Tensor`, its size
              must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
          depth (int): The depth of the shuffled key, which must be smaller than 17 (< 17).
        """
        EX, EY, EZ = self.key_lut.encode_lut(x.device)
        x, y, z = x.long(), y.long(), z.long()

        mask = 255 if depth > 8 else (1 << depth) - 1
        key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
        if depth > 8:
            mask = (1 << (depth - 8)) - 1
            key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
            key = key16 << 24 | key

        if b is not None:
            b = b.long()
            key = b << 48 | key

        return key

    def hilbert_encode(self, grid_coord: torch.Tensor, depth: int = 16):
        return self._hilbert_encode(grid_coord, num_dims=3, num_bits=depth)

    def _hilbert_encode(self, locs, num_dims, num_bits):
        """Encode an array of locations in a hypercube into Hilbert integers.

        This is a vectorized-ish version of the Hilbert curve implementation by John
        Skilling as described in:

        Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
        Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

        Params:
        -------
        locs - An ndarray of locations in a hypercube of num_dims dimensions, in
               which each dimension runs from 0 to 2**num_bits-1. The shape can be
               arbitrary, as long as the last dimension has size num_dims.
        num_dims - The dimensionality of the hypercube. Integer.
        num_bits - The number of bits for each dimension. Integer.

        Returns:
        --------
        The output is an ndarray of uint64 integers with the same shape as the
        input, excluding the last dimension, which needs to be num_dims.
        """
        # Keep around the original shape for later.
        orig_shape = locs.shape
        bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
        bitpack_mask_rev = bitpack_mask.flip(-1)

        if orig_shape[-1] != num_dims:
            raise ValueError(
                """
                The shape of locs was surprising in that the last dimension was of size
                %d, but num_dims=%d. These need to be equal.
                """
                % (orig_shape[-1], num_dims)
            )

        if num_dims * num_bits > 63:
            raise ValueError(
                """
                num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
                into an int64. Are you sure you need that many points on your Hilbert
                curve?
                """
                % (num_dims, num_bits, num_dims * num_bits)
            )

        # Treat the location integers as 64-bit unsigned and then split them up into
        # a sequence of uint8s. Preserve the association by dimension.
        locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)

        # Now turn these into bits and truncate to num_bits.
        gray = (
            locs_uint8.unsqueeze(-1)
            .bitwise_and(bitpack_mask_rev)
            .ne(0)
            .byte()
            .flatten(-2, -1)[..., -num_bits:]
        )

        # Run the decoding process the other way.
        # Iterate forwards through the bits.
        for bit in range(0, num_bits):
            # Iterate forwards through the dimensions.
            for dim in range(0, num_dims):
                # Identify which ones have this bit active.
                mask = gray[:, dim, bit]

                # Where this bit is on, invert the 0 dimension for lower bits.
                gray[:, 0, bit + 1 :] = torch.logical_xor(
                    gray[:, 0, bit + 1 :], mask[:, None]
                )

                # Where the bit is off, exchange the lower bits with the 0 dimension.
                to_flip = torch.logical_and(
                    torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                    torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
                )
                gray[:, dim, bit + 1 :] = torch.logical_xor(
                    gray[:, dim, bit + 1 :], to_flip
                )
                gray[:, 0, bit + 1 :] = torch.logical_xor(
                    gray[:, 0, bit + 1 :], to_flip
                )

        # Now flatten out.
        # Fix: shape '[-1, 0]' is invalid for input of size 192
        gray = gray.swapaxes(1, 2).reshape((gray.size(0), -1))

        # Convert Gray back to binary.
        hh_bin = self._gray2binary(gray)

        # Pad back out to 64 bits.
        extra_dims = 64 - gray.size(1)
        padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

        # Convert binary values into uint8s.
        hh_uint8 = (
            (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
            .sum(2)
            .squeeze()
            .type(torch.uint8)
        )

        # Convert uint8s into uint64s.
        hh_uint64 = hh_uint8.view(torch.int64).squeeze()

        return hh_uint64

    def _gray2binary(self, gray, axis=-1):
        """Convert an array of Gray codes back into binary values.

        Parameters:
        -----------
        gray: An ndarray of gray codes.
        axis: The axis along which to perform Gray decoding. Default=-1.

        Returns:
        --------
        Returns an ndarray of binary values.
        """
        # Loop the log2(bits) number of times necessary, with shift and xor.
        shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
        while shift > 0:
            gray = torch.logical_xor(gray, self._right_shift(gray, shift))
            shift = torch.div(shift, 2, rounding_mode="floor")

        return gray

    def _right_shift(self, binary, k=1, axis=-1):
        """Right shift an array of binary values.

        Parameters:
        -----------
        binary: An ndarray of binary values.
        k: The number of bits to shift. Default 1.
        axis: The axis along which to shift. Default -1.

        Returns:
        --------
        Returns an ndarray with zero prepended and the ends truncated, along
        whatever axis was specified.
        """
        # If we're shifting the whole thing, just return zeros.
        if binary.shape[axis] <= k:
            return torch.zeros_like(binary)

        # Determine the padding pattern.
        # padding = [(0,0)] * len(binary.shape)
        # padding[axis] = (k,0)

        # Determine the slicing pattern to eliminate just the last one.
        slicing = [slice(None)] * len(binary.shape)
        slicing[axis] = slice(None, -k)
        shifted = torch.nn.functional.pad(
            binary[tuple(slicing)], (k, 0), mode="constant", value=0
        )

        return shifted
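

# Illustrative sketch (not part of the original file): a hypothetical check of the
# Serializator z-order encoding. With depth=2 each coordinate contributes 2 bits
# and the key interleaves them as x1 y1 z1 x0 y0 z0, so (1, 0, 1) -> 0b000101 = 5;
# a batch index is packed above the 3 * depth coordinate bits. Never called by
# the model.
def _demo_serializator_z_order():
    # Three voxels: the first two belong to batch 0, the last one to batch 1.
    grid_coord = torch.tensor([[0, 0, 0], [1, 0, 1], [1, 1, 1]], dtype=torch.int64)
    batch = torch.tensor([0, 0, 1], dtype=torch.long)
    code = Serializator().encode(grid_coord, batch=batch, depth=2, order="z")
    # (1, 1, 1) interleaves to 0b111 = 7 and gains 1 << 6 from its batch index.
    return code  # tensor([ 0,  5, 71])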


class PointModule(torch.nn.Module):
    r"""PointModule
    Placeholder: all modules subclassing this take a Point in PointSequential.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)


class Point(addict.Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary that contains various
    properties of a batched point cloud. Properties with the following names have
    a specific definition:

    - "coord": original coordinate of point cloud;
    - "grid_coord": grid coordinate for a specific grid size (related to GridSampling);

    Point also supports the following optional attributes:

    - "offset": if absent, initialized assuming a batch size of 1;
    - "batch": if absent, initialized assuming a batch size of 1;
    - "feat": feature of point cloud, default input of model;
    - "grid_size": grid size of point cloud (related to GridSampling);

    (related to Serialization)

    - "serialized_depth": depth of serialization; 2 ** depth * grid_size describes
      the maximum point cloud range;
    - "serialized_code": a list of serialization codes;
    - "serialized_order": a list of serialization orders determined by code;
    - "serialized_inverse": a list of inverse mappings determined by code;

    (related to Sparsify: SpConv)

    - "sparse_shape": sparse shape for SparseConvTensor;
    - "sparse_conv_feat": SparseConvTensor initialized with information provided by Point;
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.serializator = Serializator()
        # If one of "offset" or "batch" does not exist, generate it from the other
        if "batch" not in self.keys() and "offset" in self.keys():
            self["batch"] = offset2batch(self.offset)
        elif "offset" not in self.keys() and "batch" in self.keys():
            self["offset"] = batch2offset(self.batch)

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]
        """
        assert "batch" in self.keys()
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()

        if depth is None:
            # Adaptively measure the depth of the serialization cube (length = 2 ^ depth)
            depth = int(self.grid_coord.max()).bit_length()

        self["serialized_depth"] = depth
        # Maximum bit length for serialization code is 63 (int64)
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Here we follow OCNN and set the depth limitation to 16 (48 bit) for the point
        # position. Although depth is limited to less than 16, we can encode a
        # 655.36^3 (2^16 * 0.01) meter^3 cube with a grid size of 0.01 meter. We
        # consider this enough for the current stage. We can unlock the limitation by
        # optimizing the z-order encoding function if necessary.
        assert depth <= 16

        # The serialization codes are arranged as following structures:
        # [Order1 ([n]),
        #  Order2 ([n]),
        #   ...
        #  OrderN ([n])] (k, n)
        code = [
            self.serializator.encode(
                self.grid_coord, self.grid_size, self.batch, depth, order=order_
            )
            for order_ in order
        ]
        code = torch.stack(code)
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        self["serialized_code"] = code
        self["serialized_order"] = order
        self["serialized_inverse"] = inverse
    def sparsify(self, pad=96):
        """
        Point Cloud Sparsification

        Point cloud is sparse; here we use "sparsify" to specifically refer to
        preparing "spconv.SparseConvTensor" for SpConv.

        relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        pad: padding sparse for sparse shape.
        """
        assert {"feat", "batch"}.issubset(self.keys())
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipeline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what you want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()

        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            sparse_shape = torch.add(
                torch.max(self.grid_coord, dim=0).values, pad
            ).tolist()

        sparse_conv_feat = spconv.SparseConvTensor(
            features=self.feat,
            indices=torch.cat(
                [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
            ).contiguous(),
            spatial_shape=sparse_shape,
            batch_size=self.batch[-1].tolist() + 1,
        )
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = sparse_conv_feat
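

# Illustrative sketch (not part of the original file): a hypothetical, CPU-only
# example of how a Point is assembled and serialized. sparsify() is omitted here
# because spconv.SparseConvTensor is normally built on a CUDA device. Never
# called by the model.
def _demo_point_serialization():
    point = Point(
        {
            "coord": torch.rand(8, 3) * 2.0,  # two clouds of 4 points each
            "feat": torch.rand(8, 6),
            "batch": torch.tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.long),
            "grid_size": 0.01,
        }
    )
    # "offset" is derived from "batch" in Point.__init__; serialization() then
    # voxelizes "coord" into "grid_coord" and fills the serialized_* attributes.
    point.serialization(order=["z"], shuffle_orders=False)
    return point.serialized_code.shape  # (1, 8): one order, eight points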


class PointSequential(PointModule):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the constructor.
    Alternatively, an ordered dict of modules can also be passed in.
    """

    def __init__(self, name="", *args, **kwargs):
        super().__init__()
        self.name = name
        if len(args) == 1 and isinstance(args[0], collections.OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
        for name, module in kwargs.items():
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, module)

    def __getitem__(self, idx):
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        it = iter(self._modules.values())
        for i in range(idx):
            next(it)
        return next(it)

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        if name is None:
            name = str(len(self._modules))
            if name in self._modules:
                raise KeyError("name exists")
        self.add_module(name, module)

    def forward(self, x):
        for module in self._modules.values():
            # Point module
            if isinstance(module, PointModule):
                x = module(x)
            # Spconv module
            elif spconv.modules.is_spconv_module(module):
                if isinstance(x, Point):
                    x.sparse_conv_feat = module(x.sparse_conv_feat)
                    x.feat = x.sparse_conv_feat.features
                else:
                    x = module(x)
            # Fix: Expected more than 1 value per channel when training
            elif isinstance(module, torch.nn.BatchNorm1d) and isinstance(x, Point):
                if x.feat.size(0) != 1:
                    x.feat = module(x.feat)
            # PyTorch module
            else:
                if isinstance(x, Point):
                    x.feat = module(x.feat)
                    if "sparse_conv_feat" in x.keys():
                        x.sparse_conv_feat = x.sparse_conv_feat.replace_feature(x.feat)
                elif isinstance(x, spconv.SparseConvTensor):
                    if x.indices.shape[0] != 0:
                        x = x.replace_feature(module(x.features))
                else:
                    x = module(x)
        return x


class PDNorm(PointModule):
    def __init__(
        self,
        num_features,
        norm_layer,
        context_channels=256,
        conditions=("ScanNet", "S3DIS", "Structured3D"),
        decouple=True,
        adaptive=False,
    ):
        super().__init__()
        self.conditions = conditions
        self.decouple = decouple
        self.adaptive = adaptive
        if self.decouple:
            self.norm = torch.nn.ModuleList(
                [norm_layer(num_features) for _ in conditions]
            )
        else:
            self.norm = norm_layer(num_features)
        if self.adaptive:
            self.modulation = torch.nn.Sequential(
                torch.nn.SiLU(),
                torch.nn.Linear(context_channels, 2 * num_features, bias=True),
            )

    def forward(self, point):
        assert {"feat", "condition"}.issubset(point.keys())
        if isinstance(point.condition, str):
            condition = point.condition
        else:
            condition = point.condition[0]
        if self.decouple:
            assert condition in self.conditions
            norm = self.norm[self.conditions.index(condition)]
        else:
            norm = self.norm
        point.feat = norm(point.feat)
        if self.adaptive:
            assert "context" in point.keys()
            shift, scale = self.modulation(point.context).chunk(2, dim=1)
            point.feat = point.feat * (1.0 + scale) + shift
        return point


class RPE(torch.nn.Module):
    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        self.rpe_table = torch.nn.Parameter(torch.zeros(3 * self.rpe_num, num_heads))
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        idx = (
            coord.clamp(-self.pos_bnd, self.pos_bnd)  # clamp into bnd
            + self.pos_bnd  # relative position to positive index
            + torch.arange(3, device=coord.device) * self.rpe_num  # x, y, z stride
        )
        out = self.rpe_table.index_select(0, idx.reshape(-1))
        out = out.view(idx.shape + (-1,)).sum(3)
        out = out.permute(0, 3, 1, 2)  # (N, K, K, H) -> (N, H, K, K)
        return out


class SerializedAttention(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if enable_flash:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enabling Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enabling Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enabling Flash Attention"
            assert flash_attn is not None, "Make sure flash_attn is installed."
            self.patch_size = patch_size
            self.attn_drop = attn_drop
        else:
            # When flash attention is disabled, we still don't want to use an
            # attention mask; consequently, the patch size is set to the min of
            # patch_size_max and the number of points.
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)

        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    @torch.no_grad()
    def get_rel_pos(self, point, order):
        K = self.patch_size
        rel_pos_key = f"rel_pos_{self.order_index}"
        if rel_pos_key not in point.keys():
            grid_coord = point.grid_coord[order]
            grid_coord = grid_coord.reshape(-1, K, 3)
            point[rel_pos_key] = grid_coord.unsqueeze(2) - grid_coord.unsqueeze(1)
        return point[rel_pos_key]

    @torch.no_grad()
    def get_padding_and_inverse(self, point):
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        if (
            pad_key not in point.keys()
            or unpad_key not in point.keys()
            or cu_seqlens_key not in point.keys()
        ):
            offset = point.offset
            bincount = offset2bincount(offset)
            bincount_pad = (
                torch.div(
                    bincount + self.patch_size - 1,
                    self.patch_size,
                    rounding_mode="trunc",
                )
                * self.patch_size
            )
            # only pad points when the number of points is larger than patch_size
            mask_pad = bincount > self.patch_size
            bincount_pad = ~mask_pad * bincount + mask_pad * bincount_pad
            _offset = torch.nn.functional.pad(offset, (1, 0))
            _offset_pad = torch.nn.functional.pad(
                torch.cumsum(bincount_pad, dim=0), (1, 0)
            )
            pad = torch.arange(_offset_pad[-1], device=offset.device)
            unpad = torch.arange(_offset[-1], device=offset.device)
            cu_seqlens = []
            for i in range(len(offset)):
                unpad[_offset[i] : _offset[i + 1]] += _offset_pad[i] - _offset[i]
                if bincount[i] != bincount_pad[i]:
                    pad[
                        _offset_pad[i + 1]
                        - self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                    ] = pad[
                        _offset_pad[i + 1]
                        - 2 * self.patch_size
                        + (bincount[i] % self.patch_size) : _offset_pad[i + 1]
                        - self.patch_size
                    ]
                pad[_offset_pad[i] : _offset_pad[i + 1]] -= _offset_pad[i] - _offset[i]
                cu_seqlens.append(
                    torch.arange(
                        _offset_pad[i],
                        _offset_pad[i + 1],
                        step=self.patch_size,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            point[cu_seqlens_key] = torch.nn.functional.pad(
                torch.concat(cu_seqlens), (0, 1), value=_offset_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def forward(self, point):
        if not self.enable_flash:
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        H = self.num_heads
        K = self.patch_size
        C = self.channels

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]

        if not self.enable_flash:
            # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
            q, k, v = (
                qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
            )
            # attn
            if self.upcast_attention:
                q = q.float()
                k = k.float()
            attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
            if self.enable_rpe:
                attn = attn + self.rpe(self.get_rel_pos(point, order))
            if self.upcast_softmax:
                attn = attn.float()
            attn = self.softmax(attn)
            attn = self.attn_drop(attn).to(qkv.dtype)
            feat = (attn @ v).transpose(1, 2).reshape(-1, C)
        else:
            feat = flash_attn.flash_attn_varlen_qkvpacked_func(
                qkv.half().reshape(-1, 3, H, C // H),
                cu_seqlens,
                max_seqlen=self.patch_size,
                dropout_p=self.attn_drop if self.training else 0,
                softmax_scale=self.scale,
            ).reshape(-1, C)
            feat = feat.to(qkv.dtype)
        feat = feat[inverse]

        # ffn
        feat = self.proj(feat)
        feat = self.proj_drop(feat)
        point.feat = feat
        return point


class MLP(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=torch.nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.fc1 = torch.nn.Linear(in_channels, hidden_channels)
        self.act = act_layer()
        self.fc2 = torch.nn.Linear(hidden_channels, out_channels)
        # self.drop = torch.nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        # x = self.drop(x)
        x = self.fc2(x)
        # x = self.drop(x)
        return x


class Block(PointModule):
    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=torch.nn.LayerNorm,
        act_layer=torch.nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        self.cpe = PointSequential(
            spconv.SubMConv3d(
                channels,
                channels,
                kernel_size=3,
                bias=True,
                indice_key=cpe_indice_key,
            ),
            torch.nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        self.mlp = PointSequential(
            MLP(
                in_channels=channels,
                hidden_channels=int(channels * mlp_ratio),
                out_channels=channels,
                act_layer=act_layer,
                drop=proj_drop,
            )
        )
        self.drop_path = PointSequential(
            DropPath(drop_path) if drop_path > 0.0 else torch.nn.Identity()
        )

    def forward(self, point: Point):
        shortcut = point.feat
        point = self.cpe(point)
        point.feat = shortcut + point.feat
        shortcut = point.feat
        if self.pre_norm:
            point = self.norm1(point)
        point = self.drop_path(self.attn(point))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm1(point)

        shortcut = point.feat
        if self.pre_norm:
            point = self.norm2(point)
        point = self.drop_path(self.mlp(point))
        point.feat = shortcut + point.feat
        if not self.pre_norm:
            point = self.norm2(point)
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point


class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).
    """

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def _drop_path(
        self,
        x,
        drop_prob: float = 0.0,
        training: bool = False,
        scale_by_keep: bool = True,
    ):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of
        residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc
        networks, however, the original name is misleading as 'Drop Connect' is a
        different form of dropout in a separate paper... See discussion:
        https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
        I've opted for changing the layer and argument names to 'drop path' rather
        than mix DropConnect as a layer name and use 'survival rate' as the argument.
        """
        if drop_prob == 0.0 or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (
            x.ndim - 1
        )  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0.0 and scale_by_keep:
            random_tensor.div_(keep_prob)
        return x * random_tensor

    def forward(self, x):
        return self._drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


class SerializedPooling(PointModule):
    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = torch.nn.Linear(in_channels, out_channels)
        if norm_layer is not None:
            self.norm = PointSequential(norm_layer(out_channels))
        if act_layer is not None:
            self.act = PointSequential(act_layer())

    def forward(self, point: Point):
        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() on the point cloud before SerializedPooling"

        code = point.serialized_code >> pooling_depth * 3
        code_, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of points sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted points, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduced attrs, e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        # collect information
        point_dict = addict.Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )

        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        # Fix: Expected more than 1 value per channel when training
        if self.norm is not None and point.feat.size(0) != 1:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point


class SerializedUnpooling(PointModule):
    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(torch.nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(torch.nn.Linear(skip_channels, out_channels))

        if norm_layer is not None:
            self.proj.add(norm_layer(out_channels))
            self.proj_skip.add(norm_layer(out_channels))

        if act_layer is not None:
            self.proj.add(act_layer())
            self.proj_skip.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        assert "pooling_parent" in point.keys()
        assert "pooling_inverse" in point.keys()
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")
        point = self.proj(point)
        parent = self.proj_skip(parent)
        parent.feat = parent.feat + point.feat[inverse]

        if self.traceable:
            parent["unpooling_parent"] = point
        return parent


class Embedding(PointModule):
    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        # TODO: check remove spconv
        self.stem = PointSequential(
            conv=spconv.SubMConv3d(
                in_channels,
                embed_channels,
                kernel_size=5,
                padding=1,
                bias=False,
                indice_key="stem",
            )
        )
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

    def forward(self, point: Point):
        point = self.stem(point)
        return point


class PointTransformerV3(PointModule):
    def __init__(
        self,
        in_channels=6,
        order=("cord"),
        stride=(2, 2, 2, 2),
        enc_depths=(2, 2, 2, 6, 2),
        enc_channels=(32, 64, 128, 256, 512),
        enc_num_head=(2, 4, 8, 16, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(2, 2, 2, 2),
        dec_channels=(64, 64, 128, 256),
        dec_num_head=(4, 4, 8, 16),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        grid_size=0.01,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        pdnorm_bn=False,
        pdnorm_ln=False,
        pdnorm_decouple=True,
        pdnorm_adaptive=False,
        pdnorm_affine=True,
        pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders
        self.grid_size = grid_size

        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.cls_mode or self.num_stages == len(dec_depths) + 1
        assert self.cls_mode or self.num_stages == len(dec_channels) + 1
        assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
        assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1

        # norm layers
        if pdnorm_bn:
            bn_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
                ),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            bn_layer = functools.partial(torch.nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        if pdnorm_ln:
            ln_layer = functools.partial(
                PDNorm,
                norm_layer=functools.partial(
                    torch.nn.LayerNorm, elementwise_affine=pdnorm_affine
                ),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        else:
            ln_layer = torch.nn.LayerNorm
        # activation layers
        act_layer = torch.nn.GELU

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=bn_layer,
            act_layer=act_layer,
        )

        # encoder
        enc_drop_path = [
            x.item() for x in torch.linspace(0, drop_path, sum(enc_depths))
        ]
        self.enc = PointSequential(name="encoder")
        for s in range(self.num_stages):
            enc_drop_path_ = enc_drop_path[
                sum(enc_depths[:s]) : sum(enc_depths[: s + 1])
            ]
            enc = PointSequential(name="encoder_layer_%d" % s)
            if s > 0:
                enc.add(
                    SerializedPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i in range(enc_depths[s]):
                enc.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=enc_drop_path_[i],
                        norm_layer=ln_layer,
                        act_layer=act_layer,
                        pre_norm=pre_norm,
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        enable_rpe=enable_rpe,
                        enable_flash=enable_flash,
                        upcast_attention=upcast_attention,
                        upcast_softmax=upcast_softmax,
                    ),
                    name=f"block{i}",
                )
            if len(enc) != 0:
                self.enc.add(module=enc, name=f"enc{s}")

        # decoder
        if not self.cls_mode:
            dec_drop_path = [
                x.item() for x in torch.linspace(0, drop_path, sum(dec_depths))
            ]
            self.dec = PointSequential(name="decoder")
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            for s in reversed(range(self.num_stages - 1)):
                dec_drop_path_ = dec_drop_path[
                    sum(dec_depths[:s]) : sum(dec_depths[: s + 1])
                ]
                dec_drop_path_.reverse()
                dec = PointSequential(name="decoder_layer_%d" % s)
                dec.add(
                    SerializedUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="up",
                )
                for i in range(dec_depths[s]):
                    dec.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            attn_drop=attn_drop,
                            proj_drop=proj_drop,
                            drop_path=dec_drop_path_[i],
                            norm_layer=ln_layer,
                            act_layer=act_layer,
                            pre_norm=pre_norm,
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            enable_rpe=enable_rpe,
                            enable_flash=enable_flash,
                            upcast_attention=upcast_attention,
                            upcast_softmax=upcast_softmax,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=dec, name=f"dec{s}")

    def forward(self, batch, feat, coord):
        """
        The arguments are assembled into a Point (data_dict), a dictionary
        containing properties of a batched point cloud. It should contain the
        following properties for PTv3:

        1. "feat": feature of point cloud
        2. "grid_coord": discrete coordinate after grid sampling (voxelization),
           or "coord" + "grid_size"
        3. "offset" or "batch":
           https://github.com/Pointcept/Pointcept?tab=readme-ov-file#offset
        """
        point = Point(
            {
                "batch": batch.squeeze(dim=0),
                "feat": feat.squeeze(dim=0),
                "coord": coord.squeeze(dim=0),
                "grid_size": self.grid_size,
            }
        )
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()

        point = self.embedding(point)
        point = self.enc(point)
        if not self.cls_mode:
            point = self.dec(point)

        return point.feat.unsqueeze(dim=0)
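

# Illustrative sketch (not part of the original file): a hypothetical smoke test
# showing the expected call signature of PointTransformerV3.forward. It assumes a
# CUDA device with spconv and torch_scatter installed; enable_flash is turned off
# here so the plain attention path is used and flash_attn is not required. Never
# called by the model.
def _demo_point_transformer_v3():
    device = torch.device("cuda")
    n_points = 4096
    model = PointTransformerV3(in_channels=6, enable_flash=False).to(device)
    batch = torch.zeros(1, n_points, dtype=torch.long, device=device)
    feat = torch.rand(1, n_points, 6, device=device)
    coord = torch.rand(1, n_points, 3, device=device) * 10.0
    with torch.no_grad():
        out = model(batch, feat, coord)
    # The decoder restores per-point features: (1, n_points, dec_channels[0])
    return out.shape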