# -*- coding: utf-8 -*- # # @File: gancraft.py # @Author: Haozhe Xie # @Date: 2023-04-12 19:53:21 # @Last Modified by: Haozhe Xie # @Last Modified at: 2024-03-03 11:15:36 # @Email: root@haozhexie.com # @Ref: https://github.com/FrozenBurning/SceneDreamer import numpy as np import torch import torch.nn.functional as F import citydreamer.extensions.grid_encoder class GanCraftGenerator(torch.nn.Module): def __init__(self, cfg): super(GanCraftGenerator, self).__init__() self.cfg = cfg self.render_net = RenderMLP(cfg) self.denoiser = RenderCNN(cfg) if cfg.NETWORK.GANCRAFT.ENCODER == "GLOBAL": self.encoder = GlobalEncoder(cfg) elif cfg.NETWORK.GANCRAFT.ENCODER == "LOCAL": self.encoder = LocalEncoder(cfg) else: self.encoder = None if ( not cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS and not cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES ): raise ValueError( "Either POS_EMD_INCUDE_CORDS or POS_EMD_INCUDE_FEATURES should be True." ) if cfg.NETWORK.GANCRAFT.POS_EMD == "HASH_GRID": grid_encoder_in_dim = 3 if cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS else 0 if ( cfg.NETWORK.GANCRAFT.ENCODER in ["GLOBAL", "LOCAL"] and cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES ): grid_encoder_in_dim += cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM self.pos_encoder = citydreamer.extensions.grid_encoder.GridEncoder( in_channels=grid_encoder_in_dim, n_levels=cfg.NETWORK.GANCRAFT.HASH_GRID_N_LEVELS, lvl_channels=cfg.NETWORK.GANCRAFT.HASH_GRID_LEVEL_DIM, desired_resolution=cfg.NETWORK.GANCRAFT.HASH_GRID_RESOLUTION, ) elif cfg.NETWORK.GANCRAFT.POS_EMD == "SIN_COS": self.pos_encoder = SinCosEncoder(cfg) def forward( self, hf_seg, voxel_id, depth2, raydirs, cam_origin, building_stats=None, z=None, deterministic=False, ): r"""GANcraft Generator forward. Args: hf_seg (N x (1 + M) x H' x W' tensor) : height field + seg map, where M is the number of classes. voxel_id (N x H x W x max_samples x 1 tensor): IDs of intersected tensors along each ray. depth2 (N x H x W x 2 x max_samples x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection. raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. cam_origin (N x 3 tensor): Camera origins. building_stats (N x 5 tensor): The dy, dx, h, w, ID of the target building. (Only used in building mode) z (N x STYLE_DIM tensor): The style vector. deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling. Returns: fake_images (N x 3 x H x W tensor): fake images """ bs, device = hf_seg.size(0), hf_seg.device if z is None and self.cfg.NETWORK.GANCRAFT.STYLE_DIM is not None: z = torch.randn( bs, self.cfg.NETWORK.GANCRAFT.STYLE_DIM, dtype=torch.float32, device=device, ) features = None if self.encoder is not None: features = self.encoder(hf_seg) net_out = self._forward_perpix( features, voxel_id, depth2, raydirs, cam_origin, z, building_stats, deterministic, ) fake_images = self._forward_global(net_out, z) return fake_images def _forward_perpix( self, features, voxel_id, depth2, raydirs, cam_origin, z, building_stats=None, deterministic=False, ): r"""Sample points along rays, forwarding the per-point MLP and aggregate pixel features Args: features (N x C1 tensor): Local features determined by the current pixel. voxel_id (N x H x W x M x 1 tensor): Voxel ids from ray-voxel intersection test. M: num intersected voxels depth2 (N x H x W x 2 x M x 1 tensor): Depths of entrance and exit points for each ray-voxel intersection. raydirs (N x H x W x 1 x 3 tensor): The direction of each ray. cam_origin (N x 3 tensor): Camera origins. z (N x C3 tensor): Intermediate style vectors. building_stats (N x 4 tensor): The dy, dx, h, w of the target building. (Only used in building mode) deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling. """ # Generate sky_mask; PE transform on ray direction. with torch.no_grad(): # sky_only_mask: when True, ray hits nothing but sky sky_only_mask = voxel_id[:, :, :, [0], :] == 0 with torch.no_grad(): normalized_cord, new_dists, new_idx = self._get_sampled_coordinates( self.cfg.NETWORK.GANCRAFT.N_SAMPLE_POINTS_PER_RAY, depth2, raydirs, cam_origin, building_stats, deterministic, ) # Generate per-sample segmentation label seg_map_bev = torch.gather(voxel_id, -2, new_idx) # print(seg_map_bev.size()) # torch.Size([N, H, W, n_samples + 1, 1]) # In Building Mode, the one more channel is used for building roofs n_classes = ( self.cfg.NETWORK.GANCRAFT.N_CLASSES + 1 if self.cfg.NETWORK.GANCRAFT.BUILDING_MODE else self.cfg.NETWORK.GANCRAFT.N_CLASSES ) seg_map_bev_onehot = torch.zeros( [ seg_map_bev.size(0), seg_map_bev.size(1), seg_map_bev.size(2), seg_map_bev.size(3), n_classes, ], dtype=torch.float, device=voxel_id.device, ) # print(seg_map_bev_onehot.size()) # torch.Size([N, H, W, n_samples + 1, 1]) seg_map_bev_onehot.scatter_(-1, seg_map_bev.long(), 1.0) net_out_s, net_out_c = self._forward_perpix_sub( features, normalized_cord, z, seg_map_bev_onehot ) # Blending weights = self._volum_rendering_relu( net_out_s, new_dists * self.cfg.NETWORK.GANCRAFT.DIST_SCALE, dim=-2 ) # If a ray exclusively hits the sky (no intersection with the voxels), set its weight to zero. weights = weights * torch.logical_not(sky_only_mask).float() # print(weights.size()) # torch.Size([N, H, W, n_samples + 1, 1]) rgbs = torch.clamp(net_out_c, -1, 1) + 1 net_out = torch.sum(weights * rgbs, dim=-2, keepdim=True) net_out = net_out.squeeze(-2) net_out = net_out - 1 return net_out def _get_sampled_coordinates( self, n_samples, depth2, raydirs, cam_origin, building_stats=None, deterministic=False, ): # Random sample points along the ray rand_depth, new_dists, new_idx = self._sample_depth_batched( depth2, n_samples + 1, deterministic=deterministic, use_box_boundaries=False, sample_depth=3, ) nan_mask = torch.isnan(rand_depth) inf_mask = torch.isinf(rand_depth) rand_depth[nan_mask | inf_mask] = 0.0 world_coord = raydirs * rand_depth + cam_origin[:, None, None, None, :] # assert worldcoord2.shape[-1] == 3 if self.cfg.NETWORK.GANCRAFT.BUILDING_MODE: assert building_stats is not None # Make the building object-centric building_stats = building_stats[:, None, None, None, :].repeat( 1, world_coord.size(1), world_coord.size(2), world_coord.size(3), 1 ) world_coord[..., 0] -= ( building_stats[..., 0] + self.cfg.NETWORK.GANCRAFT.CENTER_OFFSET ) world_coord[..., 1] -= ( building_stats[..., 1] + self.cfg.NETWORK.GANCRAFT.CENTER_OFFSET ) # TODO: Fix non-building rays zero_rd_mask = raydirs.repeat(1, 1, 1, n_samples, 1) world_coord[zero_rd_mask == 0] = 0 normalized_cord = self._get_normalized_coordinates(world_coord) return normalized_cord, new_dists, new_idx def _get_normalized_coordinates(self, world_coord): delimeter = torch.tensor( self.cfg.NETWORK.GANCRAFT.NORMALIZE_DELIMETER, device=world_coord.device ) normalized_cord = world_coord / delimeter * 2 - 1 # TODO: Temporary fix normalized_cord[normalized_cord > 1] = 1 normalized_cord[normalized_cord < -1] = -1 # assert (normalized_cord <= 1).all() # assert (normalized_cord >= -1).all() # print(delimeter, torch.min(normalized_cord), torch.max(normalized_cord)) # print(normalized_cord.size()) # torch.Size([1, 192, 192, 24, 3]) return normalized_cord def _sample_depth_batched( self, depth2, n_samples, deterministic=False, use_box_boundaries=True, sample_depth=3, ): r"""Make best effort to sample points within the same distance for every ray. Exception: When there is not enough voxel. Args: depth2 (N x H x W x 2 x M x 1 tensor): - N: Batch. - H, W: Height, Width. - 2: Entrance / exit depth for each intersected box. - M: Number of intersected boxes along the ray. - 1: One extra dim for consistent tensor dims. depth2 can include NaNs. deterministic (bool): Whether to use equal-distance sampling instead of random stratified sampling. use_box_boundaries (bool): Whether to add the entrance / exit points into the sample. sample_depth (float): Truncate the ray when it travels further than sample_depth inside voxels. """ bs = depth2.size(0) dim0 = depth2.size(1) dim1 = depth2.size(2) dists = depth2[:, :, :, 1] - depth2[:, :, :, 0] dists[torch.isnan(dists)] = 0 # print(dists.size()) # torch.Size([N, H, W, M, 1]) accu_depth = torch.cumsum(dists, dim=-2) # print(accu_depth.size()) # torch.Size([N, H, W, M, 1]) total_depth = accu_depth[..., [-1], :] # print(total_depth.size()) # torch.Size([N, H, W, 1, 1]) total_depth = torch.clamp(total_depth, None, sample_depth) # Ignore out of range box boundaries. Fill with random samples. if use_box_boundaries: boundary_samples = accu_depth.clone().detach() boundary_samples_filler = torch.rand_like(boundary_samples) * total_depth bad_mask = (accu_depth > sample_depth) | (dists == 0) boundary_samples[bad_mask] = boundary_samples_filler[bad_mask] rand_shape = [bs, dim0, dim1, n_samples, 1] if deterministic: rand_samples = torch.empty( rand_shape, dtype=total_depth.dtype, device=total_depth.device ) rand_samples[..., :, 0] = torch.linspace(0, 1, n_samples + 2)[1:-1] else: rand_samples = torch.rand( rand_shape, dtype=total_depth.dtype, device=total_depth.device ) # Stratified sampling as in NeRF rand_samples = rand_samples / n_samples rand_samples[..., :, 0] += torch.linspace( 0, 1, n_samples + 1, device=rand_samples.device )[:-1] rand_samples = rand_samples * total_depth # print(rand_samples.size()) # torch.Size([N, H, W, n_samples, 1]) # Can also include boundaries if use_box_boundaries: rand_samples = torch.cat( [ rand_samples, boundary_samples, torch.zeros( [bs, dim0, dim1, 1, 1], dtype=total_depth.dtype, device=total_depth.device, ), ], dim=-2, ) rand_samples, _ = torch.sort(rand_samples, dim=-2, descending=False) midpoints = (rand_samples[..., 1:, :] + rand_samples[..., :-1, :]) / 2 # print(midpoints.size()) # torch.Size([N, H, W, n_samples, 1]) new_dists = rand_samples[..., 1:, :] - rand_samples[..., :-1, :] # Scatter the random samples back # print(midpoints.unsqueeze(-3).size()) # torch.Size([N, H, W, 1, n_samples, 1]) # print(accu_depth.unsqueeze(-2).size()) # torch.Size([N, H, W, M, 1, 1]) idx = torch.sum(midpoints.unsqueeze(-3) > accu_depth.unsqueeze(-2), dim=-3) # print(idx.shape, idx.max(), idx.min()) # torch.Size([N, H, W, n_samples, 1]) max 5, min 0 depth_deltas = ( depth2[:, :, :, 0, 1:, :] - depth2[:, :, :, 1, :-1, :] ) # There might be NaNs! # print(depth_deltas.size()) # torch.Size([N, H, W, M, M - 1, 1]) depth_deltas = torch.cumsum(depth_deltas, dim=-2) depth_deltas = torch.cat( [depth2[:, :, :, 0, [0], :], depth_deltas + depth2[:, :, :, 0, [0], :]], dim=-2, ) heads = torch.gather(depth_deltas, -2, idx) # print(heads.size()) # torch.Size([N, H, W, M, 1]) # print(torch.any(torch.isnan(heads))) rand_depth = heads + midpoints # print(rand_depth.size()) # torch.Size([N, H, W, M, n_samples, 1]) return rand_depth, new_dists, idx def _volum_rendering_relu(self, sigma, dists, dim=2): free_energy = F.relu(sigma) * dists a = 1 - torch.exp(-free_energy.float()) # probability of it is not empty here b = torch.exp( -self._cumsum_exclusive(free_energy, dim=dim) ) # probability of everything is empty up to now return a * b # probability of the ray hits something here def _cumsum_exclusive(self, tensor, dim): cumsum = torch.cumsum(tensor, dim) cumsum = torch.roll(cumsum, 1, dim) cumsum.index_fill_( dim, torch.tensor([0], dtype=torch.long, device=tensor.device), 0 ) return cumsum def _forward_perpix_sub(self, features, normalized_cord, z, seg_map_bev_onehot): r"""Forwarding the MLP. Args: features (N x C1 x ...? tensor): Local features determined by the current pixel. normalized_coord (N x H x W x L x 3 tensor): 3D world coordinates of sampled points. L is number of samples; N is batch size, always 1. z (N x C3 tensor): Intermediate style vectors. seg_map_bev_onehot (N x H x W x L x C4): One-hot segmentation maps. Returns: net_out_s (N x H x W x L x 1 tensor): Opacities. net_out_c (N x H x W x L x C5 tensor): Color embeddings. """ feature_in = torch.empty( normalized_cord.size(0), normalized_cord.size(1), normalized_cord.size(2), normalized_cord.size(3), 0, device=normalized_cord.device, ) if self.cfg.NETWORK.GANCRAFT.ENCODER == "GLOBAL": # print(features.size()) # torch.Size([N, ENCODER_OUT_DIM]) feature_in = features[:, None, None, None, :].repeat( 1, normalized_cord.size(1), normalized_cord.size(2), normalized_cord.size(3), 1, ) elif self.cfg.NETWORK.GANCRAFT.ENCODER == "LOCAL": # print(features.size()) # torch.Size([N, ENCODER_OUT_DIM - 1, H, W]) # print(world_coord.size()) # torch.Size([N, H, W, L, 3]) # NOTE: grid specifies the sampling pixel locations normalized by the input spatial # dimensions. Therefore, it should have most values in the range of [-1, 1]. grid = normalized_cord.permute(0, 3, 1, 2, 4).reshape( -1, normalized_cord.size(1), normalized_cord.size(2), 3 ) # print(grid.size()) # torch.Size([N * L, H, W, 3]) feature_in = F.grid_sample( features.repeat(grid.size(0), 1, 1, 1), grid[..., [1, 0]], align_corners=False, ) # print(feature_in.size()) # torch.Size([N * L, ENCODER_OUT_DIM - 1, H, W]) feature_in = feature_in.reshape( normalized_cord.size(0), normalized_cord.size(3), feature_in.size(1), feature_in.size(2), feature_in.size(3), ).permute(0, 3, 4, 1, 2) # print(feature_in.size()) # torch.Size([N, H, W, L, ENCODER_OUT_DIM - 1]) feature_in = torch.cat([feature_in, normalized_cord[..., [2]]], dim=-1) # print(feature_in.size()) # torch.Size([N, H, W, L, ENCODER_OUT_DIM]) if self.cfg.NETWORK.GANCRAFT.POS_EMD in ["HASH_GRID", "SIN_COS"]: if ( self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS and self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES ): feature_in = self.pos_encoder( torch.cat([normalized_cord, feature_in], dim=-1) ) elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS: feature_in = torch.cat( [self.pos_encoder(normalized_cord), feature_in], dim=-1 ) elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES: # Ignore normalized_cord here to make it decoupled with coordinates feature_in = torch.cat([self.pos_encoder(feature_in)], dim=-1) else: if ( self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS and self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES ): feature_in = torch.cat([normalized_cord, feature_in], dim=-1) elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS: feature_in = normalized_cord elif self.cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES: feature_in = feature_in net_out_s, net_out_c = self.render_net(feature_in, z, seg_map_bev_onehot) return net_out_s, net_out_c def _forward_global(self, net_out, z): r"""Forward the CNN Args: net_out (N x C5 x H x W tensor): Intermediate feature maps. z (N x C3 tensor): Intermediate style vectors. Returns: fake_images (N x 3 x H x W tensor): Output image. """ fake_images = net_out.permute(0, 3, 1, 2).contiguous() if self.denoiser is not None: fake_images = self.denoiser(fake_images, z) fake_images = torch.tanh(fake_images) return fake_images class GlobalEncoder(torch.nn.Module): def __init__(self, cfg): super(GlobalEncoder, self).__init__() n_classes = cfg.NETWORK.GANCRAFT.N_CLASSES self.hf_conv = torch.nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1) self.seg_conv = torch.nn.Conv2d( n_classes, 8, kernel_size=3, stride=2, padding=1, ) conv_blocks = [] cur_hidden_channels = 16 for _ in range(1, cfg.NETWORK.GANCRAFT.GLOBAL_ENCODER_N_BLOCKS): conv_blocks.append( SRTConvBlock(in_channels=cur_hidden_channels, out_channels=None) ) cur_hidden_channels *= 2 self.conv_blocks = torch.nn.Sequential(*conv_blocks) self.fc1 = torch.nn.Linear(cur_hidden_channels, 16) self.fc2 = torch.nn.Linear(16, cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM) self.act = torch.nn.LeakyReLU(0.2) def forward(self, hf_seg): hf = self.act(self.hf_conv(hf_seg[:, [0]])) seg = self.act(self.seg_conv(hf_seg[:, 1:])) out = torch.cat([hf, seg], dim=1) for layer in self.conv_blocks: out = self.act(layer(out)) out = out.permute(0, 2, 3, 1) out = torch.mean(out.reshape(out.shape[0], -1, out.shape[-1]), dim=1) cond = self.act(self.fc1(out)) cond = torch.tanh(self.fc2(cond)) return cond class LocalEncoder(torch.nn.Module): def __init__(self, cfg): super(LocalEncoder, self).__init__() n_classes = cfg.NETWORK.GANCRAFT.N_CLASSES self.hf_conv = torch.nn.Conv2d(1, 32, kernel_size=7, stride=2, padding=3) self.seg_conv = torch.nn.Conv2d( n_classes, 32, kernel_size=7, stride=2, padding=3 ) if cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "BATCH_NORM": self.bn1 = torch.nn.BatchNorm2d(64) elif cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM == "GROUP_NORM": self.bn1 = torch.nn.GroupNorm(32, 64) else: raise ValueError( "Unknown normalization: %s" % cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM ) self.conv2 = ResConvBlock(64, 128, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM) self.conv3 = ResConvBlock(128, 256, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM) self.conv4 = ResConvBlock(256, 512, cfg.NETWORK.GANCRAFT.LOCAL_ENCODER_NORM) self.dconv5 = torch.nn.ConvTranspose2d( 512, 128, kernel_size=4, stride=2, padding=1 ) self.dconv6 = torch.nn.ConvTranspose2d( 128, 32, kernel_size=4, stride=2, padding=1 ) self.dconv7 = torch.nn.Conv2d( 32, cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM - 1, kernel_size=1 ) def forward(self, hf_seg): hf = self.hf_conv(hf_seg[:, [0]]) seg = self.seg_conv(hf_seg[:, 1:]) out = F.relu(self.bn1(torch.cat([hf, seg], dim=1)), inplace=True) # print(out.size()) # torch.Size([N, 64, H/2, W/2]) out = F.avg_pool2d(self.conv2(out), 2, stride=2) # print(out.size()) # torch.Size([N, 128, H/4, W/4]) out = self.conv3(out) # print(out.size()) # torch.Size([N, 256, H/4, W/4]) out = self.conv4(out) # print(out.size()) # torch.Size([N, 512, H/4, W/4]) out = self.dconv5(out) # print(out.size()) # torch.Size([N, 128, H/2, W/2]) out = self.dconv6(out) # print(out.size()) # torch.Size([N, 32, H, W]) out = self.dconv7(out) # print(out.size()) # torch.Size([N, OUT_DIM - 1, H, W]) return torch.tanh(out) class SinCosEncoder(torch.nn.Module): def __init__(self, cfg): super(SinCosEncoder, self).__init__() self.freq_bands = 2.0 ** torch.linspace( 0, cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS - 1, steps=cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS, ) def forward(self, features): cord_sin = torch.cat( [torch.sin(features * fb) for fb in self.freq_bands], dim=-1 ) cord_cos = torch.cat( [torch.cos(features * fb) for fb in self.freq_bands], dim=-1 ) return torch.cat([cord_sin, cord_cos], dim=-1) class RenderMLP(torch.nn.Module): r"""MLP with affine modulation.""" def __init__(self, cfg): super(RenderMLP, self).__init__() in_dim = 0 f_dim = ( cfg.NETWORK.GANCRAFT.ENCODER_OUT_DIM if cfg.NETWORK.GANCRAFT.ENCODER in ["GLOBAL", "LOCAL"] else 0 ) if cfg.NETWORK.GANCRAFT.POS_EMD == "HASH_GRID": in_dim = ( cfg.NETWORK.GANCRAFT.HASH_GRID_N_LEVELS * cfg.NETWORK.GANCRAFT.HASH_GRID_LEVEL_DIM ) in_dim += ( f_dim if cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS and not cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES else 0 ) elif cfg.NETWORK.GANCRAFT.POS_EMD == "SIN_COS": if ( cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS and cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES ): in_dim = (3 + f_dim) * cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS * 2 elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS: in_dim = 3 * cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS * 2 + f_dim elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES: in_dim = f_dim * cfg.NETWORK.GANCRAFT.SIN_COS_FREQ_BENDS * 2 else: if ( cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS and cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES ): in_dim = 3 + f_dim elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_CORDS: in_dim = 3 elif cfg.NETWORK.GANCRAFT.POS_EMD_INCUDE_FEATURES: in_dim = f_dim self.fc_m_a = torch.nn.Linear( cfg.NETWORK.GANCRAFT.N_CLASSES + 1 if cfg.NETWORK.GANCRAFT.BUILDING_MODE else cfg.NETWORK.GANCRAFT.N_CLASSES, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, bias=False, ) self.fc_1 = torch.nn.Linear( in_dim, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, ) self.fc_2 = ( ModLinear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.STYLE_DIM, bias=False, mod_bias=True, output_mode=True, ) if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None else torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, ) ) self.fc_3 = ( ModLinear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.STYLE_DIM, bias=False, mod_bias=True, output_mode=True, ) if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None else torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, ) ) self.fc_4 = ( ModLinear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.STYLE_DIM, bias=False, mod_bias=True, output_mode=True, ) if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None else torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, ) ) self.fc_sigma = ( torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_SIGMA, ) if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None else torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_SIGMA, ) ) self.fc_5 = ( ModLinear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.STYLE_DIM, bias=False, mod_bias=True, output_mode=True, ) if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None else torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, ) ) self.fc_6 = ( ModLinear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.STYLE_DIM, bias=False, mod_bias=True, output_mode=True, ) if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None else torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, ) ) self.fc_out_c = ( torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_COLOR, ) if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None else torch.nn.Linear( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_COLOR, ) ) self.act = torch.nn.LeakyReLU(negative_slope=0.2) def forward(self, x, z, m): r"""Forward network Args: x (N x H x W x M x in_channels tensor): Projected features. z (N x cfg.NETWORK.GANCRAFT.STYLE_DIM tensor): Style codes. m (N x H x W x M x mask_channels tensor): One-hot segmentation maps. """ # b, h, w, n, _ = x.size() if z is not None: z = z[:, None, None, None, :] f = self.fc_1(x) f = f + self.fc_m_a(m) # Common MLP f = self.act(f) f = self.act(self.fc_2(f, z)) if z is not None else self.act(self.fc_2(f)) f = self.act(self.fc_3(f, z)) if z is not None else self.act(self.fc_3(f)) f = self.act(self.fc_4(f, z)) if z is not None else self.act(self.fc_4(f)) # Sigma MLP sigma = self.fc_sigma(f) if z is not None else self.act(self.fc_sigma(f)) # Color MLP f = self.act(self.fc_5(f, z)) if z is not None else self.act(self.fc_5(f)) f = self.act(self.fc_6(f, z)) if z is not None else self.act(self.fc_6(f)) c = self.fc_out_c(f) return sigma, c class RenderCNN(torch.nn.Module): r"""CNN converting intermediate feature map to final image.""" def __init__(self, cfg): super(RenderCNN, self).__init__() if cfg.NETWORK.GANCRAFT.STYLE_DIM is not None: self.fc_z_cond = torch.nn.Linear( cfg.NETWORK.GANCRAFT.STYLE_DIM, 2 * 2 * cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, ) self.conv1 = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_OUT_DIM_COLOR, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 1, stride=1, padding=0, ) self.conv2a = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 3, stride=1, padding=1, ) self.conv2b = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 3, stride=1, padding=1, bias=False, ) self.conv3a = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 3, stride=1, padding=1, ) self.conv3b = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 3, stride=1, padding=1, bias=False, ) self.conv4a = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 1, stride=1, padding=0, ) self.conv4b = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 1, stride=1, padding=0, ) self.conv4 = torch.nn.Conv2d( cfg.NETWORK.GANCRAFT.RENDER_HIDDEN_DIM, 3, 1, stride=1, padding=0 ) self.act = torch.nn.LeakyReLU(negative_slope=0.2, inplace=True) def modulate(self, x, w, b): w = w[..., None, None] b = b[..., None, None] return x * (w + 1) + b def forward(self, x, z): r"""Forward network. Args: x (N x in_channels x H x W tensor): Intermediate feature map z (N x style_dim tensor): Style codes. """ if z is not None: z = self.fc_z_cond(z) adapt = torch.chunk(z, 2 * 2, dim=-1) y = self.act(self.conv1(x)) y = y + self.conv2b(self.act(self.conv2a(y))) if z is not None: y = self.act(self.modulate(y, adapt[0], adapt[1])) else: y = self.act(y) y = y + self.conv3b(self.act(self.conv3a(y))) if z is not None: y = self.act(self.modulate(y, adapt[2], adapt[3])) else: y = self.act(y) y = y + self.conv4b(self.act(self.conv4a(y))) y = self.act(y) y = self.conv4(y) return y class SRTConvBlock(torch.nn.Module): def __init__(self, in_channels, hidden_channels=None, out_channels=None): super(SRTConvBlock, self).__init__() if hidden_channels is None: hidden_channels = in_channels if out_channels is None: out_channels = 2 * hidden_channels self.layers = torch.nn.Sequential( torch.nn.Conv2d( in_channels, hidden_channels, stride=1, kernel_size=3, padding=1, bias=False, ), torch.nn.ReLU(), torch.nn.Conv2d( hidden_channels, out_channels, stride=2, kernel_size=3, padding=1, bias=False, ), torch.nn.ReLU(), ) def forward(self, x): return self.layers(x) class ResConvBlock(torch.nn.Module): def __init__(self, in_channels, out_channels, norm, bias=False): super(ResConvBlock, self).__init__() # conv3x3(in_planes, int(out_planes / 2)) self.conv1 = torch.nn.Conv2d( in_channels, out_channels // 2, kernel_size=3, stride=1, padding=1, bias=bias, ) # conv3x3(int(out_planes / 2), int(out_planes / 4)) self.conv2 = torch.nn.Conv2d( out_channels // 2, out_channels // 4, kernel_size=3, stride=1, padding=1, bias=bias, ) # conv3x3(int(out_planes / 4), int(out_planes / 4)) self.conv3 = torch.nn.Conv2d( out_channels // 4, out_channels // 4, kernel_size=3, stride=1, padding=1, bias=bias, ) if norm == "BATCH_NORM": self.bn1 = torch.nn.BatchNorm2d(in_channels) self.bn2 = torch.nn.BatchNorm2d(out_channels // 2) self.bn3 = torch.nn.BatchNorm2d(out_channels // 4) self.bn4 = torch.nn.BatchNorm2d(in_channels) elif norm == "GROUP_NORM": self.bn1 = torch.nn.GroupNorm(32, in_channels) self.bn2 = torch.nn.GroupNorm(32, out_channels // 2) self.bn3 = torch.nn.GroupNorm(32, out_channels // 4) self.bn4 = torch.nn.GroupNorm(32, in_channels) if in_channels != out_channels: self.downsample = torch.nn.Sequential( self.bn4, torch.nn.ReLU(True), torch.nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, bias=False ), ) else: self.downsample = None def forward(self, x): residual = x # print(residual.size()) # torch.Size([N, 64, H, W]) out1 = self.bn1(x) out1 = F.relu(out1, True) out1 = self.conv1(out1) # print(out1.size()) # torch.Size([N, 64, H, W]) out2 = self.bn2(out1) out2 = F.relu(out2, True) out2 = self.conv2(out2) # print(out2.size()) # torch.Size([N, 32, H, W]) out3 = self.bn3(out2) out3 = F.relu(out3, True) out3 = self.conv3(out3) # print(out3.size()) # torch.Size([N, 32, H, W]) out3 = torch.cat((out1, out2, out3), dim=1) # print(out3.size()) # torch.Size([N, 128, H, W]) if self.downsample is not None: residual = self.downsample(residual) # print(residual.size()) # torch.Size([N, 128, H, W]) out3 += residual return out3 class ModLinear(torch.nn.Module): r"""Linear layer with affine modulation (Based on StyleGAN2 mod demod). Equivalent to affine modulation following linear, but faster when the same modulation parameters are shared across multiple inputs. Args: in_features (int): Number of input features. out_features (int): Number of output features. style_features (int): Number of style features. bias (bool): Apply additive bias before the activation function? mod_bias (bool): Whether to modulate bias. output_mode (bool): If True, modulate output instead of input. weight_gain (float): Initialization gain """ def __init__( self, in_features, out_features, style_features, bias=True, mod_bias=True, output_mode=False, weight_gain=1, bias_init=0, ): super(ModLinear, self).__init__() weight_gain = weight_gain / np.sqrt(in_features) self.weight = torch.nn.Parameter( torch.randn([out_features, in_features]) * weight_gain ) self.bias = ( torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None ) self.weight_alpha = torch.nn.Parameter( torch.randn([in_features, style_features]) / np.sqrt(style_features) ) self.bias_alpha = torch.nn.Parameter( torch.full([in_features], 1, dtype=torch.float) ) # init to 1 self.weight_beta = None self.bias_beta = None self.mod_bias = mod_bias self.output_mode = output_mode if mod_bias: if output_mode: mod_bias_dims = out_features else: mod_bias_dims = in_features self.weight_beta = torch.nn.Parameter( torch.randn([mod_bias_dims, style_features]) / np.sqrt(style_features) ) self.bias_beta = torch.nn.Parameter( torch.full([mod_bias_dims], 0, dtype=torch.float) ) @staticmethod def _linear_f(x, w, b): w = w.to(x.dtype) x_shape = x.shape x = x.reshape(-1, x_shape[-1]) if b is not None: b = b.to(x.dtype) x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = x.reshape(*x_shape[:-1], -1) return x # x: B, ... , Cin # z: B, 1, 1, , Cz def forward(self, x, z): x_shape = x.shape z_shape = z.shape x = x.reshape(x_shape[0], -1, x_shape[-1]) z = z.reshape(z_shape[0], 1, z_shape[-1]) alpha = self._linear_f(z, self.weight_alpha, self.bias_alpha) # [B, ..., I] w = self.weight.to(x.dtype) # [O I] w = w.unsqueeze(0) * alpha # [1 O I] * [B 1 I] = [B O I] if self.mod_bias: beta = self._linear_f(z, self.weight_beta, self.bias_beta) # [B, ..., I] if not self.output_mode: x = x + beta b = self.bias if b is not None: b = b.to(x.dtype)[None, None, :] if self.mod_bias and self.output_mode: if b is None: b = beta else: b = b + beta # [B ? I] @ [B I O] = [B ? O] if b is not None: x = torch.baddbmm(b, x, w.transpose(1, 2)) else: x = x.bmm(w.transpose(1, 2)) x = x.reshape(*x_shape[:-1], x.shape[-1]) return x class GanCraftDiscriminator(torch.nn.Module): def __init__(self, cfg): super(GanCraftDiscriminator, self).__init__() # bottom-up pathway # down_conv2d_block = Conv2dBlock, stride=2, kernel=3, padding=1, weight_norm=spectral # self.enc1 = down_conv2d_block(num_input_channels, num_filters) # 3 self.enc1 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 3, # RGB cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=2, kernel_size=3, padding=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.enc2 = down_conv2d_block(1 * num_filters, 2 * num_filters) # 7 self.enc2 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 1 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=2, kernel_size=3, padding=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.enc3 = down_conv2d_block(2 * num_filters, 4 * num_filters) # 15 self.enc3 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=2, kernel_size=3, padding=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.enc4 = down_conv2d_block(4 * num_filters, 8 * num_filters) # 31 self.enc4 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=2, kernel_size=3, padding=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.enc5 = down_conv2d_block(8 * num_filters, 8 * num_filters) # 63 self.enc5 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=2, kernel_size=3, padding=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # top-down pathway # latent_conv2d_block = Conv2dBlock, stride=1, kernel=1, weight_norm=spectral # self.lat2 = latent_conv2d_block(2 * num_filters, 4 * num_filters) self.lat2 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=1, kernel_size=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.lat3 = latent_conv2d_block(4 * num_filters, 4 * num_filters) self.lat3 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=1, kernel_size=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.lat4 = latent_conv2d_block(8 * num_filters, 4 * num_filters) self.lat4 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=1, kernel_size=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.lat5 = latent_conv2d_block(8 * num_filters, 4 * num_filters) self.lat5 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 8 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=1, kernel_size=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # upsampling self.upsample2x = torch.nn.Upsample( scale_factor=2, mode="bilinear", align_corners=False ) # final layers # stride1_conv2d_block = Conv2dBlock, stride=1, kernel=3, padding=1, weight_norm=spectral # self.final2 = stride1_conv2d_block(4 * num_filters, 2 * num_filters) self.final2 = torch.nn.Sequential( torch.nn.utils.spectral_norm( torch.nn.Conv2d( 4 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, stride=1, kernel_size=3, padding=1, bias=True, ) ), torch.nn.LeakyReLU(0.2), ) # self.output = Conv2dBlock(num_filters * 2, num_labels + 1, kernel_size=1) self.output = torch.nn.Sequential( torch.nn.Conv2d( 2 * cfg.NETWORK.GANCRAFT.DIS_N_CHANNEL_BASE, cfg.NETWORK.GANCRAFT.N_CLASSES + 1, stride=1, kernel_size=1, bias=True, ), torch.nn.LeakyReLU(0.2), ) self.interpolator = self._smooth_interp @staticmethod def _smooth_interp(x, size): r"""Smooth interpolation of segmentation maps. Args: x (4D tensor): Segmentation maps. size(2D list): Target size (H, W). """ x = F.interpolate(x, size=size, mode="area") onehot_idx = torch.argmax(x, dim=-3, keepdims=True) x.fill_(0.0) x.scatter_(1, onehot_idx, 1.0) return x def _single_forward(self, images, seg_maps): # bottom-up pathway feat11 = self.enc1(images) feat12 = self.enc2(feat11) feat13 = self.enc3(feat12) feat14 = self.enc4(feat13) feat15 = self.enc5(feat14) # top-down pathway and lateral connections feat25 = self.lat5(feat15) feat24 = self.upsample2x(feat25) + self.lat4(feat14) feat23 = self.upsample2x(feat24) + self.lat3(feat13) feat22 = self.upsample2x(feat23) + self.lat2(feat12) # final prediction layers feat32 = self.final2(feat22) label_map = self.interpolator(seg_maps, size=feat32.size()[2:]) pred = self.output(feat32) # N, num_labels + 1, H//4, W//4 return {"pred": pred, "label": label_map} def forward(self, images, seg_maps, masks): # print(seg_maps.size()) # torch.Size([1, 7, H, W]) # print(masks.size()) # torch.Size([1, 1, H, W]) seg_maps = seg_maps * masks return self._single_forward(images * masks, seg_maps)