# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from typing import Optional
from einops import repeat
import math

from michelangelo.models.modules import checkpoint
from michelangelo.models.modules.embedder import FourierEmbedder
from michelangelo.models.modules.distributions import DiagonalGaussianDistribution
from michelangelo.models.modules.transformer_blocks import (
    ResidualCrossAttentionBlock,
    Transformer
)

from .tsal_base import ShapeAsLatentModule


class CrossAttentionEncoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 fourier_embedder: FourierEmbedder,
                 point_feats: int,
                 width: int,
                 heads: int,
                 layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.num_latents = num_latents

        # Learned latent queries that will cross-attend to the point cloud.
        self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)

        self.fourier_embedder = fourier_embedder
        self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)

        self.cross_attn = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
        )

        self.self_attn = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=width,
            layers=layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=False
        )

        if use_ln_post:
            self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
        else:
            self.ln_post = None

    def _forward(self, pc, feats):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
            pc (torch.FloatTensor): [B, N, 3], the input points passed through unchanged
        """

        bs = pc.shape[0]

        # Fourier-embed the coordinates and append any per-point features.
        data = self.fourier_embedder(pc)
        if feats is not None:
            data = torch.cat([data, feats], dim=-1)
        data = self.input_proj(data)

        # Latent queries cross-attend to the point features, then refine
        # themselves with a self-attention stack.
        query = repeat(self.query, "m c -> b m c", b=bs)

        latents = self.cross_attn(query, data)
        latents = self.self_attn(latents)

        if self.ln_post is not None:
            latents = self.ln_post(latents)

        return latents, pc

    def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]

        Returns:
            latents (torch.FloatTensor): [B, num_latents, width]
            pc (torch.FloatTensor): [B, N, 3]
        """

        return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)


class CrossAttentionDecoder(nn.Module):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 out_channels: int,
                 fourier_embedder: FourierEmbedder,
                 width: int,
                 heads: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_checkpoint: bool = False):

        super().__init__()

        self.use_checkpoint = use_checkpoint
        self.fourier_embedder = fourier_embedder

        self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)

        self.cross_attn_decoder = ResidualCrossAttentionBlock(
            device=device,
            dtype=dtype,
            n_data=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash
        )

        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
        self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)

    def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        # Project Fourier-embedded query coordinates to the transformer width,
        # let them cross-attend to the shape latents, then map to out_channels.
        queries = self.query_proj(self.fourier_embedder(queries))
        x = self.cross_attn_decoder(queries, latents)
        x = self.ln_post(x)
        x = self.output_proj(x)
        return x

    def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint)
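# Illustrative sketch (not part of the original module): how the encoder/decoder
# pair above is typically wired together. The shapes and hyper-parameters below
# are assumptions chosen for the example, not values mandated by this repository.
def _example_cross_attention_roundtrip():
    fourier = FourierEmbedder(num_freqs=8, include_pi=True)
    encoder = CrossAttentionEncoder(
        device=None, dtype=None, num_latents=256, fourier_embedder=fourier,
        point_feats=3, width=512, heads=8, layers=4,
    )
    decoder = CrossAttentionDecoder(
        device=None, dtype=None, num_latents=256, out_channels=1,
        fourier_embedder=fourier, width=512, heads=8,
    )

    pc = torch.randn(2, 4096, 3)        # [B, N, 3] surface points
    feats = torch.randn(2, 4096, 3)     # [B, N, C] per-point features (C == point_feats)
    queries = torch.randn(2, 1024, 3)   # [B, P, 3] query coordinates

    latents, _ = encoder(pc, feats)     # [B, num_latents, width]
    values = decoder(queries, latents)  # [B, P, out_channels]
    return values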
class ShapeAsLatentPerceiver(ShapeAsLatentModule):

    def __init__(self, *,
                 device: Optional[torch.device],
                 dtype: Optional[torch.dtype],
                 num_latents: int,
                 point_feats: int = 0,
                 embed_dim: int = 0,
                 num_freqs: int = 8,
                 include_pi: bool = True,
                 width: int,
                 heads: int,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 init_scale: float = 0.25,
                 qkv_bias: bool = True,
                 flash: bool = False,
                 use_ln_post: bool = False,
                 use_checkpoint: bool = False):

        super().__init__()

        self.use_checkpoint = use_checkpoint

        self.num_latents = num_latents
        self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)

        init_scale = init_scale * math.sqrt(1.0 / width)
        self.encoder = CrossAttentionEncoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            num_latents=num_latents,
            point_feats=point_feats,
            width=width,
            heads=heads,
            layers=num_encoder_layers,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_ln_post=use_ln_post,
            use_checkpoint=use_checkpoint
        )

        self.embed_dim = embed_dim
        if embed_dim > 0:
            # VAE embed
            self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype)
            self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype)
            self.latent_shape = (num_latents, embed_dim)
        else:
            self.latent_shape = (num_latents, width)

        self.transformer = Transformer(
            device=device,
            dtype=dtype,
            n_ctx=num_latents,
            width=width,
            layers=num_decoder_layers,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint
        )

        # geometry decoder
        self.geo_decoder = CrossAttentionDecoder(
            device=device,
            dtype=dtype,
            fourier_embedder=self.fourier_embedder,
            out_channels=1,
            num_latents=num_latents,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            flash=flash,
            use_checkpoint=use_checkpoint
        )

    def encode(self,
               pc: torch.FloatTensor,
               feats: Optional[torch.FloatTensor] = None,
               sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            sample_posterior (bool):

        Returns:
            latents (torch.FloatTensor): [B, num_latents, embed_dim] if embed_dim > 0 else [B, num_latents, width]
            center_pos (torch.FloatTensor or None):
            posterior (DiagonalGaussianDistribution or None):
        """

        latents, center_pos = self.encoder(pc, feats)

        posterior = None
        if self.embed_dim > 0:
            moments = self.pre_kl(latents)
            posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)

            if sample_posterior:
                latents = posterior.sample()
            else:
                latents = posterior.mode()

        return latents, center_pos, posterior

    def decode(self, latents: torch.FloatTensor):
        # Map the KL latents back to the transformer width (assumes embed_dim > 0),
        # then refine them with the latent transformer.
        latents = self.post_kl(latents)
        return self.transformer(latents)

    def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor):
        logits = self.geo_decoder(queries, latents).squeeze(-1)
        return logits

    def forward(self,
                pc: torch.FloatTensor,
                feats: torch.FloatTensor,
                volume_queries: torch.FloatTensor,
                sample_posterior: bool = True):
        """

        Args:
            pc (torch.FloatTensor): [B, N, 3]
            feats (torch.FloatTensor or None): [B, N, C]
            volume_queries (torch.FloatTensor): [B, P, 3]
            sample_posterior (bool):

        Returns:
            logits (torch.FloatTensor): [B, P]
            center_pos (torch.FloatTensor): [B, M, 3]
            posterior (DiagonalGaussianDistribution or None).
        """

        latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior)

        latents = self.decode(latents)
        logits = self.query_geometry(volume_queries, latents)

        return logits, center_pos, posterior
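# Illustrative sketch (not part of the original module): a minimal forward pass
# through ShapeAsLatentPerceiver. All sizes below are assumptions for the example.
def _example_shape_as_latent_perceiver():
    model = ShapeAsLatentPerceiver(
        device=None, dtype=None, num_latents=256, point_feats=0, embed_dim=64,
        width=512, heads=8, num_encoder_layers=4, num_decoder_layers=4,
    )

    pc = torch.randn(2, 4096, 3)              # [B, N, 3] surface points
    volume_queries = torch.randn(2, 1024, 3)  # [B, P, 3] occupancy query points

    # With point_feats=0 no per-point features are required, so feats=None.
    logits, center_pos, posterior = model(pc, None, volume_queries)
    # logits: [B, P] occupancy logits; posterior carries the KL term for a VAE-style loss.
    return logits, center_pos, posterior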
""" latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) latents = self.decode(latents) logits = self.query_geometry(volume_queries, latents) return logits, center_pos, posterior class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): def __init__(self, *, device: Optional[torch.device], dtype: Optional[torch.dtype], num_latents: int, point_feats: int = 0, embed_dim: int = 0, num_freqs: int = 8, include_pi: bool = True, width: int, heads: int, num_encoder_layers: int, num_decoder_layers: int, init_scale: float = 0.25, qkv_bias: bool = True, flash: bool = False, use_ln_post: bool = False, use_checkpoint: bool = False): super().__init__( device=device, dtype=dtype, num_latents=1 + num_latents, point_feats=point_feats, embed_dim=embed_dim, num_freqs=num_freqs, include_pi=include_pi, width=width, heads=heads, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, init_scale=init_scale, qkv_bias=qkv_bias, flash=flash, use_ln_post=use_ln_post, use_checkpoint=use_checkpoint ) self.width = width def encode(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None, sample_posterior: bool = True): """ Args: pc (torch.FloatTensor): [B, N, 3] feats (torch.FloatTensor or None): [B, N, c] sample_posterior (bool): Returns: shape_embed (torch.FloatTensor) kl_embed (torch.FloatTensor): posterior (DiagonalGaussianDistribution or None): """ shape_embed, latents = self.encode_latents(pc, feats) kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) return shape_embed, kl_embed, posterior def encode_latents(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): x, _ = self.encoder(pc, feats) shape_embed = x[:, 0] latents = x[:, 1:] return shape_embed, latents def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): posterior = None if self.embed_dim > 0: moments = self.pre_kl(latents) posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) if sample_posterior: kl_embed = posterior.sample() else: kl_embed = posterior.mode() else: kl_embed = latents return kl_embed, posterior def forward(self, pc: torch.FloatTensor, feats: torch.FloatTensor, volume_queries: torch.FloatTensor, sample_posterior: bool = True): """ Args: pc (torch.FloatTensor): [B, N, 3] feats (torch.FloatTensor or None): [B, N, C] volume_queries (torch.FloatTensor): [B, P, 3] sample_posterior (bool): Returns: shape_embed (torch.FloatTensor): [B, projection_dim] logits (torch.FloatTensor): [B, M] posterior (DiagonalGaussianDistribution or None). """ shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) latents = self.decode(kl_embed) logits = self.query_geometry(volume_queries, latents) return shape_embed, logits, posterior