# Vision Transformer (ViT) utilities: model hyper-parameter configurations
# and helpers for loading pretrained weights.
import re
import math
import torch
import collections
from torch import nn
from functools import partial
from torch.utils import model_zoo
from torch.nn import functional as F
from resnet import resnet50
################################################################################
### Help functions for model architecture
################################################################################
# Params: namedtuple
# get_width_and_height_from_size and calculate_output_image_size
# Parameters for the entire model (stem, all blocks, and head)
# Hyper-parameters for one ViT configuration (stem, all blocks, and head).
# Every field defaults to None so configs may be partially specified.
Params = collections.namedtuple(
    'Params',
    ('image_size', 'patch_size', 'emb_dim', 'mlp_dim', 'num_heads',
     'num_layers', 'num_classes', 'attn_dropout_rate', 'dropout_rate',
     'resnet'),
    defaults=(None,) * 10,
)
def get_width_and_height_from_size(x):
    """Obtain height and width from x.

    Args:
        x (int, tuple or list): Data size. An int means a square size.

    Returns:
        size: A tuple (H, W) for an int input, otherwise x unchanged.

    Raises:
        TypeError: If x is not an int, list or tuple.
    """
    if isinstance(x, int):
        return x, x
    # A single isinstance call with a tuple of types is the idiomatic check.
    if isinstance(x, (list, tuple)):
        return x
    # Include the offending type so failures are diagnosable (the original
    # raised a bare TypeError with no message).
    raise TypeError(
        'size must be an int, list or tuple, got {}'.format(type(x).__name__))
################################################################################
### Helper functions for loading model params
################################################################################
# get_model_params and efficientnet:
# Functions to get BlockArgs and GlobalParams for efficientnet
# url_map and url_map_advprop: Dicts of url_map for pretrained weights
# load_pretrained_weights: A function to load pretrained weights
def vision_transformer(model_name):
    """Build the Params for a named vision transformer configuration.

    Args:
        model_name (str): Model name to be queried, e.g. 'ViT-B_16'.

    Returns:
        Params: Hyper-parameters for the requested model.

    Raises:
        KeyError: If model_name is not a known configuration.
    """
    # Field order matches Params:
    # (image_size, patch_size, emb_dim, mlp_dim, num_heads, num_layers,
    #  num_classes, attn_dropout_rate, dropout_rate, resnet)
    configs = {
        'ViT-B_16': (384, 16, 768, 3072, 12, 12, 1000, 0.0, 0.1, None),
        'ViT-B_32': (384, 32, 768, 3072, 12, 12, 1000, 0.0, 0.1, None),
        'ViT-L_16': (384, 16, 1024, 4096, 16, 24, 1000, 0.0, 0.1, None),
        'ViT-L_32': (384, 32, 1024, 4096, 16, 24, 1000, 0.0, 0.1, None),
        'R50+ViT-B_16': (384, 1, 768, 3072, 12, 12, 1000, 0.0, 0.1, resnet50),
    }
    # Positional unpacking is equivalent to naming every keyword, since the
    # tuples above follow the Params field order exactly.
    return Params(*configs[model_name])
def get_model_params(model_name, override_params):
    """Get the params for a given model name, with optional overrides.

    Args:
        model_name (str): Model's name.
        override_params (dict or None): A dict used to modify params.

    Returns:
        Params: The (possibly overridden) model hyper-parameters.
    """
    params = vision_transformer(model_name)
    if not override_params:
        return params
    # _replace raises ValueError if override_params contains a field name
    # that is not part of Params.
    return params._replace(**override_params)
# train with Standard methods
# check more details in paper(An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)
# All released checkpoints share one URL pattern: the model name plugged
# into the 1.0.1 release path, pre-trained on imagenet21k and fine-tuned
# on imagenet2012.
url_map = {
    name: ('https://github.com/tczhangzhi/VisionTransformer-PyTorch/'
           'releases/download/1.0.1/%s_imagenet21k_imagenet2012.pth' % name)
    for name in ('ViT-B_16', 'ViT-B_32', 'ViT-L_16', 'ViT-L_32',
                 'R50+ViT-B_16')
}
def load_pretrained_weights(model,
                            model_name,
                            weights_path=None,
                            load_fc=True,
                            advprop=False):
    """Loads pretrained weights from weights path or download using url.

    Args:
        model (Module): The whole model of vision transformer.
        model_name (str): Model name of vision transformer; used to look up
            the download URL in ``url_map`` when ``weights_path`` is None.
        weights_path (None or str):
            str: path to pretrained weights file on the local disk.
            None: use pretrained weights downloaded from the Internet.
        load_fc (bool): Whether to load pretrained weights for fc layer
            (the ``classifier`` head) at the end of the model.
        advprop (bool): Currently unused; kept for interface compatibility.

    Raises:
        AssertionError: If keys are missing or unexpected after loading.
            (NOTE(review): ``assert`` is stripped under ``python -O``;
            consider raising ValueError if this must always be enforced.)
    """
    if isinstance(weights_path, str):
        state_dict = torch.load(weights_path)
    else:
        state_dict = model_zoo.load_url(url_map[model_name])

    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        assert not ret.missing_keys, \
            'Missing keys when loading pretrained weights: {}'.format(
                ret.missing_keys)
    else:
        # Drop the classifier head before loading. Use a default so a
        # checkpoint that already lacks these keys does not raise KeyError
        # (the original pop() crashed on such checkpoints).
        state_dict.pop('classifier.weight', None)
        state_dict.pop('classifier.bias', None)
        ret = model.load_state_dict(state_dict, strict=False)
        assert set(ret.missing_keys) == {
            'classifier.weight', 'classifier.bias'
        }, 'Missing keys when loading pretrained weights: {}'.format(
            ret.missing_keys)
    # BUG FIX: the original message said "Missing keys" while reporting
    # ret.unexpected_keys, which made failures misleading.
    assert not ret.unexpected_keys, \
        'Unexpected keys when loading pretrained weights: {}'.format(
            ret.unexpected_keys)
    print('Loaded pretrained weights for {}'.format(model_name))