import logging

import torch
import torch.nn as nn

from monoscene.modules import (
    Process,
    ASPP,
)

logger = logging.getLogger(__name__)


class CPMegaVoxels(nn.Module):
    """3D context-prior block relating every voxel to a set of downsampled "mega voxels".

    The input volume is aggregated with ASPP, then downsampled by a stride-2
    convolution into a coarse "mega context" grid of ``size[i] // 2`` cells per
    axis with ``2 * feature`` channels. For each of ``n_relations`` relations, a
    1x1x1 convolution predicts, for every voxel, one logit per mega voxel. The
    sigmoid of these logits weights the mega-context features, and the gathered
    features are concatenated with the input and projected back to ``feature``
    channels.

    Args:
        feature: Channel count of the input volume; also the channel count of
            the ``resize`` projection output.
        size: Spatial size ``(D0, D1, D2)`` of the input volume. ``forward``
            reshapes to exactly this size, so the input must match it.
        n_relations: Number of independent relation (context-prior) maps.
        bn_momentum: Batch-norm momentum passed to the ``Process`` block in
            ``resize``.
    """

    def __init__(self, feature, size, n_relations=4, bn_momentum=0.0003):
        super().__init__()
        self.size = size
        self.n_relations = n_relations
        # Previously a bare print() on every construction; use logging so that
        # library users control whether this is shown.
        logger.debug("n_relations %s", self.n_relations)
        self.flatten_size = size[0] * size[1] * size[2]
        self.feature = feature
        self.context_feature = feature * 2
        # Number of mega voxels: one per 2x2x2 cell of the input grid.
        self.flatten_context_size = (size[0] // 2) * (size[1] // 2) * (size[2] // 2)

        # Pad even-sized axes by 1 and odd-sized axes by 0 so that the stride-2,
        # kernel-3 convolution yields exactly size[i] // 2 cells per axis,
        # matching flatten_context_size above.
        padding = ((size[0] + 1) % 2, (size[1] + 1) % 2, (size[2] + 1) % 2)
        # Kept inside nn.Sequential so state_dict keys stay "mega_context.0.*".
        self.mega_context = nn.Sequential(
            nn.Conv3d(
                feature,
                self.context_feature,
                stride=2,
                padding=padding,
                kernel_size=3,
            ),
        )

        # One 1x1x1 conv per relation, predicting a logit for every mega voxel
        # at every input voxel (output channels = flatten_context_size).
        self.context_prior_logits = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv3d(
                        self.feature,
                        self.flatten_context_size,
                        padding=0,
                        kernel_size=1,
                    ),
                )
                for _ in range(n_relations)
            ]
        )
        # NOTE(review): assumes ASPP(feature, ...) keeps `feature` output
        # channels, since its output feeds convs built with `feature` inputs.
        self.aspp = ASPP(feature, [1, 2, 3])

        # Project [input, gathered context of every relation] back to `feature`
        # channels.
        self.resize = nn.Sequential(
            nn.Conv3d(
                self.context_feature * self.n_relations + feature,
                feature,
                kernel_size=1,
                padding=0,
                bias=False,
            ),
            Process(feature, nn.BatchNorm3d, bn_momentum, dilations=[1]),
        )

    def forward(self, input):
        """Apply the context prior to ``input``.

        Args:
            input: Tensor of shape ``(bs, feature, *self.size)``. The spatial
                size is required by the final reshape, and the channel count
                by the ``resize`` convolution.

        Returns:
            dict with:
                ``"P_logits"``: relation logits of shape
                    ``(bs, n_relations, flatten_context_size, flatten_size)``.
                ``"x"``: output of ``self.resize`` applied to the concatenation
                    of ``input`` and the gathered context.
        """
        ret = {}
        bs = input.shape[0]

        x_agg = self.aspp(input)

        # Mega context: (bs, context_feature, d0//2, d1//2, d2//2)
        # -> (bs, flatten_context_size, context_feature).
        x_mega_context_raw = self.mega_context(x_agg)
        x_mega_context = x_mega_context_raw.reshape(bs, self.context_feature, -1)
        x_mega_context = x_mega_context.permute(0, 2, 1)

        x_context_prior_logits = []
        x_context_rels = []
        for rel in range(self.n_relations):
            # Relation logits: (bs, flatten_context_size, flatten_size).
            x_context_prior_logit = self.context_prior_logits[rel](x_agg)
            x_context_prior_logit = x_context_prior_logit.reshape(
                bs, self.flatten_context_size, self.flatten_size
            )
            x_context_prior_logits.append(x_context_prior_logit.unsqueeze(1))

            # Per voxel, a sigmoid weight for each mega voxel:
            # (bs, flatten_size, flatten_context_size).
            x_context_prior_logit = x_context_prior_logit.permute(0, 2, 1)
            x_context_prior = torch.sigmoid(x_context_prior_logit)

            # Weighted gather of mega-context features per voxel:
            # (bs, flatten_size, context_feature).
            x_context_rel = torch.bmm(x_context_prior, x_mega_context)
            x_context_rels.append(x_context_rel)

        # (bs, flatten_size, context_feature * n_relations) back to a volume
        # of shape (bs, context_feature * n_relations, *self.size).
        x_context = torch.cat(x_context_rels, dim=2)
        x_context = x_context.permute(0, 2, 1)
        x_context = x_context.reshape(
            bs, x_context.shape[1], self.size[0], self.size[1], self.size[2]
        )

        x = torch.cat([input, x_context], dim=1)
        x = self.resize(x)

        x_context_prior_logits = torch.cat(x_context_prior_logits, dim=1)
        ret["P_logits"] = x_context_prior_logits
        ret["x"] = x

        return ret