| from dataclasses import dataclass |
| from typing import Literal |
|
|
| import torch |
| from einops import rearrange, repeat |
| from jaxtyping import Float |
| from torch import Tensor, nn |
|
|
| from src.dataset.types import BatchedViews |
| from .backbone import Backbone |
| from .backbone_resnet import BackboneResnet, BackboneResnetCfg |
|
|
|
|
@dataclass
class BackboneDinoCfg:
    """Configuration for the DINO backbone.

    Fields are declared in the order expected by positional construction:
    ``name``, ``model``, ``d_out``.
    """

    # Fixed tag; the only accepted value is "dino".
    name: Literal["dino"]

    # Name of the DINO ViT variant. It is handed to ``torch.hub.load`` as the
    # entry point, and the digits it contains are parsed as the patch size.
    model: Literal[
        "dino_vits16",
        "dino_vits8",
        "dino_vitb16",
        "dino_vitb8",
    ]

    # Number of channels in the features produced by the backbone.
    d_out: int
|
|
|
|
class BackboneDino(Backbone[BackboneDinoCfg]):
    """Backbone that sums ResNet features with projected DINO ViT tokens.

    Three per-pixel feature maps are added together in ``forward``:

    * the output of a ``BackboneResnet`` configured with ``"dino_resnet50"``,
    * the ViT's per-patch tokens, projected to ``d_out`` channels and
      replicated over each ``patch_size x patch_size`` pixel block,
    * the ViT's first token, projected to ``d_out`` channels and replicated
      over the whole image.
    """

    def __init__(self, cfg: BackboneDinoCfg, d_in: int) -> None:
        """Build the ResNet branch, the ViT branch and the token projections.

        Args:
            cfg: Backbone configuration; ``cfg.model`` selects the ViT variant
                loaded through ``torch.hub`` and ``cfg.d_out`` sets the output
                channel count.
            d_in: Number of input image channels. Must be 3.

        Raises:
            AssertionError: If ``d_in`` is not 3.
        """
        super().__init__(cfg)
        assert d_in == 3, "BackboneDino only supports 3-channel images"
        self.dino = torch.hub.load("facebookresearch/dino:main", cfg.model)
        self.resnet_backbone = BackboneResnet(
            BackboneResnetCfg("resnet", "dino_resnet50", 4, False, cfg.d_out),
            d_in,
        )

        # The token width depends on the ViT variant (ViT-S is narrower than
        # ViT-B), so read it from the loaded model instead of hard-coding 768.
        # For the ViT-B variants this is still 768, so parameter shapes are
        # unchanged there.
        d_token = self.dino.embed_dim
        self.global_token_mlp = nn.Sequential(
            nn.Linear(d_token, d_token),
            nn.ReLU(),
            nn.Linear(d_token, cfg.d_out),
        )
        self.local_token_mlp = nn.Sequential(
            nn.Linear(d_token, d_token),
            nn.ReLU(),
            nn.Linear(d_token, cfg.d_out),
        )

    def forward(
        self,
        context: BatchedViews,
    ) -> Float[Tensor, "batch view d_out height width"]:
        """Compute per-pixel features for every view in ``context``.

        Args:
            context: Batched views; ``context["image"]`` is unpacked as
                ``(batch, view, channel, height, width)``.

        Returns:
            The sum of the ResNet features, the upsampled per-patch ViT
            features and the broadcast global ViT feature.

        Raises:
            AssertionError: If the image height or width is not a multiple of
                the ViT patch size.
        """
        # ResNet branch. Its output must be broadcast-compatible with
        # (batch, view, d_out, height, width) for the final sum.
        resnet_features = self.resnet_backbone(context)

        # ViT branch: fold the view axis into the batch axis, run the ViT and
        # split its tokens into the first (global) token and the patch tokens.
        b, v, _, h, w = context["image"].shape
        patch_size = self.patch_size
        assert h % patch_size == 0 and w % patch_size == 0, (
            f"image size ({h}, {w}) must be divisible by patch size {patch_size}"
        )
        tokens = rearrange(context["image"], "b v c h w -> (b v) c h w")
        tokens = self.dino.get_intermediate_layers(tokens)[0]
        global_token = self.global_token_mlp(tokens[:, 0])
        local_tokens = self.local_token_mlp(tokens[:, 1:])

        # Replicate the single global feature vector at every pixel.
        global_token = repeat(global_token, "(b v) c -> b v c h w", b=b, v=v, h=h, w=w)

        # Un-flatten the patch grid and replicate each patch's feature vector
        # over the patch_size x patch_size pixels that the patch covers.
        local_tokens = repeat(
            local_tokens,
            "(b v) (h w) c -> b v c (h hps) (w wps)",
            b=b,
            v=v,
            h=h // patch_size,
            hps=patch_size,
            w=w // patch_size,
            wps=patch_size,
        )

        return resnet_features + local_tokens + global_token

    @property
    def patch_size(self) -> int:
        """Patch size parsed from the digits in ``cfg.model``.

        For example, ``"dino_vits16"`` yields 16 and ``"dino_vitb8"`` yields 8.
        """
        return int("".join(filter(str.isdigit, self.cfg.model)))

    @property
    def d_out(self) -> int:
        """Number of output feature channels, as given by ``cfg.d_out``."""
        return self.cfg.d_out
|
|