Spaces:
Sleeping
Sleeping
File size: 2,455 Bytes
ada4b81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import torch
from torch import nn
from beartype import beartype
from miche.encode import load_model
# helper functions
def exists(val):
    """Report whether ``val`` holds a value, i.e. is anything other than ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
def default(*values):
    """Return the first positional argument that is not ``None``.

    Arguments are scanned left to right; falsy-but-present values (``0``,
    ``''``) are accepted. Returns ``None`` when every argument is ``None``
    or when called with no arguments.
    """
    first_present = next(filter(exists, values), None)
    return first_present
# point-cloud encoder from Michelangelo
@beartype
class PointConditioner(torch.nn.Module):
    """Point-cloud conditioner built on the Michelangelo (miche) shape encoder.

    A point cloud is encoded by the miche encoder. The first output token and
    the remaining tokens are projected to ``feature_dim`` by two separate
    linear layers and concatenated back together.
    """

    def __init__(
        self,
        *,
        dim_latent = None,
        model_name = 'miche-256-feature',
        cond_dim = 768,
        freeze = True,
    ):
        """Build the point encoder and its projection layers.

        Args:
            dim_latent: stored as ``self.dim_latent``; falls back to
                ``self.feature_dim`` when None. Not otherwise used in this class.
            model_name: encoder variant; only ``'miche-256-feature'`` is supported.
            cond_dim: input width of both projection layers (presumably the
                width of the encoder's latent tokens — TODO confirm against miche).
            freeze: if True, every encoder parameter gets ``requires_grad = False``.

        Raises:
            NotImplementedError: if ``model_name`` is not supported.
        """
        super().__init__()

        # open-source version of miche
        if model_name == 'miche-256-feature':
            ckpt_path = None
            config_path = 'miche/shapevae-256.yaml'
            self.feature_dim = 1024  # embedding dimension
            self.cond_length = 257  # length of embedding
            self.point_encoder = load_model(ckpt_path=ckpt_path, config_path=config_path)
            # additional layers to connect miche and GPT
            self.cond_head_proj = nn.Linear(cond_dim, self.feature_dim)
            self.cond_proj = nn.Linear(cond_dim, self.feature_dim)
        else:
            raise NotImplementedError(f'unsupported model_name: {model_name!r}')

        # whether to finetune the point-cloud encoder
        if freeze:
            for parameter in self.point_encoder.parameters():
                parameter.requires_grad = False

        self.freeze = freeze
        self.model_name = model_name
        self.dim_latent = default(dim_latent, self.feature_dim)

        # Non-persistent scalar buffer: not saved in state_dict, but it follows
        # .to()/.cuda()/.half(), so it reports the module's device and dtype.
        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)

    @property
    def device(self):
        """Device of the module, as reported by the ``_device_param`` buffer."""
        # Read the dedicated buffer directly instead of relying on it being
        # the first item yielded by self.buffers().
        return self._device_param.device

    def embed_pc(self, pc_normal):
        """Encode a point cloud and project its tokens to ``feature_dim``.

        Args:
            pc_normal: passed unchanged to ``point_encoder.encode_latents``
                (presumably points with normals — TODO confirm shape).

        Returns:
            The projected first token concatenated with the projected
            remaining tokens along dim 1.

        Raises:
            NotImplementedError: if ``self.model_name`` is not supported
                (previously this path failed with an UnboundLocalError).
        """
        if self.model_name != 'miche-256-feature':
            raise NotImplementedError(f'unsupported model_name: {self.model_name!r}')

        point_feature = self.point_encoder.encode_latents(pc_normal)
        # assumes point_feature is (batch, tokens, cond_dim) — TODO confirm.
        # The leading token gets its own projection; all others share one.
        pc_embed_head = self.cond_head_proj(point_feature[:, 0:1])
        pc_embed = self.cond_proj(point_feature[:, 1:])
        return torch.cat([pc_embed_head, pc_embed], dim=1)

    def forward(
        self,
        pc = None,
        pc_embeds = None,
    ):
        """Return point-cloud embeddings, computing them from ``pc`` if needed.

        Args:
            pc: raw point cloud; cast to the module's dtype and encoded with
                ``embed_pc`` when ``pc_embeds`` is None.
            pc_embeds: precomputed embeddings; used as-is when provided.

        Returns:
            The point-cloud embeddings.

        Raises:
            ValueError: if both ``pc`` and ``pc_embeds`` are None.
            AssertionError: if the embeddings contain NaN values (the check is
                an ``assert`` and is skipped when Python runs with ``-O``).
        """
        if pc_embeds is None:
            if pc is None:
                raise ValueError('either `pc` or `pc_embeds` must be provided')
            pc_embeds = self.embed_pc(pc.to(self._device_param.dtype))

        assert not torch.any(torch.isnan(pc_embeds)), 'NAN values in pc embedings'
        return pc_embeds
|