| """ |
| DINOv2 vision backbone wrapper. |
| |
| Features: |
| - Loads DINOv2 variants via torch.hub (with local fallback) |
| - Exposes patch token features (x_norm_patchtokens) |
| - Provides preprocessing (resize + normalization) for multi-view PIL images |
| - Parallel per-view preprocessing using ThreadPoolExecutor |
| """ |
|
|
| from collections import OrderedDict |
| import os |
|
|
| from concurrent.futures import ThreadPoolExecutor |
| import torch |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from torchvision.models._utils import IntermediateLayerGetter |
| from typing import Dict, List |
| from torchvision import transforms |
|
|
|
|
def apply_transform(view, transform):
    """Apply *transform* to a single image *view* and return the result.

    Kept as a module-level function (rather than an inline lambda) so it can
    be handed to worker threads during parallel preprocessing.
    """
    transformed = transform(view)
    return transformed
|
|
|
|
| |
|
|
|
|
class DINOv2BackBone(nn.Module):
    """
    Thin wrapper around a DINOv2 vision transformer.

    Loads the requested DINOv2 variant from torch.hub, falling back to a
    local checkout under ``$TORCH_HOME/hub`` when the hub download fails.
    Exposes normalized patch-token features and a preprocessing pipeline
    (resize + normalize) for multi-view PIL images.

    Args:
        backone_name: DINOv2 model id (e.g. ``dinov2_vits14``,
            ``dinov2_vitb14``, ``dinov2_vitl14``, ``dinov2_vitg14``).
        output_channels: Unused placeholder; retained for backward
            compatibility / future extension.

    Attributes:
        body: Loaded DINOv2 model.
        num_channels: Embedding dimension of the patch tokens.
        dino_transform: Preprocessing pipeline (resize + tensor + normalize).

    Raises:
        NotImplementedError: If ``backone_name`` is not a supported variant.
    """

    # Patch-token embedding dimension per supported variant.
    # NOTE(review): ViT-g/14 ("giant2") uses a 1536-d embedding in the
    # official DINOv2 release; the previous hard-coded 1408 was incorrect.
    _EMBED_DIMS = {
        "dinov2_vits14": 384,
        "dinov2_vitb14": 768,
        "dinov2_vitl14": 1024,
        "dinov2_vitg14": 1536,
    }

    def __init__(self, backone_name="dinov2_vits14", output_channels=1024) -> None:
        super().__init__()
        # Validate the variant name BEFORE the expensive hub download so a
        # typo fails fast instead of after fetching weights.
        if backone_name not in self._EMBED_DIMS:
            raise NotImplementedError(f"DINOv2 backbone {backone_name} not implemented")
        self.num_channels = self._EMBED_DIMS[backone_name]

        try:
            self.body = torch.hub.load("facebookresearch/dinov2", backone_name)
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; hub failures (network, cache) fall back to a
            # local checkout of the repo + checkpoint under $TORCH_HOME/hub.
            import traceback

            traceback.print_exc()
            print(f"Failed to load dinov2 from torch hub, loading from local")
            torch_home = os.path.expanduser(os.environ.get("TORCH_HOME", "~/.cache/torch/"))
            weights_path = os.path.join(torch_home, "hub", "checkpoints", f"{backone_name}_pretrain.pth")
            code_path = os.path.join(torch_home, "hub", "facebookresearch_dinov2_main")
            self.body = torch.hub.load(code_path, backone_name, source="local", pretrained=False)
            # map_location="cpu" so checkpoints saved on GPU also load on
            # CPU-only hosts; the model is moved to its device by the caller.
            state_dict = torch.load(weights_path, map_location="cpu")
            self.body.load_state_dict(state_dict)

        # Standard ImageNet mean/std normalization used by DINOv2 pretraining.
        self.dino_transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def forward(self, tensor):
        """
        Extract normalized patch-token features.

        Args:
            tensor: Image batch tensor of shape [B * num_view, 3, H, W].

        Returns:
            torch.Tensor: Patch-token features [B * num_view, N_tokens, C],
            where C == ``self.num_channels``.
        """
        return self.body.forward_features(tensor)["x_norm_patchtokens"]

    def prepare_dino_input(self, img_list):
        """
        Preprocess multi-view PIL images into a DINO-ready batch tensor.

        Views are transformed in parallel via a thread pool (PIL decode /
        resize releases the GIL, so threads overlap usefully).

        Args:
            img_list: List of samples; each sample is a List[PIL.Image]
                (one image per view). All samples must have the same number
                of views so the result stacks into a regular tensor.

        Returns:
            torch.Tensor: Batch of shape [B * num_view, 3, 224, 224] on the
            model's device.
        """
        with ThreadPoolExecutor() as executor:
            per_sample = [
                torch.stack(list(executor.map(self.dino_transform, views)))
                for views in img_list
            ]
        image_tensors = torch.stack(per_sample)

        # Flatten (batch, view) into a single leading dimension.
        B, num_view, C, H, W = image_tensors.shape
        image_tensors = image_tensors.view(B * num_view, C, H, W)

        # Move to wherever the backbone's parameters currently live.
        device = next(self.parameters()).device
        return image_tensors.to(device)
|
|
|
|
def get_dino_model(backone_name="dinov2_vits14") -> DINOv2BackBone:
    """
    Build and return a DINOv2BackBone for the requested variant.

    Args:
        backone_name: DINOv2 variant name (e.g. ``dinov2_vits14``).

    Returns:
        DINOv2BackBone: Freshly constructed backbone instance.
    """
    backbone = DINOv2BackBone(backone_name)
    return backbone
|
|
|
|
if __name__ == "__main__":
    # Smoke test: construct the default backbone (may download weights).
    dino = DINOv2BackBone()
|
|