import re
import math
import torch
import collections
from torch import nn
from functools import partial
from torch.utils import model_zoo
from torch.nn import functional as F

from resnet import resnet50
################################################################################
### Helper functions for model architecture
################################################################################
# Params: namedtuple of model configuration
# get_width_and_height_from_size and calculate_output_image_size: image-size helpers
# Parameters for the entire model (patch embedding, transformer encoder, and classifier head)
Params = collections.namedtuple('Params', [
    'image_size', 'patch_size', 'emb_dim', 'mlp_dim', 'num_heads', 'num_layers',
    'num_classes', 'attn_dropout_rate', 'dropout_rate', 'resnet'
])

# Set Params's defaults to None
Params.__new__.__defaults__ = (None, ) * len(Params._fields)
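# Illustrative note: every field defaults to None, so a Params instance can be built
# with only the fields of interest, e.g.
#     Params(image_size=384, patch_size=16)
# and any field not supplied stays None.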
def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size.

    Returns:
        size: A tuple or list (H, W).
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    else:
        raise TypeError()
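# Usage sketch (illustrative):
#     get_width_and_height_from_size(384)         # -> (384, 384)
#     get_width_and_height_from_size((384, 512))  # -> (384, 512)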
################################################################################
### Helper functions for loading model params
################################################################################
# vision_transformer and get_model_params:
#     Functions to build the Params for a named vision transformer
# url_map: Dict mapping model names to URLs for pretrained weights
# load_pretrained_weights: A function to load pretrained weights
def vision_transformer(model_name):
    """Create Params for a vision transformer model.

    Args:
        model_name (str): Model name to be queried.

    Returns:
        params (Params): Configuration for the requested architecture.
    """
    # Tuple order follows the Params fields:
    # (image_size, patch_size, emb_dim, mlp_dim, num_heads, num_layers,
    #  num_classes, attn_dropout_rate, dropout_rate, resnet)
    params_dict = {
        'ViT-B_16': (384, 16, 768, 3072, 12, 12, 1000, 0.0, 0.1, None),
        'ViT-B_32': (384, 32, 768, 3072, 12, 12, 1000, 0.0, 0.1, None),
        'ViT-L_16': (384, 16, 1024, 4096, 16, 24, 1000, 0.0, 0.1, None),
        'ViT-L_32': (384, 32, 1024, 4096, 16, 24, 1000, 0.0, 0.1, None),
        'R50+ViT-B_16': (384, 1, 768, 3072, 12, 12, 1000, 0.0, 0.1, resnet50),
    }
    (image_size, patch_size, emb_dim, mlp_dim, num_heads, num_layers,
     num_classes, attn_dropout_rate, dropout_rate, resnet) = params_dict[model_name]
    params = Params(image_size=image_size,
                    patch_size=patch_size,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    num_heads=num_heads,
                    num_layers=num_layers,
                    num_classes=num_classes,
                    attn_dropout_rate=attn_dropout_rate,
                    dropout_rate=dropout_rate,
                    resnet=resnet)
    return params
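# Usage sketch (illustrative):
#     params = vision_transformer('ViT-B_16')
#     (params.image_size, params.patch_size, params.emb_dim)  # -> (384, 16, 768)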
def get_model_params(model_name, override_params):
    """Get the Params for a given model name, optionally overriding fields.

    Args:
        model_name (str): Model's name.
        override_params (dict): A dict used to modify params.

    Returns:
        params (Params): Model configuration.
    """
    params = vision_transformer(model_name)
    if override_params:
        # ValueError will be raised here if override_params has fields not included in params.
        params = params._replace(**override_params)
    return params
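# Usage sketch (illustrative): override fields for a custom task, e.g. a 10-class
# classifier at 224x224 input resolution.
#     params = get_model_params('ViT-B_16', {'num_classes': 10, 'image_size': 224})
#     params.num_classes  # -> 10
# Any key in override_params that is not a Params field raises ValueError via _replace().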
# Weights trained with the standard method; the checkpoints are pretrained on ImageNet-21k
# and fine-tuned on ImageNet-2012 (see the checkpoint file names). For more details, see the
# paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
url_map = {
    'ViT-B_16':
    'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-B_16_imagenet21k_imagenet2012.pth',
    'ViT-B_32':
    'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-B_32_imagenet21k_imagenet2012.pth',
    'ViT-L_16':
    'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-L_16_imagenet21k_imagenet2012.pth',
    'ViT-L_32':
    'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/ViT-L_32_imagenet21k_imagenet2012.pth',
    'R50+ViT-B_16':
    'https://github.com/tczhangzhi/VisionTransformer-PyTorch/releases/download/1.0.1/R50+ViT-B_16_imagenet21k_imagenet2012.pth',
}
def load_pretrained_weights(model,
                            model_name,
                            weights_path=None,
                            load_fc=True,
                            advprop=False):
    """Loads pretrained weights from a local weights path or downloads them by URL.

    Args:
        model (Module): The whole vision transformer model.
        model_name (str): Model name of the vision transformer.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for the final classifier (fc) layer.
        advprop (bool): Not used by this loader.
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        state_dict = model_zoo.load_url(url_map[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(
            ret.missing_keys)
    else:
        state_dict.pop('classifier.weight')
        state_dict.pop('classifier.bias')
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == set([
            'classifier.weight', 'classifier.bias'
        ]), 'Missing keys when loading pretrained weights: {}'.format(
            ret.missing_keys)
    assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(
        ret.unexpected_keys)

    print('Loaded pretrained weights for {}'.format(model_name))
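# Usage sketch (illustrative; assumes `model` is a vision transformer built elsewhere in
# this package whose final layer is named `classifier`, matching the keys popped above):
#     load_pretrained_weights(model, 'ViT-B_16')                                # download and load everything
#     load_pretrained_weights(model, 'ViT-B_16', load_fc=False)                 # skip the classifier head
#     load_pretrained_weights(model, 'ViT-B_16', weights_path='ViT-B_16.pth')   # load from a local file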