# TextureUpscaleBeta / models / vit_utils.py
# NightRaven109's picture
# Upload 73 files
# 6ecc7d4 verified
# MIT License
#
# Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Based on code from https://github.com/isl-org/DPT
"""Flexible configuration and feature extraction of timm VisionTransformers."""
import types
import math
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
class AddReadout(nn.Module):
    """Fold the leading readout entries of a token sequence into the rest.

    The first ``start_index`` entries along dim 1 are removed from the output
    and a single readout vector is added to every remaining entry.
    Presumably the leading entries are the class (and distillation) tokens of
    a ViT -- confirm against the caller.
    """

    def __init__(self, start_index: int = 1):
        """Store the index at which the kept entries begin.

        Args:
            start_index: Position along dim 1 of the first entry that is kept.
                When it equals ``2`` the readout is the mean of entries 0 and
                1; for every other value the readout is entry 0 alone.
        """
        super().__init__()
        self.start_index = start_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x[:, start_index:]`` with the readout added to each entry."""
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        # unsqueeze(1) lets the readout broadcast over the kept entries.
        return x[:, self.start_index:] + readout.unsqueeze(1)
class Transpose(nn.Module):
    """Module wrapper that swaps two dimensions of its input tensor."""

    def __init__(self, dim0: int, dim1: int):
        """Remember the pair of dimensions to exchange in ``forward``."""
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with ``dim0`` and ``dim1`` swapped, in contiguous memory."""
        swapped = torch.transpose(x, self.dim0, self.dim1)
        return swapped.contiguous()
def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:
    """Run the hooked backbone on ``x`` and return its rearranged activations.

    Args:
        pretrained: Wrapper exposing ``model`` (with a ``forward_flex`` method)
            and ``rearrange``, as built by ``make_vit_backbone``.
        x: Input batch handed unchanged to ``pretrained.model.forward_flex``.

    Returns:
        A dict mapping every key currently present in the module-level
        ``activations`` store to ``pretrained.rearrange`` applied to its value.

    Note:
        ``activations`` is a single module-level dict written by the forward
        hooks, so the result reflects whatever the most recent forward passes
        in this process left there.
    """
    # The return value is discarded on purpose: the pass is executed only so
    # the registered forward hooks populate ``activations``.
    pretrained.model.forward_flex(x)
    return {name: pretrained.rearrange(feat) for name, feat in activations.items()}
