from pathlib import Path

import timm
import torch.nn as nn

from .utils import activations, forward_default, get_activation

# Star import kept last, as in the original, so any names it shadows are
# shadowed exactly as before. NOTE(review): it is presumably needed for its
# side effect of registering the "nextvit_large" architecture with timm so
# that ``timm.create_model("nextvit_large")`` below resolves -- confirm, and
# consider replacing it with an explicit import.
from ..external.next_vit.classification.nextvit import *

# Default indices into ``model.features`` whose outputs are captured as
# activations "1".."4". NOTE(review): presumably the ends of the four NeXt-ViT
# stages for the large variant -- confirm against the model definition.
_DEFAULT_NEXT_VIT_HOOKS = (2, 6, 36, 39)


def forward_next_vit(pretrained, x):
    """Run ``x`` through a backbone built by ``_make_next_vit_backbone``.

    Delegates to ``forward_default``, naming ``"forward"`` as the method to
    invoke. Presumably it returns the activations captured by the registered
    hooks; see ``forward_default`` in ``.utils`` to confirm.
    """
    return forward_default(pretrained, x, "forward")


def _make_next_vit_backbone(
    model,
    hooks=None,
):
    """Wrap ``model`` and register forward hooks on four of its feature blocks.

    Args:
        model: Module exposing an indexable ``features`` container.
        hooks: Indices into ``model.features``. The first four are hooked and
            their outputs are stored under the keys ``"1"``..``"4"`` via
            ``get_activation``. Any further entries are ignored. Defaults to
            ``(2, 6, 36, 39)``.

    Returns:
        An ``nn.Module`` with ``.model`` (the wrapped model) and
        ``.activations`` (the ``activations`` object imported from
        ``.utils``, which the hooks are presumably expected to fill).

    Raises:
        ValueError: If fewer than four hook indices are given.
    """
    if hooks is None:
        hooks = _DEFAULT_NEXT_VIT_HOOKS
    if len(hooks) < 4:
        raise ValueError(
            f"expected at least 4 hook indices, got {len(hooks)}: {hooks!r}"
        )

    pretrained = nn.Module()
    pretrained.model = model

    # zip() stops after four pairs, so indices past the fourth are ignored,
    # matching the original hooks[0]..hooks[3] behaviour.
    for name, layer_idx in zip(("1", "2", "3", "4"), hooks):
        pretrained.model.features[layer_idx].register_forward_hook(
            get_activation(name)
        )

    pretrained.activations = activations
    return pretrained


def _make_pretrained_next_vit_large_6m(hooks=None):
    """Create a ``nextvit_large`` timm model and wrap it as a hooked backbone.

    Args:
        hooks: Optional feature-block indices forwarded to
            ``_make_next_vit_backbone``. ``None`` selects the default
            ``(2, 6, 36, 39)``.

    Returns:
        The wrapper module returned by ``_make_next_vit_backbone``.

    NOTE(review): ``timm.create_model`` is called without ``pretrained=True``,
    so no weights are requested here despite the function name. Presumably
    weights are loaded by the caller -- confirm.
    """
    model = timm.create_model("nextvit_large")
    return _make_next_vit_backbone(
        model,
        hooks=hooks,
    )