# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu, Yutong Lin, Yixuan Wei
# --------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py

import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec

logger = logging.getLogger('gwm')

import argparse
import torch
import torchvision.transforms
from torch import nn
from torchvision import transforms
import torch.nn.modules.utils as nn_utils
import math
import timm
import types
from pathlib import Path
from typing import Union, List, Tuple
from PIL import Image
import einops


class ViTExtractor:
    """ This class facilitates extraction of features, descriptors, and saliency maps from a ViT.

    We use the following notation in the documentation of the module's methods:
    B - batch size
    h - number of heads. usually takes place of the channel dimension in pytorch's convention BxCxHxW
    p - patch size of the ViT. either 8 or 16.
    t - number of tokens. equals the number of patches + 1, e.g. HW / p**2 + 1. Where H and W
        are the height and width of the input image.
    d - the embedding dimension in the ViT.
    """

    def __init__(self, model_type: str = 'dino_vits8', stride: int = 4, model: nn.Module = None, device: str = 'cuda'):
        """
        :param model_type: A string specifying the type of model to extract from.
                           [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
                           vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]
        :param stride: stride of first convolution layer. small stride -> higher resolution.
        :param model: Optional parameter. The nn.Module to extract from instead of creating a new one in ViTExtractor.
                      should be compatible with model_type.
        """
        self.model_type = model_type
        self.device = device
        if model is not None:
            self.model = model
        else:
            self.model = ViTExtractor.create_model(model_type)

        self.model = ViTExtractor.patch_vit_resolution(self.model, stride=stride)
        # self.model.eval()
        self.model.to(self.device)
        self.p = self.model.patch_embed.patch_size
        self.stride = self.model.patch_embed.proj.stride
        self.mean = (0.485, 0.456, 0.406) if "dino" in self.model_type else (0.5, 0.5, 0.5)
        self.std = (0.229, 0.224, 0.225) if "dino" in self.model_type else (0.5, 0.5, 0.5)

        self._feats = []
        self.hook_handlers = []
        self.load_size = None
        self.num_patches = None

    @staticmethod
    def create_model(model_type: str) -> nn.Module:
        """
        :param model_type: a string specifying which model to load.
                           [dino_vits8 | dino_vits16 | dino_vitb8 | dino_vitb16 | vit_small_patch8_224 |
                           vit_small_patch16_224 | vit_base_patch8_224 | vit_base_patch16_224]
        :return: the model
        """
        if 'dino' in model_type:
            model = torch.hub.load('facebookresearch/dino:main', model_type)
        else:  # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
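            # Instantiate the timm ViT only to borrow its pretrained weights; the forward pass runs
            # through the matching DINO architecture (which supports arbitrary input sizes), so the
            # classification head parameters are dropped before loading the state dict below.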
            temp_model = timm.create_model(model_type, pretrained=True)
            model_type_dict = {
                'vit_small_patch16_224': 'dino_vits16',
                'vit_small_patch8_224': 'dino_vits8',
                'vit_base_patch16_224': 'dino_vitb16',
                'vit_base_patch8_224': 'dino_vitb8'
            }
            model = torch.hub.load('facebookresearch/dino:main', model_type_dict[model_type])
            temp_state_dict = temp_model.state_dict()
            del temp_state_dict['head.weight']
            del temp_state_dict['head.bias']
            model.load_state_dict(temp_state_dict)
        return model

    @staticmethod
    def _fix_pos_enc(patch_size: int, stride_hw: Tuple[int, int]):
        """
        Creates a method for position encoding interpolation.
        :param patch_size: patch size of the model.
        :param stride_hw: A tuple containing the new height and width stride respectively.
        :return: the interpolation method
        """
        def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
            npatch = x.shape[1] - 1
            N = self.pos_embed.shape[1] - 1
            if npatch == N and w == h:
                return self.pos_embed
            class_pos_embed = self.pos_embed[:, 0]
            patch_pos_embed = self.pos_embed[:, 1:]
            dim = x.shape[-1]
            # compute number of tokens taking stride into account
            w0 = 1 + (w - patch_size) // stride_hw[1]
            h0 = 1 + (h - patch_size) // stride_hw[0]
            assert (w0 * h0 == npatch), f"""got wrong grid size for {h}x{w} with patch_size {patch_size} and
                                            stride {stride_hw} got {h0}x{w0}={h0 * w0} expecting {npatch}"""
            # we add a small number to avoid floating point error in the interpolation
            # see discussion at https://github.com/facebookresearch/dino/issues/8
            w0, h0 = w0 + 0.1, h0 + 0.1
            patch_pos_embed = nn.functional.interpolate(
                patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
                scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
                mode='bicubic',
                align_corners=False,
                recompute_scale_factor=False
            )
            assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
            patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
            return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

        return interpolate_pos_encoding

    @staticmethod
    def patch_vit_resolution(model: nn.Module, stride: int) -> nn.Module:
        """
        change resolution of model output by changing the stride of the patch extraction.
        :param model: the model to change resolution for.
        :param stride: the new stride parameter.
        :return: the adjusted model
        """
        patch_size = model.patch_embed.patch_size
        if stride == patch_size:  # nothing to do
            return model

        stride = nn_utils._pair(stride)
        assert all([(patch_size // s_) * s_ == patch_size for s_ in stride]), \
            f'stride {stride} should divide patch_size {patch_size}'

        # fix the stride
        model.patch_embed.proj.stride = stride
        # fix the positional encoding code
        model.interpolate_pos_encoding = types.MethodType(ViTExtractor._fix_pos_enc(patch_size, stride), model)
        return model

    def preprocess(self, image_path: Union[str, Path],
                   load_size: Union[int, Tuple[int, int]] = None) -> Tuple[torch.Tensor, Image.Image]:
        """
        Preprocesses an image before extraction.
        :param image_path: path to image to be extracted.
        :param load_size: optional. Size to resize image before the rest of preprocessing.
        :return: a tuple containing:
                    (1) the preprocessed image as a tensor of shape BxCxHxW, ready to be fed to the model.
                    (2) the pil image in relevant dimensions
        """
        pil_image = Image.open(image_path).convert('RGB')
        if load_size is not None:
            pil_image = transforms.Resize(load_size, interpolation=transforms.InterpolationMode.LANCZOS)(pil_image)
        prep = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ])
        prep_img = prep(pil_image)[None, ...]
        return prep_img, pil_image

    def _get_hook(self, facet: str):
        """
        generate a hook method for a specific block and facet.
        """
        if facet in ['attn', 'token']:
            def _hook(model, input, output):
                self._feats.append(output)
            return _hook

        if facet == 'query':
            facet_idx = 0
        elif facet == 'key':
            facet_idx = 1
        elif facet == 'value':
            facet_idx = 2
        else:
            raise TypeError(f"{facet} is not a supported facet.")

        def _inner_hook(module, input, output):
            input = input[0]
            B, N, C = input.shape
            qkv = module.qkv(input).reshape(B, N, 3, module.num_heads, C // module.num_heads).permute(2, 0, 3, 1, 4)
            self._feats.append(qkv[facet_idx])  # Bxhxtxd
        return _inner_hook

    def _register_hooks(self, layers: List[int], facet: str) -> None:
        """
        register hook to extract features.
        :param layers: layers from which to extract features.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
        """
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in layers:
                if facet == 'token':
                    self.hook_handlers.append(block.register_forward_hook(self._get_hook(facet)))
                elif facet == 'attn':
                    self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_hook(facet)))
                elif facet in ['key', 'query', 'value']:
                    self.hook_handlers.append(block.attn.register_forward_hook(self._get_hook(facet)))
                else:
                    raise TypeError(f"{facet} is not a supported facet.")

    def _unregister_hooks(self) -> None:
        """
        unregisters the hooks. should be called after feature extraction.
        """
        for handle in self.hook_handlers:
            handle.remove()
        self.hook_handlers = []

    def _extract_features(self, batch: torch.Tensor, layers: List[int] = (11,), facet: str = 'key') -> List[torch.Tensor]:
        """
        extract features from the model
        :param batch: batch to extract features for. Has shape BxCxHxW.
        :param layers: layers to extract from. Each is a number between 0 to 11.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token' | 'attn']
        :return: tensor of features.
                 if facet is 'key' | 'query' | 'value' has shape Bxhxtxd
                 if facet is 'attn' has shape Bxhxtxt
                 if facet is 'token' has shape Bxtxd
        """
        B, C, H, W = batch.shape
        self._feats = []
        self._register_hooks(layers, facet)
        with torch.no_grad():
            _ = self.model(batch)
        self._unregister_hooks()
        self.load_size = (H, W)
        self.num_patches = (1 + (H - self.p) // self.stride[0], 1 + (W - self.p) // self.stride[1])
        return self._feats

    def _log_bin(self, x: torch.Tensor, hierarchy: int = 2) -> torch.Tensor:
        """
        create a log-binned descriptor.
        :param x: tensor of features. Has shape Bxhxtxd.
        :param hierarchy: how many bin hierarchies to use.
        """
        B = x.shape[0]
        num_bins = 1 + 8 * hierarchy

        bin_x = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1)  # Bx(t-1)x(dxh)
        bin_x = bin_x.permute(0, 2, 1)
        bin_x = bin_x.reshape(B, bin_x.shape[1], self.num_patches[0], self.num_patches[1])
        # Bx(dxh)xnum_patches[0]xnum_patches[1]
        sub_desc_dim = bin_x.shape[1]

        avg_pools = []
        # compute bins of all sizes for all spatial locations.
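        # Hierarchy level k pools the per-patch features with a 3**k x 3**k window, so each level
        # summarizes a progressively larger (log-scaled) neighbourhood around every patch; the
        # per-level pooled maps are gathered here and sampled into the binned descriptor below.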
        for k in range(0, hierarchy):
            # avg pooling with kernel 3**k x 3**k
            win_size = 3 ** k
            avg_pool = torch.nn.AvgPool2d(win_size, stride=1, padding=win_size // 2, count_include_pad=False)
            avg_pools.append(avg_pool(bin_x))

        bin_x = torch.zeros((B, sub_desc_dim * num_bins, self.num_patches[0], self.num_patches[1])).to(self.device)
        for y in range(self.num_patches[0]):
            for x in range(self.num_patches[1]):
                part_idx = 0
                # fill all bins for a spatial location (y, x)
                for k in range(0, hierarchy):
                    kernel_size = 3 ** k
                    for i in range(y - kernel_size, y + kernel_size + 1, kernel_size):
                        for j in range(x - kernel_size, x + kernel_size + 1, kernel_size):
                            if i == y and j == x and k != 0:
                                continue
                            if 0 <= i < self.num_patches[0] and 0 <= j < self.num_patches[1]:
                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = \
                                    avg_pools[k][:, :, i, j]
                            else:  # handle padding in a more delicate way than zero padding
                                temp_i = max(0, min(i, self.num_patches[0] - 1))
                                temp_j = max(0, min(j, self.num_patches[1] - 1))
                                bin_x[:, part_idx * sub_desc_dim: (part_idx + 1) * sub_desc_dim, y, x] = \
                                    avg_pools[k][:, :, temp_i, temp_j]
                            part_idx += 1
        bin_x = bin_x.flatten(start_dim=-2, end_dim=-1).permute(0, 2, 1).unsqueeze(dim=1)  # Bx1x(t-1)x(dxh)
        return bin_x

    def extract_descriptors(self, batch: torch.Tensor, layer: int = 11, facet: str = 'key',
                            bin: bool = False, include_cls: bool = False) -> torch.Tensor:
        """
        extract descriptors from the model
        :param batch: batch to extract descriptors for. Has shape BxCxHxW.
        :param layer: layer to extract. A number between 0 to 11.
        :param facet: facet to extract. One of the following options: ['key' | 'query' | 'value' | 'token']
        :param bin: apply log binning to the descriptor. default is False.
        :param include_cls: keep the cls token as part of the descriptor. default is False.
        :return: tensor of descriptors. Bx1xtxd' where d' is the dimension of the descriptors.
        """
        assert facet in ['key', 'query', 'value', 'token'], f"""{facet} is not a supported facet for descriptors.
                                                                choose from ['key' | 'query' | 'value' | 'token'] """
        self._extract_features(batch, [layer], facet)
        x = self._feats[0]
        if facet == 'token':
            x.unsqueeze_(dim=1)  # Bx1xtxd
        if not include_cls:
            x = x[:, :, 1:, :]  # remove cls token
        else:
            assert not bin, "bin = True and include_cls = True are not supported together, set one of them False."
        if not bin:
            desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
        else:
            desc = self._log_bin(x)
        return desc

    def extract_saliency_maps(self, batch: torch.Tensor) -> torch.Tensor:
        """
        extract saliency maps. The saliency maps are extracted by averaging the CLS token's attention over
        several heads of the last layer. All values are then normalized to range between 0 and 1.
        :param batch: batch to extract saliency maps for. Has shape BxCxHxW.
        :return: a tensor of saliency maps. has shape Bx(t-1)
        """
        assert self.model_type == "dino_vits8", "saliency maps are supported only for dino_vits8 model_type."
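        # Take the attention maps of the final transformer block (layer 11), average the CLS token's
        # attention over a hand-picked subset of heads, and min-max normalize the result per image.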
        self._extract_features(batch, [11], 'attn')
        head_idxs = [0, 2, 4, 5]
        curr_feats = self._feats[0]  # Bxhxtxt
        cls_attn_map = curr_feats[:, head_idxs, 0, 1:].mean(dim=1)  # Bx(t-1)
        temp_mins, temp_maxs = cls_attn_map.min(dim=1)[0], cls_attn_map.max(dim=1)[0]
        cls_attn_maps = (cls_attn_map - temp_mins) / (temp_maxs - temp_mins)  # normalize to range [0,1]
        return cls_attn_maps


@BACKBONE_REGISTRY.register()
class D2ViTTransformer(Backbone):
    def __init__(self, cfg, input_shape):
        pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
        patch_size = cfg.MODEL.SWIN.PATCH_SIZE
        in_chans = 3
        embed_dim = cfg.MODEL.SWIN.EMBED_DIM
        depths = cfg.MODEL.SWIN.DEPTHS
        num_heads = cfg.MODEL.SWIN.NUM_HEADS
        window_size = cfg.MODEL.SWIN.WINDOW_SIZE
        mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
        qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
        qk_scale = cfg.MODEL.SWIN.QK_SCALE
        drop_rate = cfg.MODEL.SWIN.DROP_RATE
        attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
        drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
        norm_layer = nn.LayerNorm
        ape = cfg.MODEL.SWIN.APE
        patch_norm = cfg.MODEL.SWIN.PATCH_NORM
        frozen_stages = cfg.MODEL.BACKBONE.FREEZE_AT

        super().__init__()

        self.num_layers = 12
        num_features = [int(embed_dim) for _ in range(self.num_layers)]
        self.num_features = num_features
        self.frozen_stages = frozen_stages

        self.extractor = ViTExtractor(
            model_type='dino_vitb8',
            stride=4,
            model=None,
            device=cfg.MODEL.DEVICE)

        if self.frozen_stages >= 0:
            for block_idx, block in enumerate(self.extractor.model.blocks):
                if block_idx <= self.frozen_stages:
                    block.eval()
                    for p in block.parameters():
                        p.requires_grad = False

        for block_idx, block in enumerate(self.extractor.model.blocks):
            if all(not p.requires_grad for p in block.parameters()):
                print(f"DINO {block_idx} frozen")

        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES

        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }
        self._out_feature_channels = {
            "res2": self.num_features[0],
            "res3": self.num_features[1],
            "res4": self.num_features[2],
            "res5": self.num_features[3],
        }

    def forward(self, x):
        facet = 'key'
        self.extractor._extract_features(x, [5, 7, 9, 11], facet=facet)

        res2 = self.extractor._feats[0].unsqueeze_(dim=1)  # Bx1xhxtxd for facet 'key'
        res3 = self.extractor._feats[1].unsqueeze_(dim=1)  # Bx1xhxtxd for facet 'key'
        res4 = self.extractor._feats[2].unsqueeze_(dim=1)  # Bx1xhxtxd for facet 'key'
        res5 = self.extractor._feats[3].unsqueeze_(dim=1)  # Bx1xhxtxd for facet 'key'

        if facet == 'key':
            res2 = einops.rearrange(res2, 'b c h t d -> b c t (d h)')  # Bx1xtx(dxh)
            res3 = einops.rearrange(res3, 'b c h t d -> b c t (d h)')  # Bx1xtx(dxh)
            res4 = einops.rearrange(res4, 'b c h t d -> b c t (d h)')  # Bx1xtx(dxh)
            res5 = einops.rearrange(res5, 'b c h t d -> b c t (d h)')  # Bx1xtx(dxh)

        res2 = res2.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
        res3 = res3.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
        res4 = res4.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)
        res5 = res5.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1).unsqueeze(dim=1)  # Bx1xtx(dxh)

        res2 = res2[:, :, 1:, :]  # remove cls token
        res3 = res3[:, :, 1:, :]  # remove cls token
        res4 = res4[:, :, 1:, :]  # remove cls token
        res5 = res5[:, :, 1:, :]  # remove cls token

        res2 = res2.reshape(res2.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
        res3 = res3.reshape(res3.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
        res4 = res4.reshape(res4.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)
        res5 = res5.reshape(res5.size(0), *self.extractor.num_patches, -1).permute(0, 3, 1, 2)

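        # Note: all four maps come from the same ViT token grid, so they share the spatial size implied
        # by the extractor's stride; the res2-res5 names follow the detectron2 convention of the Swin
        # backbone this class stands in for, rather than denoting different resolutions.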
        return {
            "res2": res2,
            "res3": res3,
            "res4": res4,
            "res5": res5,
        }

    def output_shape(self):
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    @property
    def size_divisibility(self):
        return 32
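

# Minimal usage sketch for ViTExtractor: it assumes a local image at the placeholder path
# 'example.jpg' and network access for torch.hub to download the DINO weights. It exercises
# preprocessing, dense 'key' descriptors from the last block, and a CLS-attention saliency map.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    extractor = ViTExtractor(model_type='dino_vits8', stride=4, device=device)

    # 'example.jpg' is a placeholder; load_size=224 resizes the shorter side before normalization.
    image_batch, pil_image = extractor.preprocess('example.jpg', load_size=224)

    descriptors = extractor.extract_descriptors(image_batch.to(device), layer=11, facet='key')
    saliency = extractor.extract_saliency_maps(image_batch.to(device))

    print('descriptors:', tuple(descriptors.shape))  # Bx1x(t-1)xd'
    print('saliency   :', tuple(saliency.shape))     # Bx(t-1)
    print('patch grid :', extractor.num_patches)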