PKaushik commited on
Commit
4ee00b7
1 Parent(s): eddd5b6
Files changed (1) hide show
  1. yolov6/utils/checkpoint.py +60 -0
yolov6/utils/checkpoint.py CHANGED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ import os
4
+ import shutil
5
+ import torch
6
+ import os.path as osp
7
+ from yolov6.utils.events import LOGGER
8
+ from yolov6.utils.torch_utils import fuse_model
9
+
10
+
11
+ def load_state_dict(weights, model, map_location=None):
12
+ """Load weights from checkpoint file, only assign weights those layers' name and shape are match."""
13
+ ckpt = torch.load(weights, map_location=map_location)
14
+ state_dict = ckpt['model'].float().state_dict()
15
+ model_state_dict = model.state_dict()
16
+ state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.shape == model_state_dict[k].shape}
17
+ model.load_state_dict(state_dict, strict=False)
18
+ del ckpt, state_dict, model_state_dict
19
+ return model
20
+
21
+
22
+ def load_checkpoint(weights, map_location=None, inplace=True, fuse=True):
23
+ """Load model from checkpoint file."""
24
+ LOGGER.info("Loading checkpoint from {}".format(weights))
25
+ ckpt = torch.load(weights, map_location=map_location) # load
26
+ model = ckpt['ema' if ckpt.get('ema') else 'model'].float()
27
+ if fuse:
28
+ LOGGER.info("\nFusing model...")
29
+ model = fuse_model(model).eval()
30
+ else:
31
+ model = model.eval()
32
+ return model
33
+
34
+
35
+ def save_checkpoint(ckpt, is_best, save_dir, model_name=""):
36
+ """ Save checkpoint to the disk."""
37
+ if not osp.exists(save_dir):
38
+ os.makedirs(save_dir)
39
+ filename = osp.join(save_dir, model_name + '.pt')
40
+ torch.save(ckpt, filename)
41
+ if is_best:
42
+ best_filename = osp.join(save_dir, 'best_ckpt.pt')
43
+ shutil.copyfile(filename, best_filename)
44
+
45
+
46
+ def strip_optimizer(ckpt_dir, epoch):
47
+ for s in ['best', 'last']:
48
+ ckpt_path = osp.join(ckpt_dir, '{}_ckpt.pt'.format(s))
49
+ if not osp.exists(ckpt_path):
50
+ continue
51
+ ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
52
+ if ckpt.get('ema'):
53
+ ckpt['model'] = ckpt['ema'] # replace model with ema
54
+ for k in ['optimizer', 'ema', 'updates']: # keys
55
+ ckpt[k] = None
56
+ ckpt['epoch'] = epoch
57
+ ckpt['model'].half() # to FP16
58
+ for p in ckpt['model'].parameters():
59
+ p.requires_grad = False
60
+ torch.save(ckpt, ckpt_path)