try:
    from timm.layers import resample_abs_pos_embed
except ImportError as err:
    print("ImportError: {0}".format(err))

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def make_vit_b16_backbone(
    model,
    encoder_feature_dims,
    encoder_feature_layer_ids,
    vit_features,
    start_index=1,
    use_grad_checkpointing=False,
) -> nn.Module:
    """Make a ViT-B/16 backbone for the DPT model."""
    if use_grad_checkpointing:
        model.set_grad_checkpointing()

    # Wrap the encoder and record which blocks to hook for multi-scale features.
    vit_model = nn.Module()
    vit_model.hooks = encoder_feature_layer_ids
    vit_model.model = model
    vit_model.features = encoder_feature_dims
    vit_model.vit_features = vit_features
    vit_model.model.start_index = start_index
    vit_model.model.patch_size = vit_model.model.patch_embed.patch_size
    vit_model.model.is_vit = True
    # Route calls straight to forward_features so the wrapper yields patch
    # tokens rather than classification logits.
    vit_model.model.forward = vit_model.model.forward_features

    return vit_model
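
# Example usage (a sketch, not part of the original file): wrap a timm
# ViT-B/16 encoder as a DPT-style backbone. The model name, feature dims,
# and hook layer ids below are illustrative assumptions.
#
#   import timm
#
#   encoder = timm.create_model("vit_base_patch16_384", pretrained=True)
#   backbone = make_vit_b16_backbone(
#       encoder,
#       encoder_feature_dims=[96, 192, 384, 768],
#       encoder_feature_layer_ids=[2, 5, 8, 11],
#       vit_features=768,
#   )

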
def forward_features_eva_fixed(self, x):
    """Encode features, passing the rotary position embedding to each block."""
    x = self.patch_embed(x)
    x, rot_pos_embed = self._pos_embed(x)
    for blk in self.blocks:
        if self.grad_checkpointing:
            # Recompute activations in the backward pass to save memory.
            x = checkpoint(blk, x, rot_pos_embed)
        else:
            x = blk(x, rot_pos_embed)
    x = self.norm(x)
    return x
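
# A sketch of how this patched forward could be bound to a model instance
# (assumes `model` is a timm EVA-style ViT whose blocks accept a rotary
# position embedding; the binding itself is not part of the original file):
#
#   import types
#
#   model.forward_features = types.MethodType(forward_features_eva_fixed, model)

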
def resize_vit(model: nn.Module, img_size) -> nn.Module:
    """Resample the ViT module to the given size."""
    patch_size = model.patch_embed.patch_size
    model.patch_embed.img_size = img_size
    grid_size = tuple(s // p for s, p in zip(img_size, patch_size))
    model.patch_embed.grid_size = grid_size

    # Resample the absolute position embedding to the new token grid while
    # leaving any prefix (e.g. class) tokens untouched.
    pos_embed = resample_abs_pos_embed(
        model.pos_embed,
        grid_size,
        num_prefix_tokens=(
            0 if getattr(model, "no_embed_class", False) else model.num_prefix_tokens
        ),
    )
    model.pos_embed = torch.nn.Parameter(pos_embed)

    return model
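
# Example (a sketch): adapt a pretrained ViT to 512x512 inputs by resampling
# its absolute position embedding; the target size is an illustrative
# assumption.
#
#   model = resize_vit(model, img_size=(512, 512))

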
def resize_patch_embed(model: nn.Module, new_patch_size=(16, 16)) -> nn.Module:
    """Resample the ViT patch size to the given one."""
    if hasattr(model, "patch_embed"):
        old_patch_size = model.patch_embed.patch_size

        if (
            new_patch_size[0] != old_patch_size[0]
            or new_patch_size[1] != old_patch_size[1]
        ):
            patch_embed_proj = model.patch_embed.proj.weight
            patch_embed_proj_bias = model.patch_embed.proj.bias
            use_bias = patch_embed_proj_bias is not None
            _, _, h, w = patch_embed_proj.shape

            # Resample the projection kernels to the new patch size, then
            # rescale so the convolution's response magnitude is roughly
            # preserved despite the change in kernel area.
            new_patch_embed_proj = torch.nn.functional.interpolate(
                patch_embed_proj,
                size=[new_patch_size[0], new_patch_size[1]],
                mode="bicubic",
                align_corners=False,
            )
            new_patch_embed_proj = (
                new_patch_embed_proj * (h / new_patch_size[0]) * (w / new_patch_size[1])
            )

            # Rebuild the projection with the new kernel size and stride,
            # reusing the resampled weights (and the original bias, if any).
            model.patch_embed.proj = nn.Conv2d(
                in_channels=model.patch_embed.proj.in_channels,
                out_channels=model.patch_embed.proj.out_channels,
                kernel_size=new_patch_size,
                stride=new_patch_size,
                bias=use_bias,
            )

            if use_bias:
                model.patch_embed.proj.bias = patch_embed_proj_bias

            model.patch_embed.proj.weight = torch.nn.Parameter(new_patch_embed_proj)

            model.patch_size = new_patch_size
            model.patch_embed.patch_size = new_patch_size
            # Scale the recorded image size proportionally so the token grid
            # (img_size / patch_size) stays the same.
            model.patch_embed.img_size = (
                int(
                    model.patch_embed.img_size[0]
                    * new_patch_size[0]
                    / old_patch_size[0]
                ),
                int(
                    model.patch_embed.img_size[1]
                    * new_patch_size[1]
                    / old_patch_size[1]
                ),
            )

    return model
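
# Example (a sketch): halve the patch size from 16x16 to 8x8, which quadruples
# the token count at a fixed input resolution; the value is an illustrative
# assumption.
#
#   model = resize_patch_embed(model, new_patch_size=(8, 8))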