import os
import torch
import importlib
import torch.distributed as dist

try:
    # noinspection PyUnresolvedReferences
    from apex import amp
except ImportError:
    amp = None


def relative_bias_interpolate(checkpoint, config):
    """Resize relative position bias tables in a pretrained checkpoint to match the target image size."""
    for k in list(checkpoint['model']):
        if 'relative_position_index' in k:
            # The index buffer is rebuilt by the model, so it can simply be dropped.
            del checkpoint['model'][k]
        if 'relative_position_bias_table' in k:
            relative_position_bias_table = checkpoint['model'][k]
            cls_bias = relative_position_bias_table[:1, :]
            relative_position_bias_table = relative_position_bias_table[1:, :]
            size = int(relative_position_bias_table.shape[0] ** 0.5)
            img_size = (size + 1) // 2
            # The code assumes bias tables only appear under stage_3/stage_4 keys.
            if 'stage_3' in k:
                downsample_ratio = 16
            elif 'stage_4' in k:
                downsample_ratio = 32
            new_img_size = config.DATA.IMG_SIZE // downsample_ratio
            new_size = 2 * new_img_size - 1
            if new_size == size:
                continue
            # Bicubically resize the (2H-1, 2W-1) bias grid to the new resolution, then re-attach the first row.
            relative_position_bias_table = relative_position_bias_table.reshape(size, size, -1)
            relative_position_bias_table = relative_position_bias_table.unsqueeze(0).permute(0, 3, 1, 2)  # bs, nhead, h, w
            relative_position_bias_table = torch.nn.functional.interpolate(
                relative_position_bias_table, size=(new_size, new_size), mode='bicubic', align_corners=False)
            relative_position_bias_table = relative_position_bias_table.permute(0, 2, 3, 1)
            relative_position_bias_table = relative_position_bias_table.squeeze(0).reshape(new_size * new_size, -1)
            relative_position_bias_table = torch.cat((cls_bias, relative_position_bias_table), dim=0)
            checkpoint['model'][k] = relative_position_bias_table
    return checkpoint
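
# Worked example (illustrative, not part of the original file): for a checkpoint pretrained at
# 224x224, a stage_4 table with downsample_ratio 32 covers a 7x7 feature map, so size = 2*7 - 1 = 13
# and the table holds 13*13 + 1 = 170 rows (the first row is sliced off as cls_bias). Fine-tuning at
# config.DATA.IMG_SIZE = 384 gives new_img_size = 12 and new_size = 23, so the 13x13 grid is resized
# to 23x23 and the cls_bias row is re-attached in front, giving 23*23 + 1 = 530 rows.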


def load_pretained(config, model, logger=None, strict=False):
    if logger is not None:
        logger.info(f"==============> Loading pretrained weights from {config.MODEL.PRETRAINED}....................")
    checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
    if 'model' not in checkpoint:
        if 'state_dict_ema' in checkpoint:
            checkpoint['model'] = checkpoint['state_dict_ema']
        else:
            checkpoint['model'] = checkpoint
    if config.MODEL.DORP_HEAD:
        if 'head.weight' in checkpoint['model'] and 'head.bias' in checkpoint['model']:
            if logger is not None:
                logger.info(f"==============> drop head....................")
            del checkpoint['model']['head.weight']
            del checkpoint['model']['head.bias']
        if 'head.fc.weight' in checkpoint['model'] and 'head.fc.bias' in checkpoint['model']:
            if logger is not None:
                logger.info(f"==============> drop head....................")
            del checkpoint['model']['head.fc.weight']
            del checkpoint['model']['head.fc.bias']
    if config.MODEL.DORP_META:
        if logger is not None:
            logger.info(f"==============> drop meta head....................")
        for k in list(checkpoint['model']):
            if 'meta' in k:
                del checkpoint['model'][k]
    checkpoint = relative_bias_interpolate(checkpoint, config)
    if 'point_coord' in checkpoint['model']:
        if logger is not None:
            logger.info(f"==============> drop point coord....................")
        del checkpoint['model']['point_coord']
    msg = model.load_state_dict(checkpoint['model'], strict=strict)
    if logger is not None:
        logger.info(msg)
    del checkpoint
    torch.cuda.empty_cache()
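
# Illustrative usage sketch (not part of the original file; get_config and build_model are assumed
# stand-ins for this repo's own config/model builders):
#
#     config = get_config(args)
#     model = build_model(config)
#     load_pretained(config, model, logger=logger, strict=False)
#
# With strict=False, the heads dropped above and any resized bias tables load without raising on
# missing or unexpected keys.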


def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(
            config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    if 'model' not in checkpoint:
        if 'state_dict_ema' in checkpoint:
            checkpoint['model'] = checkpoint['state_dict_ema']
        else:
            checkpoint['model'] = checkpoint
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    logger.info(msg)
    max_accuracy = 0.0
    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy


def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger):
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    if config.AMP_OPT_LEVEL != "O0":
        save_state['amp'] = amp.state_dict()

    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")

    # Also keep a rolling copy of the most recent checkpoint for auto-resume.
    latest_save_path = os.path.join(config.OUTPUT, 'latest.pth')
    logger.info(f"{latest_save_path} saving......")
    torch.save(save_state, latest_save_path)
    logger.info(f"{latest_save_path} saved !!!")


def get_grad_norm(parameters, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm
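
# Note (added for clarity, not in the original file): this reproduces the total-norm formula used by
# torch.nn.utils.clip_grad_norm_, i.e. (sum_i ||grad_i||_p ** p) ** (1/p) over all parameters that
# currently have gradients. Typical use in a training loop, after loss.backward():
#
#     grad_norm = get_grad_norm(model.parameters())
#
# which logs the gradient norm without modifying (clipping) the gradients.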


def auto_resume_helper(output_dir):
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f"All checkpoints found in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
        print(f"The latest checkpoint found: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


def reduce_tensor(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt
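
# Illustrative usage (added; not in the original file): in a distributed run this averages a metric
# across all ranks, e.g.
#
#     loss_avg = reduce_tensor(loss.detach())
#
# so every process logs the same mean value. It requires torch.distributed to be initialized.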


def load_ext(name, funcs):
    ext = importlib.import_module(name)
    for fun in funcs:
        assert hasattr(ext, fun), f'{fun} is missing from module {name}'
    return ext
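
# Hypothetical example (module and function names are placeholders, not from this repo):
#
#     ext = load_ext('some_pkg._C', ['nms'])
#
# imports a compiled extension module and asserts that the listed symbols are present before use.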