# MIT License
#
# Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Based on code from https://github.com/isl-org/DPT
"""Flexible configuration and feature extraction of timm VisionTransformers."""
import math
import types
from collections.abc import Sequence
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
class AddReadout(nn.Module):
    """Fold the leading readout token(s) into the remaining tokens.

    The first ``start_index`` entries along dim 1 are treated as readout
    tokens; they are removed from the sequence and their value is added to
    every remaining token.

    Args:
        start_index: Number of leading readout tokens to strip. When it is
            ``2`` the readout is the mean of tokens 0 and 1; for any other
            value the readout is token 0 alone.
    """

    def __init__(self, start_index: int = 1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x[:, start_index:]`` with the readout added to each token.

        Assumes ``x`` is laid out as (batch, tokens, channels) -- TODO confirm
        against the caller.
        """
        if self.start_index == 2:
            # Two prefix tokens: average them into a single readout vector.
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        # unsqueeze(1) lets the readout broadcast over the token dimension.
        return x[:, self.start_index:] + readout.unsqueeze(1)
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` that returns a contiguous tensor.

    Args:
        dim0: First dimension to swap.
        dim1: Second dimension to swap.
    """

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Swap ``dim0`` and ``dim1`` of ``x`` and make the result contiguous."""
        return x.transpose(self.dim0, self.dim1).contiguous()
def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:
    """Run the wrapped model on ``x`` and return the hooked activations.

    Args:
        pretrained: Wrapper exposing ``model`` (with an injected
            ``forward_flex`` method) and a ``rearrange`` module.
        x: Input passed unchanged to ``pretrained.model.forward_flex``.

    Returns:
        A new dict mapping every key currently present in the module-level
        ``activations`` dict to ``pretrained.rearrange(value)``.

    Note:
        ``activations`` is shared module state: every entry stored in it is
        returned, not only those written during this call.
    """
    # The output is discarded on purpose; the call is made so that the
    # registered forward hooks fire and fill ``activations``.
    pretrained.model.forward_flex(x)
    return {key: pretrained.rearrange(value) for key, value in activations.items()}
def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: | |
posemb_tok, posemb_grid = ( | |
posemb[:, : self.start_index], | |
posemb[0, self.start_index :], | |
) | |
gs_old = int(math.sqrt(len(posemb_grid))) | |
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) | |
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) | |
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) | |
posemb = torch.cat([posemb_tok, posemb_grid], dim=1) | |
return posemb | |
def forward_flex(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass that resizes the position embeddings to the input size.

    Written as a free function taking ``self`` so it can be bound onto a
    model instance with ``types.MethodType``.

    Args:
        x: 4-D input batch fed to ``self.patch_embed.proj``.

    Returns:
        Output of ``self.norm`` applied after all ``self.blocks``.

    Raises:
        ValueError: If ``x`` is not 4-dimensional.
    """
    if x.dim() != 4:
        raise ValueError(f"expected a 4-D input, got {x.dim()} dimensions")

    # patch proj and dynamically resize
    feat = self.patch_embed.proj(x)
    # Read the token grid from the projected map rather than deriving it from
    # ``self.patch_size``: it then always matches the token count, whatever
    # the (height, width) ordering of a non-square patch size.
    gs_h, gs_w = feat.shape[-2:]
    x = feat.flatten(2).transpose(1, 2)
    pos_embed = self._resize_pos_embed(self.pos_embed, gs_h, gs_w)

    # add cls token
    # NOTE(review): only the class token is prepended; a model whose
    # pos_embed carries two prefix tokens (start_index == 2) would need its
    # second token prepended here as well -- confirm.
    cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    # forward pass
    x = x + pos_embed
    x = self.pos_drop(x)
    for blk in self.blocks:
        x = blk(x)
    return self.norm(x)
# Module-level store shared by every hook created through get_activation().
activations = {}


def get_activation(name: str) -> Callable:
    """Build a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the output is stored in ``activations``.

    Returns:
        A callable with the ``(module, inputs, output)`` forward-hook
        signature. It overwrites ``activations[name]`` on every call and
        returns ``None``, leaving the module output unchanged.
    """

    def hook(module, inputs, output):
        activations[name] = output

    return hook
def make_sd_backbone(
    model: nn.Module,
    hooks: Sequence[int] = (2, 5, 8, 11),
    hook_patch: bool = True,
    start_index: int = 1,
    patch_size: Sequence[int] = (16, 16),
) -> nn.Module:
    """Wrap ``model`` for multi-level feature extraction via forward hooks.

    ``model`` is modified in place: forward hooks are registered on it and
    the attributes ``start_index``, ``patch_size``, ``forward_flex`` and
    ``_resize_pos_embed`` are set on it.

    Args:
        model: Network exposing ``blocks`` and ``pos_drop``.
        hooks: Exactly four indices into ``model.blocks``; their outputs are
            recorded under the keys ``'0'`` to ``'3'``.
        hook_patch: When true, the output of ``model.pos_drop`` is also
            recorded, under the key ``'4'``.
        start_index: Number of leading readout tokens, passed to
            ``AddReadout`` and stored on the model.
        patch_size: Patch size stored on the model. The original body read
            this name without defining it, so it is now an optional trailing
            parameter.

    Returns:
        A bare ``nn.Module`` carrying ``model`` and ``rearrange``.

    Raises:
        ValueError: If ``hooks`` does not contain exactly four indices.

    Note:
        Hook handles are not kept, so calling this again on the same model
        registers additional hooks on top of the existing ones.
    """
    if len(hooks) != 4:
        raise ValueError(f"expected exactly 4 hook indices, got {len(hooks)}")

    pretrained = nn.Module()
    pretrained.model = model

    # add hooks
    for slot, block_idx in enumerate(hooks):
        pretrained.model.blocks[block_idx].register_forward_hook(get_activation(str(slot)))
    if hook_patch:
        pretrained.model.pos_drop.register_forward_hook(get_activation('4'))

    # configure readout
    pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
    pretrained.model.start_index = start_index
    # Copy so the model never aliases the caller's (or the default) sequence.
    pretrained.model.patch_size = list(patch_size)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )
    return pretrained
