Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import numpy as np | |
from .utils import activations, forward_default, get_activation, Transpose | |
def forward_swin(pretrained, x):
    """Run the forward pass for a Swin backbone.

    This is a thin wrapper: both arguments are handed unchanged to
    ``forward_default`` and its result is returned as-is.

    Args:
        pretrained: Backbone wrapper, e.g. as built by ``_make_swin_backbone``.
        x: Input passed through to ``forward_default``.

    Returns:
        Whatever ``forward_default(pretrained, x)`` returns.
    """
    outputs = forward_default(pretrained, x)
    return outputs
def _make_swin_backbone(
    model,
    hooks=(1, 1, 17, 1),
    patch_grid=(96, 96),
):
    """Wrap a Swin model with forward hooks and per-stage post-processing.

    Args:
        model: Model exposing ``layers[i].blocks[j]`` for ``i`` in 0..3. If it
            has a ``patch_grid`` attribute, that value overrides the
            ``patch_grid`` argument.
        hooks: Four block indices, one per stage; ``hooks[i]`` selects the
            block inside ``model.layers[i]`` that receives a forward hook.
            Any indexable sequence works (list or tuple).
        patch_grid: Fallback two-element grid size used when ``model`` has no
            ``patch_grid`` attribute.

    Returns:
        An ``nn.Module`` with the attributes ``model``, ``activations`` and
        ``act_postprocess1`` .. ``act_postprocess4``.

    Raises:
        IndexError: If ``hooks`` has fewer than four entries, or an index is
            out of range for the corresponding stage.
    """
    num_stages = 4

    pretrained = nn.Module()
    pretrained.model = model

    # Stage i (0-based) is hooked under the key str(i + 1), i.e. "1".."4".
    for stage_idx in range(num_stages):
        hooked_block = pretrained.model.layers[stage_idx].blocks[hooks[stage_idx]]
        hooked_block.register_forward_hook(get_activation(str(stage_idx + 1)))

    pretrained.activations = activations

    # A grid stored on the model takes precedence over the argument.
    used_patch_grid = getattr(model, "patch_grid", patch_grid)
    patch_grid_size = np.array(used_patch_grid, dtype=int)

    # Each successive stage uses a grid halved along both axes: //1, //2, //4, //8.
    # NOTE(review): presumably the hooked outputs are (batch, tokens, channels),
    # so Transpose + Unflatten yields (batch, channels, H, W) -- confirm.
    for stage_idx in range(num_stages):
        stage_grid = (patch_grid_size // (2 ** stage_idx)).tolist()
        postprocess = nn.Sequential(
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size(stage_grid)),
        )
        setattr(pretrained, f"act_postprocess{stage_idx + 1}", postprocess)

    return pretrained