from transformers import PreTrainedModel
import timm
import torch.nn as nn

from .model_config import FocusOnDepthConfig
from .reassemble import Reassemble
from .fusion import Fusion
from .head import HeadDepth, HeadSeg


class FocusOnDepth(PreTrainedModel):
    config_class = FocusOnDepthConfig

    def __init__(self, config):
        super().__init__(config)
        self.transformer_encoders = timm.create_model(config.model_timm, pretrained=True)
        self.type_ = config.type_

        # Register forward hooks on the selected transformer blocks so their
        # activations are captured during the encoder forward pass.
        self.activation = {}
        self.hooks = config.hooks
        self._get_layers_from_hooks(self.hooks)

        # One Reassemble/Fusion pair per hooked stage.
        self.reassembles = []
        self.fusions = []
        for s in config.reassemble_s:
            self.reassembles.append(
                Reassemble(config.image_size, config.read, config.patch_size, s,
                           config.emb_dim, config.resample_dim)
            )
            self.fusions.append(Fusion(config.resample_dim))
        self.reassembles = nn.ModuleList(self.reassembles)
        self.fusions = nn.ModuleList(self.fusions)

        # Output heads: depth, segmentation, or both, depending on the config.
        if self.type_ == "full":
            self.head_depth = HeadDepth(config.resample_dim)
            self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)
        elif self.type_ == "depth":
            self.head_depth = HeadDepth(config.resample_dim)
            self.head_segmentation = None
        else:
            self.head_depth = None
            self.head_segmentation = HeadSeg(config.resample_dim, nclasses=config.nclasses)

    def forward(self, img):
        # Run the encoder only for its side effect: the hooks fill self.activation.
        _ = self.transformer_encoders(img)

        # Walk the stages from the deepest hook back to the shallowest,
        # fusing each reassembled activation with the previous fusion output.
        previous_stage = None
        for i in range(len(self.fusions) - 1, -1, -1):
            hook_to_take = 't' + str(self.hooks[i])
            activation_result = self.activation[hook_to_take]
            reassemble_result = self.reassembles[i](activation_result)
            previous_stage = self.fusions[i](reassemble_result, previous_stage)

        out_depth = None
        out_segmentation = None
        if self.head_depth is not None:
            out_depth = self.head_depth(previous_stage)
        if self.head_segmentation is not None:
            out_segmentation = self.head_segmentation(previous_stage)
        return out_depth, out_segmentation

    def _get_layers_from_hooks(self, hooks):
        def get_activation(name):
            def hook(model, input, output):
                self.activation[name] = output
            return hook

        for h in hooks:
            self.transformer_encoders.blocks[h].register_forward_hook(get_activation('t' + str(h)))
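

# --- Minimal usage sketch (illustrative only) ---
# The block below shows one way this module might be exercised end to end.
# The timm backbone name and every config value here are assumptions chosen
# for illustration, not values prescribed by this repository; they also assume
# FocusOnDepthConfig accepts these fields as keyword arguments. Adjust them to
# match your actual config defaults before running.
if __name__ == "__main__":
    import torch

    config = FocusOnDepthConfig(
        model_timm="vit_base_patch16_384",  # assumed ViT backbone available in timm
        type_="full",                       # "full", "depth", or segmentation-only
        hooks=[2, 5, 8, 11],                # assumed transformer blocks to hook
        reassemble_s=[4, 8, 16, 32],        # assumed reassemble scales, one per hook
        image_size=384,
        patch_size=16,
        emb_dim=768,
        resample_dim=256,
        read="projection",
        nclasses=2,
    )
    model = FocusOnDepth(config)
    model.eval()

    # Single dummy RGB image at the assumed input resolution.
    dummy = torch.randn(1, 3, 384, 384)
    with torch.no_grad():
        depth, segmentation = model(dummy)
    print(None if depth is None else depth.shape,
          None if segmentation is None else segmentation.shape)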