import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from monoscene.flosp import FLoSP
from monoscene.unet2d import UNet2D
from monoscene.unet3d_kitti import UNet3D as UNet3DKitti
from monoscene.unet3d_nyu import UNet3D as UNet3DNYU


class MonoScene(pl.LightningModule):
    """Image-to-3D semantic scene prediction model.

    A 2D UNet (``UNet2D``) produces multi-scale image features. For each 2D
    scale listed in ``project_res``, a ``FLoSP`` module lifts those features
    into a 3D volume, and the volumes from all selected scales are summed.
    A dataset-specific 3D UNet then decodes the summed volume into
    ``"ssc_logit"``, and ``forward`` returns the argmax labels as a numpy
    array.
    """

    def __init__(
        self,
        n_classes,
        feature,
        project_scale,
        full_scene_size,
        dataset,
        n_relations=4,
        context_prior=True,
        fp_loss=True,
        project_res=None,
        frustum_size=4,
        relation_loss=False,
        CE_ssc_loss=True,
        geo_scal_loss=True,
        sem_scal_loss=True,
        lr=1e-4,
        weight_decay=1e-4,
    ):
        """Build the 2D encoder/decoder, FLoSP projections and 3D decoder.

        Args:
            n_classes: Number of output classes, passed to the 3D decoder.
            feature: Passed as ``out_feature`` to ``UNet2D.build`` and as
                ``feature`` to the 3D decoder.
            project_scale: Passed to every FLoSP module and to the kitti
                decoder. It also selects which ``projected_pix_*`` /
                ``fov_mask_*`` batch keys ``forward`` reads.
            full_scene_size: Passed to every FLoSP module and to the decoder.
            dataset: ``"NYU"`` or ``"kitti"``; selects the 3D decoder.
            n_relations: Forwarded to the NYU decoder only.
            context_prior: Stored and forwarded to the 3D decoder.
            project_res: Iterable of 2D scales (each must ``int()``-convert to
                one of 1, 2, 4, 8) whose features are lifted to 3D in
                ``forward``. ``None`` (the default) means an empty list.
            fp_loss, frustum_size, relation_loss, CE_ssc_loss, geo_scal_loss,
            sem_scal_loss, lr, weight_decay: Stored on the instance but not
                read by ``__init__`` or ``forward``. They are presumably kept
                for training/config compatibility — confirm.

        Raises:
            ValueError: If ``dataset`` is neither ``"NYU"`` nor ``"kitti"``.
        """
        super().__init__()
        # ``None`` sentinel instead of a shared mutable ``[]`` default.
        self.project_res = [] if project_res is None else project_res
        self.fp_loss = fp_loss
        self.dataset = dataset
        self.context_prior = context_prior
        self.frustum_size = frustum_size
        self.relation_loss = relation_loss
        self.CE_ssc_loss = CE_ssc_loss
        self.sem_scal_loss = sem_scal_loss
        self.geo_scal_loss = geo_scal_loss
        self.project_scale = project_scale
        self.lr = lr
        self.weight_decay = weight_decay

        # One FLoSP projection per available 2D feature scale.
        self.scale_2ds = [1, 2, 4, 8]
        self.projects = nn.ModuleDict(
            {
                str(scale_2d): FLoSP(
                    full_scene_size,
                    project_scale=self.project_scale,
                    dataset=self.dataset,
                )
                for scale_2d in self.scale_2ds
            }
        )
        self.n_classes = n_classes

        if self.dataset == "NYU":
            self.net_3d_decoder = UNet3DNYU(
                self.n_classes,
                nn.BatchNorm3d,
                n_relations=n_relations,
                feature=feature,
                full_scene_size=full_scene_size,
                context_prior=context_prior,
            )
        elif self.dataset == "kitti":
            self.net_3d_decoder = UNet3DKitti(
                self.n_classes,
                nn.BatchNorm3d,
                project_scale=project_scale,
                feature=feature,
                full_scene_size=full_scene_size,
                context_prior=context_prior,
            )
        else:
            # Fail fast: otherwise net_3d_decoder would be missing and forward
            # would die later with an unhelpful AttributeError.
            raise ValueError(
                "Unsupported dataset {!r}; expected 'NYU' or 'kitti'".format(dataset)
            )
        self.net_rgb = UNet2D.build(out_feature=feature, use_decoder=True)

    def forward(self, batch):
        """Predict per-voxel class labels for a batch.

        Args:
            batch: Mapping with ``"img"`` (fed to the 2D UNet; its ``len`` is
                the batch size) and per-sample indexable
                ``"projected_pix_{project_scale}"`` and
                ``"fov_mask_{project_scale}"`` entries.

        Returns:
            numpy array holding the argmax over axis 1 of the decoder's
            ``"ssc_logit"`` output. Axis 1 is presumably the class axis —
            confirm against the decoder.

        Raises:
            ValueError: If ``project_res`` is empty, i.e. there are no 3D
                features to decode.
        """
        if not self.project_res:
            raise ValueError("project_res is empty: no 2D scales to project into 3D")

        img = batch["img"]
        x_rgb = self.net_rgb(img)

        pix_key = "projected_pix_{}".format(self.project_scale)
        fov_key = "fov_mask_{}".format(self.project_scale)
        scales = [int(scale_2d) for scale_2d in self.project_res]

        x3ds = []
        for i in range(len(img)):
            # Move the projection inputs to wherever the image features live
            # instead of hard-coding .cuda(), so CPU / non-default GPUs work.
            # These inputs do not depend on the 2D scale, so fetch them once.
            device = x_rgb["1_" + str(scales[0])][i].device
            projected_pix = batch[pix_key][i].to(device)
            fov_mask = batch[fov_key][i].to(device)

            # Sum the 3D features lifted from every selected 2D scale.
            x3d = None
            for scale_2d in scales:
                lifted = self.projects[str(scale_2d)](
                    x_rgb["1_" + str(scale_2d)][i],
                    projected_pix // scale_2d,
                    fov_mask,
                )
                if x3d is None:
                    x3d = lifted
                else:
                    x3d += lifted
            x3ds.append(x3d)

        out_dict = self.net_3d_decoder({"x3d": torch.stack(x3ds)})
        ssc_pred = out_dict["ssc_logit"]
        y_pred = ssc_pred.detach().cpu().numpy()
        return np.argmax(y_pred, axis=1)