from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers.models.vit.modeling_vit import ViTModel

from ...utils import BaseModule


class DINOSingleImageTokenizer(BaseModule):
    """Turn images into per-token features using a DINO ViT backbone.

    Only the backbone *configuration* (``config.json``) is fetched from the
    Hugging Face Hub; the ``ViTModel`` is constructed from that config, so its
    weights are randomly initialized here. NOTE(review): the weights are
    presumably loaded afterwards from the owning model's checkpoint — confirm.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Hub repo id whose ``config.json`` defines the ViT architecture.
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        # Enable activation (gradient) checkpointing in the ViT backbone.
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the ViT backbone and register the input-normalization buffers."""
        config_path = hf_hub_download(
            repo_id=self.cfg.pretrained_model_name_or_path,
            filename="config.json",
        )
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(config_path)
        )
        if self.cfg.enable_gradient_checkpointing:
            # Use the public API instead of assigning
            # ``encoder.gradient_checkpointing`` directly: newer transformers
            # releases also need ``_gradient_checkpointing_func`` set up, which
            # only ``gradient_checkpointing_enable()`` does. On older releases
            # it sets the same encoder flag the direct assignment did.
            self.model.gradient_checkpointing_enable()

        # Per-channel RGB mean/std (the commonly used ImageNet statistics),
        # shaped (1, 1, 3, 1, 1) to broadcast over (B, N, C, H, W) input.
        # persistent=False keeps them out of the state_dict.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Normalize images and encode them into ViT token features.

        Args:
            images: ``(B, N, C, H, W)`` batch of ``N`` views per sample, or
                ``(B, C, H, W)`` for a single view. ``C`` must broadcast
                against the 3-channel normalization buffers. NOTE(review):
                pixel values are presumably in [0, 1] — confirm with callers.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``(B, N, Ct, Nt)`` features (``(B, Ct, Nt)`` for 4-D input), where
            ``Ct`` is the ViT hidden size and ``Nt`` the number of tokens in
            ``last_hidden_state`` (for HF ViT this includes the [CLS] token).
        """
        # A 4-D input is treated as a single view; restore that shape at the end.
        packed = images.ndim == 4
        if packed:
            images = images.unsqueeze(1)
        batch_size = images.shape[0]

        images = (images - self.image_mean) / self.image_std
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            # Allows input resolutions other than the one in the config.
            interpolate_pos_encoding=True,
        )

        # last_hidden_state is (B*N, Nt, Ct); put channels before tokens.
        local_features = out.last_hidden_state.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)
        return local_features

    def detokenize(self, *args, **kwargs):
        """Not supported: this tokenizer only encodes images."""
        raise NotImplementedError