from dataclasses import dataclass
import math

import torch
import torch.nn as nn
from einops import repeat, rearrange
from transformers import CLIPModel

import craftsman
from craftsman.models.transformers.perceiver_1d import Perceiver
from craftsman.models.transformers.attention import ResidualCrossAttentionBlock
from craftsman.utils.checkpoint import checkpoint
from craftsman.utils.base import BaseModule
from craftsman.utils.typing import *

from .utils import AutoEncoder, FourierEmbedder, get_embedder


class PerceiverCrossAttentionEncoder(nn.Module):
    def __init__(self,
                 use_downsample: bool,
                 num_latents: int,
                 embedder: FourierEmbedder,
                 point_feats: int,
                 embed_point_feats: bool,
                 width: int,
                 heads: int,
                 layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 use_ln_post: bool = False,
                 use_flash: bool = False,
                 use_checkpoint: bool = False):
        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents
        self.use_downsample = use_downsample
        self.embed_point_feats = embed_point_feats

        if not self.use_downsample:
            self.query = nn.Parameter(torch.randn((num_latents, width)) * 0.02)

        self.embedder = embedder
        if self.embed_point_feats:
            self.input_proj = nn.Linear(self.embedder.out_dim * 2, width)
        else:
            self.input_proj = nn.Linear(self.embedder.out_dim + point_feats, width)

        self.cross_attn = ResidualCrossAttentionBlock(
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_flash=use_flash,
        )

        self.self_attn = Perceiver(
            n_ctx=num_latents,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_flash=use_flash,
            use_checkpoint=False
        )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width)
        else:
            self.ln_post = None

    def _forward(self, pc, feats):
        """
        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
        """
        bs, N, D = pc.shape

        data = self.embedder(pc)
        if feats is not None:
            if self.embed_point_feats:
                feats = self.embedder(feats)
            data = torch.cat([data, feats], dim=-1)
        data = self.input_proj(data)

        if self.use_downsample:
            # furthest point sampling (fps): keep ~num_latents points per cloud as queries
            from torch_cluster import fps
            flattened = pc.view(bs * N, D)

            batch = torch.arange(bs).to(pc.device)
            batch = torch.repeat_interleave(batch, N)

            pos = flattened

            ratio = 1.0 * self.num_latents / N
            idx = fps(pos, batch, ratio=ratio)
            query = data.view(bs * N, -1)[idx].view(bs, -1, data.shape[-1])
        else:
            query = self.query
            query = repeat(query, "m c -> b m c", b=bs)

        latents = self.cross_attn(query, data)
        latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents

    def forward(self,
                pc: torch.FloatTensor,
                feats: Optional[torch.FloatTensor] = None):
        """
        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
        """
        return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)

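# Note on the two query modes of PerceiverCrossAttentionEncoder (illustrative
# sketch; the concrete numbers are hypothetical, e.g. num_latents=256, N=4096):
#   - use_downsample=True: torch_cluster.fps with ratio=256/4096 keeps roughly
#     256 points per cloud, and their projected features become the
#     cross-attention queries, i.e. pc [B, 4096, 3] -> query [B, ~256, width].
#   - use_downsample=False: the learned `self.query` parameter of shape
#     [num_latents, width] is broadcast to [B, num_latents, width] instead.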


class PerceiverCrossAttentionDecoder(nn.Module):
    def __init__(self,
                 num_latents: int,
                 out_dim: int,
                 embedder: FourierEmbedder,
                 width: int,
                 heads: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 use_flash: bool = False,
                 use_checkpoint: bool = False):
        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.embedder = embedder

        self.query_proj = nn.Linear(self.embedder.out_dim, width)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            n_data=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_flash=use_flash
        )

        self.ln_post = nn.LayerNorm(width)
        self.output_proj = nn.Linear(width, out_dim)

    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        queries = self.query_proj(self.embedder(queries))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        x = self.output_proj(x)
        return x

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)


@craftsman.register("michelangelo-autoencoder")
class MichelangeloAutoencoder(AutoEncoder):
    r"""
    A VAE model for encoding shapes into latents and decoding latent representations into shapes.
    """

    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = ""
        use_downsample: bool = False
        num_latents: int = 256
        point_feats: int = 0
        embed_point_feats: bool = False
        out_dim: int = 1
        embed_dim: int = 64
        embed_type: str = "fourier"
        num_freqs: int = 8
        include_pi: bool = True
        width: int = 768
        heads: int = 12
        num_encoder_layers: int = 8
        num_decoder_layers: int = 16
        init_scale: float = 0.25
        qkv_bias: bool = True
        use_ln_post: bool = False
        use_flash: bool = False
        use_checkpoint: bool = True

    cfg: Config

    def configure(self) -> None:
        super().configure()

        self.embedder = get_embedder(embed_type=self.cfg.embed_type,
                                     num_freqs=self.cfg.num_freqs,
                                     include_pi=self.cfg.include_pi)

        # encoder
        self.cfg.init_scale = self.cfg.init_scale * math.sqrt(1.0 / self.cfg.width)
        self.encoder = PerceiverCrossAttentionEncoder(
            use_downsample=self.cfg.use_downsample,
            embedder=self.embedder,
            num_latents=self.cfg.num_latents,
            point_feats=self.cfg.point_feats,
            embed_point_feats=self.cfg.embed_point_feats,
            width=self.cfg.width,
            heads=self.cfg.heads,
            layers=self.cfg.num_encoder_layers,
            init_scale=self.cfg.init_scale,
            qkv_bias=self.cfg.qkv_bias,
            use_ln_post=self.cfg.use_ln_post,
            use_flash=self.cfg.use_flash,
            use_checkpoint=self.cfg.use_checkpoint
        )

        if self.cfg.embed_dim > 0:
            # VAE embed
            self.pre_kl = nn.Linear(self.cfg.width, self.cfg.embed_dim * 2)
            self.post_kl = nn.Linear(self.cfg.embed_dim, self.cfg.width)
            self.latent_shape = (self.cfg.num_latents, self.cfg.embed_dim)
        else:
            self.latent_shape = (self.cfg.num_latents, self.cfg.width)

        self.transformer = Perceiver(
            n_ctx=self.cfg.num_latents,
            width=self.cfg.width,
            layers=self.cfg.num_decoder_layers,
            heads=self.cfg.heads,
            init_scale=self.cfg.init_scale,
            qkv_bias=self.cfg.qkv_bias,
            use_flash=self.cfg.use_flash,
            use_checkpoint=self.cfg.use_checkpoint
        )

        # decoder
        self.decoder = PerceiverCrossAttentionDecoder(
            embedder=self.embedder,
            out_dim=self.cfg.out_dim,
            num_latents=self.cfg.num_latents,
            width=self.cfg.width,
            heads=self.cfg.heads,
            init_scale=self.cfg.init_scale,
            qkv_bias=self.cfg.qkv_bias,
            use_flash=self.cfg.use_flash,
            use_checkpoint=self.cfg.use_checkpoint
        )

        if self.cfg.pretrained_model_name_or_path != "":
            print(f"Loading pretrained model from {self.cfg.pretrained_model_name_or_path}")
            pretrained_ckpt = torch.load(self.cfg.pretrained_model_name_or_path, map_location="cpu")
            if 'state_dict' in pretrained_ckpt:
                _pretrained_ckpt = {}
                for k, v in pretrained_ckpt['state_dict'].items():
                    if k.startswith('shape_model.'):
                        _pretrained_ckpt[k.replace('shape_model.', '')] = v
                pretrained_ckpt = _pretrained_ckpt
            self.load_state_dict(pretrained_ckpt, strict=True)

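    # Latent bottleneck sketch (when embed_dim > 0): the encoder output
    # [B, num_latents, width] is projected by `pre_kl` to
    # [B, num_latents, 2 * embed_dim] (mean and log-variance), reduced to
    # `kl_embed` of shape [B, num_latents, embed_dim] inside `encode_kl_embed`
    # (presumably defined on the AutoEncoder base class in .utils), and mapped
    # back to width by `post_kl` before the decoder-side Perceiver transformer.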

    def encode(self,
               surface: torch.FloatTensor,
               sample_posterior: bool = True):
        """
        Args:
            surface (torch.FloatTensor): [B, N, 3+C]
            sample_posterior (bool):

        Returns:
            shape_latents (torch.FloatTensor): [B, num_latents, width]
            kl_embed (torch.FloatTensor): [B, num_latents, embed_dim]
            posterior (DiagonalGaussianDistribution or None):
        """
        assert surface.shape[-1] == 3 + self.cfg.point_feats, \
            f"Expected {3 + self.cfg.point_feats} channels, got {surface.shape[-1]}"

        pc, feats = surface[..., :3], surface[..., 3:]  # B, n_samples, 3

        shape_latents = self.encoder(pc, feats)  # B, num_latents, width
        kl_embed, posterior = self.encode_kl_embed(shape_latents, sample_posterior)  # B, num_latents, embed_dim

        return shape_latents, kl_embed, posterior

    def decode(self, latents: torch.FloatTensor):
        """
        Args:
            latents (torch.FloatTensor): [B, num_latents, embed_dim]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
        """
        latents = self.post_kl(latents)  # [B, num_latents, embed_dim] -> [B, num_latents, width]

        return self.transformer(latents)

    def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        """
        Args:
            queries (torch.FloatTensor): [B, N, 3]
            latents (torch.FloatTensor): [B, num_latents, width]

        Returns:
            logits (torch.FloatTensor): [B, N], occupancy logits
        """
        logits = self.decoder(queries, latents).squeeze(-1)

        return logits


@craftsman.register("michelangelo-aligned-autoencoder")
class MichelangeloAlignedAutoencoder(MichelangeloAutoencoder):
    r"""
    A VAE model for encoding shapes into latents and decoding latent representations into shapes,
    with an extra latent token projected into an embedding space (CLIP when `clip_model_version`
    is set) for shape alignment.
    """

    @dataclass
    class Config(MichelangeloAutoencoder.Config):
        clip_model_version: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
        if self.cfg.clip_model_version is not None:
            self.clip_model: CLIPModel = CLIPModel.from_pretrained(self.cfg.clip_model_version)
            self.projection = nn.Parameter(torch.empty(self.cfg.width, self.clip_model.projection_dim))
            self.logit_scale = torch.exp(self.clip_model.logit_scale.data)
            nn.init.normal_(self.projection, std=self.clip_model.projection_dim ** -0.5)
        else:
            self.projection = nn.Parameter(torch.empty(self.cfg.width, 768))
            nn.init.normal_(self.projection, std=768 ** -0.5)

        # reserve one extra latent token for the projected shape embedding
        self.cfg.num_latents = self.cfg.num_latents + 1

        super().configure()

    def encode(self,
               surface: torch.FloatTensor,
               sample_posterior: bool = True):
        """
        Args:
            surface (torch.FloatTensor): [B, N, 3+C]
            sample_posterior (bool):

        Returns:
            shape_embeds (torch.FloatTensor): [B, projection_dim]
            kl_embed (torch.FloatTensor): [B, num_latents - 1, embed_dim]
            posterior (DiagonalGaussianDistribution or None):
        """
        assert surface.shape[-1] == 3 + self.cfg.point_feats, \
            f"Expected {3 + self.cfg.point_feats} channels, got {surface.shape[-1]}"

        pc, feats = surface[..., :3], surface[..., 3:]  # B, n_samples, 3

        shape_latents = self.encoder(pc, feats)  # B, num_latents, width
        shape_embeds = shape_latents[:, 0]       # B, width
        shape_latents = shape_latents[:, 1:]     # B, num_latents - 1, width
        kl_embed, posterior = self.encode_kl_embed(shape_latents, sample_posterior)  # B, num_latents - 1, embed_dim

        shape_embeds = shape_embeds @ self.projection

        return shape_embeds, kl_embed, posterior

    def forward(self,
                surface: torch.FloatTensor,
                queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """
        Args:
            surface (torch.FloatTensor): [B, N, 3+C]
            queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool):

        Returns:
            shape_embeds (torch.FloatTensor): [B, projection_dim]
            latents (torch.FloatTensor): [B, num_latents - 1, width]
            posterior (DiagonalGaussianDistribution or None):
            logits (torch.FloatTensor): [B, P], occupancy logits
        """
        shape_embeds, kl_embed, posterior = self.encode(surface, sample_posterior=sample_posterior)

        latents = self.decode(kl_embed)  # [B, num_latents - 1, width]

        logits = self.query(queries, latents)  # [B, P]

        return shape_embeds, latents, posterior, logits
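

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes the
    # BaseModule-style constructor accepts a plain config dict (as in
    # craftsman.utils.base); adjust to the actual entry point of your setup.
    cfg = {
        "num_latents": 256,
        "embed_dim": 64,
        "width": 768,
        "heads": 12,
        "num_encoder_layers": 8,
        "num_decoder_layers": 16,
        "use_checkpoint": False,  # skip activation checkpointing for a quick test
    }
    model = MichelangeloAutoencoder(cfg)  # assumed signature: BaseModule(cfg: dict)

    surface = torch.randn(2, 4096, 3)         # two point clouds, xyz only (point_feats=0)
    queries = torch.rand(2, 1024, 3) * 2 - 1  # query points, e.g. in [-1, 1]^3

    shape_latents, kl_embed, posterior = model.encode(surface)  # [2, 256, 768], [2, 256, 64]
    latents = model.decode(kl_embed)                            # [2, 256, 768]
    logits = model.query(queries, latents)                      # [2, 1024] occupancy logits
    print(shape_latents.shape, kl_embed.shape, latents.shape, logits.shape)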