Spaces:
Runtime error
Runtime error
import timm | |
import torch.nn as nn | |
from pathlib import Path | |
from .utils import activations, forward_default, get_activation | |
from ..external.next_vit.classification.nextvit import * | |
def forward_next_vit(pretrained, x):
    """Run ``x`` through the backbone wrapped by ``pretrained``.

    Pure delegation: calls ``forward_default`` with the method name
    ``"forward"`` and hands back its result unchanged.
    """
    method_name = "forward"
    return forward_default(pretrained, x, method_name)
def _make_next_vit_backbone(
    model,
    hooks=None,
):
    """Wrap a Next-ViT model and register feature-capture hooks on it.

    A forward hook built by ``get_activation`` is registered on each of
    ``model.features[hooks[0]]`` ... ``model.features[hooks[3]]``. The hooks
    are tagged "1", "2", "3", "4" in that order.

    Args:
        model: Module exposing an indexable ``features`` attribute. It is
            modified in place (forward hooks are registered on it) and stored
            as ``pretrained.model``.
        hooks: Sequence of at least four indices into ``model.features``.
            Only the first four are used. ``None`` (the default) selects
            ``[2, 6, 36, 39]``.

    Returns:
        A fresh ``nn.Module`` whose ``model`` attribute is ``model`` and
        whose ``activations`` attribute is the module-level ``activations``
        object imported from ``.utils``.

    Raises:
        ValueError: If fewer than four hook indices are supplied. The check
            runs before any hook is registered, so ``model`` is left
            untouched on error.
    """
    # Resolve the default per call instead of using a shared mutable list
    # literal as the parameter default.
    if hooks is None:
        hooks = [2, 6, 36, 39]
    if len(hooks) < 4:
        raise ValueError(
            f"expected 4 hook indices into model.features, got {len(hooks)}"
        )

    pretrained = nn.Module()
    pretrained.model = model

    # Same registration order and tags as four explicit calls would produce.
    # zip() stops after the four tags, so any extra indices are ignored.
    for tag, index in zip(("1", "2", "3", "4"), hooks):
        pretrained.model.features[index].register_forward_hook(get_activation(tag))

    # NOTE(review): presumably the hooks write into this shared object
    # (see `get_activation` in .utils) — confirm.
    pretrained.activations = activations

    return pretrained
def _make_pretrained_next_vit_large_6m(hooks=None):
    """Build a Next-ViT-Large backbone with feature-capture hooks attached.

    The model comes from ``timm.create_model("nextvit_large")`` with timm's
    default arguments. No ``pretrained=True`` is passed here, so weights are
    presumably loaded elsewhere (TODO confirm). The ``"nextvit_large"`` name
    presumably becomes available to timm through the Next-ViT star import at
    the top of this module (verify).

    Args:
        hooks: Indices into ``model.features`` whose outputs are captured.
            ``None`` selects ``[2, 6, 36, 39]``.

    Returns:
        The wrapper module produced by ``_make_next_vit_backbone``.
    """
    model = timm.create_model("nextvit_large")

    # Identity check rather than ``== None``: equality on array-like hooks
    # (e.g. a NumPy array) is element-wise, which makes the conditional
    # raise "truth value is ambiguous".
    if hooks is None:
        hooks = [2, 6, 36, 39]
    return _make_next_vit_backbone(
        model,
        hooks=hooks,
    )