from dataclasses import dataclass
from typing import Optional

import torch
from jaxtyping import Float
from torch import Tensor

from spar3d.models.transformers.transformer_1d import Transformer1D
from spar3d.models.utils import BaseModule


class TransformerPointTokenizer(BaseModule):
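    """Tokenize a point cloud into per-point feature tokens.

    Raw point attributes of shape (B, N, Ci) are linearly projected to the
    transformer width, processed by a 1D transformer, and projected to the
    output token dimension, yielding (B, N, Cp).
    """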

    @dataclass
    class Config(BaseModule.Config):
        num_attention_heads: int = 16
        attention_head_dim: int = 64
        in_channels: Optional[int] = 6
        out_channels: Optional[int] = 1024
        num_layers: int = 16
        norm_num_groups: int = 32
        attention_bias: bool = False
        activation_fn: str = "geglu"
        norm_elementwise_affine: bool = True

    cfg: Config

    def configure(self) -> None:
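        # Run the transformer at width num_attention_heads * attention_head_dim;
        # copy the config so the in_channels override does not modify self.cfg.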
        transformer_cfg = dict(self.cfg.copy())
        transformer_cfg["in_channels"] = (
            self.cfg.num_attention_heads * self.cfg.attention_head_dim
        )
        self.model = Transformer1D(transformer_cfg)
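        # Linear projections from the raw point channels into the transformer
        # width, and from the transformer width to the output token dimension.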
        self.linear_in = torch.nn.Linear(
            self.cfg.in_channels, transformer_cfg["in_channels"]
        )
        self.linear_out = torch.nn.Linear(
            transformer_cfg["in_channels"], self.cfg.out_channels
        )

    def forward(
        self, points: Float[Tensor, "B N Ci"], **kwargs
    ) -> Float[Tensor, "B N Cp"]:
        assert points.ndim == 3
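        # The transformer operates channels-first, so permute (B, N, C) -> (B, C, N)
        # around it and project back to the output token dimension afterwards.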
        inputs = self.linear_in(points).permute(0, 2, 1)
        out = self.model(inputs).permute(0, 2, 1)
        out = self.linear_out(out)
        return out

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
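

# Minimal usage sketch (an assumption: BaseModule accepts an optional config
# dict and calls configure(), following the convention used elsewhere in this
# codebase):
#
#     tokenizer = TransformerPointTokenizer({"num_layers": 8})
#     points = torch.rand(2, 2048, 6)  # (B, N, Ci), e.g. xyz + rgb
#     tokens = tokenizer(points)       # (B, N, 1024)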