# bpt/model/miche_conditioner.py
import torch
from torch import nn
from beartype import beartype

from miche.encode import load_model


# helper functions
def exists(val):
    return val is not None


def default(*values):
    for value in values:
        if exists(value):
            return value
    return None


# point-cloud encoder from Michelangelo
@beartype
class PointConditioner(torch.nn.Module):
    def __init__(
        self,
        *,
        dim_latent = None,
        model_name = 'miche-256-feature',
        cond_dim = 768,
        freeze = True,
    ):
        super().__init__()

        # open-source version of miche
        if model_name == 'miche-256-feature':
            ckpt_path = None
            config_path = 'miche/shapevae-256.yaml'
            self.feature_dim = 1024  # embedding dimension
            self.cond_length = 257   # length of embedding
            self.point_encoder = load_model(ckpt_path=ckpt_path, config_path=config_path)

            # additional layers to connect miche and GPT
            self.cond_head_proj = nn.Linear(cond_dim, self.feature_dim)
            self.cond_proj = nn.Linear(cond_dim, self.feature_dim)
        else:
            raise NotImplementedError
        # whether to finetune the point-cloud encoder
        if freeze:
            for parameter in self.point_encoder.parameters():
                parameter.requires_grad = False

        self.freeze = freeze
        self.model_name = model_name
        self.dim_latent = default(dim_latent, self.feature_dim)

        # dummy buffer used by the `device` property to track the module's device
        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def embed_pc(self, pc_normal):
        # encode point cloud to embeddings
        if self.model_name == 'miche-256-feature':
            point_feature = self.point_encoder.encode_latents(pc_normal)

            # project the first latent token and the remaining tokens with separate
            # linear heads, then re-concatenate -> (batch, cond_length, feature_dim)
            pc_embed_head = self.cond_head_proj(point_feature[:, 0:1])
            pc_embed = self.cond_proj(point_feature[:, 1:])
            pc_embed = torch.cat([pc_embed_head, pc_embed], dim=1)
        return pc_embed

    def forward(
        self,
        pc = None,
        pc_embeds = None,
    ):
        # embed the raw point cloud unless precomputed embeddings are supplied
        if pc_embeds is None:
            pc_embeds = self.embed_pc(pc.to(next(self.buffers()).dtype))

        assert not torch.any(torch.isnan(pc_embeds)), 'NaN values in pc embeddings'
        return pc_embeds
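

# Usage sketch (illustrative, not part of the original module): obtain point-cloud
# embeddings for a downstream GPT-style decoder. The (batch, num_points, 6) input
# layout, i.e. xyz coordinates concatenated with per-point normals, is an assumption
# inferred from the `pc_normal` argument name; running this requires the miche
# package and the 'miche/shapevae-256.yaml' config to be resolvable by `load_model`.
if __name__ == '__main__':
    conditioner = PointConditioner(model_name='miche-256-feature', freeze=True)

    # hypothetical batch: 2 point clouds with 4096 points each
    pc = torch.rand(2, 4096, 6, device=conditioner.device)
    pc_embeds = conditioner(pc=pc)

    # expected shape: (2, cond_length, feature_dim) == (2, 257, 1024)
    print(pc_embeds.shape)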