import os

import torch
import torch.distributed as dist


def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
    """
    Re-start a run from a checkpoint (local file or URL).

    Every object passed as a keyword argument must expose ``load_state_dict``;
    its state is restored from the checkpoint entry of the same name. Scalars
    listed in ``run_variables`` (e.g. the epoch counter) are copied back into
    that dict.
    """
    # A URL is never a local file, so only apply the existence check to local paths.
    if not ckp_path.startswith('https') and not os.path.isfile(ckp_path):
        return
    print("Found checkpoint at {}".format(ckp_path))

    if ckp_path.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            ckp_path, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(ckp_path, map_location='cpu')

    # Restore the state of every object passed as a keyword argument.
    for key, value in kwargs.items():
        if key in checkpoint and value is not None:
            if key == "model_ema":
                value.ema.load_state_dict(checkpoint[key])
            else:
                value.load_state_dict(checkpoint[key])
        else:
            print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))

    # Reload variables important for the run (e.g. epoch, best score).
    if run_variables is not None:
        for var_name in run_variables:
            if var_name in checkpoint:
                run_variables[var_name] = checkpoint[var_name]

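
# Usage sketch for restart_from_checkpoint (illustrative only; ``model``,
# ``optimizer`` and the checkpoint path are placeholders, not names defined
# in this file):
#
#   to_restore = {"epoch": 0}
#   restart_from_checkpoint(
#       "checkpoint.pth",
#       run_variables=to_restore,
#       model=model,
#       optimizer=optimizer,
#   )
#   start_epoch = to_restore["epoch"]
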
def load_pretrained_weights(model, pretrained_weights, checkpoint_key=None, prefixes=None, drop_head="head"):
    """Load ViT weights from a local path or URL, resizing the position embedding if needed."""
    if pretrained_weights == '':
        return
    elif pretrained_weights.startswith('https'):
        state_dict = torch.hub.load_state_dict_from_url(
            pretrained_weights, map_location='cpu', check_hash=True)
    else:
        state_dict = torch.load(pretrained_weights, map_location='cpu')
    epoch = state_dict['epoch'] if 'epoch' in state_dict else -1

    # If no checkpoint key was given, try the usual candidates.
    if not checkpoint_key:
        for key in ['model', 'teacher', 'encoder']:
            if key in state_dict:
                checkpoint_key = key
    print("Load pre-trained checkpoint from: %s[%s] at epoch %d" % (pretrained_weights, checkpoint_key, epoch))
    if checkpoint_key:
        state_dict = state_dict[checkpoint_key]

    # Remove the `module.` prefix (DistributedDataParallel) and the `backbone.`
    # prefix (multi-crop wrapper), and drop the projection head weights.
    if prefixes is None:
        prefixes = ["module.", "backbone."]
    for prefix in prefixes:
        state_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if drop_head not in k}
    checkpoint_model = state_dict
    # Interpolate the position embedding when the checkpoint and the target
    # model use a different number of patches (e.g. different image size).
    pos_embed_checkpoint = checkpoint_model['pos_embed']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # height (== width) of the checkpoint position embedding grid
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) of the new position embedding grid
    new_size = int(num_patches ** 0.5)
    # class_token and dist_token are kept unchanged
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # only the position tokens are interpolated
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(
        pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
    checkpoint_model['pos_embed'] = new_pos_embed

    msg = model.load_state_dict(checkpoint_model, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
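
# Usage sketch for load_pretrained_weights (illustrative only; the ViT
# constructor and checkpoint path are placeholders, not part of this file):
#
#   model = vit_small(patch_size=16)
#   load_pretrained_weights(model, "checkpoint.pth", checkpoint_key="teacher")
#
# The position-embedding interpolation above lets a checkpoint trained at one
# resolution (e.g. a 14x14 patch grid) be loaded into a model built for a
# different resolution: only the patch tokens are resized bicubically, while
# the class/dist tokens are copied unchanged.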