from collections import defaultdict
from contextlib import contextmanager
from logging import getLogger
import math
import sys
from typing import List, Optional, Union, Iterable

import numpy as np
import torch
from torch import nn

from timm.models import VisionTransformer
from einops import rearrange

from .extra_models import DinoWrapper

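# Fallback schedule sizes used when VitDetArgs leaves the corresponding field unset.
# DEFAULT_NUM_GLOBAL is consumed in ViTDetHook.__init__ below; DEFAULT_NUM_WINDOWED is
# presumably the matching default for callers that configure the windowed count instead.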
DEFAULT_NUM_WINDOWED = 5
DEFAULT_NUM_GLOBAL = 4


class VitDetArgs:
    def __init__(self,
                 window_size: int,
                 num_summary_tokens: int,
                 num_windowed: Optional[int] = None,
                 num_global: Optional[int] = None,
    ):
        self.window_size = window_size
        self.num_summary_tokens = num_summary_tokens
        self.num_windowed = num_windowed
        self.num_global = num_global


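# Hypothetical usage sketch (values are illustrative, not prescribed by this module):
#
#   args = VitDetArgs(window_size=16, num_summary_tokens=1, num_global=4)
#   hook = apply_vitdet_arch(model, args)   # model: timm VisionTransformer or DinoWrapper
#   features = model(images)                # windowed/global switching happens via hooks
#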
def apply_vitdet_arch(model: Union[VisionTransformer, DinoWrapper], args: VitDetArgs):
    if isinstance(model, VisionTransformer):
        patch_embed = getattr(model, 'patch_generator', model.patch_embed)

        return ViTDetHook(patch_embed, model.blocks, args)
    elif isinstance(model, DinoWrapper):
        inner = model.inner

        patch_embed = getattr(inner, 'patch_generator', inner.patch_embed)
        return ViTDetHook(patch_embed, inner.blocks, args)
    else:
        print('Warning: Unable to apply VitDet aug!', file=sys.stderr)


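# ViTDetHook implements the windowed/global alternation purely through forward
# (pre-)hooks: the patch embedder hook records the input resolution, the blocks
# pre-hook sorts patches into window-major order, per-block hooks fold/unfold the
# window axis into the batch axis, and the final hook restores the original order.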
class ViTDetHook:
    def __init__(self,
                 embedder: nn.Module,
                 blocks: nn.Sequential,
                 args: VitDetArgs,
    ):
        self.blocks = blocks
        self.num_summary_tokens = args.num_summary_tokens
        self.window_size = args.window_size

        self._input_resolution = None
        self._num_windows = None
        self._cls_patch = None
        self._order_cache = dict()

        embedder.register_forward_pre_hook(self._enter_model)

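        # Before the transformer blocks run, reorder the patch tokens into
        # window-major order so that switching between windowed and global
        # attention only requires cheap reshapes (see _rearrange_patches).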
        blocks.register_forward_pre_hook(self._enter_blocks)

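        # Build the windowed/global schedule: with `num_windowed`, one global block
        # follows every `num_windowed` windowed blocks; otherwise `num_global` global
        # blocks are spread evenly across the network.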
        is_global = True
        if args.num_windowed is not None:
            period = args.num_windowed + 1
        else:
            num_global = args.num_global or DEFAULT_NUM_GLOBAL
            period = max(len(blocks) // num_global, 1)

        for i, layer in enumerate(blocks[:-1]):
            ctr = i % period
            if ctr == 0:
                layer.register_forward_pre_hook(self._to_windows)
                is_global = False
            elif ctr == period - 1:
                layer.register_forward_pre_hook(self._to_global)
                is_global = True

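        # Make sure the stack ends in global mode so the final block attends
        # across the full image rather than within a single window.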
        if not is_global:
            blocks[-1].register_forward_pre_hook(self._to_global)

        blocks.register_forward_hook(self._exit_model)

    def _enter_model(self, _, input: List[torch.Tensor]):
        self._input_resolution = input[0].shape[-2:]

    def _enter_blocks(self, _, input: List[torch.Tensor]):
        patches = input[0]
        patches = self._rearrange_patches(patches)

        return (patches,) + input[1:]

    def _to_windows(self, _, input: List[torch.Tensor]):
        patches = input[0]

        if self.num_summary_tokens:
            self._cls_patch = patches[:, :self.num_summary_tokens]
            patches = patches[:, self.num_summary_tokens:]

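        # Fold the window axis into the batch axis: (B, W*T, C) -> (B*W, T, C)
        # with T = window_size ** 2, so attention in the following blocks is
        # confined to individual windows.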
        patches = rearrange(
            patches, 'b (p t) c -> (b p) t c',
            p=self._num_windows, t=self.window_size ** 2,
        )

        return (patches,) + input[1:]

    def _to_global(self, _, input: List[torch.Tensor]):
        patches = input[0]

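        # Undo the windowing: merge the window axis back out of the batch axis,
        # (B*W, T, C) -> (B, W*T, C), so the next block attends globally.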
        patches = rearrange(
            patches, '(b p) t c -> b (p t) c',
            p=self._num_windows, t=self.window_size ** 2,
            b=patches.shape[0] // self._num_windows,
        )

        if self.num_summary_tokens:
            patches = torch.cat([
                self._cls_patch,
                patches,
            ], dim=1)

        return (patches,) + input[1:]

    def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
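        # Scatter the tokens back into their original (raster) order, inverting
        # the gather performed in _rearrange_patches.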
        patch_order = self._order_cache[self._input_resolution][0]
        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)

        ret_patches = torch.empty_like(patches)
        ret_patches = torch.scatter(
            ret_patches,
            dim=1,
            index=patch_order,
            src=patches,
        )

        return ret_patches

    def _rearrange_patches(self, patches: torch.Tensor):
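        # Sort patch tokens into window-major order (cached per input resolution)
        # so that windowed/global mode switches reduce to simple reshapes.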
        patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
        if patch_order is None:
            num_feat_patches = patches.shape[1] - self.num_summary_tokens
            num_pixels = self._input_resolution[0] * self._input_resolution[1]

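            # Infer the (square) patch size from the pixel count per feature token,
            # then derive the patch grid and the number of whole windows it contains.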
            patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
            rows = self._input_resolution[-2] // patch_size
            cols = self._input_resolution[-1] // patch_size

            w_rows = rows // self.window_size
            w_cols = cols // self.window_size

            patch_order = torch.arange(0, num_feat_patches, device=patches.device)

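            # Convert raster order to window-major order: tokens of each
            # window_size x window_size window become contiguous.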
            patch_order = rearrange(
                patch_order, '(wy py wx px) -> (wy wx py px)',
                wy=w_rows, wx=w_cols,
                py=self.window_size, px=self.window_size,
            )

            if self.num_summary_tokens:
                patch_order = torch.cat([
                    torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
                    patch_order + self.num_summary_tokens,
                ])

            self._num_windows = w_rows * w_cols
            self._order_cache[self._input_resolution] = (
                patch_order,
                self._num_windows,
            )

        patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
        patches = torch.gather(patches, dim=1, index=patch_order)
        return patches