File size: 2,639 Bytes
4d85df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from monoscene.CRP3D import CPMegaVoxels
from monoscene.modules import (
    Process,
    Upsample,
    Downsample,
    SegmentationHead,
    ASPP,
)


class UNet3D(nn.Module):
    """Three-level 3D encoder-decoder that outputs semantic logits.

    The network runs two ``Process`` + ``Downsample`` encoder stages,
    optionally refines the coarsest features with ``CPMegaVoxels``, then
    applies two ``Upsample`` stages with additive skip connections and a
    ``SegmentationHead``.

    The ``1_4`` / ``1_8`` / ``1_16`` suffixes name the three feature levels;
    their channel widths are ``feature``, ``feature * 2`` and ``feature * 4``.
    NOTE(review): the suffixes presumably denote spatial scales relative to
    some reference resolution -- confirm against the caller.

    Args:
        class_num: Number of output classes, forwarded to
            ``SegmentationHead``.
        norm_layer: Normalization layer forwarded to the ``Process``,
            ``Downsample`` and ``Upsample`` modules.
        feature: Channel width of the finest (``1_4``) level.
        full_scene_size: Iterable of scene dimensions; each entry is divided
            by 4 and rounded up to obtain the size given to ``CPMegaVoxels``.
        n_relations: Forwarded to ``CPMegaVoxels``.
        project_res: Stored as ``self.project_res``; not otherwise used by
            this class. ``None`` (the default) yields a fresh empty list.
        context_prior: When true, build and apply ``CPMegaVoxels`` on the
            coarsest features.
        bn_momentum: Momentum value forwarded to the sub-modules.
    """

    def __init__(
        self,
        class_num,
        norm_layer,
        feature,
        full_scene_size,
        n_relations=4,
        project_res=None,
        context_prior=True,
        bn_momentum=0.1,
    ):
        super().__init__()
        # Kept for backward compatibility; nothing in this class reads it.
        self.business_layer = []
        # A literal ``[]`` default would be one list shared by every instance
        # constructed without the argument, so create a new list per instance.
        self.project_res = [] if project_res is None else project_res

        # Encoder channel widths: doubled at each coarser level.
        self.feature_1_4 = feature
        self.feature_1_8 = feature * 2
        self.feature_1_16 = feature * 4

        # Decoder widths mirror the encoder so skip connections can be added.
        self.feature_1_16_dec = self.feature_1_16
        self.feature_1_8_dec = self.feature_1_8
        self.feature_1_4_dec = self.feature_1_4

        # Encoder stages. NOTE(review): ``Downsample`` presumably doubles the
        # channel count, since the next stage is built with the doubled
        # width -- confirm in monoscene.modules.
        self.process_1_4 = nn.Sequential(
            Process(self.feature_1_4, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_4, norm_layer, bn_momentum),
        )
        self.process_1_8 = nn.Sequential(
            Process(self.feature_1_8, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_8, norm_layer, bn_momentum),
        )

        # Decoder stages.
        self.up_1_16_1_8 = Upsample(
            self.feature_1_16_dec, self.feature_1_8_dec, norm_layer, bn_momentum
        )
        self.up_1_8_1_4 = Upsample(
            self.feature_1_8_dec, self.feature_1_4_dec, norm_layer, bn_momentum
        )
        self.ssc_head_1_4 = SegmentationHead(
            self.feature_1_4_dec, self.feature_1_4_dec, class_num, [1, 2, 3]
        )

        self.context_prior = context_prior
        # Size handed to CPMegaVoxels: full scene size / 4, rounded up.
        size_1_16 = tuple(np.ceil(i / 4).astype(int) for i in full_scene_size)

        if context_prior:
            self.CP_mega_voxels = CPMegaVoxels(
                self.feature_1_16,
                size_1_16,
                n_relations=n_relations,
                bn_momentum=bn_momentum,
            )

    def forward(self, input_dict):
        """Run the encoder-decoder and return the outputs as a dict.

        Args:
            input_dict: Mapping with key ``"x3d"`` holding the finest-level
                input features.

        Returns:
            A dict with key ``"ssc_logit"`` (the segmentation head output).
            When ``context_prior`` is enabled it also contains every entry
            returned by ``CPMegaVoxels``, including ``"x"``.
        """
        res = {}

        # Encoder.
        x3d_1_4 = input_dict["x3d"]
        x3d_1_8 = self.process_1_4(x3d_1_4)
        x3d_1_16 = self.process_1_8(x3d_1_8)

        if self.context_prior:
            ret = self.CP_mega_voxels(x3d_1_16)
            x3d_1_16 = ret["x"]
            # Expose all context-prior outputs to the caller.
            res.update(ret)

        # Decoder with additive skip connections from the encoder.
        x3d_up_1_8 = self.up_1_16_1_8(x3d_1_16) + x3d_1_8
        x3d_up_1_4 = self.up_1_8_1_4(x3d_up_1_8) + x3d_1_4

        res["ssc_logit"] = self.ssc_head_1_4(x3d_up_1_4)

        return res