from __future__ import annotations

from functools import partial
from math import ceil
import os

from accelerate.utils import DistributedDataParallelKwargs
from beartype.typing import Tuple, Callable, List
from einops import rearrange, repeat, reduce, pack
from gateloop_transformer import SimpleGateLoopLayer
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import trimesh
from tqdm import tqdm

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from pytorch3d.loss import chamfer_distance
from pytorch3d.transforms import euler_angles_to_matrix

from x_transformers import Decoder
from x_transformers.x_transformers import LayerIntermediates
from x_transformers.autoregressive_wrapper import eval_decorator

from .michelangelo import ShapeConditioner as ShapeConditioner_miche
from .utils import (
    discretize,
    undiscretize,
    set_module_requires_grad_,
    default,
    exists,
    safe_cat,
    identity,
    is_tensor_empty,
)
from .utils.typing import Float, Int, Bool, typecheck

# constants

DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
    find_unused_parameters = True
)

SHAPE_CODE = {
    'CubeBevel': 0,
    'SphereSharp': 1,
    'CylinderSharp': 2,
}

BS_NAME = {
    0: 'CubeBevel',
    1: 'SphereSharp',
    2: 'CylinderSharp',
}

# FiLM block

class FiLM(Module):
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.to_gamma = nn.Linear(dim, dim_out, bias = False)
        self.to_beta = nn.Linear(dim, dim_out)

        self.gamma_mult = nn.Parameter(torch.zeros(1,))
        self.beta_mult = nn.Parameter(torch.zeros(1,))

    def forward(self, x, cond):
        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
        gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta))

        # for initializing to identity

        gamma = (1 + self.gamma_mult * gamma.tanh())
        beta = beta.tanh() * self.beta_mult

        # classic film

        return x * gamma + beta

# gateloop layers

class GateLoopBlock(Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        use_heinsen = True
    ):
        super().__init__()
        self.gateloops = ModuleList([])

        for _ in range(depth):
            gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
            self.gateloops.append(gateloop)

    def forward(
        self,
        x,
        cache = None
    ):
        received_cache = exists(cache)

        if is_tensor_empty(x):
            return x, None

        if received_cache:
            prev, x = x[:, :-1], x[:, -1:]

        cache = default(cache, [])
        cache = iter(cache)

        new_caches = []
        for gateloop in self.gateloops:
            layer_cache = next(cache, None)
            out, new_cache = gateloop(x, cache = layer_cache, return_cache = True)
            new_caches.append(new_cache)
            x = x + out

        if received_cache:
            x = torch.cat((prev, x), dim = -2)

        return x, new_caches

def top_k_2(logits, frac_num_tokens=0.1, k=None):
    num_tokens = logits.shape[-1]
    k = default(k, ceil(frac_num_tokens * num_tokens))
    k = min(k, num_tokens)
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(2, ind, val)
    return probs

def soft_argmax(labels):
    indices = torch.arange(labels.size(-1), dtype=labels.dtype, device=labels.device)
    soft_argmax = torch.sum(labels * indices, dim=-1)
    return soft_argmax

class PrimitiveTransformerDiscrete(Module, PyTorchModelHubMixin):
    @typecheck
    def __init__(
        self,
        *,
        num_discrete_scale = 128,
        continuous_range_scale: List[float] = [0, 1],
        dim_scale_embed = 64,
        num_discrete_rotation = 180,
        continuous_range_rotation: List[float] = [-180, 180],
        dim_rotation_embed = 64,
        num_discrete_translation = 128,
        continuous_range_translation: List[float] = [-1, 1],
        dim_translation_embed = 64,
        num_type = 3,
        dim_type_embed = 64,
        embed_order = 'ctrs',
        bin_smooth_blur_sigma = 0.4,
        dim: int | Tuple[int, int] = 512,
        flash_attn = True,
        attn_depth = 12,
        attn_dim_head = 64,
        attn_heads = 16,
        attn_kwargs: dict = dict(
            ff_glu = True,
            attn_num_mem_kv = 4
        ),
        max_primitive_len = 144,
        dropout = 0.,
        coarse_pre_gateloop_depth = 2,
        coarse_post_gateloop_depth = 0,
        coarse_adaptive_rmsnorm = False,
        gateloop_use_heinsen = False,
        pad_id = -1,
        num_sos_tokens = None,
        condition_on_shape = True,
        shape_cond_with_cross_attn = False,
        shape_cond_with_film = False,
        shape_cond_with_cat = False,
        shape_condition_model_type = 'michelangelo',
        shape_condition_len = 1,
        shape_condition_dim = None,
        cross_attn_num_mem_kv = 4,  # needed for preventing nan when dropping out shape condition
        loss_weight: dict = dict(
            eos = 1.0,
            type = 1.0,
            scale = 1.0,
            rotation = 1.0,
            translation = 1.0,
            reconstruction = 1.0,
            scale_huber = 1.0,
            rotation_huber = 1.0,
            translation_huber = 1.0,
        ),
        bs_pc_dir=None,
    ):
        super().__init__()

        # feature embedding

        self.num_discrete_scale = num_discrete_scale
        self.continuous_range_scale = continuous_range_scale
        self.discretize_scale = partial(discretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
        self.undiscretize_scale = partial(undiscretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
        self.scale_embed = nn.Embedding(num_discrete_scale, dim_scale_embed)

        self.num_discrete_rotation = num_discrete_rotation
        self.continuous_range_rotation = continuous_range_rotation
        self.discretize_rotation = partial(discretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
        self.undiscretize_rotation = partial(undiscretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
        self.rotation_embed = nn.Embedding(num_discrete_rotation, dim_rotation_embed)

        self.num_discrete_translation = num_discrete_translation
        self.continuous_range_translation = continuous_range_translation
        self.discretize_translation = partial(discretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
        self.undiscretize_translation = partial(undiscretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
        self.translation_embed = nn.Embedding(num_discrete_translation, dim_translation_embed)

        self.num_type = num_type
        self.type_embed = nn.Embedding(num_type, dim_type_embed)

        self.embed_order = embed_order
        self.bin_smooth_blur_sigma = bin_smooth_blur_sigma

        # initial dimension

        self.dim = dim
        init_dim = 3 * (dim_scale_embed + dim_rotation_embed + dim_translation_embed) + dim_type_embed

        # project into model dimension

        self.project_in = nn.Linear(init_dim, dim)

        num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_shape or not shape_cond_with_film else 4)
        assert num_sos_tokens > 0

        self.num_sos_tokens = num_sos_tokens
        self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))

        # the transformer eos token

        self.eos_token = nn.Parameter(torch.randn(1, dim))

        self.emb_layernorm = nn.LayerNorm(dim)
        self.max_seq_len = max_primitive_len

        # shape condition

        self.condition_on_shape = condition_on_shape
        self.shape_cond_with_cross_attn = False
        self.shape_cond_with_cat = False
        self.shape_condition_model_type = ''
        self.conditioner = None
        dim_shape = None

        if condition_on_shape:
            assert shape_cond_with_cross_attn or shape_cond_with_film or shape_cond_with_cat
            self.shape_cond_with_cross_attn = shape_cond_with_cross_attn
            self.shape_cond_with_cat = shape_cond_with_cat
            self.shape_condition_model_type = shape_condition_model_type
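            # three (non-exclusive) ways of injecting the shape condition are supported:
            # cross attention over the conditioner tokens, FiLM modulation of the
            # primitive codes, and concatenating the condition tokens before the sos tokens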
            if 'michelangelo' in shape_condition_model_type:
                self.conditioner = ShapeConditioner_miche(dim_latent=shape_condition_dim)
                self.to_cond_dim = nn.Linear(self.conditioner.dim_model_out * 2, self.conditioner.dim_latent)
                self.to_cond_dim_head = nn.Linear(self.conditioner.dim_model_out, self.conditioner.dim_latent)
            else:
                raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')

            dim_shape = self.conditioner.dim_latent
            set_module_requires_grad_(self.conditioner, False)

        self.shape_coarse_film_cond = FiLM(dim_shape, dim) if shape_cond_with_film else identity

        self.coarse_gateloop_block = GateLoopBlock(dim, depth=coarse_pre_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None
        self.coarse_post_gateloop_block = GateLoopBlock(dim, depth=coarse_post_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None

        self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm

        self.decoder = Decoder(
            dim=dim,
            depth=attn_depth,
            heads=attn_heads,
            attn_dim_head=attn_dim_head,
            attn_flash=flash_attn,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_adaptive_rmsnorm=coarse_adaptive_rmsnorm,
            dim_condition=dim_shape,
            cross_attend=self.shape_cond_with_cross_attn,
            cross_attn_dim_context=dim_shape,
            cross_attn_num_mem_kv=cross_attn_num_mem_kv,
            **attn_kwargs
        )

        # to logits

        self.to_eos_logits = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1)
        )

        self.to_type_logits = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, num_type)
        )

        self.to_translation_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed, dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_translation)
        )

        self.to_rotation_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed + 3 * dim_translation_embed, dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_rotation)
        )

        self.to_scale_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed + 3 * (dim_translation_embed + dim_rotation_embed), dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_scale)
        )

        self.pad_id = pad_id

        # load the canonical point cloud for each basic shape type

        bs_pc_map = {}
        for bs_name, type_code in SHAPE_CODE.items():
            pc = trimesh.load(os.path.join(bs_pc_dir, f'SM_GR_BS_{bs_name}_001.ply'))
            bs_pc_map[type_code] = torch.from_numpy(np.asarray(pc.vertices)).float()

        bs_pc_list = []
        for i in range(len(bs_pc_map)):
            bs_pc_list.append(bs_pc_map[i])
        self.bs_pc = torch.stack(bs_pc_list, dim=0)

        self.rotation_matrix_align_coord = euler_angles_to_matrix(
            torch.Tensor([np.pi / 2, 0, 0]), 'XYZ').unsqueeze(0).unsqueeze(0)

    @property
    def device(self):
        return next(self.parameters()).device

    @typecheck
    @torch.no_grad()
    def embed_pc(self, pc: Tensor):
        if 'michelangelo' in self.shape_condition_model_type:
            pc_head, pc_embed = self.conditioner(shape=pc)
            pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2).detach()
        else:
            raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
        return pc_embed

    @typecheck
    def recon_primitives(
        self,
        scale_logits: Float['b np 3 nd'],
        rotation_logits: Float['b np 3 nd'],
        translation_logits: Float['b np 3 nd'],
        type_logits: Int['b np nd'],
        primitive_mask: Bool['b np']
    ):
        recon_scale = self.undiscretize_scale(scale_logits.argmax(dim=-1))
        recon_scale = recon_scale.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))

        recon_rotation = self.undiscretize_rotation(rotation_logits.argmax(dim=-1))
        recon_rotation = recon_rotation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))

        recon_translation = self.undiscretize_translation(translation_logits.argmax(dim=-1))
        recon_translation = recon_translation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))

        recon_type_code = type_logits.argmax(dim=-1)
        recon_type_code = recon_type_code.masked_fill(~primitive_mask, -1)

        return {
            'scale': recon_scale,
            'rotation': recon_rotation,
            'translation': recon_translation,
            'type_code': recon_type_code
        }

    @typecheck
    def sample_primitives(
        self,
        scale: Float['b np 3 nd'],
        rotation: Float['b np 3 nd'],
        translation: Float['b np 3 nd'],
        type_code: Int['b np nd'],
        next_embed: Float['b 1 nd'],
        temperature: float = 1.,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict()
    ):
        def sample_func(logits):
            if logits.ndim == 4:
                enable_squeeze = True
                logits = logits.squeeze(1)
            else:
                enable_squeeze = False

            filtered_logits = filter_logits_fn(logits, **filter_kwargs)

            if temperature == 0.:
                sample = filtered_logits.argmax(dim=-1)
            else:
                probs = F.softmax(filtered_logits / temperature, dim=-1)
                sample = torch.zeros((probs.shape[0], probs.shape[1]), dtype=torch.long, device=probs.device)
                for b_i in range(probs.shape[0]):
                    sample[b_i] = torch.multinomial(probs[b_i], 1).squeeze()

            if enable_squeeze:
                sample = sample.unsqueeze(1)
            return sample

        # sample the attributes of the next primitive in order:
        # type -> translation -> rotation -> scale, feeding each sampled
        # attribute's embedding back in before predicting the next one

        next_type_logits = self.to_type_logits(next_embed)
        next_type_code = sample_func(next_type_logits)
        type_code_new, _ = pack([type_code, next_type_code], 'b *')

        type_embed = self.type_embed(next_type_code)
        next_embed_packed, _ = pack([next_embed, type_embed], 'b np *')

        next_translation_logits = rearrange(self.to_translation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_translation)
        next_discretize_translation = sample_func(next_translation_logits)
        next_translation = self.undiscretize_translation(next_discretize_translation)
        translation_new, _ = pack([translation, next_translation], 'b * nd')

        next_translation_embed = self.translation_embed(next_discretize_translation)
        next_embed_packed, _ = pack([next_embed_packed, next_translation_embed], 'b np *')

        next_rotation_logits = rearrange(self.to_rotation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_rotation)
        next_discretize_rotation = sample_func(next_rotation_logits)
        next_rotation = self.undiscretize_rotation(next_discretize_rotation)
        rotation_new, _ = pack([rotation, next_rotation], 'b * nd')

        next_rotation_embed = self.rotation_embed(next_discretize_rotation)
        next_embed_packed, _ = pack([next_embed_packed, next_rotation_embed], 'b np *')

        next_scale_logits = rearrange(self.to_scale_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_scale)
        next_discretize_scale = sample_func(next_scale_logits)
        next_scale = self.undiscretize_scale(next_discretize_scale)
        scale_new, _ = pack([scale, next_scale], 'b * nd')

        return (
            scale_new,
            rotation_new,
            translation_new,
            type_code_new
        )

    @eval_decorator
    @torch.no_grad()
    @typecheck
    def generate(
        self,
        batch_size: int | None = None,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict(),
        temperature: float = 1.,
        scale: Float['b np 3'] | None = None,
        rotation: Float['b np 3'] | None = None,
        translation: Float['b np 3'] | None = None,
        type_code: Int['b np'] | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        cache_kv = True,
        max_seq_len = None,
    ):
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
            assert not exists(batch_size)
            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
            assert scale.shape[1] <= self.max_seq_len
            batch_size = scale.shape[0]

        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
            if exists(pc):
                pc_embed = self.embed_pc(pc)
            batch_size = default(batch_size, pc_embed.shape[0])

        batch_size = default(batch_size, 1)

        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))

        curr_length = scale.shape[1]

        cache = None
        eos_codes = None

        for i in tqdm(range(curr_length, max_seq_len)):
            can_eos = i != 0

            output = self.forward(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                pc_embed=pc_embed,
                return_loss=False,
                return_cache=cache_kv,
                append_eos=False,
                cache=cache
            )

            if cache_kv:
                next_embed, cache = output
            else:
                next_embed = output

            (
                scale,
                rotation,
                translation,
                type_code
            ) = self.sample_primitives(
                scale,
                rotation,
                translation,
                type_code,
                next_embed,
                temperature=temperature,
                filter_logits_fn=filter_logits_fn,
                filter_kwargs=filter_kwargs
            )

            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
            if can_eos and eos_codes.any(dim=-1).all():
                break

        # mask out to padding anything after the first eos

        mask = eos_codes.float().cumsum(dim=-1) >= 1
        # concat cur_length to mask
        mask = torch.cat((torch.zeros((batch_size, curr_length), dtype=torch.bool, device=self.device), mask), dim=-1)

        type_code = type_code.masked_fill(mask, self.pad_id)
        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)

        recon_primitives = {
            'scale': scale,
            'rotation': rotation,
            'translation': translation,
            'type_code': type_code
        }
        primitive_mask = ~eos_codes
        return recon_primitives, primitive_mask

    @eval_decorator
    @torch.no_grad()
    @typecheck
    def generate_w_recon_loss(
        self,
        batch_size: int | None = None,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict(),
        temperature: float = 1.,
        scale: Float['b np 3'] | None = None,
        rotation: Float['b np 3'] | None = None,
        translation: Float['b np 3'] | None = None,
        type_code: Int['b np'] | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        cache_kv = True,
        max_seq_len = None,
        single_directional = True,
    ):
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
            assert not exists(batch_size)
            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
            assert scale.shape[1] <= self.max_seq_len
            batch_size = scale.shape[0]

        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
            if exists(pc):
                pc_embed = self.embed_pc(pc)
            batch_size = default(batch_size, pc_embed.shape[0])

        batch_size = default(batch_size, 1)
        assert batch_size == 1  # TODO: support any batch size

        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))

        curr_length = scale.shape[1]

        cache = None
        eos_codes = None
        last_recon_loss = 1

        for i in tqdm(range(curr_length, max_seq_len)):
            can_eos = i != 0

            output = self.forward(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                pc_embed=pc_embed,
                return_loss=False,
                return_cache=cache_kv,
                append_eos=False,
                cache=cache
            )

            if cache_kv:
                next_embed, cache = output
            else:
                next_embed = output

            (
                scale_new,
                rotation_new,
                translation_new,
                type_code_new
            ) = self.sample_primitives(
                scale,
                rotation,
                translation,
                type_code,
                next_embed,
                temperature=temperature,
                filter_logits_fn=filter_logits_fn,
                filter_kwargs=filter_kwargs
            )

            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
            if can_eos and eos_codes.any(dim=-1).all():
                scale, rotation, translation, type_code = (
                    scale_new, rotation_new, translation_new, type_code_new)
                break

            # keep the newly sampled primitive only if it improves the chamfer distance;
            # otherwise resample a few times and fall back to the best candidate found

            recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc, single_directional)
            if recon_loss < last_recon_loss:
                last_recon_loss = recon_loss
                scale, rotation, translation, type_code = (
                    scale_new, rotation_new, translation_new, type_code_new)
            else:
                best_recon_loss = recon_loss
                best_primitives = dict(
                    scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
                success_flag = False
                print(f'last_recon_loss:{last_recon_loss}, recon_loss:{recon_loss} -> to find better primitive')
                for try_i in range(5):
                    (
                        scale_new,
                        rotation_new,
                        translation_new,
                        type_code_new
                    ) = self.sample_primitives(
                        scale,
                        rotation,
                        translation,
                        type_code,
                        next_embed,
                        temperature=1.0,
                        filter_logits_fn=filter_logits_fn,
                        filter_kwargs=filter_kwargs
                    )
                    recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc)
                    print(f'[try_{try_i}] last_recon_loss:{last_recon_loss}, best_recon_loss:{best_recon_loss}, cur_recon_loss:{recon_loss}')
                    if recon_loss < last_recon_loss:
                        last_recon_loss = recon_loss
                        scale, rotation, translation, type_code = (
                            scale_new, rotation_new, translation_new, type_code_new)
                        success_flag = True
                        break
                    else:
                        if recon_loss < best_recon_loss:
                            best_recon_loss = recon_loss
                            best_primitives = dict(
                                scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)

                if not success_flag:
                    last_recon_loss = best_recon_loss
                    scale, rotation, translation, type_code = (
                        best_primitives['scale'], best_primitives['rotation'],
                        best_primitives['translation'], best_primitives['type_code'])
                    print(f'new_last_recon_loss:{last_recon_loss}')

        # mask out to padding anything after the first eos

        mask = eos_codes.float().cumsum(dim=-1) >= 1

        type_code = type_code.masked_fill(mask, self.pad_id)
        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)

        recon_primitives = {
            'scale': scale,
            'rotation': rotation,
            'translation': translation,
            'type_code': type_code
        }
        primitive_mask = ~eos_codes
        return recon_primitives, primitive_mask

    @typecheck
    def encode(
        self,
        *,
        scale: Float['b np 3'],
        rotation: Float['b np 3'],
        translation: Float['b np 3'],
        type_code: Int['b np'],
        primitive_mask: Bool['b np'],
        return_primitives = False
    ):
        """
        einops:
        b - batch
        np - number of primitives
        c - coordinates (3)
        d - embed dim
        """
        # compute feature embedding

        discretize_scale = self.discretize_scale(scale)
        scale_embed = self.scale_embed(discretize_scale)
        scale_embed = rearrange(scale_embed, 'b np c d -> b np (c d)')

        discretize_rotation = self.discretize_rotation(rotation)
        rotation_embed = self.rotation_embed(discretize_rotation)
        rotation_embed = rearrange(rotation_embed, 'b np c d -> b np (c d)')

        discretize_translation = self.discretize_translation(translation)
        translation_embed = self.translation_embed(discretize_translation)
        translation_embed = rearrange(translation_embed, 'b np c d -> b np (c d)')

        type_embed = self.type_embed(type_code.masked_fill(~primitive_mask, 0))

        # combine all features and project into model dimension

        if self.embed_order == 'srtc':
            primitive_embed, _ = pack([scale_embed, rotation_embed, translation_embed, type_embed], 'b np *')
        else:
            primitive_embed, _ = pack([type_embed, translation_embed, rotation_embed, scale_embed], 'b np *')

        primitive_embed = self.project_in(primitive_embed)
        primitive_embed = primitive_embed.masked_fill(~primitive_mask.unsqueeze(-1), 0.)

        if not return_primitives:
            return primitive_embed

        primitive_embed_unpacked = {
            'scale': scale_embed,
            'rotation': rotation_embed,
            'translation': translation_embed,
            'type_code': type_embed
        }
        primitives_gt = {
            'scale': discretize_scale,
            'rotation': discretize_rotation,
            'translation': discretize_translation,
            'type_code': type_code
        }
        return primitive_embed, primitive_embed_unpacked, primitives_gt

    @typecheck
    def compute_chamfer_distance(
        self,
        scale_pred: Float['b np 3'],
        rotation_pred: Float['b np 3'],
        translation_pred: Float['b np 3'],
        type_pred: Int['b np'],
        primitive_mask: Bool['b np'],
        pc: Tensor,  # b, num_points, c
        single_directional = True
    ):
        scale_pred = scale_pred.float()
        rotation_pred = rotation_pred.float()
        translation_pred = translation_pred.float()

        # transform the canonical point cloud of each predicted primitive type,
        # then compare against the conditioning point cloud

        pc_pred = apply_transformation(self.bs_pc.to(type_pred.device)[type_pred], scale_pred, torch.deg2rad(rotation_pred), translation_pred)
        pc_pred = torch.matmul(pc_pred, self.rotation_matrix_align_coord.to(type_pred.device))
        pc_pred_flat = rearrange(pc_pred, 'b np p c -> b (np p) c')
        pc_pred_sampled = random_sample_pc(pc_pred_flat, primitive_mask.sum(dim=-1, keepdim=True), n_points=self.bs_pc.shape[1])

        if single_directional:
            recon_loss, _ = chamfer_distance(pc[:, :, :3].float(), pc_pred_sampled.float(), single_directional=True)  # single directional
        else:
            recon_loss, _ = chamfer_distance(pc_pred_sampled.float(), pc[:, :, :3].float())
        return recon_loss

    def forward(
        self,
        *,
        scale: Float['b np 3'],
        rotation: Float['b np 3'],
        translation: Float['b np 3'],
        type_code: Int['b np'],
        loss_reduction: str = 'mean',
        return_cache = False,
        append_eos = True,
        cache: LayerIntermediates | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        **kwargs
    ):
        primitive_mask = reduce(scale != self.pad_id, 'b np 3 -> b np', 'all')

        if scale.shape[1] > 0:
            codes, primitives_embeds, primitives_gt = self.encode(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                primitive_mask=primitive_mask,
                return_primitives=True
            )
        else:
            codes = torch.empty((scale.shape[0], 0, self.dim), dtype=torch.float32, device=self.device)

        # handle shape conditions

        attn_context_kwargs = dict()
        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'

            if exists(pc):
                if 'michelangelo' in self.shape_condition_model_type:
                    pc_head, pc_embed = self.conditioner(shape=pc)
                    pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2)
                else:
                    raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')

            assert pc_embed.shape[0] == codes.shape[0], 'batch size of point cloud is not equal to the batch size of the primitive codes'

            pooled_pc_embed = pc_embed.mean(dim=1)  # (b, shape_condition_dim)

            if self.shape_cond_with_cross_attn:
                attn_context_kwargs = dict(
                    context=pc_embed
                )

            if self.coarse_adaptive_rmsnorm:
                attn_context_kwargs.update(
                    condition=pooled_pc_embed
                )

        batch, seq_len, _ = codes.shape  # (b, np, dim)
        device = codes.device

        assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'

        if append_eos:
            assert exists(codes)
            code_lens = primitive_mask.sum(dim=-1)
            codes = pad_tensor(codes)

            batch_arange = torch.arange(batch, device=device)
            batch_arange = rearrange(batch_arange, '... -> ... 1')
            code_lens = rearrange(code_lens, '... -> ... 1')
            codes[batch_arange, code_lens] = self.eos_token  # (b, np+1, dim)

        primitive_codes = codes  # (b, np, dim)
        primitive_codes_len = primitive_codes.shape[-2]

        (
            coarse_cache,
            coarse_gateloop_cache,
            coarse_post_gateloop_cache,
        ) = cache if exists(cache) else ((None,) * 3)

        if not exists(cache):
            sos = repeat(self.sos_token, 'n d -> b n d', b=batch)
            if self.shape_cond_with_cat:
                sos, _ = pack([pc_embed, sos], 'b * d')
            primitive_codes, packed_sos_shape = pack([sos, primitive_codes], 'b * d')  # (b, n_sos+np, dim)

        # condition primitive codes with shape if needed

        if self.condition_on_shape:
            primitive_codes = self.shape_coarse_film_cond(primitive_codes, pooled_pc_embed)

        # attention on primitive codes (coarse)

        if exists(self.coarse_gateloop_block):
            primitive_codes, coarse_gateloop_cache = self.coarse_gateloop_block(primitive_codes, cache=coarse_gateloop_cache)

        attended_primitive_codes, coarse_cache = self.decoder(  # (b, n_sos+np, dim)
            primitive_codes,
            cache=coarse_cache,
            return_hiddens=True,
            **attn_context_kwargs
        )

        if exists(self.coarse_post_gateloop_block):
            primitive_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(primitive_codes, cache=coarse_post_gateloop_cache)

        embed = attended_primitive_codes[:, -(primitive_codes_len + 1):]  # (b, np+1, dim)

        if not return_cache:
            return embed[:, -1:]

        next_cache = (
            coarse_cache,
            coarse_gateloop_cache,
            coarse_post_gateloop_cache
        )
        return embed[:, -1:], next_cache

def pad_tensor(tensor):
    if tensor.dim() == 3:
        bs, seq_len, dim = tensor.shape
        padding = torch.zeros((bs, 1, dim), dtype=tensor.dtype, device=tensor.device)
    elif tensor.dim() == 2:
        bs, seq_len = tensor.shape
        padding = torch.zeros((bs, 1), dtype=tensor.dtype, device=tensor.device)
    else:
        raise ValueError('Unsupported tensor shape: {}'.format(tensor.shape))
    return torch.cat([tensor, padding], dim=1)

def apply_transformation(pc, scale, rotation_vector, translation):
    # pc: (b, num_prims, num_points, 3) canonical point clouds, one per primitive
    bs, num_prims, num_points, _ = pc.shape
    scaled_pc = pc * scale.unsqueeze(2)
    rotation_matrix = euler_angles_to_matrix(rotation_vector.view(-1, 3), 'XYZ').view(bs, num_prims, 3, 3)  # euler tmp
    rotated_pc = torch.einsum('bnij,bnpj->bnpi', rotation_matrix, scaled_pc)
    transformed_pc = rotated_pc + translation.unsqueeze(2)
    return transformed_pc

def random_sample_pc(pc, max_lens, n_points=10000):
    bs = max_lens.shape[0]
    max_len = max_lens.max().item() * n_points
    random_values = torch.rand(bs, max_len, device=max_lens.device)
    mask = torch.arange(max_len).expand(bs, max_len).to(max_lens.device) < (max_lens * n_points)
    masked_random_values = random_values * mask.float()
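    # topk over the masked random values selects `n_points` indices uniformly at
    # random (without replacement) from the valid, unmasked region of each
    # flattened predicted point cloud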
    _, indices = torch.topk(masked_random_values, n_points, dim=1)
    return pc[torch.arange(bs).unsqueeze(1), indices]
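
# Example usage (illustrative sketch only; the directory, point-cloud shape and
# settings below are placeholder assumptions, not values prescribed by this module):
#
#   model = PrimitiveTransformerDiscrete(
#       shape_cond_with_cross_attn = True,
#       bs_pc_dir = 'path/to/basic_shape_point_clouds',   # folder holding the SM_GR_BS_*.ply files
#   )
#
#   pc = torch.randn(1, 8192, 3)   # conditioning point cloud, (b, num_points, c); channel count depends on the conditioner
#   recon_primitives, primitive_mask = model.generate(pc = pc)
#   # recon_primitives is a dict with 'scale', 'rotation', 'translation', 'type_code'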