import os

from omegaconf import OmegaConf
import torch
from torch import nn

from .utils.misc import instantiate_from_config
from ..utils import default, exists

# Checkpoint location used when the caller does not supply one. It is a
# relative path, so it is resolved against the current working directory.
_DEFAULT_CKPT_PATH = "./ckpt/checkpoints/aligned_shape_latents/shapevae-256.ckpt"


def load_model(ckpt_path=None):
    """Build the shape VAE described by ``shapevae-256.yaml`` in eval mode.

    The YAML file is read from the directory that contains this module. If the
    loaded config has a ``model`` attribute, that sub-config is used;
    otherwise the whole config is passed to ``instantiate_from_config``.

    Args:
        ckpt_path: Checkpoint path forwarded to ``instantiate_from_config``.
            Defaults to ``_DEFAULT_CKPT_PATH`` when ``None``.

    Returns:
        The instantiated model after ``.eval()`` has been called on it.
    """
    if ckpt_path is None:
        ckpt_path = _DEFAULT_CKPT_PATH

    config_file = os.path.join(os.path.dirname(__file__), "shapevae-256.yaml")
    model_config = OmegaConf.load(config_file)
    if hasattr(model_config, "model"):
        model_config = model_config.model

    model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
    # No device placement happens here; the model stays wherever
    # instantiate_from_config created it.
    model = model.eval()
    return model


class ShapeConditioner(nn.Module):
    """Wraps the shape VAE from ``load_model`` to produce shape embeddings."""

    def __init__(
        self,
        *,
        dim_latent = None
    ):
        """Load the shape VAE and record the embedding widths.

        Args:
            dim_latent: Stored as ``self.dim_latent``; falls back to
                ``self.dim_model_out`` (768) via the ``default`` helper when
                not given. It is not used elsewhere in this class.
        """
        super().__init__()
        self.model = load_model()

        # Hard-coded here rather than read from the model config.
        # NOTE(review): presumably the encoder's feature width - confirm.
        self.dim_model_out = 768
        dim_latent = default(dim_latent, self.dim_model_out)
        self.dim_latent = dim_latent

    def forward(
        self,
        shape = None,
        shape_embed = None,
    ):
        """Encode ``shape``, or pass through a precomputed ``shape_embed``.

        Exactly one of ``shape`` / ``shape_embed`` must be provided, as judged
        by the ``exists`` helper.

        Args:
            shape: Input handed to ``self.model.encode_latents``.
                NOTE(review): expected type/layout is defined by the model,
                not visible here - confirm against the caller.
            shape_embed: Precomputed embedding; returned unchanged when given.

        Returns:
            A ``(shape_head, shape_embed)`` tuple. When ``shape`` is encoded,
            ``shape_head`` is ``point_feature[:, 0:1]`` and ``shape_embed`` is
            ``point_feature[:, 1:]`` concatenated with the model's shape
            latents along the last dimension. When ``shape_embed`` is passed
            in, ``shape_head`` is ``None``.

        Raises:
            AssertionError: If both or neither of the inputs are provided.
        """
        assert exists(shape) ^ exists(shape_embed)

        # Only the encode branch can compute a head. Initialising it here
        # keeps the precomputed-embedding path from failing with
        # UnboundLocalError at the return statement.
        shape_head = None

        if not exists(shape_embed):
            point_feature = self.model.encode_latents(shape)
            # Index 0 along dim 1 is split off as the head; the remaining
            # entries feed the latent projection.
            rest_feature = point_feature[:, 1:]
            shape_latents = self.model.to_shape_latents(rest_feature)
            shape_head = point_feature[:, 0:1]
            # Concatenated along the last dimension. An earlier experiment
            # concatenated along dim=-2 instead.
            shape_embed = torch.cat([rest_feature, shape_latents], dim=-1)

        return shape_head, shape_embed