import logging
import os
import re
import sys

import torch
import torch.nn as nn

from .base import BaseModel
from .schema import DINOConfiguration
from .dinov2.eval.depth.ops.wrappers import resize
from .dinov2.hub.backbones import dinov2_vitb14_reg

# Make the vendored dinov2 package resolvable for its internal imports.
module_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_dir)

logger = logging.getLogger(__name__)


class FeatureExtractor(BaseModel):
    """Dense feature extractor backed by a DINOv2 ViT-B/14 with registers."""

    # ImageNet normalization statistics expected by DINOv2.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf: DINOConfiguration):
        # The backbone is hardcoded to ViT-B/14 via dinov2_vitb14_reg below,
        # so "base" is the entry that actually matches the loaded weights.
        BACKBONE_SIZE = "base"
        backbone_archs = {
            "small": "vits14",
            "base": "vitb14",  # this one
            "large": "vitl14",
            "giant": "vitg14",
        }
        backbone_arch = backbone_archs[BACKBONE_SIZE]
        # All DINOv2 ViTs use a 14-px patch; extract it from the arch name.
        self.crop_size = int(re.search(r"\d+", backbone_arch).group())
        self.backbone_model = dinov2_vitb14_reg(
            pretrained=conf.pretrained, drop_path_rate=0.1)
        if conf.frozen:
            # Freeze the patch embedding and the first 10 of the 12
            # transformer blocks; only the last 2 blocks stay trainable.
            for param in self.backbone_model.patch_embed.parameters():
                param.requires_grad = False
            for i in range(10):
                for param in self.backbone_model.blocks[i].parameters():
                    param.requires_grad = False
                # Stochastic depth is pointless in frozen blocks; disable it.
                self.backbone_model.blocks[i].drop_path1 = nn.Identity()
                self.backbone_model.blocks[i].drop_path2 = nn.Identity()
        # Project the 768-dim ViT-B tokens down to the configured feature size.
        self.feat_projection = nn.Conv2d(768, conf.output_dim, kernel_size=1)
        return self.backbone_model

    def _init(self, conf: DINOConfiguration):
        # Preprocessing buffers (not persisted in checkpoints).
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)
        self.build_encoder(conf)

    def _forward(self, data):
        _, _, h, w = data["image"].shape
        # Resize to the largest size divisible by the 14-px patch size.
        h_num_patches = h // self.crop_size
        w_num_patches = w // self.crop_size
        h_dino = h_num_patches * self.crop_size
        w_dino = w_num_patches * self.crop_size
        image = resize(data["image"], (h_dino, w_dino))
        # ImageNet normalization, broadcast over batch and spatial dims.
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
        # Keep only the normalized patch tokens (CLS/register tokens dropped).
        output = self.backbone_model.forward_features(image)["x_norm_patchtokens"]
        # (B, N, C) -> (B, H_patches, W_patches, C)
        output = output.reshape(-1, h_num_patches, w_num_patches, output.shape[-1])
        output = output.permute(0, 3, 1, 2)  # channel first
        output = self.feat_projection(output)
        # Rescale the camera intrinsics to the feature-map resolution.
        camera = data["camera"].to(data["image"].device, non_blocking=True)
        camera = camera.scale(output.shape[-1] / data["image"].shape[-1])
        return output, camera
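

# --- Usage sketch -----------------------------------------------------------
# A minimal smoke test, assuming a gluefactory-style BaseModel in which
# constructing the model with a config runs _init() and calling the model
# runs _forward(). The DINOConfiguration field names (pretrained, frozen,
# output_dim) and the camera interface (.to(), .scale()) are inferred from
# their use above, not confirmed APIs; _StubCamera is a hypothetical
# stand-in for the real camera object.
if __name__ == "__main__":
    class _StubCamera:
        # Implements only the two methods _forward() calls on the camera.
        def to(self, device, non_blocking=False):
            return self

        def scale(self, factor):
            return self

    conf = DINOConfiguration(pretrained=False, frozen=True, output_dim=128)
    model = FeatureExtractor(conf).eval()
    # 448 = 32 * 14, so the image maps onto a 32 x 32 grid of patch tokens.
    data = {"image": torch.rand(1, 3, 448, 448), "camera": _StubCamera()}
    with torch.no_grad():
        feats, camera = model(data)
    print(feats.shape)  # expected: torch.Size([1, 128, 32, 32])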