# https://github.com/baaivision/Uni3D/blob/main/models/point_encoder.py
from typing import Optional, Union
import timm
import torch
import torch.nn as nn
from timm.models.eva import Eva
from timm.models.vision_transformer import VisionTransformer
from .common import KNNGrouper, NNGrouper, PatchEncoder


class PatchEmbed(nn.Module):
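    """Group points into `num_patches` KNN patches of `patch_size` points and encode
    each patch into a single embedding with a shared `PatchEncoder`."""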
def __init__(
self,
in_channels,
out_channels,
num_patches,
patch_size,
        radius: Optional[float] = None,
centralize_features=False,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.grouper = KNNGrouper(
num_patches,
patch_size,
radius=radius,
centralize_features=centralize_features,
)
self.patch_encoder = PatchEncoder(in_channels, out_channels, [128, 512])

    def forward(self, coords: torch.Tensor, features: torch.Tensor):
patches = self.grouper(coords, features)
patch_features = patches["features"] # [B, L, K, C_in]
x = self.patch_encoder(patch_features)
patches["embeddings"] = x
return patches


class PatchDropout(nn.Module):
"""Randomly drop patches.
References:
- https://arxiv.org/abs/2212.00794
- `timm.layers.patch_dropout`. It uses `argsort` rather than `topk`, which might be inefficient.
"""

    def __init__(self, prob: float, num_prefix_tokens: int = 1):
super().__init__()
assert 0.0 <= prob < 1.0, prob
self.prob = prob
# exclude CLS token (or other prefix tokens)
self.num_prefix_tokens = num_prefix_tokens

    def forward(self, x: torch.Tensor):
# x: [B, L, ...]
if not self.training or self.prob == 0.0:
return x
if self.num_prefix_tokens:
prefix_tokens = x[:, : self.num_prefix_tokens]
x = x[:, self.num_prefix_tokens :]
else:
prefix_tokens = None
B, L = x.shape[:2]
num_keep = max(1, int(L * (1.0 - self.prob)))
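        # i.i.d. Gaussian scores per token; topk over them selects a uniformly
        # random subset of num_keep tokens without a full argsort.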
rand = torch.randn(B, L, device=x.device)
keep_indices = rand.topk(num_keep, dim=1).indices
_keep_indices = keep_indices.reshape((B, num_keep) + (-1,) * (x.dim() - 2))
_keep_indices = _keep_indices.expand((-1, -1) + x.shape[2:])
x = x.gather(1, _keep_indices)
if prefix_tokens is not None:
x = torch.cat((prefix_tokens, x), dim=1)
return x


class PointCloudEncoder(nn.Module):
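    """Uni3D-style point cloud encoder: patch embeddings plus an MLP positional
    embedding of the patch centers are fed through a timm ViT/EVA backbone, and
    the output tokens are projected to `embed_dim`."""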
def __init__(
self,
patch_embed: PatchEmbed,
transformer: Union[VisionTransformer, Eva],
embed_dim: int,
patch_drop_rate=0.0,
):
super().__init__()
self.transformer_dim = transformer.embed_dim
self.embed_dim = embed_dim
# Patch embedding
self.patch_embed = patch_embed
# Project patch features to transformer input dim
self.patch_proj = nn.Linear(self.patch_embed.out_channels, self.transformer_dim)
# Positional embedding
self.pos_embed = nn.Sequential(
nn.Linear(3, 128), nn.GELU(), nn.Linear(128, self.transformer_dim)
)
assert patch_drop_rate == 0, "PatchDropout is not compatible with decoder."
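        # NOTE: the assert above pins patch_drop_rate to 0, so the branch below
        # never instantiates PatchDropout; it is kept for reference.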
if patch_drop_rate > 0:
self.patch_dropout = PatchDropout(patch_drop_rate, num_prefix_tokens=0)
else:
self.patch_dropout = nn.Identity()
# Transformer encoder
self.transformer = transformer
# Project transformer output to embedding dim
self.out_proj = nn.Linear(self.transformer_dim, self.embed_dim)

    def forward(self, coords: torch.Tensor, features: torch.Tensor):
# Group points into patches and get embeddings
patches = self.patch_embed(coords, features)
if isinstance(patches, list):
patch_embed = patches[-1]["embeddings"]
centers = patches[-1]["centers"]
else:
patch_embed = patches["embeddings"] # [B, L, D]
centers = patches["centers"] # [B, L, 3]
patch_embed = self.patch_proj(patch_embed)
# Positional embedding for patches
pos_embed = self.pos_embed(centers)
x = patch_embed + pos_embed
# Dropout patch
x = self.patch_dropout(x)
# Dropout features
x = self.transformer.pos_drop(x)
for block in self.transformer.blocks:
x = block(x)
        # Only one of `norm` / `fc_norm` is non-identity in these timm transformers.
x = self.transformer.norm(x)
x = self.transformer.fc_norm(x)
x = self.out_proj(x)
return x, patches


class Block(nn.Module):
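    """Pre-LN residual MLP block: x + MLP(LayerNorm(x))."""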
def __init__(self, in_channels, hidden_dim, out_channels):
super().__init__()
# Follow timm.layers.mlp
self.mlp = nn.Sequential(
nn.Linear(in_channels, hidden_dim),
nn.GELU(),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, out_channels),
)
self.norm = nn.LayerNorm(out_channels)

    def forward(self, x):
# PreLN. Follow timm.models.vision_transformer
return x + self.mlp(self.norm(x))


class PatchEmbedNN(nn.Module):
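    """Patch embedding via nearest-neighbor grouping: per-point MLP blocks, a
    max-pool scatter into the groups from `NNGrouper`, then further MLP blocks
    on the pooled group features."""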
def __init__(self, in_channels, hidden_dim, out_channels, num_patches) -> None:
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
hidden_dim = hidden_dim or out_channels
self.grouper = NNGrouper(num_patches)
self.in_proj = nn.Linear(in_channels, hidden_dim)
self.blocks1 = nn.Sequential(
*[Block(hidden_dim, hidden_dim, hidden_dim) for _ in range(3)]
)
self.blocks2 = nn.Sequential(
*[Block(hidden_dim, hidden_dim, hidden_dim) for _ in range(3)]
)
self.norm = nn.LayerNorm(hidden_dim)
self.out_proj = nn.Linear(hidden_dim, out_channels)

    def forward(self, coords: torch.Tensor, features: torch.Tensor):
patches = self.grouper(coords, features)
patch_features = patches["features"] # [B, N, D]
nn_idx = patches["nn_idx"] # [B, N]
x = self.in_proj(patch_features)
x = self.blocks1(x) # [B, N, D]
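        # Max-pool per-point features into their assigned group (PointNet-style
        # aggregation over nn_idx).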
y = x.new_zeros(x.shape[0], self.grouper.num_groups, x.shape[-1])
y.scatter_reduce_(
1, nn_idx.unsqueeze(-1).expand_as(x), x, "amax", include_self=False
)
x = self.blocks2(y)
x = self.norm(x)
x = self.out_proj(x)
patches["embeddings"] = x
return patches
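

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes the
    # `.common` helpers (KNNGrouper, PatchEncoder) run on CPU with random inputs,
    # that 6-channel per-point features match `in_channels=6`, and that the timm
    # ViT exposes `embed_dim`, `pos_drop`, `blocks`, `norm`, and `fc_norm` as used
    # in PointCloudEncoder.forward above.
    vit = timm.create_model("vit_small_patch16_224", num_classes=0)
    patch_embed = PatchEmbed(
        in_channels=6, out_channels=256, num_patches=64, patch_size=32
    )
    encoder = PointCloudEncoder(patch_embed, vit, embed_dim=256)
    coords = torch.rand(2, 1024, 3)  # [B, N, 3] point coordinates
    features = torch.rand(2, 1024, 6)  # [B, N, C_in] per-point features
    with torch.no_grad():
        out, patches = encoder(coords, features)
    print(out.shape)  # expected: [2, 64, 256], one embed_dim vector per patch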