# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
from monoscene.modules import SegmentationHead
from monoscene.CRP3D import CPMegaVoxels
from monoscene.modules import Process, Upsample, Downsample


class UNet3D(nn.Module):
    """3D U-Net that turns a feature volume into scene-completion logits.

    The input volume is processed and downsampled twice (levels l1 -> l2 ->
    l3). A ``CPMegaVoxels`` context-prior module is optionally applied at the
    coarsest level. The volume is then upsampled back to l1 with additive
    skip connections, upsampled once more (``up_l1_lfull``), and fed to a
    ``SegmentationHead``.

    Args:
        class_num: Number of classes predicted by the segmentation head.
        norm_layer: Normalization layer forwarded to the ``Process``,
            ``Downsample`` and ``Upsample`` building blocks.
        full_scene_size: Indexable with at least 3 entries; per-axis size of
            the full scene grid.
        feature: Channel count expected for ``input_dict["x3d"]``. The l2 and
            l3 levels are built for ``2 * feature`` and ``4 * feature``
            channels. The segmentation head receives ``feature // 2``.
        project_scale: Divisor applied to ``full_scene_size`` to obtain the
            level-1 grid size.
        context_prior: If truthy, build and apply ``CPMegaVoxels`` on the
            level-3 features.
        bn_momentum: Momentum forwarded to the building blocks and to
            ``CPMegaVoxels``.
    """

    def __init__(
        self,
        class_num,
        norm_layer,
        full_scene_size,
        feature,
        project_scale,
        context_prior=None,
        bn_momentum=0.1,
    ):
        super().__init__()
        # Not used anywhere in this class; kept so external code that reads
        # the attribute keeps working.
        self.business_layer = []
        self.project_scale = project_scale
        self.full_scene_size = full_scene_size
        self.feature = feature

        # Spatial grid sizes of the three pyramid levels. Level 1 is the full
        # scene shrunk by project_scale (presumably the spatial size of
        # input_dict["x3d"] -- TODO confirm). Each further level halves every
        # axis. Only size_l3 is consumed, by the context-prior module.
        size_l1 = tuple(int(self.full_scene_size[i] / project_scale) for i in range(3))
        size_l2 = tuple(s // 2 for s in size_l1)
        size_l3 = tuple(s // 2 for s in size_l2)

        dilations = [1, 2, 3]

        # Encoder. process_l2 is built for 2 * feature channels, so Downsample
        # is presumably expected to double the channel count -- confirm in
        # monoscene.modules.
        self.process_l1 = nn.Sequential(
            Process(self.feature, norm_layer, bn_momentum, dilations=dilations),
            Downsample(self.feature, norm_layer, bn_momentum),
        )
        self.process_l2 = nn.Sequential(
            Process(self.feature * 2, norm_layer, bn_momentum, dilations=dilations),
            Downsample(self.feature * 2, norm_layer, bn_momentum),
        )

        # Decoder. Each Upsample maps (in_channels -> out_channels) to match
        # the encoder level it is added to in forward().
        self.up_13_l2 = Upsample(
            self.feature * 4, self.feature * 2, norm_layer, bn_momentum
        )
        self.up_12_l1 = Upsample(
            self.feature * 2, self.feature, norm_layer, bn_momentum
        )
        self.up_l1_lfull = Upsample(
            self.feature, self.feature // 2, norm_layer, bn_momentum
        )

        self.ssc_head = SegmentationHead(
            self.feature // 2, self.feature // 2, class_num, dilations
        )

        self.context_prior = context_prior
        if context_prior:
            self.CP_mega_voxels = CPMegaVoxels(
                self.feature * 4, size_l3, bn_momentum=bn_momentum
            )

    def forward(self, input_dict):
        """Run the network on ``input_dict["x3d"]``.

        Args:
            input_dict: Mapping that holds the input feature volume under the
                key ``"x3d"``. It is presumably shaped (B, feature, X, Y, Z);
                TODO confirm against the caller.

        Returns:
            dict: When the context prior is enabled, every entry returned by
            ``CPMegaVoxels`` (including its ``"x"`` key). Always includes
            ``"ssc_logit"``, the segmentation-head output.
        """
        res = {}

        # Encoder.
        x3d_l1 = input_dict["x3d"]
        x3d_l2 = self.process_l1(x3d_l1)
        x3d_l3 = self.process_l2(x3d_l2)

        if self.context_prior:
            ret = self.CP_mega_voxels(x3d_l3)
            x3d_l3 = ret["x"]
            # Expose all context-prior outputs to the caller.
            res.update(ret)

        # Decoder with additive skip connections.
        # NOTE(review): the sums require upsampled and encoder shapes to
        # match. This may fail when a level-1 axis is odd -- confirm against
        # Downsample/Upsample.
        x3d_up_l2 = self.up_13_l2(x3d_l3) + x3d_l2
        x3d_up_l1 = self.up_12_l1(x3d_up_l2) + x3d_l1
        x3d_up_lfull = self.up_l1_lfull(x3d_up_l1)

        res["ssc_logit"] = self.ssc_head(x3d_up_lfull)
        return res