File size: 2,759 Bytes
2fd6166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from contextlib import nullcontext

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers import ModelMixin
from torch import Tensor

from .pvcnn.pvcnn import PVCNN2
from .pvcnn.pvcnn_plus_plus import PVCNN2PlusPlus
from .simple.simple_model import SimplePointModel


class PointCloudModel(ModelMixin, ConfigMixin):
    """Diffusers-compatible wrapper around a point-cloud backbone.

    The backbone is chosen by ``model_type``:

    * ``'pvcnn'``         -- ``PVCNN2``, run under a float32 CUDA autocast.
    * ``'pvcnnplusplus'`` -- ``PVCNN2PlusPlus``, run under a float32 CUDA autocast.
    * ``'simple'``        -- ``SimplePointModel``, no autocast override.

    ``forward`` takes channels-last input ``(B, N, C)`` and transposes it to
    ``(B, C, N)`` before calling the backbone, then transposes the prediction
    back.
    """

    @register_to_config
    def __init__(
        self,
        model_type: str = 'pvcnn',
        in_channels: int = 3,
        out_channels: int = 3,
        embed_dim: int = 64,
        dropout: float = 0.1,
        width_multiplier: int = 1,
        voxel_resolution_multiplier: int = 1,
    ):
        """Build the backbone selected by ``model_type``.

        Args:
            model_type: ``'pvcnn'``, ``'pvcnnplusplus'`` or ``'simple'``.
            in_channels: Channels per input point. The backbone is given
                ``in_channels - 3`` extra feature channels (presumably the
                first 3 are xyz coordinates -- confirm against the backbones).
            out_channels: Channels per output point (the backbone's
                ``num_classes``).
            embed_dim: Forwarded to the backbone as ``embed_dim``.
            dropout: Used only by ``'pvcnn'``; ignored otherwise.
            width_multiplier: Used only by ``'pvcnn'``; ignored otherwise.
            voxel_resolution_multiplier: Used only by ``'pvcnn'``; ignored
                otherwise.

        Raises:
            NotImplementedError: If ``model_type`` is not one of the
                supported values.
        """
        super().__init__()
        self.model_type = model_type
        extra_feature_channels = in_channels - 3
        if model_type == 'pvcnn':
            # PVCNN backbones are forced to float32 even inside an outer
            # mixed-precision autocast region.
            self._force_fp32_autocast = True
            self.model = PVCNN2(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
                dropout=dropout,
                width_multiplier=width_multiplier,
                voxel_resolution_multiplier=voxel_resolution_multiplier,
            )
            self._init_output_layer(self.model.classifier[-1])
        elif model_type == 'pvcnnplusplus':
            self._force_fp32_autocast = True
            self.model = PVCNN2PlusPlus(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
            )
            self._init_output_layer(self.model.output_projection[-1])
        elif model_type == 'simple':
            self._force_fp32_autocast = False
            self.model = SimplePointModel(
                embed_dim=embed_dim,
                num_classes=out_channels,
                extra_feature_channels=extra_feature_channels,
            )
            self._init_output_layer(self.model.output_projection)
        else:
            raise NotImplementedError(
                f"Unsupported model_type {model_type!r}; expected one of "
                "'pvcnn', 'pvcnnplusplus', 'simple'."
            )

    @staticmethod
    def _init_output_layer(layer, std: float = 1e-6) -> None:
        """Redraw ``layer.bias`` then ``layer.weight`` from N(0, std**2).

        Near-zero weights and bias make this layer's initial output close to
        zero. ``nn.init.normal_`` runs under ``no_grad`` instead of mutating
        ``.data`` behind autograd's back.
        """
        torch.nn.init.normal_(layer.bias, mean=0.0, std=std)
        torch.nn.init.normal_(layer.weight, mean=0.0, std=std)

    @property
    def autocast_context(self):
        """Return a fresh context manager to wrap the backbone's forward pass.

        A new object is created on each access because ``torch.autocast``
        saves the state it restores on exit as attributes of the instance.
        A single shared instance (e.g. one shared by ``nn.DataParallel``
        replicas running in threads, or entered re-entrantly) would let those
        calls overwrite each other's saved state.
        """
        if self._force_fp32_autocast:
            return torch.autocast('cuda', dtype=torch.float32)
        return nullcontext()

    def forward(self, inputs: Tensor, t: Tensor, ret_feats=False) -> Tensor:
        """Run the backbone on a batch of point clouds.

        Args:
            inputs: Tensor of shape ``(B, N, in_channels)``.
            t: Passed through to the backbone unchanged (presumably the
                diffusion timestep -- confirm against the backbones).
            ret_feats: If true, also return the extra output the backbone
                produces alongside its prediction.

        Returns:
            The prediction of shape ``(B, N, out_channels)``. When
            ``ret_feats`` is true, a tuple ``(prediction, feats)`` where
            ``feats`` is returned exactly as the backbone produced it
            (not transposed).
        """
        with self.autocast_context:
            # The backbones take channels-first input: (B, N, C) -> (B, C, N).
            x = inputs.transpose(1, 2)
            if not ret_feats:
                return self.model(x, t, ret_feats=False).transpose(1, 2)
            pred, feats = self.model(x, t, ret_feats=True)
            return pred.transpose(1, 2), feats