def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor:
    """Bilinearly resample the grid part of a position embedding.

    The first ``self.start_index`` entries along dim 1 are passed through
    untouched; the remaining entries of batch element 0 are treated as a
    square grid, resized to ``gs_h`` x ``gs_w`` and appended again.

    Args:
        posemb: Position embedding; indexed as ``[batch, token, channel]``.
        gs_h: Target grid height.
        gs_w: Target grid width.

    Returns:
        Tensor with ``self.start_index + gs_h * gs_w`` entries along dim 1.
    """
    n_prefix = self.start_index
    prefix_tokens = posemb[:, :n_prefix]
    grid_tokens = posemb[0, n_prefix:]

    # The stored grid is taken to be square, so its side is the square root of
    # the number of grid tokens.
    side = int(math.sqrt(grid_tokens.shape[0]))

    # (tokens, C) -> (1, C, side, side), the layout F.interpolate works on.
    grid = grid_tokens.reshape(1, side, side, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False
    )
    # Back to a flat token sequence: (1, gs_h * gs_w, C).
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    return torch.cat([prefix_tokens, grid], dim=1)
def forward_flex(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass that resizes the position embedding to the input size.

    Written to be bound onto a model instance that provides ``patch_embed``,
    ``pos_embed``, ``cls_token``, ``pos_drop``, ``blocks``, ``norm``,
    ``patch_size`` and ``_resize_pos_embed``.

    Args:
        x: Input batch; must be 4-D because its size is unpacked into four
            values, the last two being used as height and width.

    Returns:
        The output of ``self.norm`` applied after the last block.
    """
    _, _, height, width = x.shape

    # Patch projection, then flatten the spatial dims into a token axis.
    tokens = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    # NOTE(review): height is divided by patch_size[1] and width by
    # patch_size[0]; this only matters for non-square patches -- confirm the
    # intended (h, w) vs (w, h) ordering of ``patch_size``.
    grid_h = height // self.patch_size[1]
    grid_w = width // self.patch_size[0]
    pos_embed = self._resize_pos_embed(self.pos_embed, grid_h, grid_w)

    # Prepend one class token per batch element.
    cls_tokens = self.cls_token.expand(tokens.size(0), -1, -1)
    tokens = torch.cat((cls_tokens, tokens), dim=1)

    tokens = self.pos_drop(tokens + pos_embed)
    for block in self.blocks:
        tokens = block(tokens)
    return self.norm(tokens)
# Module-level store that the forward hooks write into, keyed by hook name.
# It is shared by every hook created through get_activation().
activations = {}


def get_activation(name: str) -> Callable:
    """Build a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the output is stored in ``activations``.

    Returns:
        A callable taking ``(module, inputs, output)`` that overwrites
        ``activations[name]`` with ``output`` and returns ``None``.
    """

    def _store_output(module, inputs, output):
        activations[name] = output

    return _store_output
def make_sd_backbone(
    model: nn.Module,
    hooks: list[int] = [2, 5, 8, 11],
    hook_patch: bool = True,
    start_index: int = 1,
    patch_size: list[int] = [16, 16],
):
    """Wrap ``model`` as a hooked feature-extraction backbone.

    The setup is identical to ``make_vit_backbone``, so this function simply
    delegates to it. Previously the body read an undefined ``patch_size`` name
    and raised ``NameError`` on every call; ``patch_size`` is now a trailing
    parameter so existing positional calls keep working.

    Args:
        model: Model exposing ``blocks`` and ``pos_drop``.
        hooks: Four indices into ``model.blocks`` whose outputs are recorded.
        hook_patch: Whether the output of ``model.pos_drop`` is recorded too.
        start_index: Number of leading tokens treated as readout tokens.
        patch_size: Patch size stored on the model for ``forward_flex``.
            NOTE(review): the default mirrors ``make_vit_backbone``; confirm
            it is the right value for the models passed here.

    Returns:
        The wrapper module produced by ``make_vit_backbone``.
    """
    return make_vit_backbone(
        model,
        # Copy so the stored list never aliases the shared default object.
        patch_size=list(patch_size),
        hooks=hooks,
        hook_patch=hook_patch,
        start_index=start_index,
    )
def make_vit_backbone(
    model: nn.Module,
    patch_size: list[int] = [16, 16],
    hooks: list[int] = [2, 5, 8, 11],
    hook_patch: bool = True,
    start_index: int = 1,
) -> nn.Module:
    """Wrap ``model`` as a hooked feature-extraction backbone.

    Forward hooks are registered on four blocks (stored under the keys
    ``'0'``..``'3'`` of the module-level ``activations`` dict) and optionally
    on ``model.pos_drop`` (key ``'4'``). ``forward_flex`` and
    ``_resize_pos_embed`` are bound onto ``model`` so it can run with an
    interpolated position embedding without modifying the library source.

    Args:
        model: Model exposing ``blocks`` and ``pos_drop``; it is modified in
            place (hooks, ``start_index``, ``patch_size`` and two methods).
        patch_size: Patch size stored on the model for ``forward_flex``.
        hooks: Exactly four indices into ``model.blocks``.
        hook_patch: Whether the output of ``model.pos_drop`` is recorded too.
        start_index: Number of leading tokens treated as readout tokens;
            passed to ``AddReadout`` and stored on the model.

    Returns:
        A bare ``nn.Module`` with ``model`` and ``rearrange`` attributes.
    """
    assert len(hooks) == 4, "hooks must contain exactly four block indices"

    pretrained = nn.Module()
    pretrained.model = model

    # Record the output of each selected block under '0', '1', '2', '3'.
    for slot, block_index in enumerate(hooks):
        model.blocks[block_index].register_forward_hook(get_activation(str(slot)))
    if hook_patch:
        model.pos_drop.register_forward_hook(get_activation("4"))

    # Readout handling followed by a swap of dims 1 and 2.
    pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))

    model.start_index = start_index
    # Copy so the model never aliases the shared default list object.
    model.patch_size = list(patch_size)

    # Inject the flexible forward pass into this VisionTransformer instance so
    # interpolated position embeddings work without touching the library.
    model.forward_flex = types.MethodType(forward_flex, model)
    model._resize_pos_embed = types.MethodType(_resize_pos_embed, model)

    return pretrained