import torch
from torch import nn
from beartype import beartype
from miche.encode import load_model

# helper functions

def exists(val):
    return val is not None

# return the first non-None value among the arguments, else None
def default(*values):
    for value in values:
        if exists(value):
            return value
    return None
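# e.g. default(None, 3) == 3 and default(None, None) is None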


# point-cloud encoder from Michelangelo
@beartype
class PointConditioner(nn.Module):
    def __init__(
        self,
        *,
        dim_latent = None,
        model_name = 'miche-256-feature',
        cond_dim = 768,
        freeze = True,
    ):
        super().__init__()

        # open-source version of miche
        if model_name == 'miche-256-feature':
            ckpt_path = None
            config_path = 'miche/shapevae-256.yaml'

            self.feature_dim = 1024    # dimension of each output embedding
            self.cond_length = 257     # number of embedding tokens (1 head token + 256 latents)
            self.point_encoder = load_model(ckpt_path=ckpt_path, config_path=config_path)
            
            # projection layers mapping miche features (cond_dim) to the GPT width
            self.cond_head_proj = nn.Linear(cond_dim, self.feature_dim)
            self.cond_proj = nn.Linear(cond_dim, self.feature_dim)
            
        else:
            raise NotImplementedError

        # freeze the point-cloud encoder unless it is being finetuned
        if freeze:
            for parameter in self.point_encoder.parameters():
                parameter.requires_grad = False

        self.freeze = freeze
        self.model_name = model_name
        self.dim_latent = default(dim_latent, self.feature_dim)
        
        # dummy buffer so the module's device / dtype can be inferred
        self.register_buffer('_device_param', torch.tensor(0.), persistent = False)


    @property
    def device(self):
        return next(self.buffers()).device


    def embed_pc(self, pc_normal):
        # encode the point cloud into a sequence of conditioning embeddings
        if self.model_name == 'miche-256-feature':
            point_feature = self.point_encoder.encode_latents(pc_normal)
            # project the leading (head) token and the remaining tokens separately
            pc_embed_head = self.cond_head_proj(point_feature[:, 0:1])
            pc_embed = self.cond_proj(point_feature[:, 1:])
            pc_embed = torch.cat([pc_embed_head, pc_embed], dim=1)
        else:
            raise NotImplementedError

        return pc_embed


    def forward(
        self,
        pc = None,
        pc_embeds = None,
    ):
        # embed the raw point cloud unless precomputed embeddings are given
        if pc_embeds is None:
            assert exists(pc), 'either pc or pc_embeds must be provided'
            pc_embeds = self.embed_pc(pc.to(next(self.buffers()).dtype))

        assert not torch.any(torch.isnan(pc_embeds)), 'NaN values in pc embeddings'

        return pc_embeds
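

# minimal usage sketch (assumptions: the 'miche/shapevae-256.yaml' config
# resolves locally, and `encode_latents` accepts a (batch, num_points, 6)
# tensor of xyz coordinates concatenated with unit normals; the shapes below
# are illustrative, not taken from the source)
if __name__ == '__main__':
    conditioner = PointConditioner()  # frozen 'miche-256-feature' encoder by default

    pc_normal = torch.randn(2, 2048, 6)  # hypothetical batch of 2 point clouds

    pc_embeds = conditioner(pc = pc_normal)
    print(pc_embeds.shape)  # expect (2, 257, 1024) given cond_length / feature_dim above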