Spaces:
Running
Running
# Copyright (C) 2022-present Naver Corporation. All rights reserved. | |
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
# -------------------------------------------------------- | |
# Heads for downstream tasks | |
# -------------------------------------------------------- | |
""" | |
A head is a module where the __init__ defines only the head hyperparameters. | |
A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes. | |
The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height' | |
""" | |
import torch | |
import torch.nn as nn | |
from .dpt_block import DPTOutputAdapter | |
class PixelwiseTaskWithDPT(nn.Module): | |
""" DPT module for CroCo. | |
by default, hooks_idx will be equal to: | |
* for encoder-only: 4 equally spread layers | |
* for encoder+decoder: last encoder + 3 equally spread layers of the decoder | |
""" | |
def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768], | |
output_width_ratio=1, num_channels=1, postprocess=None, **kwargs): | |
super(PixelwiseTaskWithDPT, self).__init__() | |
self.return_all_blocks = True # backbone needs to return all layers | |
self.postprocess = postprocess | |
self.output_width_ratio = output_width_ratio | |
self.num_channels = num_channels | |
self.hooks_idx = hooks_idx | |
self.layer_dims = layer_dims | |
def setup(self, croconet): | |
dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels} | |
if self.hooks_idx is None: | |
if hasattr(croconet, 'dec_blocks'): # encoder + decoder | |
step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth] | |
hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)] | |
else: # encoder only | |
step = croconet.enc_depth//4 | |
hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)] | |
self.hooks_idx = hooks_idx | |
print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}') | |
dpt_args['hooks'] = self.hooks_idx | |
dpt_args['layer_dims'] = self.layer_dims | |
self.dpt = DPTOutputAdapter(**dpt_args) | |
dim_tokens = [croconet.enc_embed_dim if hook<croconet.enc_depth else croconet.dec_embed_dim for hook in self.hooks_idx] | |
dpt_init_args = {'dim_tokens_enc': dim_tokens} | |
self.dpt.init(**dpt_init_args) | |
def forward(self, x, img_info): | |
out = self.dpt(x, image_size=(img_info['height'],img_info['width'])) | |
if self.postprocess: out = self.postprocess(out) | |
return out |