# NOTE(review): the following lines were non-code residue from a web-page
# export ("English" / "Thomas Male" / "Upload 98 files" / "a5407e7") and
# would be a syntax error; preserved here as comments.
from abc import abstractmethod
from typing import Dict, Optional
import torch
import torch.nn as nn
from .perceiver import SimplePerceiver
from .transformer import Transformer
class PointCloudSDFModel(nn.Module):
@property
@abstractmethod
def device(self) -> torch.device:
"""
Get the device that should be used for input tensors.
"""
@property
@abstractmethod
def default_batch_size(self) -> int:
"""
Get a reasonable default number of query points for the model.
In some cases, this might be the only supported size.
"""
@abstractmethod
def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Encode a batch of point clouds to cache part of the SDF calculation
done by forward().
:param point_clouds: a batch of [batch x 3 x N] points.
:return: a state representing the encoded point cloud batch.
"""
def forward(
self,
x: torch.Tensor,
point_clouds: Optional[torch.Tensor] = None,
encoded: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
"""
Predict the SDF at the coordinates x, given a batch of point clouds.
Either point_clouds or encoded should be passed. Only exactly one of
these arguments should be None.
:param x: a [batch x 3 x N'] tensor of query points.
:param point_clouds: a [batch x 3 x N] batch of point clouds.
:param encoded: the result of calling encode_point_clouds().
:return: a [batch x N'] tensor of SDF predictions.
"""
assert point_clouds is not None or encoded is not None
assert point_clouds is None or encoded is None
if point_clouds is not None:
encoded = self.encode_point_clouds(point_clouds)
return self.predict_sdf(x, encoded)
@abstractmethod
def predict_sdf(
self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
"""
Predict the SDF at the query points given the encoded point clouds.
Each query point should be treated independently, only conditioning on
the point clouds themselves.
"""
class CrossAttentionPointCloudSDFModel(PointCloudSDFModel):
"""
Encode point clouds using a transformer, and query points using cross
attention to the encoded latents.
"""
def __init__(
self,
*,
device: torch.device,
dtype: torch.dtype,
n_ctx: int = 4096,
width: int = 512,
encoder_layers: int = 12,
encoder_heads: int = 8,
decoder_layers: int = 4,
decoder_heads: int = 8,
init_scale: float = 0.25,
):
super().__init__()
self._device = device
self.n_ctx = n_ctx
self.encoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
self.encoder = Transformer(
device=device,
dtype=dtype,
n_ctx=n_ctx,
width=width,
layers=encoder_layers,
heads=encoder_heads,
init_scale=init_scale,
)
self.decoder_input_proj = nn.Linear(3, width, device=device, dtype=dtype)
self.decoder = SimplePerceiver(
device=device,
dtype=dtype,
n_data=n_ctx,
width=width,
layers=decoder_layers,
heads=decoder_heads,
init_scale=init_scale,
)
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
@property
def device(self) -> torch.device:
return self._device
@property
def default_batch_size(self) -> int:
return self.n_query
def encode_point_clouds(self, point_clouds: torch.Tensor) -> Dict[str, torch.Tensor]:
h = self.encoder_input_proj(point_clouds.permute(0, 2, 1))
h = self.encoder(h)
return dict(latents=h)
def predict_sdf(
self, x: torch.Tensor, encoded: Optional[Dict[str, torch.Tensor]]
) -> torch.Tensor:
data = encoded["latents"]
x = self.decoder_input_proj(x.permute(0, 2, 1))
x = self.decoder(x, data)
x = self.ln_post(x)
x = self.output_proj(x)
return x[..., 0]