"""
Depth decoders: a DADA-style decoder (depth as an auxiliary task, as in the
DADA paper) and a generic decoder built on BaseDecoder.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from climategan.blocks import BaseDecoder, Conv2dBlock, InterpolateNearest2d
from climategan.utils import find_target_size


def create_depth_decoder(opts, no_init=False, verbose=0):
    """
    Create a depth decoder from opts: BaseDepthDecoder when
    opts.gen.d.architecture is "base", DADADepthDecoder otherwise.
    """
    if opts.gen.d.architecture == "base":
        decoder = BaseDepthDecoder(opts)
        # DADA-style feature fusion requires the DADA depth decoder
        if "s" in opts.tasks:
            assert opts.gen.s.use_dada is False
        if "m" in opts.tasks:
            assert opts.gen.m.use_dada is False
    else:
        decoder = DADADepthDecoder(opts)

    if verbose > 0:
        print(f"  - Add {decoder.__class__.__name__}")

    return decoder


class DADADepthDecoder(nn.Module):
    """
    Depth decoder based on the depth auxiliary task in the DADA paper
    """

    def __init__(self, opts):
        super().__init__()
        # number of channels in the encoder's output feature map
        if (
            opts.gen.encoder.architecture == "deeplabv3"
            and opts.gen.deeplabv3.backbone == "mobilenet"
        ):
            res_dim = 320
        else:
            res_dim = 2048

        mid_dim = 512

        # fuse depth features back into the encoder's features when another
        # task (masker or segmentation) uses DADA-style conditioning
        self.do_feat_fusion = False
        if opts.gen.m.use_dada or ("s" in opts.tasks and opts.gen.s.use_dada):
            self.do_feat_fusion = True
            self.dec4 = Conv2dBlock(
                128,
                res_dim,
                1,
                stride=1,
                padding=0,
                bias=True,
                activation="lrelu",
                norm="none",
            )

        self.relu = nn.ReLU(inplace=True)
        # 1x1 -> 3x3 -> 1x1 bottleneck on the encoder's features
        self.enc4_1 = Conv2dBlock(
            res_dim,
            mid_dim,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_2 = Conv2dBlock(
            mid_dim,
            mid_dim,
            3,
            stride=1,
            padding=1,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )
        self.enc4_3 = Conv2dBlock(
            mid_dim,
            128,
            1,
            stride=1,
            padding=0,
            bias=False,
            activation="lrelu",
            pad_type="reflect",
            norm="batch",
        )

        self.upsample = None
        if opts.gen.d.upsample_featuremaps:
            self.upsample = nn.Sequential(
                *[
                    InterpolateNearest2d(),
                    Conv2dBlock(
                        128,
                        32,
                        3,
                        stride=1,
                        padding=1,
                        bias=False,
                        activation="lrelu",
                        pad_type="reflect",
                        norm="batch",
                    ),
                    nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
                ]
            )

        self._target_size = find_target_size(opts, "d")
        print(
            "  - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w).
                If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z):
        if isinstance(z, (list, tuple)):
            z = z[0]

        z4_enc = self.enc4_1(z)
        z4_enc = self.enc4_2(z4_enc)
        z4_enc = self.enc4_3(z4_enc)

        z_depth = None
        if self.do_feat_fusion:
            # project depth features back to the encoder's dimension
            z_depth = self.dec4(z4_enc)

        if self.upsample is not None:
            z4_enc = self.upsample(z4_enc)

        depth = torch.mean(z4_enc, dim=1, keepdim=True)  # DADA paper decoder
        if depth.shape[-1] != self._target_size:
            depth = F.interpolate(
                depth,
                size=(384, 384),  # size used in MiDaS inference
                mode="bicubic",  # what MiDaS uses
                align_corners=False,
            )
            depth = F.interpolate(
                depth, (self._target_size, self._target_size), mode="nearest"
            )  # what we used in the transforms to resize input

        return depth, z_depth

    def __str__(self):
        return "DADA Depth Decoder"


class BaseDepthDecoder(BaseDecoder):
    def __init__(self, opts):
        low_level_feats_dim = -1
        use_v3 = opts.gen.encoder.architecture == "deeplabv3"
        use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
        use_low = opts.gen.d.use_low_level_feats

        # input/skip dimensions depend on the encoder backbone
        if use_v3 and use_mobile_net:
            input_dim = 320
            if use_low:
                low_level_feats_dim = 24
        elif use_v3:
            input_dim = 2048
            if use_low:
                low_level_feats_dim = 256
        else:
            input_dim = 2048

        n_upsample = 1 if opts.gen.d.upsample_featuremaps else 0
        # regress a single depth channel, or classify into depth buckets
        output_dim = (
            1
            if not opts.gen.d.classify.enable
            else opts.gen.d.classify.linspace.buckets
        )

        self._target_size = find_target_size(opts, "d")
        print(
            "  - {}: setting target size to {}".format(
                self.__class__.__name__, self._target_size
            )
        )

        super().__init__(
            n_upsample=n_upsample,
            n_res=opts.gen.d.n_res,
            input_dim=input_dim,
            proj_dim=opts.gen.d.proj_dim,
            output_dim=output_dim,
            norm=opts.gen.d.norm,
            activ=opts.gen.d.activ,
            pad_type=opts.gen.d.pad_type,
            output_activ="none",
            low_level_feats_dim=low_level_feats_dim,
        )

    def set_target_size(self, size):
        """
        Set final interpolation's target size

        Args:
            size (int, list, tuple): target size (h, w).
                If int, target will be (i, i)
        """
        if isinstance(size, (list, tuple)):
            self._target_size = size[:2]
        else:
            self._target_size = (size, size)

    def forward(self, z, cond=None):
        if self._target_size is None:
            error = "self._target_size should be set with self.set_target_size() "
            error += "to interpolate depth to the target depth map's size"
            raise ValueError(error)

        d = super().forward(z)

        preds = F.interpolate(
            d, size=self._target_size, mode="bilinear", align_corners=True
        )

        return preds, None
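

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch of the DADA-style depth head above, kept
# independent of the project's `opts` (loaded from YAML in the real pipeline).
# Assumptions for illustration only: the random tensor standing in for
# deeplabv3 encoder features, its 32x32 spatial size, the 0.2 LeakyReLU slope,
# plain Conv2d layers in place of Conv2dBlock, and the 640x640 target size.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    feats = torch.randn(2, 2048, 32, 32)  # stand-in for encoder features

    # 1x1 -> 3x3 -> 1x1 bottleneck, mirroring enc4_1 / enc4_2 / enc4_3
    head = nn.Sequential(
        nn.Conv2d(2048, 512, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(512, 512, 3, padding=1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Conv2d(512, 128, 1, bias=False),
        nn.LeakyReLU(0.2, inplace=True),
    )

    z4 = head(feats)
    # channel-wise mean as in the DADA paper, then MiDaS-style resizing
    depth = torch.mean(z4, dim=1, keepdim=True)
    depth = F.interpolate(depth, size=(384, 384), mode="bicubic", align_corners=False)
    depth = F.interpolate(depth, size=(640, 640), mode="nearest")
    print(depth.shape)  # torch.Size([2, 1, 640, 640])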