from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers.models.vit.modeling_vit import ViTModel

from ...utils import BaseModule


class DINOSingleImageTokenizer(BaseModule):
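    """Tokenizes one image (or a stack of views) into DINO ViT patch features."""
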
    @dataclass
    class Config(BaseModule.Config):
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        # Instantiate the ViT architecture from its hub config only; no
        # pretrained weights are downloaded or loaded here.
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(
                hf_hub_download(
                    repo_id=self.cfg.pretrained_model_name_or_path,
                    filename="config.json",
                )
            )
        )

        if self.cfg.enable_gradient_checkpointing:
            self.model.encoder.gradient_checkpointing = True

        # ImageNet normalization statistics, shaped to broadcast over
        # (B, N, C, H, W) batches; non-persistent so they stay out of
        # checkpoints.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
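        """Encode images into patch tokens.

        Accepts (B, C, H, W) or (B, N, C, H, W) float images in [0, 1] and
        returns features of shape (B, Ct, Nt) or (B, N, Ct, Nt), where Ct is
        the ViT hidden size and Nt is the token count (patches plus CLS).
        """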
        # Single-view input (B, C, H, W): insert a view axis of size 1.
        packed = False
        if images.ndim == 4:
            packed = True
            images = images.unsqueeze(1)

        batch_size, n_input_views = images.shape[:2]
        # Normalize with ImageNet statistics, then fold the view axis into the
        # batch axis so the ViT sees a plain (B*N, C, H, W) batch.
        images = (images - self.image_mean) / self.image_std
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            interpolate_pos_encoding=True,
        )
        # last_hidden_state: (B*N, Nt, Ct); pooler_output: pooled CLS features.
        local_features, global_features = out.last_hidden_state, out.pooler_output
        # (B*N, Nt, Ct) -> (B*N, Ct, Nt) -> (B, N, Ct, Nt)
        local_features = local_features.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )
        if packed:
            local_features = local_features.squeeze(1)

        return local_features

    def detokenize(self, *args, **kwargs):
        raise NotImplementedError
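

# Usage sketch (assumes BaseModule, imported from ...utils, builds self.cfg
# from a plain config dict and calls configure() in its constructor):
#
#   tokenizer = DINOSingleImageTokenizer(
#       {"pretrained_model_name_or_path": "facebook/dino-vitb16"}
#   )
#   images = torch.rand(2, 3, 224, 224)   # (B, C, H, W) in [0, 1]
#   tokens = tokenizer(images)            # (B, Ct, Nt), e.g. (2, 768, 197)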