laurenok24 committed on
Commit
b41b87f
1 Parent(s): 2114261

Upload 6 files

models/detectron2/diver_detector_setup.py ADDED
@@ -0,0 +1,37 @@
+ import sys, os, distutils.core
+
+ # os.system('python -m pip install pyyaml==5.3.1')
+ # dist = distutils.core.run_setup("./detectron2/setup.py")
+ # temp = ' '.join([f"'{x}'" for x in dist.install_requires])
+ # cmd = "python -m pip install {0}".format(temp)
+ # os.system(cmd)
+ sys.path.insert(0, os.path.abspath('./detectron2'))
+
+ import detectron2
+ import cv2
+
+ from detectron2.utils.logger import setup_logger
+ setup_logger()
+
+ # from detectron2.modeling import build_model
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer
+ from detectron2.data import MetadataCatalog, DatasetCatalog
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.data.datasets import register_coco_instances
+
+ def get_diver_detector():
+     cfg = get_cfg()
+     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+     cfg.OUTPUT_DIR = "./output/diver/"
+     cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
+     cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
+     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (diver). (see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets)
+     diver_detector = DefaultPredictor(cfg)
+     return diver_detector
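
For reference, a minimal usage sketch of the detector built above (the image path is a placeholder; DefaultPredictor takes a BGR image and returns a dict containing a detectron2 Instances object):

    import cv2
    from models.detectron2.diver_detector_setup import get_diver_detector

    diver_detector = get_diver_detector()       # loads ./output/diver/model_final.pth
    frame = cv2.imread("example_frame.jpg")     # hypothetical input frame (BGR)
    outputs = diver_detector(frame)
    instances = outputs["instances"]
    print(instances.scores, instances.pred_boxes)  # detection confidences and diver boxes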
models/detectron2/platform_detector_setup.py ADDED
@@ -0,0 +1,51 @@
+ import sys, os, distutils.core
+
+ # os.system('python -m pip install pyyaml==5.3.1')
+ # dist = distutils.core.run_setup("./detectron2/setup.py")
+ # temp = ' '.join([f"'{x}'" for x in dist.install_requires])
+ # cmd = "python -m pip install {0}".format(temp)
+ # os.system(cmd)
+ sys.path.insert(0, os.path.abspath('./detectron2'))
+
+ import detectron2
+ import cv2
+
+ from detectron2.utils.logger import setup_logger
+ setup_logger()
+
+ # from detectron2.modeling import build_model
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer
+ from detectron2.data import MetadataCatalog, DatasetCatalog
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.data.datasets import register_coco_instances
+
+ def get_platform_detector():
+     cfg = get_cfg()
+     cfg.OUTPUT_DIR = "./output/platform/"
+     # model = build_model(cfg)  # returns a torch.nn.Module
+     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+     cfg.DATASETS.TEST = ()
+     cfg.DATALOADER.NUM_WORKERS = 2
+     cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
+     cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
+     cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
+     cfg.SOLVER.MAX_ITER = 300  # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
+     cfg.SOLVER.STEPS = []  # do not decay learning rate
+     cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
+     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
+     cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
+     predictor = DefaultPredictor(cfg)
+     return predictor
+
+ # register_coco_instances("springboard_trains", {}, "./coco_annotations/springboard/train.json", "../data/Boards/spring")
+ # register_coco_instances("springboard_vals", {}, "./coco_annotations/springboard/val.json", "../data/Boards/spring")
+
+ # from detectron2.utils.visualizer import ColorMode
+ # splash_metadata = MetadataCatalog.get('springboard_vals')
+ # dataset_dicts = DatasetCatalog.get("springboard_vals")
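
Visualizer is imported above but not used in this file; a hedged sketch of how the platform predictions could be overlaid on a frame (file names are placeholders):

    import cv2
    from detectron2.utils.visualizer import Visualizer
    from models.detectron2.platform_detector_setup import get_platform_detector

    predictor = get_platform_detector()
    frame = cv2.imread("platform_frame.jpg")          # hypothetical input frame (BGR)
    outputs = predictor(frame)
    v = Visualizer(frame[:, :, ::-1])                 # Visualizer expects an RGB image
    vis = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2.imwrite("platform_vis.jpg", vis.get_image()[:, :, ::-1])  # back to BGR for OpenCV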
models/detectron2/splash_detector_setup.py ADDED
@@ -0,0 +1,45 @@
+ import sys, os, distutils.core
+
+ # os.system('python -m pip install pyyaml==5.3.1')
+ # dist = distutils.core.run_setup("./detectron2/setup.py")
+ # temp = ' '.join([f"'{x}'" for x in dist.install_requires])
+ # cmd = "python -m pip install {0}".format(temp)
+ # os.system(cmd)
+ sys.path.insert(0, os.path.abspath('./detectron2'))
+
+ import detectron2
+ import cv2
+
+ from detectron2.utils.logger import setup_logger
+ setup_logger()
+
+ # from detectron2.modeling import build_model
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer
+ from detectron2.data import MetadataCatalog, DatasetCatalog
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.data.datasets import register_coco_instances
+
+ def get_splash_detector():
+     cfg = get_cfg()
+     cfg.OUTPUT_DIR = "./output/splash/"
+     # model = build_model(cfg)  # returns a torch.nn.Module
+     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+     cfg.DATASETS.TRAIN = ("splash_trains",)
+     cfg.DATASETS.TEST = ()
+     cfg.DATALOADER.NUM_WORKERS = 2
+     cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
+     cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
+     cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
+     cfg.SOLVER.MAX_ITER = 300  # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
+     cfg.SOLVER.STEPS = []  # do not decay learning rate
+     cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
+     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
+     cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
+     predictor = DefaultPredictor(cfg)
+     return predictor
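
Since the config above is a Mask R-CNN, the trained splash model should also predict instance masks; a sketch (an assumption about how the masks might be consumed downstream) of reading a per-frame splash area off the predictions:

    import cv2
    from models.detectron2.splash_detector_setup import get_splash_detector

    predictor = get_splash_detector()
    frame = cv2.imread("entry_frame.jpg")        # hypothetical frame around water entry
    outputs = predictor(frame)
    instances = outputs["instances"].to("cpu")
    if len(instances) > 0:
        mask = instances.pred_masks[0].numpy()   # boolean HxW mask of the top-scoring detection
        print("splash area (pixels):", mask.sum())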
models/detectron2/springboard_detector_setup.py ADDED
@@ -0,0 +1,52 @@
+ import sys, os, distutils.core
+
+ # os.system('python -m pip install pyyaml==5.3.1')
+ # dist = distutils.core.run_setup("./detectron2/setup.py")
+ # temp = ' '.join([f"'{x}'" for x in dist.install_requires])
+ # cmd = "python -m pip install {0}".format(temp)
+ # os.system(cmd)
+ sys.path.insert(0, os.path.abspath('./detectron2'))
+
+ import detectron2
+ import cv2
+
+ from detectron2.utils.logger import setup_logger
+ setup_logger()
+
+ # from detectron2.modeling import build_model
+ from detectron2 import model_zoo
+ from detectron2.engine import DefaultPredictor
+ from detectron2.config import get_cfg
+ from detectron2.utils.visualizer import Visualizer
+ from detectron2.data import MetadataCatalog, DatasetCatalog
+ from detectron2.checkpoint import DetectionCheckpointer
+ from detectron2.data.datasets import register_coco_instances
+
+ def get_springboard_detector():
+     cfg = get_cfg()
+     cfg.OUTPUT_DIR = "./output/springboard/"
+     # model = build_model(cfg)  # returns a torch.nn.Module
+     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
+     cfg.DATASETS.TEST = ()
+     cfg.DATALOADER.NUM_WORKERS = 2
+     cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # Let training initialize from model zoo
+     cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
+     cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
+     cfg.SOLVER.MAX_ITER = 300  # 300 iterations seems good enough for this toy dataset; you will need to train longer for a practical dataset
+     cfg.SOLVER.STEPS = []  # do not decay learning rate
+     cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
+     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
+     cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path to the model we just trained
+     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
+     predictor = DefaultPredictor(cfg)
+     return predictor
+
+ # register_coco_instances("springboard_trains", {}, "./coco_annotations/springboard/train.json", "../data/Boards/spring")
+ # register_coco_instances("springboard_vals", {}, "./coco_annotations/springboard/val.json", "../data/Boards/spring")
+
+ # from detectron2.utils.visualizer import ColorMode
+ # splash_metadata = MetadataCatalog.get('springboard_vals')
+ # dataset_dicts = DatasetCatalog.get("springboard_vals")
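
The solver settings above mirror a training configuration, and the commented-out register_coco_instances calls show the dataset layout; a hedged sketch of how such a training run could be wired up with detectron2's DefaultTrainer (the dataset paths come from the comments above, everything else is an assumption, not the author's actual training script):

    import os
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultTrainer
    from detectron2.data.datasets import register_coco_instances

    register_coco_instances("springboard_trains", {}, "./coco_annotations/springboard/train.json", "../data/Boards/spring")

    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
    cfg.OUTPUT_DIR = "./output/springboard/"
    cfg.DATASETS.TRAIN = ("springboard_trains",)
    cfg.DATASETS.TEST = ()
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  # start from the model zoo
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.00025
    cfg.SOLVER.MAX_ITER = 300
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    trainer = DefaultTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()   # writes ./output/springboard/model_final.pth, which get_springboard_detector() then loads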
models/pose_estimator/pose_estimator_model_setup.py ADDED
@@ -0,0 +1,198 @@
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import argparse
+ import csv
+ import os
+ import shutil
+ import sys
+
+ from PIL import Image
+ import torch
+ import torch.nn.parallel
+ import torch.backends.cudnn as cudnn
+ import torch.optim
+ import torch.utils.data
+ import torch.utils.data.distributed
+ import torchvision.transforms as transforms
+ import torchvision
+ import cv2
+ import numpy as np
+ import time
+ sys.path.append('./deep-high-resolution-net.pytorch/lib')
+ import models
+ from config import cfg
+ from config import update_config
+ from core.function import get_final_preds
+ from utils.transforms import get_affine_transform
+
+ import distutils.core
+
+ # os.system('python -m pip install pyyaml==5.3.1')
+ # dist = distutils.core.run_setup("./detectron2/setup.py")
+ # temp = ' '.join([f"'{x}'" for x in dist.install_requires])
+ # cmd = "python -m pip install {0}".format(temp)
+ # os.system(cmd)
+ # sys.path.insert(0, os.path.abspath('./detectron2'))
+
+ # import detectron2
+ # # from detectron2.modeling import build_model
+ # from detectron2 import model_zoo
+ # from detectron2.engine import DefaultPredictor
+ # from detectron2.config import get_cfg
+ # from detectron2.utils.visualizer import Visualizer
+ # from detectron2.data import MetadataCatalog, DatasetCatalog
+ # from detectron2.checkpoint import DetectionCheckpointer
+ # from detectron2.data.datasets import register_coco_instances
+ # from detectron2.utils.visualizer import ColorMode
+ from models.detectron2.diver_detector_setup import get_diver_detector
+ from models.pose_estimator.pose_hrnet import get_pose_net
+
+
+ def box_to_center_scale(box, model_image_width, model_image_height):
+     """convert a box to center,scale information required for pose transformation
+     Parameters
+     ----------
+     box : list of tuple
+         list of length 2 with two tuples of floats representing
+         bottom left and top right corner of a box
+     model_image_width : int
+     model_image_height : int
+
+     Returns
+     -------
+     (numpy array, numpy array)
+         Two numpy arrays, coordinates for the center of the box and the scale of the box
+     """
+     center = np.zeros((2), dtype=np.float32)
+
+     bottom_left_corner = (box[0].data.cpu().item(), box[1].data.cpu().item())
+     top_right_corner = (box[2].data.cpu().item(), box[3].data.cpu().item())
+     box_width = top_right_corner[0] - bottom_left_corner[0]
+     box_height = top_right_corner[1] - bottom_left_corner[1]
+     bottom_left_x = bottom_left_corner[0]
+     bottom_left_y = bottom_left_corner[1]
+     center[0] = bottom_left_x + box_width * 0.5
+     center[1] = bottom_left_y + box_height * 0.5
+
+     aspect_ratio = model_image_width * 1.0 / model_image_height
+     pixel_std = 200
+
+     if box_width > aspect_ratio * box_height:
+         box_height = box_width * 1.0 / aspect_ratio
+     elif box_width < aspect_ratio * box_height:
+         box_width = box_height * aspect_ratio
+     scale = np.array(
+         [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
+         dtype=np.float32)
+     if center[0] != -1:
+         scale = scale * 1.25
+
+     return center, scale
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Train keypoints network')
+     # general
+     parser.add_argument('--cfg', type=str, default='./deep-high-resolution-net.pytorch/experiments/mpii/hrnet/w32_256x256_adam_lr1e-3.yaml')
+     parser.add_argument('opts',
+                         help='Modify config options using the command-line',
+                         default=None,
+                         nargs=argparse.REMAINDER)
+
+     args = parser.parse_args()
+
+     # args expected by supporting codebase
+     args.modelDir = ''
+     args.logDir = ''
+     args.dataDir = ''
+     args.prevModelDir = ''
+     return args
+
+ def get_pose_estimation_prediction(pose_model, image, center, scale):
+     rotation = 0
+     trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
+     # trans = cv2.getAffineTransform(srcTri, dstTri)
+     transform = transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                              std=[0.229, 0.224, 0.225]),
+     ])
+     model_input = cv2.warpAffine(
+         image,
+         trans,
+         (256, 256),
+         flags=cv2.INTER_LINEAR)
+
+     # pose estimation inference
+     model_input = transform(model_input).unsqueeze(0)
+     # switch to evaluate mode
+     pose_model.eval()
+     with torch.no_grad():
+         # compute output heatmap
+         output = pose_model(model_input)
+         preds, _ = get_final_preds(
+             cfg,
+             output.clone().cpu().numpy(),
+             np.asarray([center]),
+             np.asarray([scale]))
+         return preds
+
+ def get_pose_model():
+     CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+     cudnn.benchmark = cfg.CUDNN.BENCHMARK
+     torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
+     torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
+     args = parse_args()
+     update_config(cfg, args)
+     pose_model = get_pose_net(cfg, is_train=False)
+     if cfg.TEST.MODEL_FILE:
+         print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
+         pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
+     else:
+         print('expected model defined in config at TEST.MODEL_FILE')
+     pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)
+     pose_model.to(CTX)
+     pose_model.eval()
+     return pose_model
+
+
+ def get_pose_estimation(filepath, image_bgr=None, diver_detector=None, pose_model=None):
+     if image_bgr is None:
+         image_bgr = cv2.imread(filepath)
+         if image_bgr is None:
+             print("ERROR: image {} does not exist".format(filepath))
+             return None, None
+     if diver_detector is None:
+         diver_detector = get_diver_detector()
+
+     if pose_model is None:
+         pose_model = get_pose_model()
+
+     image = image_bgr[:, :, [2, 1, 0]]
+
+     outputs = diver_detector(image_bgr)
+     scores = outputs['instances'].scores
+     pred_boxes = []
+     if len(scores) > 0:
+         pred_boxes = outputs['instances'].pred_boxes
+
+     if len(pred_boxes) >= 1:
+         for box in pred_boxes:
+             center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
+             image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
+             box = box.detach().cpu().numpy()
+             return box, get_pose_estimation_prediction(pose_model, image_pose, center, scale)
+             # print("pose_preds", pose_preds)
+             # draw_bbox(box, image_bgr)
+             # if len(pose_preds) >= 1:
+             #     print('drawing preds')
+             #     for kpt in pose_preds:
+             #         draw_pose(kpt, image_bgr)  # draw the poses
+             # break  # only want to use the box with the highest confidence score
+     return None, None
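
A brief usage sketch of the wrapper above (the frame path is a placeholder; get_pose_estimation returns the detected diver box and the predicted joints, with pose_preds shaped (1, num_joints, 2)):

    from models.detectron2.diver_detector_setup import get_diver_detector
    from models.pose_estimator.pose_estimator_model_setup import get_pose_model, get_pose_estimation

    diver_detector = get_diver_detector()
    pose_model = get_pose_model()
    box, pose_preds = get_pose_estimation("dive_frame.jpg",
                                          diver_detector=diver_detector,
                                          pose_model=pose_model)
    if pose_preds is not None:
        for x, y in pose_preds[0]:   # one (x, y) pair per joint
            print(round(x), round(y))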
models/pose_estimator/pose_hrnet.py ADDED
@@ -0,0 +1,501 @@
+ # ------------------------------------------------------------------------------
+ # Copyright (c) Microsoft
+ # Licensed under the MIT License.
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
+ # ------------------------------------------------------------------------------
+
+ from __future__ import absolute_import
+ from __future__ import division
+ from __future__ import print_function
+
+ import os
+ import logging
+
+ import torch
+ import torch.nn as nn
+
+
+ BN_MOMENTUM = 0.1
+ logger = logging.getLogger(__name__)
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+                                bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+                                   momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class HighResolutionModule(nn.Module):
+     def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+                  num_channels, fuse_method, multi_scale_output=True):
+         super(HighResolutionModule, self).__init__()
+         self._check_branches(
+             num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+         self.num_inchannels = num_inchannels
+         self.fuse_method = fuse_method
+         self.num_branches = num_branches
+
+         self.multi_scale_output = multi_scale_output
+
+         self.branches = self._make_branches(
+             num_branches, blocks, num_blocks, num_channels)
+         self.fuse_layers = self._make_fuse_layers()
+         self.relu = nn.ReLU(True)
+
+     def _check_branches(self, num_branches, blocks, num_blocks,
+                         num_inchannels, num_channels):
+         if num_branches != len(num_blocks):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+                 num_branches, len(num_blocks))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+         if num_branches != len(num_channels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+                 num_branches, len(num_channels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+         if num_branches != len(num_inchannels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+                 num_branches, len(num_inchannels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+     def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+                          stride=1):
+         downsample = None
+         if stride != 1 or \
+            self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(
+                     self.num_inchannels[branch_index],
+                     num_channels[branch_index] * block.expansion,
+                     kernel_size=1, stride=stride, bias=False
+                 ),
+                 nn.BatchNorm2d(
+                     num_channels[branch_index] * block.expansion,
+                     momentum=BN_MOMENTUM
+                 ),
+             )
+
+         layers = []
+         layers.append(
+             block(
+                 self.num_inchannels[branch_index],
+                 num_channels[branch_index],
+                 stride,
+                 downsample
+             )
+         )
+         self.num_inchannels[branch_index] = \
+             num_channels[branch_index] * block.expansion
+         for i in range(1, num_blocks[branch_index]):
+             layers.append(
+                 block(
+                     self.num_inchannels[branch_index],
+                     num_channels[branch_index]
+                 )
+             )
+
+         return nn.Sequential(*layers)
+
+     def _make_branches(self, num_branches, block, num_blocks, num_channels):
+         branches = []
+
+         for i in range(num_branches):
+             branches.append(
+                 self._make_one_branch(i, block, num_blocks, num_channels)
+             )
+
+         return nn.ModuleList(branches)
+
+     def _make_fuse_layers(self):
+         if self.num_branches == 1:
+             return None
+
+         num_branches = self.num_branches
+         num_inchannels = self.num_inchannels
+         fuse_layers = []
+         for i in range(num_branches if self.multi_scale_output else 1):
+             fuse_layer = []
+             for j in range(num_branches):
+                 if j > i:
+                     fuse_layer.append(
+                         nn.Sequential(
+                             nn.Conv2d(
+                                 num_inchannels[j],
+                                 num_inchannels[i],
+                                 1, 1, 0, bias=False
+                             ),
+                             nn.BatchNorm2d(num_inchannels[i]),
+                             nn.Upsample(scale_factor=2**(j-i), mode='nearest')
+                         )
+                     )
+                 elif j == i:
+                     fuse_layer.append(None)
+                 else:
+                     conv3x3s = []
+                     for k in range(i-j):
+                         if k == i - j - 1:
+                             num_outchannels_conv3x3 = num_inchannels[i]
+                             conv3x3s.append(
+                                 nn.Sequential(
+                                     nn.Conv2d(
+                                         num_inchannels[j],
+                                         num_outchannels_conv3x3,
+                                         3, 2, 1, bias=False
+                                     ),
+                                     nn.BatchNorm2d(num_outchannels_conv3x3)
+                                 )
+                             )
+                         else:
+                             num_outchannels_conv3x3 = num_inchannels[j]
+                             conv3x3s.append(
+                                 nn.Sequential(
+                                     nn.Conv2d(
+                                         num_inchannels[j],
+                                         num_outchannels_conv3x3,
+                                         3, 2, 1, bias=False
+                                     ),
+                                     nn.BatchNorm2d(num_outchannels_conv3x3),
+                                     nn.ReLU(True)
+                                 )
+                             )
+                     fuse_layer.append(nn.Sequential(*conv3x3s))
+             fuse_layers.append(nn.ModuleList(fuse_layer))
+
+         return nn.ModuleList(fuse_layers)
+
+     def get_num_inchannels(self):
+         return self.num_inchannels
+
+     def forward(self, x):
+         if self.num_branches == 1:
+             return [self.branches[0](x[0])]
+
+         for i in range(self.num_branches):
+             x[i] = self.branches[i](x[i])
+
+         x_fuse = []
+
+         for i in range(len(self.fuse_layers)):
+             y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+             for j in range(1, self.num_branches):
+                 if i == j:
+                     y = y + x[j]
+                 else:
+                     y = y + self.fuse_layers[i][j](x[j])
+             x_fuse.append(self.relu(y))
+
+         return x_fuse
+
+
+ blocks_dict = {
+     'BASIC': BasicBlock,
+     'BOTTLENECK': Bottleneck
+ }
+
+
+ class PoseHighResolutionNet(nn.Module):
+
+     def __init__(self, cfg, **kwargs):
+         self.inplanes = 64
+         extra = cfg['MODEL']['EXTRA']
+         super(PoseHighResolutionNet, self).__init__()
+
+         # stem net
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.layer1 = self._make_layer(Bottleneck, 64, 4)
+
+         self.stage2_cfg = extra['STAGE2']
+         num_channels = self.stage2_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage2_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))
+         ]
+         self.transition1 = self._make_transition_layer([256], num_channels)
+         self.stage2, pre_stage_channels = self._make_stage(
+             self.stage2_cfg, num_channels)
+
+         self.stage3_cfg = extra['STAGE3']
+         num_channels = self.stage3_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage3_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))
+         ]
+         self.transition2 = self._make_transition_layer(
+             pre_stage_channels, num_channels)
+         self.stage3, pre_stage_channels = self._make_stage(
+             self.stage3_cfg, num_channels)
+
+         self.stage4_cfg = extra['STAGE4']
+         num_channels = self.stage4_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage4_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))
+         ]
+         self.transition3 = self._make_transition_layer(
+             pre_stage_channels, num_channels)
+         self.stage4, pre_stage_channels = self._make_stage(
+             self.stage4_cfg, num_channels, multi_scale_output=False)
+
+         self.final_layer = nn.Conv2d(
+             in_channels=pre_stage_channels[0],
+             out_channels=cfg['MODEL']['NUM_JOINTS'],
+             kernel_size=extra['FINAL_CONV_KERNEL'],
+             stride=1,
+             padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
+         )
+
+         self.pretrained_layers = extra['PRETRAINED_LAYERS']
+
+     def _make_transition_layer(
+             self, num_channels_pre_layer, num_channels_cur_layer):
+         num_branches_cur = len(num_channels_cur_layer)
+         num_branches_pre = len(num_channels_pre_layer)
+
+         transition_layers = []
+         for i in range(num_branches_cur):
+             if i < num_branches_pre:
+                 if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                     transition_layers.append(
+                         nn.Sequential(
+                             nn.Conv2d(
+                                 num_channels_pre_layer[i],
+                                 num_channels_cur_layer[i],
+                                 3, 1, 1, bias=False
+                             ),
+                             nn.BatchNorm2d(num_channels_cur_layer[i]),
+                             nn.ReLU(inplace=True)
+                         )
+                     )
+                 else:
+                     transition_layers.append(None)
+             else:
+                 conv3x3s = []
+                 for j in range(i+1-num_branches_pre):
+                     inchannels = num_channels_pre_layer[-1]
+                     outchannels = num_channels_cur_layer[i] \
+                         if j == i-num_branches_pre else inchannels
+                     conv3x3s.append(
+                         nn.Sequential(
+                             nn.Conv2d(
+                                 inchannels, outchannels, 3, 2, 1, bias=False
+                             ),
+                             nn.BatchNorm2d(outchannels),
+                             nn.ReLU(inplace=True)
+                         )
+                     )
+                 transition_layers.append(nn.Sequential(*conv3x3s))
+
+         return nn.ModuleList(transition_layers)
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(
+                     self.inplanes, planes * block.expansion,
+                     kernel_size=1, stride=stride, bias=False
+                 ),
+                 nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def _make_stage(self, layer_config, num_inchannels,
+                     multi_scale_output=True):
+         num_modules = layer_config['NUM_MODULES']
+         num_branches = layer_config['NUM_BRANCHES']
+         num_blocks = layer_config['NUM_BLOCKS']
+         num_channels = layer_config['NUM_CHANNELS']
+         block = blocks_dict[layer_config['BLOCK']]
+         fuse_method = layer_config['FUSE_METHOD']
+
+         modules = []
+         for i in range(num_modules):
+             # multi_scale_output is only used by the last module
+             if not multi_scale_output and i == num_modules - 1:
+                 reset_multi_scale_output = False
+             else:
+                 reset_multi_scale_output = True
+
+             modules.append(
+                 HighResolutionModule(
+                     num_branches,
+                     block,
+                     num_blocks,
+                     num_inchannels,
+                     num_channels,
+                     fuse_method,
+                     reset_multi_scale_output
+                 )
+             )
+             num_inchannels = modules[-1].get_num_inchannels()
+
+         return nn.Sequential(*modules), num_inchannels
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+         x = self.layer1(x)
+
+         x_list = []
+         for i in range(self.stage2_cfg['NUM_BRANCHES']):
+             if self.transition1[i] is not None:
+                 x_list.append(self.transition1[i](x))
+             else:
+                 x_list.append(x)
+         y_list = self.stage2(x_list)
+
+         x_list = []
+         for i in range(self.stage3_cfg['NUM_BRANCHES']):
+             if self.transition2[i] is not None:
+                 x_list.append(self.transition2[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         y_list = self.stage3(x_list)
+
+         x_list = []
+         for i in range(self.stage4_cfg['NUM_BRANCHES']):
+             if self.transition3[i] is not None:
+                 x_list.append(self.transition3[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         y_list = self.stage4(x_list)
+
+         x = self.final_layer(y_list[0])
+
+         return x
+
+     def init_weights(self, pretrained=''):
+         logger.info('=> init weights from normal distribution')
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 nn.init.normal_(m.weight, std=0.001)
+                 for name, _ in m.named_parameters():
+                     if name in ['bias']:
+                         nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.ConvTranspose2d):
+                 nn.init.normal_(m.weight, std=0.001)
+                 for name, _ in m.named_parameters():
+                     if name in ['bias']:
+                         nn.init.constant_(m.bias, 0)
+
+         if os.path.isfile(pretrained):
+             pretrained_state_dict = torch.load(pretrained)
+             logger.info('=> loading pretrained model {}'.format(pretrained))
+
+             need_init_state_dict = {}
+             for name, m in pretrained_state_dict.items():
+                 if name.split('.')[0] in self.pretrained_layers \
+                    or self.pretrained_layers[0] == '*':
+                     need_init_state_dict[name] = m
+             self.load_state_dict(need_init_state_dict, strict=False)
+         elif pretrained:
+             logger.error('=> please download pre-trained models first!')
+             raise ValueError('{} does not exist!'.format(pretrained))
+
+
+ def get_pose_net(cfg, is_train, **kwargs):
+     model = PoseHighResolutionNet(cfg, **kwargs)
+
+     if is_train and cfg['MODEL']['INIT_WEIGHTS']:
+         model.init_weights(cfg['MODEL']['PRETRAINED'])
+
+     return model
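
As a quick sanity check of the residual blocks defined here (a sketch; it assumes the repo root is on sys.path so the module is importable), both block types preserve the spatial resolution of their input:

    import torch
    from models.pose_estimator.pose_hrnet import BasicBlock, Bottleneck

    x = torch.randn(1, 32, 64, 64)
    print(BasicBlock(32, 32)(x).shape)   # torch.Size([1, 32, 64, 64]) -- expansion 1, channels unchanged
    y = torch.randn(1, 64, 64, 64)
    print(Bottleneck(64, 16)(y).shape)   # torch.Size([1, 64, 64, 64]) -- 16 * expansion (4) = 64 output channels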