pmf_with_gis / models /__init__.py
hushell's picture
add app.py
b9288df
import os
import numpy as np
import torch
#from timm.models import create_model
from .protonet import ProtoNet
from .deploy import ProtoNet_Finetune, ProtoNet_Auto_Finetune, ProtoNet_AdaTok, ProtoNet_AdaTok_EntMin
def _load_deit_weights(model, url):
    """Load a DeiT checkpoint from ``url`` into ``model``, dropping its classifier head.

    The checkpoint is expected to store the weights under the ``"model"`` key.
    ``head.weight`` / ``head.bias`` are removed (when present) so the remaining
    keys can be loaded with ``strict=True``.
    """
    # map_location="cpu" keeps deserialization working on hosts without CUDA.
    state_dict = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")["model"]

    for k in ['head.weight', 'head.bias']:
        if k in state_dict:
            print(f"removing key {k} from pretrained checkpoint")
            del state_dict[k]

    model.load_state_dict(state_dict, strict=True)
    print('Pretrained weights found at {}'.format(url))


def get_backbone(args):
    """Build the feature-extractor backbone selected by ``args.arch``.

    Args:
        args: Namespace-like object. ``args.arch`` is always read. Depending on
            the selected architecture, ``args.no_pretrain`` and
            ``args.pretrained_checkpoint_path`` are read as well, and the BEiT
            branch forwards the whole ``args`` object to
            ``default_pretrained_model``.

    Returns:
        The constructed backbone model.

    Raises:
        ValueError: If ``args.arch`` is not one of the handled architectures.
    """
    if args.arch == 'vit_base_patch16_224_in21k':
        from .vit_google import VisionTransformer, CONFIGS

        config = CONFIGS['ViT-B_16']
        model = VisionTransformer(config, 224)

        url = 'https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz'
        pretrained_weights = 'pretrained_ckpts/vit_base_patch16_224_in21k.npz'

        if not os.path.exists(pretrained_weights):
            try:
                # wget is optional; a missing package is reported like a failed download.
                import wget
                os.makedirs('pretrained_ckpts', exist_ok=True)
                wget.download(url, pretrained_weights)
            except Exception as err:
                # Only catch Exception so Ctrl-C / SystemExit are not swallowed.
                print(f'Cannot download pretrained weights from {url}. Check if `pip install wget` works.')
                print(f'Download error: {err!r}')

        # If the download failed, np.load raises here because the file is missing.
        model.load_from(np.load(pretrained_weights))
        print('Pretrained weights found at {}'.format(pretrained_weights))

    elif args.arch == 'dino_base_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
        url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/" + url,
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=True)
        print('Pretrained weights found at {}'.format(url))

    elif args.arch == 'deit_base_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_base'](patch_size=16, num_classes=0)
        _load_deit_weights(
            model,
            "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
        )

    elif args.arch == 'deit_small_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)
        _load_deit_weights(
            model,
            "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
        )

    elif args.arch == 'dino_small_patch16':
        from . import vision_transformer as vit

        model = vit.__dict__['vit_small'](patch_size=16, num_classes=0)

        if not args.no_pretrain:
            url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/" + url,
                map_location="cpu",
            )
            model.load_state_dict(state_dict, strict=True)
            print('Pretrained weights found at {}'.format(url))

    elif args.arch == 'beit_base_patch16_224_pt22k':
        from .beit import default_pretrained_model

        model = default_pretrained_model(args)
        print('Pretrained BEiT loaded')

    elif args.arch == 'clip_base_patch16_224':
        from . import clip

        model, _ = clip.load('ViT-B/16', 'cpu')

    elif args.arch == 'clip_resnet50':
        from . import clip

        model, _ = clip.load('RN50', 'cpu')

    elif args.arch == 'dino_resnet50':
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        # Replace the classifier so the network outputs features.
        model.fc = torch.nn.Identity()

        if not args.no_pretrain:
            state_dict = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
                map_location="cpu",
            )
            model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'resnet50':
        from torchvision.models.resnet import resnet50

        pretrained = not args.no_pretrain
        model = resnet50(pretrained=pretrained)
        model.fc = torch.nn.Identity()

    elif args.arch == 'resnet18':
        from torchvision.models.resnet import resnet18

        pretrained = not args.no_pretrain
        model = resnet18(pretrained=pretrained)
        model.fc = torch.nn.Identity()

    elif args.arch == 'dino_xcit_medium_24_p16':
        model = torch.hub.load('facebookresearch/xcit:main', 'xcit_medium_24_p16')
        model.head = torch.nn.Identity()
        state_dict = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/dino/dino_xcit_medium_24_p16_pretrain/dino_xcit_medium_24_p16_pretrain.pth",
            map_location="cpu",
        )
        model.load_state_dict(state_dict, strict=False)

    elif args.arch == 'dino_xcit_medium_24_p8':
        model = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')

    elif args.arch == 'simclrv2_resnet50':
        import sys

        # model_utils is imported from the local 'cog' directory.
        sys.path.insert(
            0,
            'cog',
        )
        import model_utils

        model_utils.MODELS_ROOT_DIR = 'cog/models'
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts/simclrv2_resnet50.pth')
        resnet, _ = model_utils.load_pretrained_backbone(args.arch, ckpt_file)

        class Wrapper(torch.nn.Module):
            """Call the wrapped model with ``apply_fc=False`` on every forward pass."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, x):
                return self.model(x, apply_fc=False)

        model = Wrapper(resnet)

    elif args.arch in ['mocov2_resnet50', 'swav_resnet50', 'barlow_resnet50']:
        from torchvision.models.resnet import resnet50

        model = resnet50(pretrained=False)
        ckpt_file = os.path.join(args.pretrained_checkpoint_path, 'pretrained_ckpts_converted/{}.pth'.format(args.arch))
        # Load onto CPU so checkpoints saved from a GPU work on CPU-only hosts.
        ckpt = torch.load(ckpt_file, map_location="cpu")

        msg = model.load_state_dict(ckpt, strict=False)
        # Only the classifier weights may be absent from the checkpoint.
        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

        # remove the fully-connected layer
        model.fc = torch.nn.Identity()

    else:
        raise ValueError(f'{args.arch} is not conisdered in the current code.')

    return model
def get_model(args):
    """Wrap the backbone for ``args.arch`` in the ProtoNet variant named by ``args.deploy``.

    The backbone is always built first via ``get_backbone(args)``. Only the
    attributes needed by the selected deployment mode are read from ``args``.

    Args:
        args: Namespace-like object providing ``deploy`` plus the
            hyper-parameters of the chosen mode (``ada_steps``, ``ada_lr``,
            ``aug_prob``, ``aug_types``, ``num_adapters``).

    Returns:
        The ProtoNet-style model wrapping the backbone.

    Raises:
        ValueError: If ``args.deploy`` is not a supported deployment mode.
    """
    feature_extractor = get_backbone(args)

    if args.deploy == 'vanilla':
        return ProtoNet(feature_extractor)

    if args.deploy == 'finetune':
        return ProtoNet_Finetune(
            feature_extractor,
            args.ada_steps,
            args.ada_lr,
            args.aug_prob,
            args.aug_types,
        )

    if args.deploy == 'finetune_autolr':
        return ProtoNet_Auto_Finetune(
            feature_extractor,
            args.ada_steps,
            args.aug_prob,
            args.aug_types,
        )

    if args.deploy == 'ada_tokens':
        return ProtoNet_AdaTok(
            feature_extractor,
            args.num_adapters,
            args.ada_steps,
            args.ada_lr,
        )

    if args.deploy == 'ada_tokens_entmin':
        return ProtoNet_AdaTok_EntMin(
            feature_extractor,
            args.num_adapters,
            args.ada_steps,
            args.ada_lr,
        )

    # No supported mode matched.
    raise ValueError(f'deploy method {args.deploy} is not supported.')