# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from monoscene.CRP3D import CPMegaVoxels
from monoscene.modules import (
    Process,
    Upsample,
    Downsample,
    SegmentationHead,
    ASPP,
)


class UNet3D(nn.Module):
    """3D encoder-decoder producing semantic scene completion logits.

    The encoder runs two ``Process`` + ``Downsample`` stages, taking the input
    grid from the "1/4" level to the "1/8" and then the "1/16" level. The
    level names come from the attribute names. The actual strides are defined
    by ``monoscene.modules.Downsample`` (TODO confirm there).

    An optional context-prior module (``CPMegaVoxels``) refines the deepest
    features. The decoder then upsamples twice, adding the matching encoder
    feature map at each step as a residual skip. A ``SegmentationHead`` emits
    the final logits.
    """

    def __init__(
        self,
        class_num,
        norm_layer,
        feature,
        full_scene_size,
        n_relations=4,
        project_res=None,
        context_prior=True,
        bn_momentum=0.1,
    ):
        """Build the encoder, decoder, segmentation head and optional context prior.

        Args:
            class_num: Number of output classes passed to the segmentation head.
            norm_layer: Normalization layer type forwarded to the sub-modules.
            feature: Channel count at the 1/4 level. The 1/8 and 1/16 levels
                use ``2 * feature`` and ``4 * feature`` channels.
            full_scene_size: Per-axis size of the scene grid. Each axis is
                divided by 4 (rounded up) to size the context-prior module.
                Presumably this is the spatial size of ``input_dict["x3d"]``;
                TODO confirm against the caller.
            n_relations: Number of relations for ``CPMegaVoxels``.
            project_res: Stored as ``self.project_res`` and not used in this
                class. ``None`` (the default) means an empty list.
            context_prior: If True, build and apply ``CPMegaVoxels`` at the
                1/16 level.
            bn_momentum: Batch-norm momentum forwarded to the sub-modules.
        """
        super().__init__()
        # Not populated anywhere in this class; kept because external code may
        # read it.
        self.business_layer = []
        # Create a fresh list per instance. A shared ``[]`` default would let a
        # mutation through one model's ``project_res`` leak into every other
        # model built with the default.
        self.project_res = [] if project_res is None else project_res

        # Channel widths per level. The decoder uses the same widths as the
        # encoder, which the additive skips in ``forward`` rely on.
        self.feature_1_4 = feature
        self.feature_1_8 = feature * 2
        self.feature_1_16 = feature * 4

        self.feature_1_16_dec = self.feature_1_16
        self.feature_1_8_dec = self.feature_1_8
        self.feature_1_4_dec = self.feature_1_4

        # Encoder stages. Because forward() adds these outputs to decoder maps
        # of width feature_1_8 / feature_1_16, each Downsample must double the
        # channel count (TODO confirm in monoscene.modules).
        self.process_1_4 = nn.Sequential(
            Process(self.feature_1_4, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_4, norm_layer, bn_momentum),
        )
        self.process_1_8 = nn.Sequential(
            Process(self.feature_1_8, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_8, norm_layer, bn_momentum),
        )

        # Decoder stages.
        self.up_1_16_1_8 = Upsample(
            self.feature_1_16_dec, self.feature_1_8_dec, norm_layer, bn_momentum
        )
        self.up_1_8_1_4 = Upsample(
            self.feature_1_8_dec, self.feature_1_4_dec, norm_layer, bn_momentum
        )
        self.ssc_head_1_4 = SegmentationHead(
            self.feature_1_4_dec, self.feature_1_4_dec, class_num, [1, 2, 3]
        )

        self.context_prior = context_prior
        # Grid size handed to CPMegaVoxels: ceil(size / 4) per axis. This
        # assumes the two encoder stages reduce the input by 4x overall
        # (TODO confirm).
        size_1_16 = tuple(np.ceil(i / 4).astype(int) for i in full_scene_size)

        if context_prior:
            self.CP_mega_voxels = CPMegaVoxels(
                self.feature_1_16,
                size_1_16,
                n_relations=n_relations,
                bn_momentum=bn_momentum,
            )

    def forward(self, input_dict):
        """Run the encoder-decoder on ``input_dict["x3d"]``.

        Args:
            input_dict: Mapping holding the input feature volume under "x3d".
                Presumably a 5-D tensor with ``feature`` channels; TODO confirm.

        Returns:
            dict: Every entry returned by ``CPMegaVoxels`` when the context
            prior is enabled (including its refined features under "x").
            Also "ssc_logit", the segmentation-head output. "ssc_logit" is
            written last, so it wins on a key clash.
        """
        res = {}

        x3d_1_4 = input_dict["x3d"]
        x3d_1_8 = self.process_1_4(x3d_1_4)
        x3d_1_16 = self.process_1_8(x3d_1_8)

        if self.context_prior:
            ret = self.CP_mega_voxels(x3d_1_16)
            # Continue decoding from the context-refined features, and expose
            # every auxiliary output of the context-prior module to the caller.
            x3d_1_16 = ret["x"]
            res.update(ret)

        # Decoder with additive skip connections from the encoder.
        x3d_up_1_8 = self.up_1_16_1_8(x3d_1_16) + x3d_1_8
        x3d_up_1_4 = self.up_1_8_1_4(x3d_up_1_8) + x3d_1_4

        ssc_logit_1_4 = self.ssc_head_1_4(x3d_up_1_4)

        res["ssc_logit"] = ssc_logit_1_4

        return res