# Copyright (C) 2022-present Naver Corporation. All rights reserved. # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). # -------------------------------------------------------- # Heads for downstream tasks # -------------------------------------------------------- """ A head is a module where the __init__ defines only the head hyperparameters. A method setup(croconet) takes a CroCoNet and sets all layers according to the head and croconet attributes. The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' """ import torch import torch.nn as nn from .dpt_block import DPTOutputAdapter class PixelwiseTaskWithDPT(nn.Module): """ DPT module for CroCo. by default, hooks_idx will be equal to: * for encoder-only: 4 equally spread layers * for encoder+decoder: last encoder + 3 equally spread layers of the decoder """ def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768], output_width_ratio=1, num_channels=1, postprocess=None, **kwargs): super(PixelwiseTaskWithDPT, self).__init__() self.return_all_blocks = True # backbone needs to return all layers self.postprocess = postprocess self.output_width_ratio = output_width_ratio self.num_channels = num_channels self.hooks_idx = hooks_idx self.layer_dims = layer_dims def setup(self, croconet): dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels} if self.hooks_idx is None: if hasattr(croconet, 'dec_blocks'): # encoder + decoder step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)] else: # encoder only step = croconet.enc_depth//4 hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)] self.hooks_idx = hooks_idx print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}') dpt_args['hooks'] = self.hooks_idx dpt_args['layer_dims'] = self.layer_dims self.dpt = DPTOutputAdapter(**dpt_args) dim_tokens = 
[croconet.enc_embed_dim if hook