import os

import numpy as np
import torch

# from timm.models import create_model
from .protonet import ProtoNet
from .deploy import ProtoNet_Finetune, ProtoNet_Auto_Finetune, ProtoNet_AdaTok, ProtoNet_AdaTok_EntMin
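
# The `args` namespace below is consumed but not defined in this file. As a
# reading aid (inferred from the branches below, not an authoritative spec),
# the attributes used are:
#   args.arch                       -- backbone identifier, e.g. 'dino_small_patch16'
#   args.no_pretrain                -- if True, skip loading pretrained weights (where honored)
#   args.pretrained_checkpoint_path -- root directory for locally stored/converted checkpoints
#   args.deploy                     -- ProtoNet variant: 'vanilla', 'finetune', 'finetune_autolr',
#                                      'ada_tokens' or 'ada_tokens_entmin'
#   args.ada_steps / args.ada_lr / args.aug_prob / args.aug_types / args.num_adapters
#                                   -- test-time adaptation hyper-parameters (see .deploy)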


def get_backbone(args):
    """Build the requested backbone and load its pretrained weights."""
    if args.arch == 'vit_base_patch16_224_in21k':
        from .vit_google import VisionTransformer, CONFIGS

        config = CONFIGS['ViT-B_16']
        model = VisionTransformer(config, 224)

        url = 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'
        pretrained_weights = 'pretrained_ckpts/vit_base_patch16_224_in21k.npz'

        if not os.path.exists(pretrained_weights):
            try:
                import wget
                os.makedirs('pretrained_ckpts', exist_ok=True)
                wget.download(url, pretrained_weights)
            except Exception:
                print(f'Cannot download pretrained weights from {url}. Check if `pip install wget` works.')

        model.load_from(np.load(pretrained_weights))
        print(f'Pretrained weights found at {pretrained_weights}')
    elif args.arch == 'dino_base_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
        url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
        model.load_state_dict(state_dict, strict=True)
        print(f'Pretrained weights found at {url}')

    elif args.arch == 'deit_base_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
        url = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"
        state_dict = torch.hub.load_state_dict_from_url(url=url)["model"]
        # DeiT checkpoints ship with a classification head; drop it since the
        # backbone is built with num_classes=0 and used as a feature extractor.
        for k in ['head.weight', 'head.bias']:
            if k in state_dict:
                print(f"removing key {k} from pretrained checkpoint")
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        print(f'Pretrained weights found at {url}')

    elif args.arch == 'deit_small_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
        url = "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
        state_dict = torch.hub.load_state_dict_from_url(url=url)["model"]
        for k in ['head.weight', 'head.bias']:
            if k in state_dict:
                print(f"removing key {k} from pretrained checkpoint")
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        print(f'Pretrained weights found at {url}')

    elif args.arch == 'dino_small_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
        if not args.no_pretrain:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
            state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
            model.load_state_dict(state_dict, strict=True)
            print(f'Pretrained weights found at {url}')
    elif args.arch == 'beit_base_patch16_224_pt22k':
        from .beit import default_pretrained_model

        model = default_pretrained_model(args)
        print('Pretrained BEiT loaded')

    elif args.arch == 'clip_base_patch16_224':
        from . import clip
        model, _ = clip.load('ViT-B/16', 'cpu')

    elif args.arch == 'clip_resnet50':
        from . import clip
        model, _ = clip.load('RN50', 'cpu')
    elif args.arch == 'dino_resnet50':
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        model.fc = torch.nn.Identity()
        if not args.no_pretrain:
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
                map_location="cpu",
            )
            # strict=False: the self-supervised checkpoint carries no fc weights
            model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'resnet50':
        from torchvision.models.resnet import resnet50

        pretrained = not args.no_pretrain
        model = resnet50(pretrained=pretrained)
        model.fc = torch.nn.Identity()

    elif args.arch == 'resnet18':
        from torchvision.models.resnet import resnet18

        pretrained = not args.no_pretrain
        model = resnet18(pretrained=pretrained)
        model.fc = torch.nn.Identity()
    elif args.arch == 'dino_xcit_medium_24_p16':
        model = torch.hub.load('facebookresearch/xcit:main', 'xcit_medium_24_p16')
        model.head = torch.nn.Identity()
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'dino_xcit_medium_24_p8':
        model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
    elif args.arch == 'simclrv2_resnet50':
        import sys
        sys.path.insert(0, 'cog')
        import model_utils

        model_utils.MODELS_ROOT_DIR = 'cog/models'
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts/simclrv2_resnet50.pth')
        resnet, _ = model_utils.load_pretrained_backbone(args.arch, ckpt_file)

        class Wrapper(torch.nn.Module):
            """Expose the SimCLRv2 trunk without its final fully-connected head."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, x):
                return self.model(x, apply_fc=False)

        model = Wrapper(resnet)
    elif args.arch in ['mocov2_resnet50', 'swav_resnet50', 'barlow_resnet50']:
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts_converted/{}.pth'.format(args.arch))
        ckpt = torch.load(ckpt_file)
        msg = model.load_state_dict(ckpt, strict=False)
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

        # remove the fully-connected layer
        model.fc = torch.nn.Identity()
    else:
        raise ValueError(f'{args.arch} is not supported in the current code.')

    return model


def get_model(args):
    backbone = get_backbone(args)

    if args.deploy == 'vanilla':
        model = ProtoNet(backbone)
    elif args.deploy == 'finetune':
        model = ProtoNet_Finetune(backbone, args.ada_steps, args.ada_lr, args.aug_prob, args.aug_types)
    elif args.deploy == 'finetune_autolr':
        model = ProtoNet_Auto_Finetune(backbone, args.ada_steps, args.aug_prob, args.aug_types)
    elif args.deploy == 'ada_tokens':
        model = ProtoNet_AdaTok(backbone, args.num_adapters,
                                args.ada_steps, args.ada_lr)
    elif args.deploy == 'ada_tokens_entmin':
        model = ProtoNet_AdaTok_EntMin(backbone, args.num_adapters,
                                       args.ada_steps, args.ada_lr)
    else:
        raise ValueError(f'deploy method {args.deploy} is not supported.')

    return model
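

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original pipeline. Assumptions:
    # this file lives inside a package (it uses relative imports), so run it with
    # `python -m <package>`; the hyper-parameter values below are illustrative
    # placeholders, not the repo's actual defaults.
    from argparse import Namespace

    args = Namespace(
        arch='dino_small_patch16',
        no_pretrain=True,            # skip the weight download for a quick check
        deploy='vanilla',
        pretrained_checkpoint_path='.',
        # unused by the 'vanilla' deploy, included only for completeness:
        ada_steps=50, ada_lr=1e-4, aug_prob=0.9,
        aug_types=['color', 'translation'], num_adapters=1,
    )
    model = get_model(args)
    with torch.no_grad():
        # ProtoNet is assumed to expose the feature extractor as `model.backbone`
        feats = model.backbone(torch.randn(2, 3, 224, 224))
    print('backbone feature shape:', tuple(feats.shape))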