import torch
from torch import nn

from beartype import beartype

from miche.encode import load_model

# helper functions

def exists(val):
    return val is not None

def default(*values):
    for value in values:
        if exists(value):
            return value
    return None
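
# usage sketch (hypothetical values):
#   default(None, 1024)   # -> 1024, the first non-None value wins
#   default(None, None)   # -> None when every candidate is missing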

# point-cloud encoder from Michelangelo

class PointConditioner(nn.Module):
    @beartype
    def __init__(
        self,
        *,
        dim_latent = None,
        model_name = 'miche-256-feature',
        cond_dim = 768,
        freeze = True,
    ):
        super().__init__()

        # open-source version of miche
        if model_name == 'miche-256-feature':
            ckpt_path = None
            config_path = 'miche/shapevae-256.yaml'
            self.feature_dim = 1024   # dimension of the output embedding
            self.cond_length = 257    # sequence length of the output embedding
            self.point_encoder = load_model(ckpt_path = ckpt_path, config_path = config_path)

            # additional layers to connect miche and the GPT backbone
            self.cond_head_proj = nn.Linear(cond_dim, self.feature_dim)
            self.cond_proj = nn.Linear(cond_dim, self.feature_dim)
        else:
            raise NotImplementedError(f'unknown model_name: {model_name}')

        # whether to fine-tune the point-cloud encoder
        if freeze:
            for parameter in self.point_encoder.parameters():
                parameter.requires_grad = False

        self.freeze = freeze
        self.model_name = model_name
        self.dim_latent = default(dim_latent, self.feature_dim)

        # dummy buffer so the module always has at least one buffer to read the device from
        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device
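
    # usage sketch: `conditioner.device` tracks wherever the module is moved,
    # e.g. after `.to('cuda')` it reports the cuda device, since `_device_param`
    # travels with the rest of the module's buffers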

    def embed_pc(self, pc_normal):
        # encode point cloud to embeddings
        if self.model_name == 'miche-256-feature':
            # point_feature is assumed to be (batch, cond_length, cond_dim), i.e. (B, 257, 768)
            point_feature = self.point_encoder.encode_latents(pc_normal)

            # project the leading global token and the remaining tokens separately, then re-concatenate
            pc_embed_head = self.cond_head_proj(point_feature[:, 0:1])
            pc_embed = self.cond_proj(point_feature[:, 1:])
            pc_embed = torch.cat([pc_embed_head, pc_embed], dim = 1)

        return pc_embed

    def forward(
        self,
        pc = None,
        pc_embeds = None,
    ):
        assert exists(pc) or exists(pc_embeds), 'either pc or pc_embeds must be given'

        if pc_embeds is None:
            pc_embeds = self.embed_pc(pc.to(next(self.buffers()).dtype))

        assert not torch.any(torch.isnan(pc_embeds)), 'NaN values in pc embeddings'
        return pc_embeds
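
# minimal end-to-end sketch (assumptions: 'miche/shapevae-256.yaml' exists on disk,
# `load_model` can build the encoder with ckpt_path = None, and `encode_latents`
# accepts a point cloud with per-point normals of shape (batch, num_points, 6) --
# the exact input layout should be checked against Michelangelo)
if __name__ == '__main__':
    conditioner = PointConditioner()
    pc_normal = torch.randn(1, 2048, 6)    # xyz + normal per point (assumed layout)
    pc_embeds = conditioner(pc = pc_normal)
    print(pc_embeds.shape)                 # expected (1, 257, 1024) for miche-256-feature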