def make_vit_backbone(
    model: nn.Module,
    patch_size: Sequence[int] = (16, 16),
    hooks: Sequence[int] = (2, 5, 8, 11),
    hook_patch: bool = True,
    start_index: int = 1,
) -> nn.Module:
    """Wrap ``model`` for multi-level feature extraction via forward hooks.

    ``model`` is modified in place: forward hooks are registered on it and
    the attributes ``start_index``, ``patch_size``, ``forward_flex`` and
    ``_resize_pos_embed`` are set on it.

    Args:
        model: Network exposing ``blocks`` and ``pos_drop``.
        patch_size: Patch size stored on the model (as a fresh list).
        hooks: Exactly four indices into ``model.blocks``; their outputs are
            recorded under the keys ``'0'`` to ``'3'``.
        hook_patch: When true, the output of ``model.pos_drop`` is also
            recorded, under the key ``'4'``.
        start_index: Number of leading readout tokens, passed to
            ``AddReadout`` and stored on the model.

    Returns:
        A bare ``nn.Module`` carrying ``model`` and ``rearrange``.

    Raises:
        ValueError: If ``hooks`` does not contain exactly four indices.

    Note:
        Hook handles are not kept, so calling this again on the same model
        registers additional hooks on top of the existing ones.
    """
    if len(hooks) != 4:
        raise ValueError(f"expected exactly 4 hook indices, got {len(hooks)}")

    pretrained = nn.Module()
    pretrained.model = model

    # add hooks
    for slot, block_idx in enumerate(hooks):
        pretrained.model.blocks[block_idx].register_forward_hook(get_activation(str(slot)))
    if hook_patch:
        pretrained.model.pos_drop.register_forward_hook(get_activation('4'))

    # configure readout
    pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
    pretrained.model.start_index = start_index
    # Copy so separate backbones never share one patch-size list object.
    pretrained.model.patch_size = list(patch_size)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )
    return pretrained