# mapper/models/feature_extractor_DPT.py
import logging
import os
import re
import sys

import torch
import torch.nn as nn

from .base import BaseModel
from .dinov2.eval.depth.ops.wrappers import resize
from .dinov2.hub.backbones import dinov2_vitb14_reg
from .schema import DINOConfiguration

# Put this file's directory on the import path so the bundled DINOv2
# modules can be resolved.
module_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_dir)
logger = logging.getLogger(__name__)


class FeatureExtractor(BaseModel):
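    """Dense feature extractor built on a DINOv2 backbone.

    Resizes the input to a patch-aligned resolution, extracts DINOv2
    patch tokens, and projects them to `conf.output_dim` channels.
    """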

    # ImageNet normalization statistics.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf: DINOConfiguration):
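        """Instantiate the DINOv2 backbone and the 1x1 projection head."""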
        BACKBONE_SIZE = "base"
        backbone_archs = {
            "small": "vits14",
            "base": "vitb14",
            "large": "vitl14",
            "giant": "vitg14",
        }
        backbone_arch = backbone_archs[BACKBONE_SIZE]
        # Every DINOv2 variant uses a 14x14 patch size; parse it from the
        # architecture name (e.g. "vitb14" -> 14).
        self.crop_size = int(re.search(r"\d+", backbone_arch).group())
        backbone_name = f"dinov2_{backbone_arch}"
        # ViT-B/14 with registers; its 768-dim tokens match the projection below.
        self.backbone_model = dinov2_vitb14_reg(
            pretrained=conf.pretrained, drop_path_rate=0.1)
        if conf.frozen:
            # Freeze the patch embedding and the first 10 of the 12
            # transformer blocks; only the last two blocks stay trainable.
            for param in self.backbone_model.patch_embed.parameters():
                param.requires_grad = False
            for i in range(10):
                for param in self.backbone_model.blocks[i].parameters():
                    param.requires_grad = False
                # Disable stochastic depth in the frozen blocks.
                self.backbone_model.blocks[i].drop_path1 = nn.Identity()
                self.backbone_model.blocks[i].drop_path2 = nn.Identity()
        # Project the 768-dim ViT-B patch tokens to the requested output size.
        self.feat_projection = nn.Conv2d(768, conf.output_dim, kernel_size=1)
return self.backbone_model

    def _init(self, conf: DINOConfiguration):
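        """Register preprocessing buffers and build the encoder."""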
        # Register the normalization constants as non-persistent buffers so
        # they move with the module across devices but stay out of checkpoints.
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)
self.build_encoder(conf)

    def _forward(self, data):
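        """Return a (B, output_dim, h // 14, w // 14) feature map and the
        camera rescaled to that resolution."""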
        _, _, h, w = data["image"].shape
        # Resize to the largest resolution divisible by the 14-pixel patch size.
        h_num_patches = h // self.crop_size
        w_num_patches = w // self.crop_size
        h_dino = h_num_patches * self.crop_size
        w_dino = w_num_patches * self.crop_size
        image = resize(data["image"], (h_dino, w_dino))
        # Normalize with the ImageNet statistics (broadcast over H and W).
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
        # (B, N, C) normalized patch tokens, N = h_num_patches * w_num_patches.
        output = self.backbone_model.forward_features(
            image)["x_norm_patchtokens"]
        output = output.reshape(-1, h_num_patches,
                                w_num_patches, output.shape[-1])
        output = output.permute(0, 3, 1, 2)  # channel first: (B, C, H', W')
        output = self.feat_projection(output)
        # Rescale the camera intrinsics to match the feature-map resolution.
        camera = data["camera"].to(data["image"].device, non_blocking=True)
        camera = camera.scale(output.shape[-1] / data["image"].shape[-1])
return output, camera
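

# Minimal usage sketch (assumptions, not guaranteed by this file alone:
# `BaseModel(conf)` runs `_init` and dispatches calls to `_forward`,
# `DINOConfiguration` exposes the `pretrained` / `frozen` / `output_dim`
# fields used above, and `camera` is an object providing the `.to()` and
# `.scale()` methods invoked in `_forward`):
#
#     conf = DINOConfiguration(pretrained=True, frozen=True, output_dim=128)
#     model = FeatureExtractor(conf).eval()
#     data = {"image": torch.rand(1, 3, 224, 448), "camera": camera}
#     features, camera_scaled = model(data)
#     # A 224x448 input with 14-pixel patches yields a (1, 128, 16, 32) map.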