import functools
import logging
import typing

import beartype
import torch
from jaxtyping import Float, jaxtyped
from torch import Tensor
from torchvision.transforms import v2

logger = logging.getLogger("modeling.py")


@jaxtyped(typechecker=beartype.beartype)
class SplitDinov2(torch.nn.Module):
    """DINOv2 ViT-B/14 (with registers) split into two stages at block index `split_at`.

    `forward_start` runs patch embedding plus blocks ``[:split_at]``;
    `forward_end` runs the remaining blocks ``[split_at:]``, the final norm,
    and strips CLS + register tokens. Composing the two is intended to be
    equivalent to the full forward pass of the underlying ViT.
    """

    def __init__(self, *, split_at: int):
        """Load the pretrained DINOv2 backbone (network I/O via torch.hub) in eval mode.

        Args:
            split_at: Block index at which the forward pass is split.
        """
        super().__init__()
        self.vit = torch.hub.load(
            "facebookresearch/dinov2", "dinov2_vitb14_reg"
        ).eval()
        self.split_at = split_at

    def forward_start(
        self, x: Float[Tensor, "batch channels width height"]
    ) -> Float[Tensor, "batch total_patches dim"]:
        """Embed `x` into tokens and run the first `split_at` transformer blocks.

        Returns activations that still include the CLS and register tokens
        ("total_patches" = patches + 1 + num_register_tokens).
        """
        x_BPD = self.vit.prepare_tokens_with_masks(x)
        for blk in self.vit.blocks[: self.split_at]:
            x_BPD = blk(x_BPD)
        return x_BPD

    def forward_end(
        self, x_BPD: Float[Tensor, "batch total_patches dim"]
    ) -> Float[Tensor, "batch patches dim"]:
        """Run the remaining blocks, apply the final norm, and drop special tokens.

        Returns only the spatial patch tokens (CLS and register tokens removed).
        """
        # BUG FIX: the original sliced `blocks[-self.split_at:]`, which is the
        # complement of `blocks[:self.split_at]` only when the ViT has exactly
        # 2 * split_at blocks. dinov2_vitb14_reg has 12 blocks, and load_vit()
        # uses split_at=11, so `blocks[-11:]` would re-run blocks 1..11 — ten
        # blocks already executed by forward_start. `blocks[self.split_at:]`
        # is the true complement for any depth.
        for blk in self.vit.blocks[self.split_at :]:
            x_BPD = blk(x_BPD)
        x_BPD = self.vit.norm(x_BPD)
        # Token layout after prepare_tokens_with_masks: [CLS, registers..., patches...];
        # keep patches only.
        return x_BPD[:, self.vit.num_register_tokens + 1 :]


@functools.cache
def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
    """Load (and memoize per device) the split ViT and its preprocessing transform.

    Args:
        device: torch device string, e.g. "cpu" or "cuda:0".

    Returns:
        Tuple of (model on `device`, torchvision v2 transform producing
        normalized 224x224 float32 image tensors with ImageNet statistics).
    """
    vit = SplitDinov2(split_at=11).to(device)
    vit_transform = v2.Compose([
        v2.Resize(size=(256, 256)),
        v2.CenterCrop(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        # ImageNet channel statistics.
        v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ])
    logger.info("Loaded ViT.")
    return vit, vit_transform