imabackstabber committed
Commit 0a34307 · 1 Parent(s): d48f0a2

test postometro pipeline

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +4 -34
  2. assets/02.jpg +0 -0
  3. assets/04.jpg +0 -0
  4. assets/05.jpg +0 -0
  5. assets/06.jpg +0 -0
  6. assets/07.jpg +0 -0
  7. common/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  8. common/utils/__pycache__/inference_utils.cpython-39.pyc +0 -0
  9. common/utils/__pycache__/preprocessing.cpython-39.pyc +0 -0
  10. common/utils/__pycache__/transforms.cpython-39.pyc +0 -0
  11. common/utils/__pycache__/vis.cpython-39.pyc +0 -0
  12. common/utils/vis.py +35 -17
  13. main/__pycache__/config.cpython-39.pyc +0 -0
  14. main/__pycache__/postometro.cpython-39.pyc +0 -0
  15. main/config/config_postometro.py +3 -100
  16. main/config/config_smpler_x_b32.py +0 -112
  17. main/config/config_smpler_x_h32.py +0 -111
  18. main/config/config_smpler_x_l32.py +0 -112
  19. main/config/config_smpler_x_s32.py +0 -111
  20. main/inference.py +83 -59
  21. main/pct_utils/__pycache__/modules.cpython-39.pyc +0 -0
  22. main/pct_utils/__pycache__/pct.cpython-39.pyc +0 -0
  23. main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc +0 -0
  24. main/pct_utils/__pycache__/pct_head.cpython-39.pyc +0 -0
  25. main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc +0 -0
  26. main/pct_utils/modules.py +117 -0
  27. main/pct_utils/pct.py +69 -0
  28. main/pct_utils/pct_backbone.py +1475 -0
  29. main/pct_utils/pct_head.py +208 -0
  30. main/pct_utils/pct_tokenizer.py +315 -0
  31. main/postometro.py +305 -0
  32. main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc +0 -0
  33. main/postometro_utils/__pycache__/modules.cpython-39.pyc +0 -0
  34. main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc +0 -0
  35. main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc +0 -0
  36. main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc +0 -0
  37. main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc +0 -0
  38. main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc +0 -0
  39. main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc +0 -0
  40. main/postometro_utils/__pycache__/smpl.cpython-39.pyc +0 -0
  41. main/postometro_utils/__pycache__/transformer.cpython-39.pyc +0 -0
  42. main/postometro_utils/geometric_layers.py +679 -0
  43. main/postometro_utils/modules.py +117 -0
  44. main/postometro_utils/pose_hrnet.py +502 -0
  45. main/postometro_utils/pose_hrnet_config.py +137 -0
  46. main/postometro_utils/pose_resnet.py +318 -0
  47. main/postometro_utils/pose_resnet_config.py +229 -0
  48. main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml +127 -0
  49. main/postometro_utils/positional_encoding.py +57 -0
  50. main/postometro_utils/renderer_pyrender.py +225 -0
app.py CHANGED
@@ -33,40 +33,7 @@ def infer(image_input, in_threshold=0.5, num_people="Single person", render_mesh
os.system(f'rm -rf {OUT_FOLDER}/*')
multi_person = False if (num_people == "Single person") else True
vis_img, num_bbox, mmdet_box = inferer.infer(image_input, in_threshold, 0, multi_person, not(render_mesh))
-
- # cap = cv2.VideoCapture(video_input)
- # fps = math.ceil(cap.get(5))
- # width = int(cap.get(3))
- # height = int(cap.get(4))
- # fourcc = cv2.VideoWriter_fourcc(*'mp4v')
- # video_path = osp.join(OUT_FOLDER, f'out.m4v')
- # final_video_path = osp.join(OUT_FOLDER, f'out.mp4')
- # video_output = cv2.VideoWriter(video_path, fourcc, fps, (width, height))
- # success = 1
- # frame = 0
- # while success:
- # success, original_img = cap.read()
- # if not success:
- # break
- # frame += 1
- # img, mesh_paths, smplx_paths = inferer.infer(original_img, in_threshold, frame, multi_person, not(render_mesh))
- # video_output.write(img)
- # yield img, None, None, None
- # cap.release()
- # video_output.release()
- # cv2.destroyAllWindows()
- # os.system(f'ffmpeg -i {video_path} -c copy {final_video_path}')
-
- # #Compress mesh and smplx files
- # save_path_mesh = os.path.join(OUT_FOLDER, 'mesh')
- # save_mesh_file = os.path.join(OUT_FOLDER, 'mesh.zip')
- # os.makedirs(save_path_mesh, exist_ok= True)
- # save_path_smplx = os.path.join(OUT_FOLDER, 'smplx')
- # save_smplx_file = os.path.join(OUT_FOLDER, 'smplx.zip')
- # os.makedirs(save_path_smplx, exist_ok= True)
- # os.system(f'zip -r {save_mesh_file} {save_path_mesh}')
- # os.system(f'zip -r {save_smplx_file} {save_path_smplx}')
- # yield img, video_path, save_mesh_file, save_smplx_file
+
return vis_img, "bbox num: {}, bbox meta: {}".format(num_bbox, mmdet_box)

TITLE = '''<h1 align="center">PostoMETRO: Pose Token Enhanced Mesh Transformer for Robust 3D Human Mesh Recovery</h1>'''
@@ -113,6 +80,9 @@ with gr.Blocks(title="PostoMETRO", css=".gradio-container") as demo:
['/home/user/app/assets/02.jpg'],
['/home/user/app/assets/03.jpg'],
['/home/user/app/assets/04.jpg'],
+ ['/home/user/app/assets/05.jpg'],
+ ['/home/user/app/assets/06.jpg'],
+ ['/home/user/app/assets/07.jpg'],
],
inputs=[image_input, 0.2])
 
assets/02.jpg CHANGED
assets/04.jpg CHANGED
assets/05.jpg ADDED
assets/06.jpg ADDED
assets/07.jpg ADDED
common/utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (169 Bytes).

common/utils/__pycache__/inference_utils.cpython-39.pyc ADDED
Binary file (4.33 kB).

common/utils/__pycache__/preprocessing.cpython-39.pyc ADDED
Binary file (14.3 kB).

common/utils/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (5.52 kB).

common/utils/__pycache__/vis.cpython-39.pyc ADDED
Binary file (7.55 kB).
 
common/utils/vis.py CHANGED
@@ -5,7 +5,7 @@ from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl
import os
- os.environ["PYOPENGL_PLATFORM"] = "egl"
+ os.environ["PYOPENGL_PLATFORM"] = "osmesa"
import pyrender
import trimesh
from config import cfg
@@ -138,6 +138,20 @@ def perspective_projection(vertices, cam_param):
vertices[:, 1] = vertices[:, 1] * fy / vertices[:, 2] + cy
return vertices

+ class WeakPerspectiveCamera(pyrender.Camera):
+     def __init__(self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None):
+         super(WeakPerspectiveCamera, self).__init__(znear=znear, zfar=zfar, name=name)
+         self.scale = scale
+         self.translation = translation
+
+     def get_projection_matrix(self, width=None, height=None):
+         P = np.eye(4)
+         P[0, 0] = self.scale[0]
+         P[1, 1] = self.scale[1]
+         P[0, 3] = self.translation[0] * self.scale[0]
+         P[1, 3] = -self.translation[1] * self.scale[1]
+         P[2, 2] = -1
+         return P

def render_mesh(img, mesh, face, cam_param, mesh_as_vertices=False):
if mesh_as_vertices:
@@ -150,28 +164,32 @@ def render_mesh(img, mesh, face, cam_param, mesh_as_vertices=False):
rot = trimesh.transformations.rotation_matrix(
np.radians(180), [1, 0, 0])
mesh.apply_transform(rot)
- material = pyrender.MetallicRoughnessMaterial(metallicFactor=0.0, alphaMode='OPAQUE', baseColorFactor=(1.0, 1.0, 0.9, 1.0))
- mesh = pyrender.Mesh.from_trimesh(mesh, material=material, smooth=False)
- scene = pyrender.Scene(ambient_light=(0.3, 0.3, 0.3))
+ color=[0.7, 0.7, 0.6]
+ material = pyrender.MetallicRoughnessMaterial(
+     metallicFactor=0.2,
+     roughnessFactor=1.0,
+     alphaMode='OPAQUE',
+     baseColorFactor=(color[0], color[1], color[2], 1.0)
+ )
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+ scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.05, 0.05, 0.05))
+ light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(-45), [1, 0, 0])
+ scene.add(light, pose=light_pose)
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(45), [0, 1, 0])
+ scene.add(light, pose=light_pose)
scene.add(mesh, 'mesh')

- focal, princpt = cam_param['focal'], cam_param['princpt']
- camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
- scene.add(camera)
+ # focal, princpt = cam_param['focal'], cam_param['princpt']
+ # camera = pyrender.IntrinsicsCamera(fx=focal[0], fy=focal[1], cx=princpt[0], cy=princpt[1])
+ sx, sy, tx, ty = cam_param
+ camera = WeakPerspectiveCamera(scale=[sx, sy], translation=[tx, ty], zfar=1000.0)
+ camera_pose = np.eye(4)
+ scene.add(camera, pose=camera_pose)

# renderer
renderer = pyrender.OffscreenRenderer(viewport_width=img.shape[1], viewport_height=img.shape[0], point_size=1.0)

- # light
- light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.8)
- light_pose = np.eye(4)
- light_pose[:3, 3] = np.array([0, -1, 1])
- scene.add(light, pose=light_pose)
- light_pose[:3, 3] = np.array([0, 1, 1])
- scene.add(light, pose=light_pose)
- light_pose[:3, 3] = np.array([1, 1, 2])
- scene.add(light, pose=light_pose)
-
# render
rgb, depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
rgb = rgb[:,:,:3].astype(np.float32)
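The WeakPerspectiveCamera added above packs a weak-perspective projection into the 4x4 matrix pyrender expects. As a sanity check on what the [sx, sy, tx, ty] parameters mean, the same mapping can be written directly in NumPy; weak_perspective_project below is an illustrative sketch, not a function from this repository.

import numpy as np

def weak_perspective_project(vertices, sx, sy, tx, ty):
    # Mirrors the matrix above: P[0,0]=sx, P[1,1]=sy, P[0,3]=tx*sx, P[1,3]=-ty*sy,
    # so x' = sx * (x + tx) and y' = sy * (y - ty); depth only affects z ordering
    # (P[2,2]=-1), never the 2D position, which is the weak-perspective assumption.
    vertices = np.asarray(vertices, dtype=np.float32)
    x = sx * (vertices[:, 0] + tx)
    y = sy * (vertices[:, 1] - ty)
    return np.stack([x, y], axis=1)

# e.g. weak_perspective_project([[0.1, -0.2, 2.0]], sx=1.2, sy=1.2, tx=0.05, ty=-0.1)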
main/__pycache__/config.cpython-39.pyc ADDED
Binary file (2.31 kB).

main/__pycache__/postometro.cpython-39.pyc ADDED
Binary file (9.31 kB).
 
main/config/config_postometro.py CHANGED
@@ -3,109 +3,12 @@ import os.path as osp
3
 
4
  # will be update in exp
5
  num_gpus = -1
6
- exp_name = 'output/exp1/pre_analysis'
7
-
8
- # quick access
9
- save_epoch = 1
10
- lr = 1e-5
11
- end_epoch = 10
12
- train_batch_size = 16
13
-
14
- syncbn = True
15
- bbox_ratio = 1.2
16
-
17
- # continue
18
- continue_train = False
19
- start_over = True
20
-
21
- # dataset setting
22
- agora_fix_betas = True
23
- agora_fix_global_orient_transl = True
24
- agora_valid_root_pose = True
25
-
26
- # all
27
- dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
28
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
29
- trainset_3d = ['MSCOCO','AGORA', 'UBody']
30
- trainset_2d = ['PW3D', 'MPII', 'Human36M']
31
- trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
32
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
33
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
34
- 'Behave', 'UP3D', 'ARCTIC',
35
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
36
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
37
- testset = 'EHF'
38
-
39
- use_cache = True
40
- # downsample
41
- BEDLAM_train_sample_interval = 5
42
- EgoBody_Kinect_train_sample_interval = 10
43
- train_sample_interval = 10 # UBody
44
- MPI_INF_3DHP_train_sample_interval = 5
45
- InstaVariety_train_sample_interval = 10
46
- RenBody_HiRes_train_sample_interval = 5
47
- ARCTIC_train_sample_interval = 10
48
- # RenBody_train_sample_interval = 10
49
- FIT3D_train_sample_interval = 10
50
- Talkshow_train_sample_interval = 10
51
-
52
- # strategy
53
- data_strategy = 'balance' # 'balance' need to define total_data_len
54
- total_data_len = 4500000
55
-
56
- # model
57
- smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
58
- smplx_pose_weight = 10.0
59
-
60
- smplx_kps_3d_weight = 100.0
61
- smplx_kps_2d_weight = 1.0
62
- net_kps_2d_weight = 1.0
63
-
64
- agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
65
-
66
- model_type = 'smpler_x_h'
67
- encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_huge.py'
68
- encoder_pretrained_model_path = 'pretrained_models/vitpose_huge.pth'
69
- feat_dim = 1280
70
-
71
- ## =====FIXED ARGS============================================================
72
- ## model setting
73
- upscale = 4
74
- hand_pos_joint_num = 20
75
- face_pos_joint_num = 72
76
- num_task_token = 24
77
- num_noise_sample = 0
78
-
79
- ## UBody setting
80
- train_sample_interval = 10
81
- test_sample_interval = 100
82
- make_same_len = False
83
 
84
  ## input, output size
85
  input_img_shape = (256, 256)
86
  input_body_shape = (256, 256)
87
- output_hm_shape = (16, 16, 12)
88
- input_hand_shape = (256, 256)
89
- output_hand_hm_shape = (16, 16, 16)
90
- output_face_hm_shape = (8, 8, 8)
91
- input_face_shape = (192, 192)
92
- focal = (5000, 5000) # virtual focal lengths
93
- princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
94
- body_3d_size = 2
95
- hand_3d_size = 0.3
96
- face_3d_size = 0.3
97
- camera_3d_size = 2.5
98
-
99
- ## training config
100
- print_iters = 100
101
- lr_mult = 1
102
 
103
- ## testing config
104
- test_batch_size = 32
105
-
106
- ## others
107
- num_thread = 2
108
- vis = False
109
 
110
- ## directory
111
- output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
 
3
 
4
  # will be update in exp
5
  num_gpus = -1
6
 
7
  ## input, output size
8
  input_img_shape = (256, 256)
9
  input_body_shape = (256, 256)
10
 
11
+ renderer_input_body_shape = (256, 256)
12
+ focal = (5000, 5000) # virtual focal lengths
13
+ princpt = (renderer_input_body_shape[1] / 2, renderer_input_body_shape[0] / 2) # virtual principal point position
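The three added lines above keep only the virtual camera the renderer needs. The sketch below shows how such a virtual focal length and principal point, defined for a 256x256 renderer input, would be rescaled to a detected bbox in the original image; it mirrors the commented-out focal/princpt computation in main/inference.py, and rescale_virtual_cam is an illustrative name, not a helper that exists in this repository.

renderer_input_body_shape = (256, 256)
focal = (5000, 5000)  # virtual focal lengths
princpt = (renderer_input_body_shape[1] / 2, renderer_input_body_shape[0] / 2)

def rescale_virtual_cam(bbox):
    # bbox = [x, y, w, h] in original-image pixels: scale the crop-space intrinsics
    # by the bbox size and shift the principal point to the bbox origin.
    x, y, w, h = bbox
    scaled_focal = [focal[0] / renderer_input_body_shape[1] * w,
                    focal[1] / renderer_input_body_shape[0] * h]
    scaled_princpt = [princpt[0] / renderer_input_body_shape[1] * w + x,
                      princpt[1] / renderer_input_body_shape[0] * h + y]
    return {'focal': scaled_focal, 'princpt': scaled_princpt}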
main/config/config_smpler_x_b32.py DELETED
@@ -1,112 +0,0 @@
1
- import os
2
- import os.path as osp
3
-
4
- # will be update in exp
5
- num_gpus = -1
6
- exp_name = 'output/exp1/pre_analysis'
7
-
8
- # quick access
9
- save_epoch = 1
10
- lr = 1e-5
11
- end_epoch = 10
12
- train_batch_size = 32
13
-
14
- syncbn = True
15
- bbox_ratio = 1.2
16
-
17
- # continue
18
- continue_train = False
19
- start_over = True
20
-
21
- # dataset setting
22
- agora_fix_betas = True
23
- agora_fix_global_orient_transl = True
24
- agora_valid_root_pose = True
25
-
26
- # all
27
- dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
28
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
29
- trainset_3d = ['MSCOCO','AGORA', 'UBody']
30
- trainset_2d = ['PW3D', 'MPII', 'Human36M']
31
- trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
32
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
33
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
34
- 'Behave', 'UP3D', 'ARCTIC',
35
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
36
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
37
- testset = 'EHF'
38
-
39
- use_cache = True
40
- # downsample
41
- BEDLAM_train_sample_interval = 5
42
- EgoBody_Kinect_train_sample_interval = 10
43
- train_sample_interval = 10 # UBody
44
- MPI_INF_3DHP_train_sample_interval = 5
45
- InstaVariety_train_sample_interval = 10
46
- RenBody_HiRes_train_sample_interval = 5
47
- ARCTIC_train_sample_interval = 10
48
- # RenBody_train_sample_interval = 10
49
- FIT3D_train_sample_interval = 10
50
- Talkshow_train_sample_interval = 10
51
-
52
- # strategy
53
- data_strategy = 'balance' # 'balance' need to define total_data_len
54
- total_data_len = 4500000
55
-
56
- # model
57
- smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
58
- smplx_pose_weight = 10.0
59
-
60
- smplx_kps_3d_weight = 100.0
61
- smplx_kps_2d_weight = 1.0
62
- net_kps_2d_weight = 1.0
63
-
64
- agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
65
-
66
- model_type = 'smpler_x_b'
67
- encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_base.py'
68
- encoder_pretrained_model_path = 'pretrained_models/vitpose_base.pth'
69
- feat_dim = 768
70
-
71
-
72
- ## =====FIXED ARGS============================================================
73
- ## model setting
74
- upscale = 4
75
- hand_pos_joint_num = 20
76
- face_pos_joint_num = 72
77
- num_task_token = 24
78
- num_noise_sample = 0
79
-
80
- ## UBody setting
81
- train_sample_interval = 10
82
- test_sample_interval = 100
83
- make_same_len = False
84
-
85
- ## input, output size
86
- input_img_shape = (512, 384)
87
- input_body_shape = (256, 192)
88
- output_hm_shape = (16, 16, 12)
89
- input_hand_shape = (256, 256)
90
- output_hand_hm_shape = (16, 16, 16)
91
- output_face_hm_shape = (8, 8, 8)
92
- input_face_shape = (192, 192)
93
- focal = (5000, 5000) # virtual focal lengths
94
- princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
95
- body_3d_size = 2
96
- hand_3d_size = 0.3
97
- face_3d_size = 0.3
98
- camera_3d_size = 2.5
99
-
100
- ## training config
101
- print_iters = 100
102
- lr_mult = 1
103
-
104
- ## testing config
105
- test_batch_size = 32
106
-
107
- ## others
108
- num_thread = 2
109
- vis = False
110
-
111
- ## directory
112
- output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
main/config/config_smpler_x_h32.py DELETED
@@ -1,111 +0,0 @@
1
- import os
2
- import os.path as osp
3
-
4
- # will be update in exp
5
- num_gpus = -1
6
- exp_name = 'output/exp1/pre_analysis'
7
-
8
- # quick access
9
- save_epoch = 1
10
- lr = 1e-5
11
- end_epoch = 10
12
- train_batch_size = 16
13
-
14
- syncbn = True
15
- bbox_ratio = 1.2
16
-
17
- # continue
18
- continue_train = False
19
- start_over = True
20
-
21
- # dataset setting
22
- agora_fix_betas = True
23
- agora_fix_global_orient_transl = True
24
- agora_valid_root_pose = True
25
-
26
- # all
27
- dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
28
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
29
- trainset_3d = ['MSCOCO','AGORA', 'UBody']
30
- trainset_2d = ['PW3D', 'MPII', 'Human36M']
31
- trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
32
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
33
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
34
- 'Behave', 'UP3D', 'ARCTIC',
35
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
36
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
37
- testset = 'EHF'
38
-
39
- use_cache = True
40
- # downsample
41
- BEDLAM_train_sample_interval = 5
42
- EgoBody_Kinect_train_sample_interval = 10
43
- train_sample_interval = 10 # UBody
44
- MPI_INF_3DHP_train_sample_interval = 5
45
- InstaVariety_train_sample_interval = 10
46
- RenBody_HiRes_train_sample_interval = 5
47
- ARCTIC_train_sample_interval = 10
48
- # RenBody_train_sample_interval = 10
49
- FIT3D_train_sample_interval = 10
50
- Talkshow_train_sample_interval = 10
51
-
52
- # strategy
53
- data_strategy = 'balance' # 'balance' need to define total_data_len
54
- total_data_len = 4500000
55
-
56
- # model
57
- smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
58
- smplx_pose_weight = 10.0
59
-
60
- smplx_kps_3d_weight = 100.0
61
- smplx_kps_2d_weight = 1.0
62
- net_kps_2d_weight = 1.0
63
-
64
- agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
65
-
66
- model_type = 'smpler_x_h'
67
- encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_huge.py'
68
- encoder_pretrained_model_path = 'pretrained_models/vitpose_huge.pth'
69
- feat_dim = 1280
70
-
71
- ## =====FIXED ARGS============================================================
72
- ## model setting
73
- upscale = 4
74
- hand_pos_joint_num = 20
75
- face_pos_joint_num = 72
76
- num_task_token = 24
77
- num_noise_sample = 0
78
-
79
- ## UBody setting
80
- train_sample_interval = 10
81
- test_sample_interval = 100
82
- make_same_len = False
83
-
84
- ## input, output size
85
- input_img_shape = (512, 384)
86
- input_body_shape = (256, 192)
87
- output_hm_shape = (16, 16, 12)
88
- input_hand_shape = (256, 256)
89
- output_hand_hm_shape = (16, 16, 16)
90
- output_face_hm_shape = (8, 8, 8)
91
- input_face_shape = (192, 192)
92
- focal = (5000, 5000) # virtual focal lengths
93
- princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
94
- body_3d_size = 2
95
- hand_3d_size = 0.3
96
- face_3d_size = 0.3
97
- camera_3d_size = 2.5
98
-
99
- ## training config
100
- print_iters = 100
101
- lr_mult = 1
102
-
103
- ## testing config
104
- test_batch_size = 32
105
-
106
- ## others
107
- num_thread = 2
108
- vis = False
109
-
110
- ## directory
111
- output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
main/config/config_smpler_x_l32.py DELETED
@@ -1,112 +0,0 @@
1
- import os
2
- import os.path as osp
3
-
4
- # will be update in exp
5
- num_gpus = -1
6
- exp_name = 'output/exp1/pre_analysis'
7
-
8
- # quick access
9
- save_epoch = 1
10
- lr = 1e-5
11
- end_epoch = 10
12
- train_batch_size = 32
13
-
14
- syncbn = True
15
- bbox_ratio = 1.2
16
-
17
- # continue
18
- continue_train = False
19
- start_over = True
20
-
21
- # dataset setting
22
- agora_fix_betas = True
23
- agora_fix_global_orient_transl = True
24
- agora_valid_root_pose = True
25
-
26
- # all
27
- dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
28
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
29
- trainset_3d = ['MSCOCO','AGORA', 'UBody']
30
- trainset_2d = ['PW3D', 'MPII', 'Human36M']
31
- trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
32
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
33
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
34
- 'Behave', 'UP3D', 'ARCTIC',
35
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
36
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
37
- testset = 'EHF'
38
-
39
- use_cache = True
40
- # downsample
41
- BEDLAM_train_sample_interval = 5
42
- EgoBody_Kinect_train_sample_interval = 10
43
- train_sample_interval = 10 # UBody
44
- MPI_INF_3DHP_train_sample_interval = 5
45
- InstaVariety_train_sample_interval = 10
46
- RenBody_HiRes_train_sample_interval = 5
47
- ARCTIC_train_sample_interval = 10
48
- # RenBody_train_sample_interval = 10
49
- FIT3D_train_sample_interval = 10
50
- Talkshow_train_sample_interval = 10
51
-
52
- # strategy
53
- data_strategy = 'balance' # 'balance' need to define total_data_len
54
- total_data_len = 4500000
55
-
56
- # model
57
- smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
58
- smplx_pose_weight = 10.0
59
-
60
- smplx_kps_3d_weight = 100.0
61
- smplx_kps_2d_weight = 1.0
62
- net_kps_2d_weight = 1.0
63
-
64
- agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
65
-
66
- model_type = 'smpler_x_l'
67
- encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_large.py'
68
- encoder_pretrained_model_path = 'pretrained_models/vitpose_large.pth'
69
- feat_dim = 1024
70
-
71
-
72
- ## =====FIXED ARGS============================================================
73
- ## model setting
74
- upscale = 4
75
- hand_pos_joint_num = 20
76
- face_pos_joint_num = 72
77
- num_task_token = 24
78
- num_noise_sample = 0
79
-
80
- ## UBody setting
81
- train_sample_interval = 10
82
- test_sample_interval = 100
83
- make_same_len = False
84
-
85
- ## input, output size
86
- input_img_shape = (512, 384)
87
- input_body_shape = (256, 192)
88
- output_hm_shape = (16, 16, 12)
89
- input_hand_shape = (256, 256)
90
- output_hand_hm_shape = (16, 16, 16)
91
- output_face_hm_shape = (8, 8, 8)
92
- input_face_shape = (192, 192)
93
- focal = (5000, 5000) # virtual focal lengths
94
- princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
95
- body_3d_size = 2
96
- hand_3d_size = 0.3
97
- face_3d_size = 0.3
98
- camera_3d_size = 2.5
99
-
100
- ## training config
101
- print_iters = 100
102
- lr_mult = 1
103
-
104
- ## testing config
105
- test_batch_size = 32
106
-
107
- ## others
108
- num_thread = 2
109
- vis = False
110
-
111
- ## directory
112
- output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
main/config/config_smpler_x_s32.py DELETED
@@ -1,111 +0,0 @@
1
- import os
2
- import os.path as osp
3
-
4
- # will be update in exp
5
- num_gpus = -1
6
- exp_name = 'output/exp1/pre_analysis'
7
-
8
- # quick access
9
- save_epoch = 1
10
- lr = 1e-5
11
- end_epoch = 10
12
- train_batch_size = 32
13
-
14
- syncbn = True
15
- bbox_ratio = 1.2
16
-
17
- # continue
18
- continue_train = False
19
- start_over = True
20
-
21
- # dataset setting
22
- agora_fix_betas = True
23
- agora_fix_global_orient_transl = True
24
- agora_valid_root_pose = True
25
-
26
- # all data
27
- dataset_list = ['Human36M', 'MSCOCO', 'MPII', 'AGORA', 'EHF', 'SynBody', 'GTA_Human2', \
28
- 'EgoBody_Egocentric', 'EgoBody_Kinect', 'UBody', 'PW3D', 'MuCo', 'PROX']
29
- trainset_3d = ['MSCOCO','AGORA', 'UBody']
30
- trainset_2d = ['PW3D', 'MPII', 'Human36M']
31
- trainset_humandata = ['BEDLAM', 'SPEC', 'GTA_Human2','SynBody', 'PoseTrack',
32
- 'EgoBody_Egocentric', 'PROX', 'CrowdPose',
33
- 'EgoBody_Kinect', 'MPI_INF_3DHP', 'RICH', 'MuCo', 'InstaVariety',
34
- 'Behave', 'UP3D', 'ARCTIC',
35
- 'OCHuman', 'CHI3D', 'RenBody_HiRes', 'MTP', 'HumanSC3D', 'RenBody',
36
- 'FIT3D', 'Talkshow' , 'SSP3D', 'LSPET']
37
- testset = 'EHF'
38
-
39
- use_cache = True
40
- # downsample
41
- BEDLAM_train_sample_interval = 5
42
- EgoBody_Kinect_train_sample_interval = 10
43
- train_sample_interval = 10 # UBody
44
- MPI_INF_3DHP_train_sample_interval = 5
45
- InstaVariety_train_sample_interval = 10
46
- RenBody_HiRes_train_sample_interval = 5
47
- ARCTIC_train_sample_interval = 10
48
- # RenBody_train_sample_interval = 10
49
- FIT3D_train_sample_interval = 10
50
- Talkshow_train_sample_interval = 10
51
-
52
- # strategy
53
- data_strategy = 'balance' # 'balance' need to define total_data_len
54
- total_data_len = 4500000
55
-
56
- # model
57
- smplx_loss_weight = 1.0 #2 for agora_model for smplx shape
58
- smplx_pose_weight = 10.0
59
-
60
- smplx_kps_3d_weight = 100.0
61
- smplx_kps_2d_weight = 1.0
62
- net_kps_2d_weight = 1.0
63
-
64
- agora_benchmark = 'agora_model' # 'agora_model', 'test_only'
65
-
66
- model_type = 'smpler_x_s'
67
- encoder_config_file = 'main/transformer_utils/configs/smpler_x/encoder/body_encoder_small.py'
68
- encoder_pretrained_model_path = 'pretrained_models/vitpose_small.pth'
69
- feat_dim = 384
70
-
71
- ## =====FIXED ARGS============================================================
72
- ## model setting
73
- upscale = 4
74
- hand_pos_joint_num = 20
75
- face_pos_joint_num = 72
76
- num_task_token = 24
77
- num_noise_sample = 0
78
-
79
- ## UBody setting
80
- train_sample_interval = 10
81
- test_sample_interval = 100
82
- make_same_len = False
83
-
84
- ## input, output size
85
- input_img_shape = (512, 384)
86
- input_body_shape = (256, 192)
87
- output_hm_shape = (16, 16, 12)
88
- input_hand_shape = (256, 256)
89
- output_hand_hm_shape = (16, 16, 16)
90
- output_face_hm_shape = (8, 8, 8)
91
- input_face_shape = (192, 192)
92
- focal = (5000, 5000) # virtual focal lengths
93
- princpt = (input_body_shape[1] / 2, input_body_shape[0] / 2) # virtual principal point position
94
- body_3d_size = 2
95
- hand_3d_size = 0.3
96
- face_3d_size = 0.3
97
- camera_3d_size = 2.5
98
-
99
- ## training config
100
- print_iters = 100
101
- lr_mult = 1
102
-
103
- ## testing config
104
- test_batch_size = 32
105
-
106
- ## others
107
- num_thread = 2
108
- vis = False
109
-
110
- ## directory
111
- output_dir, model_dir, vis_dir, log_dir, result_dir, code_dir = None, None, None, None, None, None
main/inference.py CHANGED
@@ -11,11 +11,12 @@ sys.path.insert(0, osp.join(CUR_DIR, '..', 'main'))
sys.path.insert(0, osp.join(CUR_DIR , '..', 'common'))
from config import cfg
import cv2
- from tqdm import tqdm
- import json
- from typing import Literal, Union
from mmdet.apis import init_detector, inference_detector
from utils.inference_utils import process_mmdet_results, non_max_suppression
+ from postometro_utils.smpl import SMPL
+ import data.config as smpl_cfg
+ from postometro import get_model
+ from postometro_utils.renderer_pyrender import PyRender_Renderer

class Inferer:
@@ -29,16 +30,18 @@ class Inferer:
# ckpt_path = osp.join(CUR_DIR, '../pretrained_models', f'{pretrained_model}.pth.tar')
ckpt_path = None # for config
cfg.get_config_fromfile(config_path)
+ # uodate config
cfg.update_config(num_gpus, ckpt_path, output_folder, self.device)
self.cfg = cfg
cudnn.benchmark = True

- # # load model
- # from base import Demoer
- # demoer = Demoer()
- # demoer._make_model()
- # demoer.model.eval()
- # self.demoer = demoer
+ # load SMPL
+ self.smpl = SMPL().to(self.device)
+ self.faces = self.smpl.faces.cpu().numpy()
+
+ # load model
+ hmr_model_checkpoint_file = osp.join(CUR_DIR, '../pretrained_models/postometro/resnet_state_dict.bin')
+ self.hmr_model = get_model(backbone_str='resnet50',device=self.device, checkpoint_file = hmr_model_checkpoint_file)

# load faster-rcnn as human detector
checkpoint_file = osp.join(CUR_DIR, '../pretrained_models/mmdet/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth')
@@ -46,17 +49,20 @@ class Inferer:
model = init_detector(config_file, checkpoint_file, device=self.device) # or device='cuda:0'
self.model = model

- def infer(self, original_img, iou_thr, frame, multi_person=False, mesh_as_vertices=False):
+ def infer(self, original_img, iou_thr, multi_person=False, mesh_as_vertices=False):
from utils.preprocessing import process_bbox, generate_patch_image
- # from utils.vis import render_mesh, save_obj
+ from utils.vis import render_mesh
# from utils.human_models import smpl_x
- mesh_paths = []
- smplx_paths = []
+
# prepare input image
- transform = transforms.ToTensor()
+ transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                  std=[0.229, 0.224, 0.225])
vis_img = original_img.copy()
original_img_height, original_img_width = original_img.shape[:2]

+ # load renderer
+ # self.renderer = PyRender_Renderer(resolution=(original_img_width, original_img_height), faces=self.faces)
+
## mmdet inference
mmdet_results = inference_detector(self.model, original_img)
mmdet_box = process_mmdet_results(mmdet_results, cat_id=0, multi_person=True)
@@ -99,51 +105,69 @@ class Inferer:
top_left = (int(bbox[0]), int(bbox[1]))
bottom_right = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
cv2.rectangle(vis_img, top_left, bottom_right, (0, 0, 255), 2)
-
-
+
# human model inference
- # img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape)
- # img = transform(img.astype(np.float32))/255
- # img = img.to(cfg.device)[None,:,:,:]
- # inputs = {'img': img}
- # targets = {}
- # meta_info = {}
-
- # # mesh recovery
- # with torch.no_grad():
- # out = self.demoer.model(inputs, targets, meta_info, 'test')
- # mesh = out['smplx_mesh_cam'].detach().cpu().numpy()[0]
-
- # ## save mesh
- # save_path_mesh = os.path.join(self.output_folder, 'mesh')
- # os.makedirs(save_path_mesh, exist_ok= True)
- # obj_path = os.path.join(save_path_mesh, f'{frame:05}_{bbox_id}.obj')
- # save_obj(mesh, smpl_x.face, obj_path)
- # mesh_paths.append(obj_path)
- # ## save single person param
- # smplx_pred = {}
- # smplx_pred['global_orient'] = out['smplx_root_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['body_pose'] = out['smplx_body_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['left_hand_pose'] = out['smplx_lhand_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['right_hand_pose'] = out['smplx_rhand_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['jaw_pose'] = out['smplx_jaw_pose'].reshape(-1,3).cpu().numpy()
- # smplx_pred['leye_pose'] = np.zeros((1, 3))
- # smplx_pred['reye_pose'] = np.zeros((1, 3))
- # smplx_pred['betas'] = out['smplx_shape'].reshape(-1,10).cpu().numpy()
- # smplx_pred['expression'] = out['smplx_expr'].reshape(-1,10).cpu().numpy()
- # smplx_pred['transl'] = out['cam_trans'].reshape(-1,3).cpu().numpy()
- # save_path_smplx = os.path.join(self.output_folder, 'smplx')
- # os.makedirs(save_path_smplx, exist_ok= True)
-
- # npz_path = os.path.join(save_path_smplx, f'{frame:05}_{bbox_id}.npz')
- # np.savez(npz_path, **smplx_pred)
- # smplx_paths.append(npz_path)
-
- # ## render single person mesh
- # focal = [self.cfg.focal[0] / self.cfg.input_body_shape[1] * bbox[2], self.cfg.focal[1] / self.cfg.input_body_shape[0] * bbox[3]]
- # princpt = [self.cfg.princpt[0] / self.cfg.input_body_shape[1] * bbox[2] + bbox[0], self.cfg.princpt[1] / self.cfg.input_body_shape[0] * bbox[3] + bbox[1]]
- # vis_img = render_mesh(vis_img, mesh, smpl_x.face, {'focal': focal, 'princpt': princpt},
+ img, img2bb_trans, bb2img_trans = generate_patch_image(original_img, bbox, 1.0, 0.0, False, self.cfg.input_img_shape)
+ vis_patched_images = img.copy()
+ # here we pre-process images
+ img = img.transpose((2,0,1)) # h,w,c -> c,h,w
+ img = torch.from_numpy(img).float() / 255.0
+ # Store image before normalization to use it in visualization
+ img = transform(img)
+ img = img.to(cfg.device)[None,:,:,:]
+
+ self.renderer = PyRender_Renderer(resolution=(bbox[2], bbox[3]), faces=self.faces)
+
+ # mesh recovery
+ with torch.no_grad():
+     out = self.hmr_model(img)
+ pred_cam, pred_3d_vertices_fine = out['pred_cam'], out['pred_3d_vertices_fine']
+ pred_3d_joints_from_smpl = self.smpl.get_h36m_joints(pred_3d_vertices_fine) # batch_size X 17 X 3
+ pred_3d_joints_from_smpl_pelvis = pred_3d_joints_from_smpl[:,smpl_cfg.H36M_J17_NAME.index('Pelvis'),:]
+ pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,smpl_cfg.H36M_J17_TO_J14,:] # batch_size X 14 X 3
+ # normalize predicted vertices
+ pred_3d_vertices_fine = pred_3d_vertices_fine - pred_3d_joints_from_smpl_pelvis[:, None, :] # batch_size X 6890 X 3
+ pred_3d_vertices_fine = pred_3d_vertices_fine.detach().cpu().numpy()[0] # 6890 X 3
+ pred_cam = pred_cam.detach().cpu().numpy()[0]
+ bbox_cx, bbox_cy = bbox[0] + bbox[2] / 2, bbox[1] + bbox[3] / 2
+ img_cx, img_cy = original_img_width / 2, original_img_height / 2
+ cx_delta, cy_delta = bbox_cx / img_cx - 1, bbox_cy / img_cy - 1
+
+ # render single person mesh
+ # focal = [self.cfg.focal[0] / self.cfg.renderer_input_body_shape[1] * bbox[2], self.cfg.focal[1] / self.cfg.renderer_input_body_shape[0] * bbox[3]]
+ # princpt = [self.cfg.princpt[0] / self.cfg.renderer_input_body_shape[1] * bbox[2] + bbox[0], self.cfg.princpt[1] / self.cfg.renderer_input_body_shape[0] * bbox[3] + bbox[1]]
+ # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, {'focal': focal, 'princpt': princpt},
# mesh_as_vertices=mesh_as_vertices)
- # vis_img = vis_img.astype('uint8')
- return vis_img, num_bbox, ok_bboxes
+ # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]), pred_cam[1], pred_cam[2]], mesh_as_vertices=mesh_as_vertices)
+ # import ipdb
+ # ipdb.set_trace()
+ vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]),
+     pred_cam[1] + cx_delta / (pred_cam[0] / (original_img_width / bbox[2])),
+     pred_cam[2] + cy_delta / (pred_cam[0] / (original_img_height / bbox[3]))],
+     mesh_as_vertices=mesh_as_vertices)
+ # vis_img = render_mesh(vis_img, pred_3d_vertices_fine, self.faces, [pred_cam[0] / (original_img_width / bbox[2]), pred_cam[0] / (original_img_height / bbox[3]), 0, 0], mesh_as_vertices=mesh_as_vertices)
+
+ # bbox_meta = {'bbox': bbox, 'img_hw': [original_img_height, original_img_width]}
+ # vis_img = self.renderer(pred_3d_vertices_fine, bbox_meta, vis_img, pred_cam)
+ vis_img = vis_img.astype('uint8')
+ return vis_img, len(ok_bboxes), ok_bboxes
+
+
+ if __name__ == '__main__':
+     from PIL import Image
+     inferer = Inferer('postometro', 0, './out_folder') # gpu
+     image_path = f'../assets/07.jpg'
+     image = Image.open(image_path)
+     # Convert the PIL image to a NumPy array
+     image_np = np.array(image)
+     vis_img, _ , _ = inferer.infer(image_np, 0.2, multi_person=True, mesh_as_vertices=False)
+     save_path = f'./saved_vis_07.jpg'
+
+     # Ensure the image is in the correct format (PIL expects uint8)
+     if vis_img.dtype != np.uint8:
+         vis_img = vis_img.astype('uint8')

+     # Convert the Numpy array (if RGB) to a PIL image and save
+     image = Image.fromarray(vis_img)
+     image.save(save_path)
+
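The camera argument built inline in the render_mesh call above converts the crop-space weak-perspective camera predicted for each bbox into full-image coordinates. Factored out, the conversion looks roughly like the sketch below; convert_crop_cam_to_image is an illustrative name, not an existing helper in this repository.

def convert_crop_cam_to_image(pred_cam, bbox, img_w, img_h):
    # pred_cam = [s, tx, ty] predicted for the bbox crop;
    # bbox = [x, y, w, h] in original-image pixels.
    s, tx, ty = pred_cam
    bx, by, bw, bh = bbox
    sx = s / (img_w / bw)                        # rescale scale to image width
    sy = s / (img_h / bh)                        # rescale scale to image height
    cx_delta = (bx + bw / 2) / (img_w / 2) - 1   # bbox-centre offset, roughly in [-1, 1]
    cy_delta = (by + bh / 2) / (img_h / 2) - 1
    return [sx, sy, tx + cx_delta / sx, ty + cy_delta / sy]

# e.g. convert_crop_cam_to_image([0.9, 0.02, 0.10], bbox=[100, 50, 200, 400], img_w=640, img_h=480)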
main/pct_utils/__pycache__/modules.cpython-39.pyc ADDED
Binary file (3.4 kB).

main/pct_utils/__pycache__/pct.cpython-39.pyc ADDED
Binary file (1.89 kB).

main/pct_utils/__pycache__/pct_backbone.cpython-39.pyc ADDED
Binary file (40.6 kB).

main/pct_utils/__pycache__/pct_head.cpython-39.pyc ADDED
Binary file (6.93 kB).

main/pct_utils/__pycache__/pct_tokenizer.cpython-39.pyc ADDED
Binary file (9.14 kB).
 
main/pct_utils/modules.py ADDED
@@ -0,0 +1,117 @@
1
+ # --------------------------------------------------------
2
+ # Borrow from unofficial MLPMixer (https://github.com/920232796/MlpMixer-pytorch)
3
+ # Borrow from ResNet
4
+ # Modified by Zigang Geng (zigang@mail.ustc.edu.cn)
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class FCBlock(nn.Module):
12
+ def __init__(self, dim, out_dim):
13
+ super().__init__()
14
+
15
+ self.ff = nn.Sequential(
16
+ nn.Linear(dim, out_dim),
17
+ nn.LayerNorm(out_dim),
18
+ nn.ReLU(inplace=True),
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.ff(x)
23
+
24
+
25
+ class MLPBlock(nn.Module):
26
+ def __init__(self, dim, inter_dim, dropout_ratio):
27
+ super().__init__()
28
+
29
+ self.ff = nn.Sequential(
30
+ nn.Linear(dim, inter_dim),
31
+ nn.GELU(),
32
+ nn.Dropout(dropout_ratio),
33
+ nn.Linear(inter_dim, dim),
34
+ nn.Dropout(dropout_ratio)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.ff(x)
39
+
40
+
41
+ class MixerLayer(nn.Module):
42
+ def __init__(self,
43
+ hidden_dim,
44
+ hidden_inter_dim,
45
+ token_dim,
46
+ token_inter_dim,
47
+ dropout_ratio):
48
+ super().__init__()
49
+
50
+ self.layernorm1 = nn.LayerNorm(hidden_dim)
51
+ self.MLP_token = MLPBlock(token_dim, token_inter_dim, dropout_ratio)
52
+ self.layernorm2 = nn.LayerNorm(hidden_dim)
53
+ self.MLP_channel = MLPBlock(hidden_dim, hidden_inter_dim, dropout_ratio)
54
+
55
+ def forward(self, x):
56
+ y = self.layernorm1(x)
57
+ y = y.transpose(2, 1)
58
+ y = self.MLP_token(y)
59
+ y = y.transpose(2, 1)
60
+ z = self.layernorm2(x + y)
61
+ z = self.MLP_channel(z)
62
+ out = x + y + z
63
+ return out
64
+
65
+
66
+ class BasicBlock(nn.Module):
67
+ expansion = 1
68
+
69
+ def __init__(self, inplanes, planes, stride=1,
70
+ downsample=None, dilation=1):
71
+ super(BasicBlock, self).__init__()
72
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
73
+ padding=dilation, bias=False, dilation=dilation)
74
+ self.bn1 = nn.BatchNorm2d(planes, momentum=0.1)
75
+ self.relu = nn.ReLU(inplace=True)
76
+ self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
77
+ padding=dilation, bias=False, dilation=dilation)
78
+ self.bn2 = nn.BatchNorm2d(planes, momentum=0.1)
79
+ self.downsample = downsample
80
+ self.stride = stride
81
+
82
+
83
+ def forward(self, x):
84
+ residual = x
85
+
86
+ out = self.conv1(x)
87
+ out = self.bn1(out)
88
+ out = self.relu(out)
89
+
90
+ out = self.conv2(out)
91
+ out = self.bn2(out)
92
+
93
+ if self.downsample is not None:
94
+ residual = self.downsample(x)
95
+
96
+ out += residual
97
+ out = self.relu(out)
98
+
99
+ return out
100
+
101
+ def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
102
+ layers = []
103
+ for i in range(len(feat_dims)-1):
104
+ layers.append(
105
+ nn.Conv2d(
106
+ in_channels=feat_dims[i],
107
+ out_channels=feat_dims[i+1],
108
+ kernel_size=kernel,
109
+ stride=stride,
110
+ padding=padding
111
+ ))
112
+ # Do not use BN and ReLU for final estimation
113
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
114
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
115
+ layers.append(nn.ReLU(inplace=True))
116
+
117
+ return nn.Sequential(*layers)
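The MixerLayer added above alternates a token-mixing MLP with a channel-mixing MLP over a (batch, tokens, channels) input. A quick shape check follows, assuming main/ is on sys.path so pct_utils imports the way the rest of the repo does; the dimensions are illustrative, not values taken from any config.

import torch
from pct_utils.modules import MixerLayer   # assumes main/ is on sys.path

layer = MixerLayer(hidden_dim=256, hidden_inter_dim=512,
                   token_dim=34, token_inter_dim=64, dropout_ratio=0.1)
tokens = torch.randn(2, 34, 256)            # (batch, tokens, channels)
out = layer(tokens)                         # token-mixing, then channel-mixing
assert out.shape == tokens.shape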
main/pct_utils/pct.py ADDED
@@ -0,0 +1,69 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from pct_utils.pct_head import PCT_Head
4
+
5
+ class PCT(nn.Module):
6
+ def __init__(self,
7
+ args,
8
+ backbone,
9
+ stage_pct,
10
+ in_channels,
11
+ image_size,
12
+ num_joints,
13
+ pretrained=None,
14
+ tokenizer_pretrained=None):
15
+ super().__init__()
16
+ self.stage_pct = stage_pct
17
+ assert self.stage_pct in ["tokenizer", "classifier"]
18
+ self.guide_ratio = args.tokenizer_guide_ratio
19
+ self.image_guide = self.guide_ratio > 0.0
20
+ self.num_joints = num_joints
21
+
22
+ self.backbone = backbone
23
+ if self.image_guide:
24
+ self.extra_backbone = backbone
25
+
26
+ self.keypoint_head = PCT_Head(args,stage_pct,in_channels,image_size,num_joints)
27
+
28
+ if (pretrained is not None) or (tokenizer_pretrained is not None):
29
+ self.init_weights(pretrained, tokenizer_pretrained)
30
+
31
+ def init_weights(self, pretrained, tokenizer):
32
+ """Weight initialization for model."""
33
+ if self.stage_pct == "classifier":
34
+ self.backbone.init_weights(pretrained)
35
+ if self.image_guide:
36
+ self.extra_backbone.init_weights(pretrained)
37
+ self.keypoint_head.init_weights()
38
+ self.keypoint_head.tokenizer.init_weights(tokenizer)
39
+
40
+ def forward(self,img, joints, train = True):
41
+ if train:
42
+ output = None if self.stage_pct == "tokenizer" else self.backbone(img)
43
+ extra_output = self.extra_backbone(img) if self.image_guide else None
44
+
45
+ p_logits, p_joints, g_logits, e_latent_loss = \
46
+ self.keypoint_head(output, extra_output, joints, train=True)
47
+ return {
48
+ 'cls_logits': p_logits,
49
+ 'pred_pose': p_joints,
50
+ 'encoding_indices': g_logits,
51
+ 'e_latent_loss': e_latent_loss
52
+ }
53
+ else:
54
+ results = {}
55
+
56
+ batch_size, _, img_height, img_width = img.shape
57
+
58
+ output = None if self.stage_pct == "tokenizer" \
59
+ else self.backbone(img)
60
+ extra_output = self.extra_backbone(img) \
61
+ if self.image_guide and self.stage_pct == "tokenizer" else None
62
+
63
+ p_joints, encoding_scores, out_part_token_feat = \
64
+ self.keypoint_head(output, extra_output, joints, train=False)
65
+ return {
66
+ 'pred_pose': p_joints,
67
+ 'encoding_scores': encoding_scores,
68
+ 'part_token_feat': out_part_token_feat
69
+ }
main/pct_utils/pct_backbone.py ADDED
@@ -0,0 +1,1475 @@
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu, Yutong Lin, Yixuan Wei
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from functools import partial
13
+ import torch.utils.checkpoint as checkpoint
14
+ from torch.nn.utils import weight_norm
15
+ from torch import Tensor, Size
16
+ from typing import Union, List
17
+ import numpy as np
18
+ import logging
19
+
20
+ # Copyright (c) Open-MMLab. All rights reserved.
21
+ # Copy from mmcv source code.
22
+ import io
23
+ import os
24
+ import os.path as osp
25
+ import pkgutil
26
+ import time
27
+ import warnings
28
+ import numpy as np
29
+ from scipy import interpolate
30
+
31
+ import torch
32
+ import torchvision
33
+ import torch.distributed as dist
34
+ from torch.utils import model_zoo
35
+ from torch.nn import functional as F
36
+
37
+
38
+ def _load_checkpoint(filename, map_location=None):
39
+ if not osp.isfile(filename):
40
+ raise IOError(f'{filename} is not a checkpoint file')
41
+ checkpoint = torch.load(filename, map_location=map_location)
42
+ return checkpoint
43
+
44
+
45
+ def load_checkpoint_swin(model,
46
+ filename,
47
+ map_location='cpu',
48
+ strict=False,
49
+ rpe_interpolation='outer_mask',
50
+ logger=None):
51
+ """Load checkpoint from a file or URI.
52
+ Args:
53
+ model (Module): Module to load checkpoint.
54
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
55
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
56
+ details.
57
+ map_location (str): Same as :func:`torch.load`.
58
+ strict (bool): Whether to allow different params for the model and
59
+ checkpoint.
60
+ logger (:mod:`logging.Logger` or None): The logger for error message.
61
+ Returns:
62
+ dict or OrderedDict: The loaded checkpoint.
63
+ """
64
+ checkpoint = _load_checkpoint(filename, map_location)
65
+ # OrderedDict is a subclass of dict
66
+ if not isinstance(checkpoint, dict):
67
+ raise RuntimeError(
68
+ f'No state_dict found in checkpoint file {filename}')
69
+ # get state_dict from checkpoint
70
+ if 'state_dict' in checkpoint:
71
+ state_dict = checkpoint['state_dict']
72
+ elif 'model' in checkpoint:
73
+ state_dict = checkpoint['model']
74
+ elif 'module' in checkpoint:
75
+ state_dict = checkpoint['module']
76
+ else:
77
+ state_dict = checkpoint
78
+ # strip prefix of state_dict
79
+ if list(state_dict.keys())[0].startswith('module.'):
80
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
81
+
82
+ if list(state_dict.keys())[0].startswith('backbone.'):
83
+ state_dict = {k[9:]: v for k, v in state_dict.items()}
84
+
85
+ # for MoBY, load model of online branch
86
+ if sorted(list(state_dict.keys()))[2].startswith('encoder'):
87
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
88
+
89
+ # directly load here
90
+
91
+ model.load_state_dict(state_dict, strict=True)
92
+
93
+ return checkpoint
94
+
95
+
96
+ _shape_t = Union[int, List[int], Size]
97
+
98
+ from itertools import repeat
99
+ import collections.abc
100
+
101
+ def _ntuple(n):
102
+ def parse(x):
103
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
104
+ return tuple(x)
105
+ return tuple(repeat(x, n))
106
+ return parse
107
+
108
+ to_2tuple = _ntuple(2)
109
+
110
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
111
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
112
+
113
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
114
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
115
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
116
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
117
+ 'survival rate' as the argument.
118
+
119
+ """
120
+ if drop_prob == 0. or not training:
121
+ return x
122
+ keep_prob = 1 - drop_prob
123
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
124
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
125
+ if keep_prob > 0.0 and scale_by_keep:
126
+ random_tensor.div_(keep_prob)
127
+ return x * random_tensor
128
+
129
+
130
+ class DropPath(nn.Module):
131
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
132
+ """
133
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
134
+ super(DropPath, self).__init__()
135
+ self.drop_prob = drop_prob
136
+ self.scale_by_keep = scale_by_keep
137
+
138
+ def forward(self, x):
139
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
140
+
141
+ def extra_repr(self):
142
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
143
+
144
+ def _trunc_normal_(tensor, mean, std, a, b):
145
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
146
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
147
+ def norm_cdf(x):
148
+ # Computes standard normal cumulative distribution function
149
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
150
+
151
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
152
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
153
+ "The distribution of values may be incorrect.",
154
+ stacklevel=2)
155
+
156
+ # Values are generated by using a truncated uniform distribution and
157
+ # then using the inverse CDF for the normal distribution.
158
+ # Get upper and lower cdf values
159
+ l = norm_cdf((a - mean) / std)
160
+ u = norm_cdf((b - mean) / std)
161
+
162
+ # Uniformly fill tensor with values from [l, u], then translate to
163
+ # [2l-1, 2u-1].
164
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
165
+
166
+ # Use inverse cdf transform for normal distribution to get truncated
167
+ # standard normal
168
+ tensor.erfinv_()
169
+
170
+ # Transform to proper mean, std
171
+ tensor.mul_(std * math.sqrt(2.))
172
+ tensor.add_(mean)
173
+
174
+ # Clamp to ensure it's in the proper range
175
+ tensor.clamp_(min=a, max=b)
176
+ return tensor
177
+
178
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
179
+ r"""Fills the input Tensor with values drawn from a truncated
180
+ normal distribution. The values are effectively drawn from the
181
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
182
+ with values outside :math:`[a, b]` redrawn until they are within
183
+ the bounds. The method used for generating the random values works
184
+ best when :math:`a \leq \text{mean} \leq b`.
185
+
186
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
187
+ applied while sampling the normal with mean/std applied, therefore a, b args
188
+ should be adjusted to match the range of mean, std args.
189
+
190
+ Args:
191
+ tensor: an n-dimensional `torch.Tensor`
192
+ mean: the mean of the normal distribution
193
+ std: the standard deviation of the normal distribution
194
+ a: the minimum cutoff value
195
+ b: the maximum cutoff value
196
+ Examples:
197
+ >>> w = torch.empty(3, 5)
198
+ >>> nn.init.trunc_normal_(w)
199
+ """
200
+ with torch.no_grad():
201
+ return _trunc_normal_(tensor, mean, std, a, b)
202
+
203
+
204
+ class LayerNorm2D(nn.Module):
205
+ def __init__(self, normalized_shape, norm_layer=None):
206
+ super().__init__()
207
+ self.ln = norm_layer(normalized_shape) if norm_layer is not None else nn.Identity()
208
+
209
+ def forward(self, x):
210
+ """
211
+ x: N C H W
212
+ """
213
+ x = x.permute(0, 2, 3, 1)
214
+ x = self.ln(x)
215
+ x = x.permute(0, 3, 1, 2)
216
+ return x
217
+
218
+
219
+ class LayerNormFP32(nn.LayerNorm):
220
+ def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
221
+ super(LayerNormFP32, self).__init__(normalized_shape, eps, elementwise_affine)
222
+
223
+ def forward(self, input: Tensor) -> Tensor:
224
+ return F.layer_norm(
225
+ input.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(input)
226
+
227
+
228
+ class LinearFP32(nn.Linear):
229
+ def __init__(self, in_features, out_features, bias=True):
230
+ super(LinearFP32, self).__init__(in_features, out_features, bias)
231
+
232
+ def forward(self, input: Tensor) -> Tensor:
233
+ return F.linear(input.float(), self.weight.float(),
234
+ self.bias.float() if self.bias is not None else None)
235
+
236
+
237
+ class Mlp(nn.Module):
238
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
239
+ norm_layer=None, mlpfp32=False):
240
+ super().__init__()
241
+ out_features = out_features or in_features
242
+ hidden_features = hidden_features or in_features
243
+ self.mlpfp32 = mlpfp32
244
+
245
+ self.fc1 = nn.Linear(in_features, hidden_features)
246
+ self.act = act_layer()
247
+ self.fc2 = nn.Linear(hidden_features, out_features)
248
+ self.drop = nn.Dropout(drop)
249
+ if norm_layer is not None:
250
+ self.norm = norm_layer(hidden_features)
251
+ else:
252
+ self.norm = None
253
+
254
+ def forward(self, x, H, W):
255
+ x = self.fc1(x)
256
+ if self.norm:
257
+ x = self.norm(x)
258
+ x = self.act(x)
259
+ x = self.drop(x)
260
+ if self.mlpfp32:
261
+ x = self.fc2.float()(x.type(torch.float32))
262
+ x = self.drop.float()(x)
263
+ # print(f"======>[MLP FP32]")
264
+ else:
265
+ x = self.fc2(x)
266
+ x = self.drop(x)
267
+ return x
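Shape-wise (an illustrative sketch, not from the commit): `Mlp` acts on flattened (B, H*W, C) tokens; the H, W arguments are accepted for interface parity but only the `ConvMlp` variant below actually uses them:

    mlp = Mlp(in_features=96, hidden_features=384)
    tokens = torch.randn(2, 56 * 56, 96)     # (B, H*W, C)
    out = mlp(tokens, 56, 56)
    print(out.shape)                         # torch.Size([2, 3136, 96])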
268
+
269
+
270
+ class ConvMlp(nn.Module):
271
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
272
+ norm_layer=None, mlpfp32=False, proj_ln=False):
273
+ super().__init__()
274
+ self.mlp = Mlp(in_features=in_features, hidden_features=hidden_features, out_features=out_features,
275
+ act_layer=act_layer, drop=drop, norm_layer=norm_layer, mlpfp32=mlpfp32)
276
+ self.conv_proj = nn.Conv2d(in_features,
277
+ in_features,
278
+ kernel_size=3,
279
+ padding=1,
280
+ stride=1,
281
+ bias=False,
282
+ groups=in_features)
283
+ self.proj_ln = LayerNorm2D(in_features, LayerNormFP32) if proj_ln else None
284
+
285
+ def forward(self, x, H, W):
286
+ B, L, C = x.shape
287
+ assert L == H * W
288
+ x = x.view(B, H, W, C).permute(0, 3, 1, 2) # B C H W
289
+ x = self.conv_proj(x)
290
+ if self.proj_ln:
291
+ x = self.proj_ln(x)
292
+ x = x.permute(0, 2, 3, 1) # B H W C
293
+ x = x.reshape(B, L, C)
294
+ x = self.mlp(x, H, W)
295
+ return x
296
+
297
+
298
+ def window_partition(x, window_size):
299
+ """
300
+ Args:
301
+ x: (B, H, W, C)
302
+ window_size (int): window size
303
+
304
+ Returns:
305
+ windows: (num_windows*B, window_size, window_size, C)
306
+ """
307
+ B, H, W, C = x.shape
308
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
309
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
310
+ return windows
311
+
312
+
313
+ def window_reverse(windows, window_size, H, W):
314
+ """
315
+ Args:
316
+ windows: (num_windows*B, window_size, window_size, C)
317
+ window_size (int): Window size
318
+ H (int): Height of image
319
+ W (int): Width of image
320
+
321
+ Returns:
322
+ x: (B, H, W, C)
323
+ """
324
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
325
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
326
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
327
+ return x
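A quick round-trip check (illustrative only): `window_partition` followed by `window_reverse` with the same window size reconstructs the input exactly when H and W are divisible by the window size:

    x = torch.randn(2, 14, 14, 96)                 # (B, H, W, C), H and W divisible by 7
    windows = window_partition(x, window_size=7)   # (2*4, 7, 7, 96)
    restored = window_reverse(windows, 7, 14, 14)
    assert torch.equal(restored, x)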
328
+
329
+
330
+ class WindowAttention(nn.Module):
331
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
332
+ It supports both of shifted and non-shifted window.
333
+
334
+ Args:
335
+ dim (int): Number of input channels.
336
+ window_size (tuple[int]): The height and width of the window.
337
+ num_heads (int): Number of attention heads.
338
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
339
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
340
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
341
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
342
+ """
343
+
344
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
345
+ relative_coords_table_type='norm8_log', rpe_hidden_dim=512,
346
+ rpe_output_type='normal', attn_type='normal', mlpfp32=False, pretrain_window_size=-1):
347
+
348
+ super().__init__()
349
+ self.dim = dim
350
+ self.window_size = window_size # Wh, Ww
351
+ self.num_heads = num_heads
352
+ self.mlpfp32 = mlpfp32
353
+ self.attn_type = attn_type
354
+ self.rpe_output_type = rpe_output_type
355
+ self.relative_coords_table_type = relative_coords_table_type
356
+
357
+ if self.attn_type == 'cosine_mh':
358
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
359
+ elif self.attn_type == 'normal':
360
+ head_dim = dim // num_heads
361
+ self.scale = qk_scale or head_dim ** -0.5
362
+ else:
363
+ raise NotImplementedError()
364
+ if self.relative_coords_table_type != "none":
365
+ # mlp to generate table of relative position bias
366
+ self.rpe_mlp = nn.Sequential(nn.Linear(2, rpe_hidden_dim, bias=True),
367
+ nn.ReLU(inplace=True),
368
+ LinearFP32(rpe_hidden_dim, num_heads, bias=False))
369
+
370
+ # get relative_coords_table
371
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
372
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
373
+ relative_coords_table = torch.stack(
374
+ torch.meshgrid([relative_coords_h,
375
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
376
+ if relative_coords_table_type == 'linear':
377
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
378
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
379
+ elif relative_coords_table_type == 'linear_bylayer':
380
+ relative_coords_table[:, :, :, 0] /= (pretrain_window_size - 1)
381
+ relative_coords_table[:, :, :, 1] /= (pretrain_window_size - 1)
382
+ elif relative_coords_table_type == 'norm8_log':
383
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
384
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
385
+ relative_coords_table *= 8 # normalize to -8, 8
386
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
387
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
388
+ elif relative_coords_table_type == 'norm8_log_192to640':
389
+ if self.window_size[0] == 40:
390
+ relative_coords_table[:, :, :, 0] /= (11)
391
+ relative_coords_table[:, :, :, 1] /= (11)
392
+ elif self.window_size[0] == 20:
393
+ relative_coords_table[:, :, :, 0] /= (5)
394
+ relative_coords_table[:, :, :, 1] /= (5)
395
+ else:
396
+ raise NotImplementedError
397
+ relative_coords_table *= 8 # normalize to -8, 8
398
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
399
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
400
+ # check
401
+ elif relative_coords_table_type == 'norm8_log_256to640':
402
+ if self.window_size[0] == 40:
403
+ relative_coords_table[:, :, :, 0] /= (15)
404
+ relative_coords_table[:, :, :, 1] /= (15)
405
+ elif self.window_size[0] == 20:
406
+ relative_coords_table[:, :, :, 0] /= (7)
407
+ relative_coords_table[:, :, :, 1] /= (7)
408
+ else:
409
+ raise NotImplementedError
410
+ relative_coords_table *= 8 # normalize to -8, 8
411
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
412
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
413
+ elif relative_coords_table_type == 'norm8_log_bylayer':
414
+ relative_coords_table[:, :, :, 0] /= (pretrain_window_size - 1)
415
+ relative_coords_table[:, :, :, 1] /= (pretrain_window_size - 1)
416
+ relative_coords_table *= 8 # normalize to -8, 8
417
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
418
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8) # log8
419
+ else:
420
+ raise NotImplementedError
421
+ self.register_buffer("relative_coords_table", relative_coords_table)
422
+ else:
423
+ self.relative_position_bias_table = nn.Parameter(
424
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
425
+ trunc_normal_(self.relative_position_bias_table, std=.02)
426
+
427
+ # get pair-wise relative position index for each token inside the window
428
+ coords_h = torch.arange(self.window_size[0])
429
+ coords_w = torch.arange(self.window_size[1])
430
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
431
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
432
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
433
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
434
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
435
+ relative_coords[:, :, 1] += self.window_size[1] - 1
436
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
437
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
438
+ self.register_buffer("relative_position_index", relative_position_index)
439
+
440
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
441
+ if qkv_bias:
442
+ self.q_bias = nn.Parameter(torch.zeros(dim))
443
+ self.v_bias = nn.Parameter(torch.zeros(dim))
444
+ else:
445
+ self.q_bias = None
446
+ self.v_bias = None
447
+
448
+ self.attn_drop = nn.Dropout(attn_drop)
449
+ self.proj = nn.Linear(dim, dim)
450
+ self.proj_drop = nn.Dropout(proj_drop)
451
+
452
+ self.softmax = nn.Softmax(dim=-1)
453
+
454
+ def forward(self, x, mask=None):
455
+ """
456
+ Args:
457
+ x: input features with shape of (num_windows*B, N, C)
458
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
459
+ """
460
+ B_, N, C = x.shape
461
+
462
+ qkv_bias = None
463
+ if self.q_bias is not None:
464
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
465
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
466
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
467
+ # qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
468
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
469
+
470
+ if self.attn_type == 'cosine_mh':
471
+ q = F.normalize(q.float(), dim=-1)
472
+ k = F.normalize(k.float(), dim=-1)
473
+ logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01, device=self.logit_scale.device))).exp()
474
+ attn = (q @ k.transpose(-2, -1)) * logit_scale.float()
475
+ elif self.attn_type == 'normal':
476
+ q = q * self.scale
477
+ attn = (q.float() @ k.float().transpose(-2, -1))
478
+ else:
479
+ raise NotImplementedError()
480
+
481
+ if self.relative_coords_table_type != "none":
482
+ # relative_position_bias_table: 2*Wh-1 * 2*Ww-1, nH
483
+ relative_position_bias_table = self.rpe_mlp(self.relative_coords_table).view(-1, self.num_heads)
484
+ else:
485
+ relative_position_bias_table = self.relative_position_bias_table
486
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
487
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
488
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
489
+ if self.rpe_output_type == 'normal':
490
+ pass
491
+ elif self.rpe_output_type == 'sigmoid':
492
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
493
+ else:
494
+ raise NotImplementedError
495
+
496
+ attn = attn + relative_position_bias.unsqueeze(0)
497
+
498
+ if mask is not None:
499
+ nW = mask.shape[0]
500
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
501
+ attn = attn.view(-1, self.num_heads, N, N)
502
+
503
+ attn = self.softmax(attn)
504
+ attn = attn.type_as(x)
505
+ attn = self.attn_drop(attn)
506
+
507
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
508
+ if self.mlpfp32:
509
+ x = self.proj.float()(x.type(torch.float32))
510
+ x = self.proj_drop.float()(x)
511
+ # print(f"======>[ATTN FP32]")
512
+ else:
513
+ x = self.proj(x)
514
+ x = self.proj_drop(x)
515
+ return x
516
+
517
+ def extra_repr(self) -> str:
518
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
519
+
520
+ def flops(self, N):
521
+ # calculate flops for 1 window with token length of N
522
+ flops = 0
523
+ # qkv = self.qkv(x)
524
+ flops += N * self.dim * 3 * self.dim
525
+ # attn = (q @ k.transpose(-2, -1))
526
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
527
+ # x = (attn @ v)
528
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
529
+ # x = self.proj(x)
530
+ flops += N * self.dim * self.dim
531
+ return flops
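A minimal forward sketch under assumed hyper-parameters (cosine attention with the sigmoid-scaled MLP relative-position bias, the combination used elsewhere in this backbone); the input is a batch of already-partitioned windows:

    attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3,
                           attn_type='cosine_mh', rpe_output_type='sigmoid')
    win_tokens = torch.randn(8, 49, 96)   # (num_windows*B, window_size**2, C)
    out = attn(win_tokens)                # no shift here, so mask=None
    print(out.shape)                      # torch.Size([8, 49, 96])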
532
+
533
+
534
+ class SwinTransformerBlockPost(nn.Module):
535
+ """ Swin Transformer Block.
536
+
537
+ Args:
538
+ dim (int): Number of input channels.
539
+ num_heads (int): Number of attention heads.
540
+ window_size (int): Window size.
541
+ shift_size (int): Shift size for SW-MSA.
542
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
543
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
544
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
545
+ drop (float, optional): Dropout rate. Default: 0.0
546
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
547
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
548
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
549
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
550
+ """
551
+
552
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
553
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
554
+ use_mlp_norm=False, endnorm=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
555
+ relative_coords_table_type='norm8_log', rpe_hidden_dim=512,
556
+ rpe_output_type='normal', attn_type='normal', mlp_type='normal', mlpfp32=False,
557
+ pretrain_window_size=-1):
558
+ super().__init__()
559
+ self.dim = dim
560
+ self.num_heads = num_heads
561
+ self.window_size = window_size
562
+ self.shift_size = shift_size
563
+ self.mlp_ratio = mlp_ratio
564
+ self.use_mlp_norm = use_mlp_norm
565
+ self.endnorm = endnorm
566
+ self.mlpfp32 = mlpfp32
567
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
568
+
569
+ self.norm1 = norm_layer(dim)
570
+ self.attn = WindowAttention(
571
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
572
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
573
+ relative_coords_table_type=relative_coords_table_type, rpe_output_type=rpe_output_type,
574
+ rpe_hidden_dim=rpe_hidden_dim, attn_type=attn_type, mlpfp32=mlpfp32,
575
+ pretrain_window_size=pretrain_window_size)
576
+
577
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
578
+ self.norm2 = norm_layer(dim)
579
+ mlp_hidden_dim = int(dim * mlp_ratio)
580
+
581
+ if mlp_type == 'normal':
582
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
583
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
584
+ elif mlp_type == 'conv':
585
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
586
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
587
+ elif mlp_type == 'conv_ln':
588
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
589
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32, proj_ln=True)
590
+
591
+ if self.endnorm:
592
+ self.enorm = norm_layer(dim)
593
+ else:
594
+ self.enorm = None
595
+
596
+ self.H = None
597
+ self.W = None
598
+
599
+ def forward(self, x, mask_matrix):
600
+ H, W = self.H, self.W
601
+ B, L, C = x.shape
602
+ assert L == H * W, f"input feature has wrong size, with L = {L}, H = {H}, W = {W}"
603
+
604
+ shortcut = x
605
+
606
+ x = x.view(B, H, W, C)
607
+
608
+ # pad feature maps to multiples of window size
609
+ pad_l = pad_t = 0
610
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
611
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
612
+ if pad_r > 0 or pad_b > 0:
613
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
614
+ _, Hp, Wp, _ = x.shape
615
+
616
+ # cyclic shift
617
+ if self.shift_size > 0:
618
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
619
+ attn_mask = mask_matrix
620
+ else:
621
+ shifted_x = x
622
+ attn_mask = None
623
+
624
+ # partition windows
625
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
626
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
627
+
628
+ # W-MSA/SW-MSA
629
+ orig_type = x.dtype # attn may force to fp32
630
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
631
+
632
+ # merge windows
633
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
634
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
635
+
636
+ # reverse cyclic shift
637
+ if self.shift_size > 0:
638
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
639
+ else:
640
+ x = shifted_x
641
+
642
+ if pad_r > 0 or pad_b > 0:
643
+ x = x[:, :H, :W, :].contiguous()
644
+
645
+ x = x.view(B, H * W, C)
646
+
647
+ # FFN
648
+ if self.mlpfp32:
649
+ x = self.norm1.float()(x)
650
+ x = x.type(orig_type)
651
+ else:
652
+ x = self.norm1(x)
653
+ x = shortcut + self.drop_path(x)
654
+ shortcut = x
655
+
656
+ orig_type = x.dtype
657
+ x = self.mlp(x, H, W)
658
+ if self.mlpfp32:
659
+ x = self.norm2.float()(x)
660
+ x = x.type(orig_type)
661
+ else:
662
+ x = self.norm2(x)
663
+ x = shortcut + self.drop_path(x)
664
+
665
+ if self.endnorm:
666
+ x = self.enorm(x)
667
+
668
+ return x
669
+
670
+
671
+ class SwinTransformerBlockPre(nn.Module):
672
+ """ Swin Transformer Block.
673
+
674
+ Args:
675
+ dim (int): Number of input channels.
676
+ num_heads (int): Number of attention heads.
677
+ window_size (int): Window size.
678
+ shift_size (int): Shift size for SW-MSA.
679
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
680
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
681
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
682
+ drop (float, optional): Dropout rate. Default: 0.0
683
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
684
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
685
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
686
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
687
+ """
688
+
689
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
690
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
691
+ use_mlp_norm=False, endnorm=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
692
+ init_values=None, relative_coords_table_type='norm8_log', rpe_hidden_dim=512,
693
+ rpe_output_type='normal', attn_type='normal', mlp_type='normal', mlpfp32=False,
694
+ pretrain_window_size=-1):
695
+ super().__init__()
696
+ self.dim = dim
697
+ self.num_heads = num_heads
698
+ self.window_size = window_size
699
+ self.shift_size = shift_size
700
+ self.mlp_ratio = mlp_ratio
701
+ self.use_mlp_norm = use_mlp_norm
702
+ self.endnorm = endnorm
703
+ self.mlpfp32 = mlpfp32
704
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
705
+
706
+ self.norm1 = norm_layer(dim)
707
+ self.attn = WindowAttention(
708
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
709
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
710
+ relative_coords_table_type=relative_coords_table_type, rpe_output_type=rpe_output_type,
711
+ rpe_hidden_dim=rpe_hidden_dim, attn_type=attn_type, mlpfp32=mlpfp32,
712
+ pretrain_window_size=pretrain_window_size)
713
+
714
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
715
+ self.norm2 = norm_layer(dim)
716
+ mlp_hidden_dim = int(dim * mlp_ratio)
717
+
718
+ if mlp_type == 'normal':
719
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
720
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
721
+ elif mlp_type == 'conv':
722
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
723
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32)
724
+ elif mlp_type == 'conv_ln':
725
+ self.mlp = ConvMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
726
+ norm_layer=norm_layer if self.use_mlp_norm else None, mlpfp32=mlpfp32, proj_ln=True)
727
+
728
+ if init_values is not None and init_values >= 0:
729
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
730
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
731
+ else:
732
+ self.gamma_1, self.gamma_2 = 1.0, 1.0
733
+
734
+ if self.endnorm:
735
+ self.enorm = norm_layer(dim)
736
+ else:
737
+ self.enorm = None
738
+
739
+ self.H = None
740
+ self.W = None
741
+
742
+ def forward(self, x, mask_matrix):
743
+ H, W = self.H, self.W
744
+ B, L, C = x.shape
745
+ assert L == H * W, f"input feature has wrong size, with L = {L}, H = {H}, W = {W}"
746
+
747
+ shortcut = x
748
+ x = self.norm1(x)
749
+ x = x.view(B, H, W, C)
750
+
751
+ # pad feature maps to multiples of window size
752
+ pad_l = pad_t = 0
753
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
754
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
755
+ if pad_r > 0 or pad_b > 0:
756
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
757
+ _, Hp, Wp, _ = x.shape
758
+
759
+ # cyclic shift
760
+ if self.shift_size > 0:
761
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
762
+ attn_mask = mask_matrix
763
+ else:
764
+ shifted_x = x
765
+ attn_mask = None
766
+
767
+ # partition windows
768
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
769
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
770
+
771
+ # W-MSA/SW-MSA
772
+ orig_type = x.dtype # attn may force to fp32
773
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
774
+
775
+ # merge windows
776
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
777
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
778
+
779
+ # reverse cyclic shift
780
+ if self.shift_size > 0:
781
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
782
+ else:
783
+ x = shifted_x
784
+
785
+ if pad_r > 0 or pad_b > 0:
786
+ x = x[:, :H, :W, :].contiguous()
787
+
788
+ x = x.view(B, H * W, C)
789
+
790
+ # FFN
791
+ if self.mlpfp32:
792
+ x = self.gamma_1 * x
793
+ x = x.type(orig_type)
794
+ else:
795
+ x = self.gamma_1 * x
796
+ x = shortcut + self.drop_path(x)
797
+ shortcut = x
798
+
799
+ orig_type = x.dtype
800
+ x = self.norm2(x)
801
+ if self.mlpfp32:
802
+ x = self.gamma_2 * self.mlp(x, H, W)
803
+ x = x.type(orig_type)
804
+ else:
805
+ x = self.gamma_2 * self.mlp(x, H, W)
806
+ x = shortcut + self.drop_path(x)
807
+
808
+ if self.endnorm:
809
+ x = self.enorm(x)
810
+
811
+ return x
812
+
813
+
814
+ class PatchMerging(nn.Module):
815
+ """ Patch Merging Layer
816
+
817
+ Args:
818
+ dim (int): Number of input channels.
819
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
820
+ """
821
+
822
+ def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True):
823
+ super().__init__()
824
+ self.dim = dim
825
+ self.postnorm = postnorm
826
+
827
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
828
+ self.norm = norm_layer(2 * dim) if postnorm else norm_layer(4 * dim)
829
+
830
+ def forward(self, x, H, W):
831
+ """ Forward function.
832
+
833
+ Args:
834
+ x: Input feature, tensor size (B, H*W, C).
835
+ H, W: Spatial resolution of the input feature.
836
+ """
837
+ B, L, C = x.shape
838
+ assert L == H * W, "input feature has wrong size"
839
+
840
+ x = x.view(B, H, W, C)
841
+
842
+ # padding
843
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
844
+ if pad_input:
845
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
846
+
847
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
848
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
849
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
850
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
851
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
852
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
853
+
854
+ if self.postnorm:
855
+ x = self.reduction(x)
856
+ x = self.norm(x)
857
+ else:
858
+ x = self.norm(x)
859
+ x = self.reduction(x)
860
+
861
+ return x
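Shape-wise (a small illustrative check, not from the commit): `PatchMerging` halves the spatial resolution and doubles the channel dimension:

    merge = PatchMerging(dim=96)
    x = torch.randn(2, 56 * 56, 96)       # (B, H*W, C)
    y = merge(x, 56, 56)
    print(y.shape)                        # torch.Size([2, 784, 192]) == (B, H/2*W/2, 2*C)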
862
+
863
+
864
+ class PatchReduction1C(nn.Module):
865
+ r""" Patch Reduction Layer.
866
+
867
+ Args:
868
+ input_resolution (tuple[int]): Resolution of input feature.
869
+ dim (int): Number of input channels.
870
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
871
+ """
872
+
873
+ def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True):
874
+ super().__init__()
875
+ self.dim = dim
876
+ self.postnorm = postnorm
877
+
878
+ self.reduction = nn.Linear(dim, dim, bias=False)
879
+ self.norm = norm_layer(dim)
880
+
881
+ def forward(self, x, H, W):
882
+ """
883
+ x: B, H*W, C
884
+ """
885
+ if self.postnorm:
886
+ x = self.reduction(x)
887
+ x = self.norm(x)
888
+ else:
889
+ x = self.norm(x)
890
+ x = self.reduction(x)
891
+
892
+ return x
893
+
894
+
895
+ class ConvPatchMerging(nn.Module):
896
+ r""" Patch Merging Layer.
897
+
898
+ Args:
899
+ input_resolution (tuple[int]): Resolution of input feature.
900
+ dim (int): Number of input channels.
901
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
902
+ """
903
+
904
+ def __init__(self, dim, norm_layer=nn.LayerNorm, postnorm=True):
905
+ super().__init__()
906
+ self.dim = dim
907
+ self.postnorm = postnorm
908
+
909
+ self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=3, stride=2, padding=1)
910
+ self.norm = norm_layer(2 * dim) if postnorm else norm_layer(dim)
911
+
912
+ def forward(self, x, H, W):
913
+ B, L, C = x.shape
914
+ assert L == H * W, "input feature has wrong size"
915
+
916
+ x = x.view(B, H, W, C)
917
+
918
+ # padding
919
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
920
+ if pad_input:
921
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
922
+
923
+ if self.postnorm:
924
+ x = x.permute(0, 3, 1, 2) # B C H W
925
+ x = self.reduction(x).flatten(2).transpose(1, 2) # B H//2*W//2 2*C
926
+ x = self.norm(x)
927
+ else:
928
+ x = self.norm(x)
929
+ x = x.permute(0, 3, 1, 2) # B C H W
930
+ x = self.reduction(x).flatten(2).transpose(1, 2) # B H//2*W//2 2*C
931
+
932
+ return x
933
+
934
+
935
+ class BasicLayer(nn.Module):
936
+ """ A basic Swin Transformer layer for one stage.
937
+
938
+ Args:
939
+ dim (int): Number of feature channels
940
+ depth (int): Depth (number of blocks) of this stage.
941
+ num_heads (int): Number of attention heads.
942
+ window_size (int): Local window size. Default: 7.
943
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
944
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
945
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
946
+ drop (float, optional): Dropout rate. Default: 0.0
947
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
948
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
949
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
950
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
951
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
952
+ use_shift (bool): Whether to use shifted window. Default: True.
953
+ """
954
+
955
+ def __init__(self,
956
+ dim,
957
+ depth,
958
+ num_heads,
959
+ window_size=7,
960
+ mlp_ratio=4.,
961
+ qkv_bias=True,
962
+ qk_scale=None,
963
+ drop=0.,
964
+ attn_drop=0.,
965
+ drop_path=0.,
966
+ norm_layer=nn.LayerNorm,
967
+ downsample=None,
968
+ use_checkpoint=False,
969
+ checkpoint_blocks=255,
970
+ init_values=None,
971
+ endnorm_interval=-1,
972
+ use_mlp_norm=False,
973
+ use_shift=True,
974
+ relative_coords_table_type='norm8_log',
975
+ rpe_hidden_dim=512,
976
+ rpe_output_type='normal',
977
+ attn_type='normal',
978
+ mlp_type='normal',
979
+ mlpfp32_blocks=[-1],
980
+ postnorm=True,
981
+ pretrain_window_size=-1):
982
+ super().__init__()
983
+ self.window_size = window_size
984
+ self.shift_size = window_size // 2
985
+ self.depth = depth
986
+ self.use_checkpoint = use_checkpoint
987
+ self.checkpoint_blocks = checkpoint_blocks
988
+ self.init_values = init_values if init_values is not None else 0.0
989
+ self.endnorm_interval = endnorm_interval
990
+ self.mlpfp32_blocks = mlpfp32_blocks
991
+ self.postnorm = postnorm
992
+
993
+ # build blocks
994
+ if self.postnorm:
995
+ self.blocks = nn.ModuleList([
996
+ SwinTransformerBlockPost(
997
+ dim=dim,
998
+ num_heads=num_heads,
999
+ window_size=window_size,
1000
+ shift_size=0 if (i % 2 == 0) or (not use_shift) else window_size // 2,
1001
+ mlp_ratio=mlp_ratio,
1002
+ qkv_bias=qkv_bias,
1003
+ qk_scale=qk_scale,
1004
+ drop=drop,
1005
+ attn_drop=attn_drop,
1006
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
1007
+ norm_layer=norm_layer,
1008
+ use_mlp_norm=use_mlp_norm,
1009
+ endnorm=True if ((i + 1) % endnorm_interval == 0) and (
1010
+ endnorm_interval > 0) else False,
1011
+ relative_coords_table_type=relative_coords_table_type,
1012
+ rpe_hidden_dim=rpe_hidden_dim,
1013
+ rpe_output_type=rpe_output_type,
1014
+ attn_type=attn_type,
1015
+ mlp_type=mlp_type,
1016
+ mlpfp32=True if i in mlpfp32_blocks else False,
1017
+ pretrain_window_size=pretrain_window_size)
1018
+ for i in range(depth)])
1019
+ else:
1020
+ self.blocks = nn.ModuleList([
1021
+ SwinTransformerBlockPre(
1022
+ dim=dim,
1023
+ num_heads=num_heads,
1024
+ window_size=window_size,
1025
+ shift_size=0 if (i % 2 == 0) or (not use_shift) else window_size // 2,
1026
+ mlp_ratio=mlp_ratio,
1027
+ qkv_bias=qkv_bias,
1028
+ qk_scale=qk_scale,
1029
+ drop=drop,
1030
+ attn_drop=attn_drop,
1031
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
1032
+ norm_layer=norm_layer,
1033
+ init_values=init_values,
1034
+ use_mlp_norm=use_mlp_norm,
1035
+ endnorm=True if ((i + 1) % endnorm_interval == 0) and (
1036
+ endnorm_interval > 0) else False,
1037
+ relative_coords_table_type=relative_coords_table_type,
1038
+ rpe_hidden_dim=rpe_hidden_dim,
1039
+ rpe_output_type=rpe_output_type,
1040
+ attn_type=attn_type,
1041
+ mlp_type=mlp_type,
1042
+ mlpfp32=True if i in mlpfp32_blocks else False,
1043
+ pretrain_window_size=pretrain_window_size)
1044
+ for i in range(depth)])
1045
+
1046
+ # patch merging layer
1047
+ if downsample is not None:
1048
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer, postnorm=postnorm)
1049
+ else:
1050
+ self.downsample = None
1051
+
1052
+ def forward(self, x, H, W):
1053
+ """ Forward function.
1054
+
1055
+ Args:
1056
+ x: Input feature, tensor size (B, H*W, C).
1057
+ H, W: Spatial resolution of the input feature.
1058
+ """
1059
+
1060
+ # calculate attention mask for SW-MSA
1061
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
1062
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
1063
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
1064
+ h_slices = (slice(0, -self.window_size),
1065
+ slice(-self.window_size, -self.shift_size),
1066
+ slice(-self.shift_size, None))
1067
+ w_slices = (slice(0, -self.window_size),
1068
+ slice(-self.window_size, -self.shift_size),
1069
+ slice(-self.shift_size, None))
1070
+ cnt = 0
1071
+ for h in h_slices:
1072
+ for w in w_slices:
1073
+ img_mask[:, h, w, :] = cnt
1074
+ cnt += 1
1075
+
1076
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
1077
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1078
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1079
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
1080
+ for idx, blk in enumerate(self.blocks):
1081
+ blk.H, blk.W = H, W
1082
+ if self.use_checkpoint:
1083
+ x = checkpoint.checkpoint(blk, x, attn_mask)
1084
+ else:
1085
+ x = blk(x, attn_mask)
1086
+
1087
+ if self.downsample is not None:
1088
+ x_down = self.downsample(x, H, W)
1089
+ if isinstance(self.downsample, PatchReduction1C):
1090
+ return x, H, W, x_down, H, W
1091
+ else:
1092
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
1093
+ return x, H, W, x_down, Wh, Ww
1094
+ else:
1095
+ return x, H, W, x, H, W
1096
+
1097
+ def _init_block_norm_weights(self):
1098
+ for blk in self.blocks:
1099
+ nn.init.constant_(blk.norm1.bias, 0)
1100
+ nn.init.constant_(blk.norm1.weight, self.init_values)
1101
+ nn.init.constant_(blk.norm2.bias, 0)
1102
+ nn.init.constant_(blk.norm2.weight, self.init_values)
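A hedged end-to-end sketch of a single stage (default post-norm blocks, 'normal' attention, log-spaced relative-coordinate table; assumes the module-level imports of this file):

    stage = BasicLayer(dim=96, depth=2, num_heads=3, window_size=7,
                       downsample=PatchMerging)
    x = torch.randn(2, 56 * 56, 96)
    x_out, H, W, x_down, Wh, Ww = stage(x, 56, 56)
    print(x_out.shape, (H, W))            # torch.Size([2, 3136, 96]) (56, 56)
    print(x_down.shape, (Wh, Ww))         # torch.Size([2, 784, 192]) (28, 28)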
1103
+
1104
+
1105
+ class PatchEmbed(nn.Module):
1106
+ """ Image to Patch Embedding
1107
+
1108
+ Args:
1109
+ patch_size (int): Patch token size. Default: 4.
1110
+ in_chans (int): Number of input image channels. Default: 3.
1111
+ embed_dim (int): Number of linear projection output channels. Default: 96.
1112
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
1113
+ """
1114
+
1115
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
1116
+ super().__init__()
1117
+ patch_size = to_2tuple(patch_size)
1118
+ self.patch_size = patch_size
1119
+
1120
+ self.in_chans = in_chans
1121
+ self.embed_dim = embed_dim
1122
+
1123
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
1124
+ if norm_layer is not None:
1125
+ self.norm = norm_layer(embed_dim)
1126
+ else:
1127
+ self.norm = None
1128
+
1129
+ def forward(self, x):
1130
+ """Forward function."""
1131
+ # padding
1132
+ _, _, H, W = x.size()
1133
+ if W % self.patch_size[1] != 0:
1134
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
1135
+ if H % self.patch_size[0] != 0:
1136
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
1137
+
1138
+ x = self.proj(x) # B C Wh Ww
1139
+ if self.norm is not None:
1140
+ Wh, Ww = x.size(2), x.size(3)
1141
+ x = x.flatten(2).transpose(1, 2)
1142
+ x = self.norm(x)
1143
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
1144
+
1145
+ return x
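A tiny shape sketch (illustrative): a 224x224 image with the default 4x4 patches produces a 56x56 grid of embed_dim-channel tokens:

    embed = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
    img = torch.randn(1, 3, 224, 224)
    feat = embed(img)
    print(feat.shape)                     # torch.Size([1, 96, 56, 56])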
1146
+
1147
+
1148
+ class ResNetDLNPatchEmbed(nn.Module):
1149
+ def __init__(self, in_chans=3, embed_dim=96, norm_layer=None):
1150
+ super().__init__()
1151
+ patch_size = to_2tuple(4)
1152
+ self.patch_size = patch_size
1153
+
1154
+ self.in_chans = in_chans
1155
+ self.embed_dim = embed_dim
1156
+
1157
+ self.conv1 = nn.Sequential(nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False),
1158
+ LayerNorm2D(64, norm_layer),
1159
+ nn.GELU(),
1160
+ nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False),
1161
+ LayerNorm2D(64, norm_layer),
1162
+ nn.GELU(),
1163
+ nn.Conv2d(64, embed_dim, 3, stride=1, padding=1, bias=False))
1164
+ self.norm = LayerNorm2D(embed_dim, norm_layer if norm_layer is not None else LayerNormFP32) # use ln always
1165
+ self.act = nn.GELU()
1166
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
1167
+
1168
+ def forward(self, x):
1169
+ _, _, H, W = x.size()
1170
+ if W % self.patch_size[1] != 0:
1171
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
1172
+ if H % self.patch_size[0] != 0:
1173
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
1174
+
1175
+ x = self.conv1(x)
1176
+ x = self.norm(x)
1177
+ x = self.act(x)
1178
+ x = self.maxpool(x)
1179
+ # x = x.flatten(2).transpose(1, 2)
1180
+ return x
1181
+
1182
+
1183
+ class SwinV2TransformerRPE2FC(nn.Module):
1184
+ """ Swin Transformer backbone.
1185
+ A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
1186
+ https://arxiv.org/pdf/2103.14030
1187
+
1188
+ Args:
1189
+ pretrain_img_size (int): Input image size for training the pretrained model,
1190
+ used in absolute position embedding. Default: 224.
1191
+ patch_size (int | tuple(int)): Patch size. Default: 4.
1192
+ in_chans (int): Number of input image channels. Default: 3.
1193
+ embed_dim (int): Number of linear projection output channels. Default: 96.
1194
+ depths (tuple[int]): Depths of each Swin Transformer stage.
1195
+ num_heads (tuple[int]): Number of attention heads for each stage.
1196
+ window_size (int): Window size. Default: 7.
1197
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
1198
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
1199
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
1200
+ drop_rate (float): Dropout rate.
1201
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
1202
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
1203
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
1204
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
1205
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
1206
+ out_indices (Sequence[int]): Output from which stages.
1207
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
1208
+ -1 means not freezing any parameters.
1209
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
1210
+ use_shift (bool): Whether to use shifted window. Default: True.
1211
+ """
1212
+
1213
+ def __init__(self,
1214
+ pretrain_img_size=224,
1215
+ patch_size=4,
1216
+ in_chans=3,
1217
+ embed_dim=96,
1218
+ depths=[2, 2, 6, 2],
1219
+ num_heads=[3, 6, 12, 24],
1220
+ window_size=7,
1221
+ mlp_ratio=4.,
1222
+ qkv_bias=True,
1223
+ qk_scale=None,
1224
+ drop_rate=0.,
1225
+ attn_drop_rate=0.,
1226
+ drop_path_rate=0.1,
1227
+ norm_layer=partial(LayerNormFP32, eps=1e-6),
1228
+ ape=False,
1229
+ patch_norm=True,
1230
+ use_checkpoint=False,
1231
+ init_values=1e-5,
1232
+ endnorm_interval=-1,
1233
+ use_mlp_norm_layers=[],
1234
+ relative_coords_table_type='norm8_log',
1235
+ rpe_hidden_dim=512,
1236
+ attn_type='cosine_mh',
1237
+ rpe_output_type='sigmoid',
1238
+ rpe_wd=False,
1239
+ postnorm=True,
1240
+ mlp_type='normal',
1241
+ patch_embed_type='normal',
1242
+ patch_merge_type='normal',
1243
+ strid16=False,
1244
+ checkpoint_blocks=[255, 255, 255, 255],
1245
+ mlpfp32_layer_blocks=[[-1], [-1], [-1], [-1]],
1246
+ out_indices=(0, 1, 2, 3),
1247
+ frozen_stages=-1,
1248
+ use_shift=True,
1249
+ rpe_interpolation='geo',
1250
+ pretrain_window_size=[-1, -1, -1, -1],
1251
+ **kwargs):
1252
+ super().__init__()
1253
+
1254
+ self.pretrain_img_size = pretrain_img_size
1255
+ self.depths = depths
1256
+ self.num_layers = len(depths)
1257
+ self.embed_dim = embed_dim
1258
+ self.ape = ape
1259
+ self.patch_norm = patch_norm
1260
+ self.out_indices = out_indices
1261
+ self.frozen_stages = frozen_stages
1262
+ self.rpe_interpolation = rpe_interpolation
1263
+ self.mlp_ratio = mlp_ratio
1264
+ self.endnorm_interval = endnorm_interval
1265
+ self.use_mlp_norm_layers = use_mlp_norm_layers
1266
+ self.relative_coords_table_type = relative_coords_table_type
1267
+ self.rpe_hidden_dim = rpe_hidden_dim
1268
+ self.rpe_output_type = rpe_output_type
1269
+ self.rpe_wd = rpe_wd
1270
+ self.attn_type = attn_type
1271
+ self.postnorm = postnorm
1272
+ self.mlp_type = mlp_type
1273
+ self.strid16 = strid16
1274
+
1275
+ if isinstance(window_size, list):
1276
+ pass
1277
+ elif isinstance(window_size, int):
1278
+ window_size = [window_size] * self.num_layers
1279
+ else:
1280
+ raise TypeError("We only support list or int for window size")
1281
+
1282
+ if isinstance(use_shift, list):
1283
+ pass
1284
+ elif isinstance(use_shift, bool):
1285
+ use_shift = [use_shift] * self.num_layers
1286
+ else:
1287
+ raise TypeError("We only support list or bool for use_shift")
1288
+
1289
+ if isinstance(use_checkpoint, list):
1290
+ pass
1291
+ elif isinstance(use_checkpoint, bool):
1292
+ use_checkpoint = [use_checkpoint] * self.num_layers
1293
+ else:
1294
+ raise TypeError("We only support list or bool for use_checkpoint")
1295
+
1296
+ # split image into non-overlapping patches
1297
+ if patch_embed_type == 'normal':
1298
+ self.patch_embed = PatchEmbed(
1299
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
1300
+ norm_layer=norm_layer if self.patch_norm else None)
1301
+ elif patch_embed_type == 'resnetdln':
1302
+ assert patch_size == 4, "check"
1303
+ self.patch_embed = ResNetDLNPatchEmbed(
1304
+ in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer)
1305
+ elif patch_embed_type == 'resnetdnf':
1306
+ assert patch_size == 4, "check"
1307
+ self.patch_embed = ResNetDLNPatchEmbed(
1308
+ in_chans=in_chans, embed_dim=embed_dim, norm_layer=None)
1309
+ else:
1310
+ raise NotImplementedError()
1311
+ # absolute position embedding
1312
+ if self.ape:
1313
+ pretrain_img_size = to_2tuple(pretrain_img_size)
1314
+ patch_size = to_2tuple(patch_size)
1315
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
1316
+
1317
+ self.absolute_pos_embed = nn.Parameter(
1318
+ torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
1319
+ trunc_normal_(self.absolute_pos_embed, std=.02)
1320
+
1321
+ self.pos_drop = nn.Dropout(p=drop_rate)
1322
+
1323
+ # stochastic depth
1324
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
1325
+
1326
+ if patch_merge_type == 'normal':
1327
+ downsample_layer = PatchMerging
1328
+ elif patch_merge_type == 'conv':
1329
+ downsample_layer = ConvPatchMerging
1330
+ else:
1331
+ raise NotImplementedError()
1332
+ # build layers
1333
+ self.layers = nn.ModuleList()
1334
+ num_features = []
1335
+ for i_layer in range(self.num_layers):
1336
+ cur_dim = int(embed_dim * 2 ** (i_layer - 1)) \
1337
+ if (i_layer == self.num_layers - 1 and strid16) else \
1338
+ int(embed_dim * 2 ** i_layer)
1339
+ num_features.append(cur_dim)
1340
+ if i_layer < self.num_layers - 2:
1341
+ cur_downsample_layer = downsample_layer
1342
+ elif i_layer == self.num_layers - 2:
1343
+ if strid16:
1344
+ cur_downsample_layer = PatchReduction1C
1345
+ else:
1346
+ cur_downsample_layer = downsample_layer
1347
+ else:
1348
+ cur_downsample_layer = None
1349
+ layer = BasicLayer(
1350
+ dim=cur_dim,
1351
+ depth=depths[i_layer],
1352
+ num_heads=num_heads[i_layer],
1353
+ window_size=window_size[i_layer],
1354
+ mlp_ratio=mlp_ratio,
1355
+ qkv_bias=qkv_bias,
1356
+ qk_scale=qk_scale,
1357
+ drop=drop_rate,
1358
+ attn_drop=attn_drop_rate,
1359
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
1360
+ norm_layer=norm_layer,
1361
+ downsample=cur_downsample_layer,
1362
+ use_checkpoint=use_checkpoint[i_layer],
1363
+ checkpoint_blocks=checkpoint_blocks[i_layer],
1364
+ init_values=init_values,
1365
+ endnorm_interval=endnorm_interval,
1366
+ use_mlp_norm=True if i_layer in use_mlp_norm_layers else False,
1367
+ use_shift=use_shift[i_layer],
1368
+ relative_coords_table_type=self.relative_coords_table_type,
1369
+ rpe_hidden_dim=self.rpe_hidden_dim,
1370
+ rpe_output_type=self.rpe_output_type,
1371
+ attn_type=self.attn_type,
1372
+ mlp_type=self.mlp_type,
1373
+ mlpfp32_blocks=mlpfp32_layer_blocks[i_layer],
1374
+ postnorm=self.postnorm,
1375
+ pretrain_window_size=pretrain_window_size[i_layer]
1376
+ )
1377
+ self.layers.append(layer)
1378
+
1379
+ self.num_features = num_features
1380
+
1381
+ # add a norm layer for each output
1382
+ for i_layer in out_indices:
1383
+ layer = norm_layer(num_features[i_layer])
1384
+ layer_name = f'norm{i_layer}'
1385
+ self.add_module(layer_name, layer)
1386
+
1387
+ self._freeze_stages()
1388
+
1389
+ def _freeze_stages(self):
1390
+ if self.frozen_stages >= 0:
1391
+ self.patch_embed.eval()
1392
+ for param in self.patch_embed.parameters():
1393
+ param.requires_grad = False
1394
+
1395
+ if self.frozen_stages >= 1 and self.ape:
1396
+ self.absolute_pos_embed.requires_grad = False
1397
+
1398
+ if self.frozen_stages >= 2:
1399
+ self.pos_drop.eval()
1400
+ for i in range(0, self.frozen_stages - 1):
1401
+ m = self.layers[i]
1402
+ m.eval()
1403
+ for param in m.parameters():
1404
+ param.requires_grad = False
1405
+
1406
+ def init_weights(self, pretrained=None):
1407
+ """Initialize the weights in backbone.
1408
+
1409
+ Args:
1410
+ pretrained (str, optional): Path to pre-trained weights.
1411
+ Defaults to None.
1412
+ """
1413
+ self.norm3.eval()
1414
+ for param in self.norm3.parameters():
1415
+ param.requires_grad = False
1416
+
1417
+ def _init_weights(m):
1418
+ if isinstance(m, nn.Linear):
1419
+ trunc_normal_(m.weight, std=.02)
1420
+ if isinstance(m, nn.Linear) and m.bias is not None:
1421
+ nn.init.constant_(m.bias, 0)
1422
+ elif isinstance(m, nn.LayerNorm):
1423
+ nn.init.constant_(m.bias, 0)
1424
+ nn.init.constant_(m.weight, 1.0)
1425
+ elif isinstance(m, nn.Conv2d):
1426
+ trunc_normal_(m.weight, std=.02)
1427
+ if m.bias is not None:
1428
+ nn.init.constant_(m.bias, 0)
1429
+
1430
+ self.apply(_init_weights)
1431
+ for bly in self.layers:
1432
+ bly._init_block_norm_weights()
1433
+
1434
+ if isinstance(pretrained, str):
1435
+ logger = None
1436
+ load_checkpoint_swin(self, pretrained, strict=False, map_location='cpu',
1437
+ logger=logger, rpe_interpolation=self.rpe_interpolation)
1438
+ elif pretrained is None:
1439
+ pass
1440
+ else:
1441
+ raise TypeError('pretrained must be a str or None')
1442
+
1443
+ def forward(self, x):
1444
+ """Forward function."""
1445
+ x = self.patch_embed(x)
1446
+
1447
+ Wh, Ww = x.size(2), x.size(3)
1448
+ if self.ape:
1449
+ # interpolate the position embedding to the corresponding size
1450
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
1451
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
1452
+ else:
1453
+ x = x.flatten(2).transpose(1, 2)
1454
+
1455
+ x = self.pos_drop(x)
1456
+
1457
+ outs = []
1458
+ for i in range(self.num_layers):
1459
+ layer = self.layers[i]
1460
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
1461
+
1462
+ if i in self.out_indices:
1463
+ norm_layer = getattr(self, f'norm{i}')
1464
+ x_out = norm_layer.float()(x_out.float())
1465
+
1466
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
1467
+
1468
+ outs.append(out)
1469
+
1470
+ return outs
1471
+
1472
+ def train(self, mode=True):
1473
+ """Convert the model into training mode while keeping frozen stages in eval mode."""
1474
+ super(SwinV2TransformerRPE2FC, self).train(mode)
1475
+ self._freeze_stages()
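Putting the pieces together, a hedged sketch (not from the commit) of how this backbone could be instantiated in isolation with a Swin-T-like config; `out_indices` selects which stage outputs are returned as NCHW feature maps:

    backbone = SwinV2TransformerRPE2FC(embed_dim=96,
                                       depths=[2, 2, 6, 2],
                                       num_heads=[3, 6, 12, 24],
                                       window_size=7,
                                       out_indices=(0, 1, 2, 3))
    backbone.init_weights()               # random init; pass a checkpoint path to load pretrained weights
    feats = backbone(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]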
main/pct_utils/pct_head.py ADDED
@@ -0,0 +1,208 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from pct_utils.pct_tokenizer import PCT_Tokenizer
5
+ from pct_utils.modules import MixerLayer, FCBlock, BasicBlock
6
+
7
+ def constant_init(module, val, bias=0):
8
+ if hasattr(module, 'weight') and module.weight is not None:
9
+ nn.init.constant_(module.weight, val)
10
+ if hasattr(module, 'bias') and module.bias is not None:
11
+ nn.init.constant_(module.bias, bias)
12
+
13
+ def normal_init(module, mean=0, std=1, bias=0):
14
+ if hasattr(module, 'weight') and module.weight is not None:
15
+ nn.init.normal_(module.weight, mean, std)
16
+ if hasattr(module, 'bias') and module.bias is not None:
17
+ nn.init.constant_(module.bias, bias)
18
+
19
+ class PCT_Head(nn.Module):
20
+ """ Head of Pose Compositional Tokens.
21
+ paper ref: Zigang Geng et al. "Human Pose as
22
+ Compositional Tokens"
23
+
24
+ The pipelines of two stage during training and inference:
25
+
26
+ Tokenizer Stage & Train:
27
+ Joints -> (Img Guide) -> Encoder -> Codebook -> Decoder -> Recovered Joints
28
+ Loss: (Joints, Recovered Joints)
29
+ Tokenizer Stage & Test:
30
+ Joints -> (Img Guide) -> Encoder -> Codebook -> Decoder -> Recovered Joints
31
+
32
+ Classifer Stage & Train:
33
+ Img -> Classifier -> Predict Class -> Codebook -> Decoder -> Recovered Joints
34
+ Joints -> (Img Guide) -> Encoder -> Codebook -> Groundtruth Class
35
+ Loss: (Predict Class, Groundtruth Class), (Joints, Recovered Joints)
36
+ Classifer Stage & Test:
37
+ Img -> Classifier -> Predict Class -> Codebook -> Decoder -> Recovered Joints
38
+
39
+ Args:
40
+ stage_pct (str): Training stage (Tokenizer or Classifier).
41
+ in_channels (int): Feature dimension of the backbone feature.
42
+ image_size (tuple): Input image size.
43
+ num_joints (int): Number of annotated joints in the dataset.
44
+ cls_head (dict): Config for PCT classification head. Default: None.
45
+ tokenizer (dict): Config for PCT tokenizer. Default: None.
46
+ loss_keypoint (dict): Config for the keypoint loss used to train the classifier. Default: None.
47
+ """
48
+
49
+ def __init__(self,
50
+ args,
51
+ stage_pct,
52
+ in_channels,
53
+ image_size,
54
+ num_joints,
55
+ cls_head=None,
56
+ tokenizer=None,
57
+ loss_keypoint=None,):
58
+ super().__init__()
59
+
60
+ self.image_size = image_size
61
+ self.stage_pct = stage_pct
62
+
63
+ self.guide_ratio = args.tokenizer_guide_ratio
64
+ self.img_guide = self.guide_ratio > 0.0
65
+
66
+ self.conv_channels = args.cls_head_conv_channels
67
+ self.hidden_dim = args.cls_head_hidden_dim
68
+
69
+ self.num_blocks = args.cls_head_num_blocks
70
+ self.hidden_inter_dim = args.cls_head_hidden_inter_dim
71
+ self.token_inter_dim = args.cls_head_token_inter_dim
72
+ self.dropout = args.cls_head_dropout
73
+
74
+ self.token_num = args.tokenizer_codebook_token_num
75
+ self.token_class_num = args.tokenizer_codebook_token_class_num
76
+
77
+ if stage_pct == "classifier":
78
+ self.conv_trans = self._make_transition_for_head(
79
+ in_channels, self.conv_channels)
80
+ self.conv_head = self._make_cls_head(args)
81
+
82
+ input_size = (image_size[0]//32)*(image_size[1]//32)
83
+ self.mixer_trans = FCBlock(
84
+ self.conv_channels * input_size,
85
+ self.token_num * self.hidden_dim)
86
+
87
+ self.mixer_head = nn.ModuleList(
88
+ [MixerLayer(self.hidden_dim, self.hidden_inter_dim,
89
+ self.token_num, self.token_inter_dim,
90
+ self.dropout) for _ in range(self.num_blocks)])
91
+ self.mixer_norm_layer = FCBlock(
92
+ self.hidden_dim, self.hidden_dim)
93
+
94
+ self.cls_pred_layer = nn.Linear(
95
+ self.hidden_dim, self.token_class_num)
96
+
97
+ self.tokenizer = PCT_Tokenizer(
98
+ args = args, stage_pct=stage_pct, num_joints=num_joints,
99
+ guide_ratio=self.guide_ratio, guide_channels=in_channels)
100
+
101
+ def forward(self, x, extra_x, joints=None, train=True):
102
+ """Forward function."""
103
+
104
+ if self.stage_pct == "classifier":
105
+ batch_size = x[-1].shape[0]
106
+ cls_feat = self.conv_head[0](self.conv_trans(x[-1]))
107
+
108
+ cls_feat = cls_feat.flatten(2).transpose(2,1).flatten(1)
109
+ cls_feat = self.mixer_trans(cls_feat)
110
+ cls_feat = cls_feat.reshape(batch_size, self.token_num, -1)
111
+
112
+ for mixer_layer in self.mixer_head:
113
+ cls_feat = mixer_layer(cls_feat)
114
+ cls_feat = self.mixer_norm_layer(cls_feat)
115
+
116
+ cls_logits = self.cls_pred_layer(cls_feat)
117
+
118
+ encoding_scores = cls_logits.topk(1, dim=2)[0]
119
+ cls_logits = cls_logits.flatten(0,1)
120
+ cls_logits_softmax = cls_logits.clone().softmax(1)
121
+ else:
122
+ encoding_scores = None
123
+ cls_logits = None
124
+ cls_logits_softmax = None
125
+
126
+ if not self.img_guide or \
127
+ (self.stage_pct == "classifier" and not train):
128
+ joints_feat = None
129
+ else:
130
+ joints_feat = self.extract_joints_feat(extra_x[-1], joints)
131
+
132
+ output_joints, cls_label, e_latent_loss, out_part_token_feat = \
133
+ self.tokenizer(joints, joints_feat, cls_logits_softmax, train=train)
134
+
135
+ if train:
136
+ return cls_logits, output_joints, cls_label, e_latent_loss
137
+ else:
138
+ return output_joints, encoding_scores, out_part_token_feat
139
+
140
+ def _make_transition_for_head(self, inplanes, outplanes):
141
+ transition_layer = [
142
+ nn.Conv2d(inplanes, outplanes, 1, 1, 0, bias=False),
143
+ nn.BatchNorm2d(outplanes),
144
+ nn.ReLU(True)
145
+ ]
146
+ return nn.Sequential(*transition_layer)
147
+
148
+ def _make_cls_head(self, args):
149
+ feature_convs = []
150
+ feature_conv = self._make_layer(
151
+ BasicBlock,
152
+ args.cls_head_conv_channels,
153
+ args.cls_head_conv_channels,
154
+ args.cls_head_conv_num_blocks,
155
+ dilation=args.cls_head_dilation
156
+ )
157
+ feature_convs.append(feature_conv)
158
+
159
+ return nn.ModuleList(feature_convs)
160
+
161
+ def _make_layer(
162
+ self, block, inplanes, planes, blocks, stride=1, dilation=1):
163
+ downsample = None
164
+ if stride != 1 or inplanes != planes * block.expansion:
165
+ downsample = nn.Sequential(
166
+ nn.Conv2d(inplanes, planes * block.expansion,
167
+ kernel_size=1, stride=stride, bias=False),
168
+ nn.BatchNorm2d(planes * block.expansion, momentum=0.1),
169
+ )
170
+
171
+ layers = []
172
+ layers.append(block(inplanes, planes,
173
+ stride, downsample, dilation=dilation))
174
+ inplanes = planes * block.expansion
175
+ for _ in range(1, blocks):
176
+ layers.append(block(inplanes, planes, dilation=dilation))
177
+
178
+ return nn.Sequential(*layers)
179
+
180
+ def extract_joints_feat(self, feature_map, joint_coords):
181
+ assert self.image_size[1] == self.image_size[0], \
182
+ 'If you want to use a rectangle input, ' \
183
+ 'please carefully check the length and width below.'
184
+ batch_size, _, _, width = feature_map.shape
185
+ stride = self.image_size[0] / feature_map.shape[-1]
186
+ joint_x = (joint_coords[:,:,0] / stride + 0.5).int()
187
+ joint_y = (joint_coords[:,:,1] / stride + 0.5).int()
188
+ joint_x = joint_x.clamp(0, feature_map.shape[-1] - 1)
189
+ joint_y = joint_y.clamp(0, feature_map.shape[-2] - 1)
190
+ joint_indices = (joint_y * width + joint_x).long()
191
+
192
+ flattened_feature_map = feature_map.clone().flatten(2)
193
+ joint_features = flattened_feature_map[
194
+ torch.arange(batch_size).unsqueeze(1), :, joint_indices]
195
+
196
+ return joint_features
197
+
198
+ def init_weights(self):
199
+ if self.stage_pct == "classifier":
200
+ self.tokenizer.eval()
201
+ for name, params in self.tokenizer.named_parameters():
202
+ params.requires_grad = False
203
+
204
+ for m in self.modules():
205
+ if isinstance(m, nn.Conv2d):
206
+ normal_init(m, std=0.001, bias=0)
207
+ elif isinstance(m, nn.BatchNorm2d):
208
+ constant_init(m, 1)
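A note on extract_joints_feat above: each joint's (x, y) pixel coordinate is snapped to a cell of the H x W feature grid and used to index the flattened feature map. Below is a minimal, self-contained sketch of that lookup; the sizes are illustrative only and not taken from the real configuration.

import torch

# illustrative sizes only (hypothetical, not the repo configuration)
batch_size, channels, feat_hw, image_size = 2, 8, 8, 256
stride = image_size / feat_hw                           # 256 / 8 = 32

feature_map = torch.randn(batch_size, channels, feat_hw, feat_hw)
joints = torch.rand(batch_size, 17, 2) * image_size     # (x, y) in pixel space

# snap pixel coordinates to feature-grid cells and clamp them onto the map
jx = (joints[:, :, 0] / stride + 0.5).int().clamp(0, feat_hw - 1)
jy = (joints[:, :, 1] / stride + 0.5).int().clamp(0, feat_hw - 1)
indices = (jy * feat_hw + jx).long()                    # index along the flattened H*W axis

flat = feature_map.flatten(2)                           # batch_size X channels X (H*W)
joint_feats = flat[torch.arange(batch_size).unsqueeze(1), :, indices]
print(joint_feats.shape)                                # torch.Size([2, 17, 8])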
main/pct_utils/pct_tokenizer.py ADDED
@@ -0,0 +1,315 @@
1
+ # --------------------------------------------------------
2
+ # Pose Compositional Tokens
3
+ # Written by Zigang Geng (zigang@mail.ustc.edu.cn)
4
+ # --------------------------------------------------------
5
+
6
+ import os
7
+ import math
8
+ import warnings
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.distributed as dist
13
+
14
+ from pct_utils.modules import MixerLayer
15
+
16
+ def _trunc_normal_(tensor, mean, std, a, b):
17
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
18
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
19
+ def norm_cdf(x):
20
+ # Computes standard normal cumulative distribution function
21
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
22
+
23
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
24
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
25
+ "The distribution of values may be incorrect.",
26
+ stacklevel=2)
27
+
28
+ # Values are generated by using a truncated uniform distribution and
29
+ # then using the inverse CDF for the normal distribution.
30
+ # Get upper and lower cdf values
31
+ l = norm_cdf((a - mean) / std)
32
+ u = norm_cdf((b - mean) / std)
33
+
34
+ # Uniformly fill tensor with values from [l, u], then translate to
35
+ # [2l-1, 2u-1].
36
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
37
+
38
+ # Use inverse cdf transform for normal distribution to get truncated
39
+ # standard normal
40
+ tensor.erfinv_()
41
+
42
+ # Transform to proper mean, std
43
+ tensor.mul_(std * math.sqrt(2.))
44
+ tensor.add_(mean)
45
+
46
+ # Clamp to ensure it's in the proper range
47
+ tensor.clamp_(min=a, max=b)
48
+ return tensor
49
+
50
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
51
+ r"""Fills the input Tensor with values drawn from a truncated
52
+ normal distribution. The values are effectively drawn from the
53
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
54
+ with values outside :math:`[a, b]` redrawn until they are within
55
+ the bounds. The method used for generating the random values works
56
+ best when :math:`a \leq \text{mean} \leq b`.
57
+
58
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
59
+ applied while sampling the normal with mean/std applied, therefore a, b args
60
+ should be adjusted to match the range of mean, std args.
61
+
62
+ Args:
63
+ tensor: an n-dimensional `torch.Tensor`
64
+ mean: the mean of the normal distribution
65
+ std: the standard deviation of the normal distribution
66
+ a: the minimum cutoff value
67
+ b: the maximum cutoff value
68
+ Examples:
69
+ >>> w = torch.empty(3, 5)
70
+ >>> nn.init.trunc_normal_(w)
71
+ """
72
+ with torch.no_grad():
73
+ return _trunc_normal_(tensor, mean, std, a, b)
74
+
75
+ class PCT_Tokenizer(nn.Module):
76
+ """ Tokenizer of Pose Compositional Tokens.
77
+ paper ref: Zigang Geng et al. "Human Pose as
78
+ Compositional Tokens"
79
+
80
+ Args:
81
+ stage_pct (str): Training stage (Tokenizer or Classifier).
82
+ args (Namespace): Config of the tokenizer.
83
+ num_joints (int): Number of annotated joints in the dataset.
84
+ guide_ratio (float): The ratio of image guidance.
85
+ guide_channels (int): Feature Dim of the image guidance.
86
+ """
87
+
88
+ def __init__(self,
89
+ args,
90
+ stage_pct,
91
+ num_joints=14,
92
+ theta_dim=2,
93
+ guide_ratio=0,
94
+ guide_channels=0):
95
+ super().__init__()
96
+
97
+ self.stage_pct = stage_pct
98
+ self.guide_ratio = guide_ratio
99
+ self.num_joints = num_joints
100
+ self.theta_dim = theta_dim
101
+
102
+ self.drop_rate = args.tokenizer_encoder_drop_rate
103
+ self.enc_num_blocks = args.tokenizer_encoder_num_blocks
104
+ self.enc_hidden_dim = args.tokenizer_encoder_hidden_dim
105
+ self.enc_token_inter_dim = args.tokenizer_encoder_token_inter_dim
106
+ self.enc_hidden_inter_dim = args.tokenizer_encoder_hidden_inter_dim
107
+ self.enc_dropout = args.tokenizer_encoder_dropout
108
+
109
+ self.dec_num_blocks = args.tokenizer_decoder_num_blocks
110
+ self.dec_hidden_dim = args.tokenizer_decoder_hidden_dim
111
+ self.dec_token_inter_dim = args.tokenizer_decoder_token_inter_dim
112
+ self.dec_hidden_inter_dim = args.tokenizer_decoder_hidden_inter_dim
113
+ self.dec_dropout = args.tokenizer_decoder_dropout
114
+
115
+ self.token_num = args.tokenizer_codebook_token_num
116
+ self.token_class_num = args.tokenizer_codebook_token_class_num
117
+ self.token_dim = args.tokenizer_codebook_token_dim
118
+ self.decay = args.tokenizer_codebook_ema_decay
119
+
120
+ self.invisible_token = nn.Parameter(
121
+ torch.zeros(1, 1, self.enc_hidden_dim))
122
+ trunc_normal_(self.invisible_token, mean=0., std=0.02, a=-0.02, b=0.02)
123
+
124
+ if self.guide_ratio > 0:
125
+ self.start_img_embed = nn.Linear(
126
+ guide_channels, int(self.enc_hidden_dim*self.guide_ratio))
127
+ self.start_embed = nn.Linear(
128
+ 2, int(self.enc_hidden_dim*(1-self.guide_ratio)))
129
+
130
+ self.encoder = nn.ModuleList(
131
+ [MixerLayer(self.enc_hidden_dim, self.enc_hidden_inter_dim,
132
+ self.num_joints, self.enc_token_inter_dim,
133
+ self.enc_dropout) for _ in range(self.enc_num_blocks)])
134
+ self.encoder_layer_norm = nn.LayerNorm(self.enc_hidden_dim)
135
+
136
+ self.token_mlp = nn.Linear(
137
+ self.num_joints, self.token_num)
138
+ self.feature_embed = nn.Linear(
139
+ self.enc_hidden_dim, self.token_dim)
140
+
141
+ self.register_buffer('codebook',
142
+ torch.empty(self.token_class_num, self.token_dim))
143
+ self.codebook.data.normal_()
144
+ self.register_buffer('ema_cluster_size',
145
+ torch.zeros(self.token_class_num))
146
+ self.register_buffer('ema_w',
147
+ torch.empty(self.token_class_num, self.token_dim))
148
+ self.ema_w.data.normal_()
149
+
150
+ self.decoder_token_mlp = nn.Linear(
151
+ self.token_num, self.num_joints)
152
+ self.decoder_start = nn.Linear(
153
+ self.token_dim, self.dec_hidden_dim)
154
+
155
+ self.decoder = nn.ModuleList(
156
+ [MixerLayer(self.dec_hidden_dim, self.dec_hidden_inter_dim,
157
+ self.num_joints, self.dec_token_inter_dim,
158
+ self.dec_dropout) for _ in range(self.dec_num_blocks)])
159
+ self.decoder_layer_norm = nn.LayerNorm(self.dec_hidden_dim)
160
+
161
+ self.recover_embed = nn.Linear(self.dec_hidden_dim, 2)
162
+
163
+ def forward(self, joints, joints_feature, cls_logits, train=True):
164
+ """Forward function. """
165
+
166
+ if train or self.stage_pct == "tokenizer":
167
+ # Encoder of Tokenizer, Get the PCT groundtruth class labels.
168
+ bs, num_joints, _ = joints.shape
169
+ device = joints.device
170
+ joints_coord, joints_visible, bs \
171
+ = joints[:,:,:-1], joints[:,:,-1].bool(), joints.shape[0]
172
+
173
+ encode_feat = self.start_embed(joints_coord)
174
+ if self.guide_ratio > 0:
175
+ encode_img_feat = self.start_img_embed(joints_feature)
176
+ encode_feat = torch.cat((encode_feat, encode_img_feat), dim=2)
177
+
178
+ if train and self.stage_pct == "tokenizer":
179
+ rand_mask_ind = torch.rand(
180
+ joints_visible.shape, device=joints.device) > self.drop_rate
181
+ joints_visible = torch.logical_and(rand_mask_ind, joints_visible)
182
+
183
+ mask_tokens = self.invisible_token.expand(bs, joints.shape[1], -1)
184
+ w = joints_visible.unsqueeze(-1).type_as(mask_tokens)
185
+ encode_feat = encode_feat * w + mask_tokens * (1 - w)
186
+
187
+ for num_layer in self.encoder:
188
+ encode_feat = num_layer(encode_feat)
189
+ encode_feat = self.encoder_layer_norm(encode_feat)
190
+
191
+ encode_feat = encode_feat.transpose(2, 1)
192
+ encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
193
+ encode_feat = self.feature_embed(encode_feat).flatten(0,1)
194
+
195
+ distances = torch.sum(encode_feat**2, dim=1, keepdim=True) \
196
+ + torch.sum(self.codebook**2, dim=1) \
197
+ - 2 * torch.matmul(encode_feat, self.codebook.t())
198
+
199
+ encoding_indices = torch.argmin(distances, dim=1)
200
+ encodings = torch.zeros(
201
+ encoding_indices.shape[0], self.token_class_num, device=joints.device)
202
+ encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
203
+ else:
204
+ # cls_logits is expected to have shape [bs * token_num, token_cls_num]:
205
+ # the probability that each token 0,1,...,M-1 belongs to codebook entry 0,1,...,V-1
206
+ # (see the PCT paper)
207
+ bs = cls_logits.shape[0] // self.token_num
208
+ encoding_indices = None
209
+
210
+ if self.stage_pct == "classifier":
211
+ part_token_feat = torch.matmul(cls_logits, self.codebook)
212
+ else:
213
+ part_token_feat = torch.matmul(encodings, self.codebook)
214
+
215
+ if train and self.stage_pct == "tokenizer":
216
+ # Updating Codebook using EMA
217
+ dw = torch.matmul(encodings.t(), encode_feat.detach())
218
+ # sync
219
+ n_encodings, n_dw = encodings.numel(), dw.numel()
220
+ encodings_shape, dw_shape = encodings.shape, dw.shape
221
+ combined = torch.cat((encodings.flatten(), dw.flatten()))
222
+ dist.all_reduce(combined) # element-wise sum across all processes
223
+ sync_encodings, sync_dw = torch.split(combined, [n_encodings, n_dw])
224
+ sync_encodings, sync_dw = \
225
+ sync_encodings.view(encodings_shape), sync_dw.view(dw_shape)
226
+
227
+ self.ema_cluster_size = self.ema_cluster_size * self.decay + \
228
+ (1 - self.decay) * torch.sum(sync_encodings, 0)
229
+
230
+ n = torch.sum(self.ema_cluster_size.data)
231
+ self.ema_cluster_size = (
232
+ (self.ema_cluster_size + 1e-5)
233
+ / (n + self.token_class_num * 1e-5) * n)
234
+
235
+ self.ema_w = self.ema_w * self.decay + (1 - self.decay) * sync_dw
236
+ self.codebook = self.ema_w / self.ema_cluster_size.unsqueeze(1)
237
+ e_latent_loss = F.mse_loss(part_token_feat.detach(), encode_feat)
238
+ part_token_feat = encode_feat + (part_token_feat - encode_feat).detach()
239
+ else:
240
+ e_latent_loss = None
241
+
242
+ # Decoder of Tokenizer, Recover the joints.
243
+ part_token_feat = part_token_feat.view(bs, -1, self.token_dim)
244
+
245
+ # Store part token
246
+ out_part_token_feat = part_token_feat.clone().detach()
247
+
248
+ part_token_feat = part_token_feat.transpose(2,1)
249
+ part_token_feat = self.decoder_token_mlp(part_token_feat).transpose(2,1)
250
+ decode_feat = self.decoder_start(part_token_feat)
251
+
252
+ for num_layer in self.decoder:
253
+ decode_feat = num_layer(decode_feat)
254
+ decode_feat = self.decoder_layer_norm(decode_feat)
255
+
256
+ recovered_joints = self.recover_embed(decode_feat)
257
+
258
+ return recovered_joints, encoding_indices, e_latent_loss, out_part_token_feat
259
+
260
+ def init_weights(self, pretrained=""):
261
+ """Initialize model weights."""
262
+
263
+ parameters_names = set()
264
+ for name, _ in self.named_parameters():
265
+ parameters_names.add(name)
266
+
267
+ buffers_names = set()
268
+ for name, _ in self.named_buffers():
269
+ buffers_names.add(name)
270
+
271
+ if os.path.isfile(pretrained):
272
+ assert (self.stage_pct == "classifier"), \
273
+ "Training tokenizer does not need to load model"
274
+ pretrained_state_dict = torch.load(pretrained,
275
+ map_location=lambda storage, loc: storage)
276
+
277
+ need_init_state_dict = {}
278
+
279
+ if 'state_dict' in pretrained_state_dict:
280
+ key = 'state_dict'
281
+ else:
282
+ key = 'model'
283
+ for name, m in pretrained_state_dict[key].items():
284
+ if 'keypoint_head.tokenizer.' in name:
285
+ name = name.replace('keypoint_head.tokenizer.', '')
286
+ if name in parameters_names or name in buffers_names:
287
+ need_init_state_dict[name] = m
288
+ self.load_state_dict(need_init_state_dict, strict=True)
289
+ else:
290
+ if self.stage_pct == "classifier":
291
+ print('If you are training a classifier, '\
292
+ 'you must check that the well-trained tokenizer '\
293
+ 'is located in the correct path.')
294
+
295
+
296
+ def save_checkpoint(model, optimizer, epoch, loss, filepath):
297
+ checkpoint = {
298
+ 'epoch': epoch,
299
+ 'model_state_dict': model.state_dict(),
300
+ 'optimizer_state_dict': optimizer.state_dict(),
301
+ 'loss': loss
302
+ }
303
+ torch.save(checkpoint, filepath)
304
+ print(f"Checkpoint saved at {filepath}")
305
+
306
+ def load_checkpoint(model, optimizer, filepath):
307
+ checkpoint = torch.load(filepath)
308
+ model.load_state_dict(checkpoint['model_state_dict'])
309
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
310
+ epoch = checkpoint['epoch']
311
+ loss = checkpoint['loss']
312
+
313
+ print(f"Checkpoint loaded from {filepath}. Resuming training from epoch {epoch} with loss {loss}")
314
+
315
+ return epoch, loss
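The encoder-to-codebook step in PCT_Tokenizer.forward above is a vector-quantization lookup followed by an EMA codebook update and a straight-through estimator. The sketch below reproduces that computation on a single process with toy sizes (no torch.distributed all-reduce); it is an illustration, not the repo's training code.

import torch
import torch.nn.functional as F

token_class_num, token_dim, decay = 16, 4, 0.9          # toy codebook, not the real config
codebook = torch.randn(token_class_num, token_dim)
ema_cluster_size = torch.zeros(token_class_num)
ema_w = codebook.clone()

encode_feat = torch.randn(6, token_dim)                 # 6 flattened token features

# squared L2 distance to every codebook entry, then nearest-entry assignment
distances = (encode_feat ** 2).sum(1, keepdim=True) \
    + (codebook ** 2).sum(1) \
    - 2 * encode_feat @ codebook.t()
indices = distances.argmin(1)
encodings = F.one_hot(indices, token_class_num).float()
quantized = encodings @ codebook                        # (6, token_dim)

# EMA codebook update (the real code first all-reduces these statistics)
ema_cluster_size = ema_cluster_size * decay + (1 - decay) * encodings.sum(0)
ema_w = ema_w * decay + (1 - decay) * (encodings.t() @ encode_feat)
codebook = ema_w / (ema_cluster_size.unsqueeze(1) + 1e-5)

# commitment loss and straight-through estimator, as in the tokenizer
e_latent_loss = F.mse_loss(quantized.detach(), encode_feat)
quantized = encode_feat + (quantized - encode_feat).detach()
print(e_latent_loss.item(), quantized.shape)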
main/postometro.py ADDED
@@ -0,0 +1,305 @@
1
+ # ----------------------------------------------------------------------------------------------
2
+ # FastMETRO Official Code
3
+ # Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved
4
+ # Licensed under the MIT license.
5
+ # ----------------------------------------------------------------------------------------------
6
+
7
+ # ----------------------------------------------------------------------------------------------
8
+ # PostoMETRO Official Code
9
+ # Copyright (c) MIRACLE Lab. All Rights Reserved
10
+ # Licensed under the MIT license.
11
+ # ----------------------------------------------------------------------------------------------
12
+
13
+ from __future__ import absolute_import, division, print_function
14
+ import torch
15
+ import numpy as np
16
+ import argparse
17
+ import os
18
+ import os.path as osp
19
+ from torch import nn
20
+ from postometro_utils.smpl import Mesh
21
+ from postometro_utils.transformer import build_transformer
22
+ from postometro_utils.positional_encoding import build_position_encoding
23
+ from postometro_utils.modules import FCBlock, MixerLayer
24
+ from pct_utils.pct import PCT
25
+ from pct_utils.pct_backbone import SwinV2TransformerRPE2FC
26
+ from postometro_utils.pose_resnet import get_pose_net as get_pose_resnet
27
+ from postometro_utils.pose_resnet_config import config as resnet_config
28
+ from postometro_utils.pose_hrnet import get_pose_hrnet
29
+ from postometro_utils.pose_hrnet_config import _C as hrnet_config
30
+ from postometro_utils.pose_hrnet_config import update_config as hrnet_update_config
31
+
32
+ CUR_DIR = osp.dirname(os.path.abspath(__file__))
33
+
34
+ class PostoMETRO(nn.Module):
35
+ """PostoMETRO for 3D human pose and mesh reconstruction from a single RGB image"""
36
+ def __init__(self, args, backbone, mesh_sampler, pct = None, num_joints=14, num_vertices=431):
37
+ """
38
+ Parameters:
39
+ - args: Arguments
40
+ - backbone: CNN Backbone used to extract image features from the given image
41
+ - mesh_sampler: Mesh Sampler used in the coarse-to-fine mesh upsampling
42
+ - num_joints: The number of joint tokens used in the transformer decoder
43
+ - num_vertices: The number of vertex tokens used in the transformer decoder
44
+ """
45
+ super().__init__()
46
+ self.args = args
47
+ self.backbone = backbone
48
+ self.mesh_sampler = mesh_sampler
49
+ self.num_joints = num_joints
50
+ self.num_vertices = num_vertices
51
+
52
+ # the number of transformer layers, set to default
53
+ num_enc_layers = 3
54
+ num_dec_layers = 3
55
+
56
+ # configurations for the first transformer
57
+ self.transformer_config_1 = {"model_dim": args.model_dim_1, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
58
+ "feedforward_dim": args.feedforward_dim_1, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
59
+ "pos_type": args.pos_type}
60
+ # configurations for the second transformer
61
+ self.transformer_config_2 = {"model_dim": args.model_dim_2, "dropout": args.transformer_dropout, "nhead": args.transformer_nhead,
62
+ "feedforward_dim": args.feedforward_dim_2, "num_enc_layers": num_enc_layers, "num_dec_layers": num_dec_layers,
63
+ "pos_type": args.pos_type}
64
+ # build transformers
65
+ self.transformer_1 = build_transformer(self.transformer_config_1)
66
+ self.transformer_2 = build_transformer(self.transformer_config_2)
67
+
68
+ # dimensionality reduction
69
+ self.dim_reduce_enc_cam = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
70
+ self.dim_reduce_enc_img = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
71
+ self.dim_reduce_dec = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
72
+
73
+ # token embeddings
74
+ self.cam_token_embed = nn.Embedding(1, self.transformer_config_1["model_dim"])
75
+ self.joint_token_embed = nn.Embedding(self.num_joints, self.transformer_config_1["model_dim"])
76
+ self.vertex_token_embed = nn.Embedding(self.num_vertices, self.transformer_config_1["model_dim"])
77
+ # positional encodings
78
+ self.position_encoding_1 = build_position_encoding(pos_type=self.transformer_config_1['pos_type'], hidden_dim=self.transformer_config_1['model_dim'])
79
+ self.position_encoding_2 = build_position_encoding(pos_type=self.transformer_config_2['pos_type'], hidden_dim=self.transformer_config_2['model_dim'])
80
+ # estimators
81
+ self.xyz_regressor = nn.Linear(self.transformer_config_2["model_dim"], 3)
82
+ self.cam_predictor = nn.Linear(self.transformer_config_2["model_dim"], 3)
83
+
84
+ # 1x1 Convolution
85
+ self.conv_1x1 = nn.Conv2d(args.conv_1x1_dim, self.transformer_config_1["model_dim"], kernel_size=1)
86
+
87
+ # attention mask
88
+ zeros_1 = torch.tensor(np.zeros((num_vertices, num_joints)).astype(bool))
89
+ zeros_2 = torch.tensor(np.zeros((num_joints, (num_joints + num_vertices))).astype(bool))
90
+ adjacency_indices = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_indices.pt'))
91
+ adjacency_matrix_value = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_values.pt'))
92
+ adjacency_matrix_size = torch.load(osp.join(CUR_DIR, 'data/smpl_431_adjmat_size.pt'))
93
+ adjacency_matrix = torch.sparse_coo_tensor(adjacency_indices, adjacency_matrix_value, size=adjacency_matrix_size).to_dense()
94
+ temp_mask_1 = (adjacency_matrix == 0)
95
+ temp_mask_2 = torch.cat([zeros_1, temp_mask_1], dim=1)
96
+ self.attention_mask = torch.cat([zeros_2, temp_mask_2], dim=0)
97
+
98
+ # learnable upsampling layer is used (from coarse mesh to intermediate mesh); for visually pleasing mesh result
99
+ ### pre-computed upsampling matrix is used (from intermediate mesh to fine mesh); to reduce optimization difficulty
100
+ self.coarse2intermediate_upsample = nn.Linear(431, 1723)
101
+
102
+ # using extra token
103
+ self.pct = None
104
+ if pct is not None:
105
+ self.pct = pct
106
+ # +1 to align with uncertainty score
107
+ self.token_mixer = FCBlock(args.tokenizer_codebook_token_dim + 1, self.transformer_config_1["model_dim"])
108
+ self.start_embed = nn.Linear(512, args.enc_hidden_dim)
109
+ self.encoder = nn.ModuleList(
110
+ [MixerLayer(args.enc_hidden_dim, args.enc_hidden_inter_dim,
111
+ args.num_joints, args.token_inter_dim,
112
+ args.enc_dropout) for _ in range(args.enc_num_blocks)])
113
+ self.encoder_layer_norm = nn.LayerNorm(args.enc_hidden_dim)
114
+ self.token_mlp = nn.Linear(args.num_joints, args.token_num)
115
+ self.dim_reduce_enc_pct = nn.Linear(self.transformer_config_1["model_dim"], self.transformer_config_2["model_dim"])
116
+
117
+
118
+ def forward(self, images):
119
+ device = images.device
120
+ batch_size = images.size(0)
121
+
122
+ # preparation
123
+ cam_token = self.cam_token_embed.weight.unsqueeze(1).repeat(1, batch_size, 1) # 1 X batch_size X 512
124
+ jv_tokens = torch.cat([self.joint_token_embed.weight, self.vertex_token_embed.weight], dim=0).unsqueeze(1).repeat(1, batch_size, 1) # (num_joints + num_vertices) X batch_size X 512
125
+ attention_mask = self.attention_mask.to(device) # (num_joints + num_vertices) X (num_joints + num_vertices)
126
+
127
+ pct_token = None
128
+ if self.pct is not None:
129
+ pct_out = self.pct(images, None, train=False)
130
+ pct_pose = pct_out['part_token_feat'].clone()
131
+
132
+ encode_feat = self.start_embed(pct_pose) # batch_size X 34 X 512
133
+ for num_layer in self.encoder:
134
+ encode_feat = num_layer(encode_feat)
135
+ encode_feat = self.encoder_layer_norm(encode_feat)
136
+ encode_feat = encode_feat.transpose(2, 1)
137
+ encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
138
+ pct_token_out = encode_feat.permute(1,0,2)
139
+
140
+ pct_score = pct_out['encoding_scores']
141
+ pct_score = pct_score.permute(1,0,2)
142
+ pct_token = torch.cat([pct_token_out, pct_score], dim = -1)
143
+ pct_token = self.token_mixer(pct_token) # 34 X batch_size X 512
144
+
145
+ # extract image features through a CNN backbone
146
+ _img_features = self.backbone(images) # batch_size X 2048 X 7 X 7
147
+ _, _, h, w = _img_features.shape
148
+ img_features = self.conv_1x1(_img_features).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
149
+
150
+ # positional encodings
151
+ pos_enc_1 = self.position_encoding_1(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 512
152
+ pos_enc_2 = self.position_encoding_2(batch_size, h, w, device).flatten(2).permute(2, 0, 1) # 49 X batch_size X 128
153
+
154
+ # first transformer encoder-decoder
155
+ cam_features_1, enc_img_features_1, jv_features_1, pct_features_1 = self.transformer_1(img_features, cam_token, jv_tokens, pos_enc_1, pct_token = pct_token, attention_mask=attention_mask)
156
+
157
+ # progressive dimensionality reduction
158
+ reduced_cam_features_1 = self.dim_reduce_enc_cam(cam_features_1) # 1 X batch_size X 128
159
+ reduced_enc_img_features_1 = self.dim_reduce_enc_img(enc_img_features_1) # 49 X batch_size X 128
160
+ reduced_jv_features_1 = self.dim_reduce_dec(jv_features_1) # (num_joints + num_vertices) X batch_size X 128
161
+ reduced_pct_features_1 = None
162
+ if pct_features_1 is not None:
163
+ reduced_pct_features_1 = self.dim_reduce_enc_pct(pct_features_1)
164
+
165
+ # second transformer encoder-decoder
166
+ cam_features_2, _, jv_features_2,_ = self.transformer_2(reduced_enc_img_features_1, reduced_cam_features_1, reduced_jv_features_1, pos_enc_2, pct_token = reduced_pct_features_1, attention_mask=attention_mask)
167
+
168
+ # estimators
169
+ pred_cam = self.cam_predictor(cam_features_2).view(batch_size, 3) # batch_size X 3
170
+
171
+ pred_3d_coordinates = self.xyz_regressor(jv_features_2.transpose(0, 1)) # batch_size X (num_joints + num_vertices) X 3
172
+ pred_3d_joints = pred_3d_coordinates[:,:self.num_joints,:] # batch_size X num_joints X 3
173
+ pred_3d_vertices_coarse = pred_3d_coordinates[:,self.num_joints:,:] # batch_size X num_vertices(coarse) X 3
174
+
175
+ # coarse-to-intermediate mesh upsampling
176
+ pred_3d_vertices_intermediate = self.coarse2intermediate_upsample(pred_3d_vertices_coarse.transpose(1,2)).transpose(1,2) # batch_size X num_vertices(intermediate) X 3
177
+ # intermediate-to-fine mesh upsampling
178
+ pred_3d_vertices_fine = self.mesh_sampler.upsample(pred_3d_vertices_intermediate, n1=1, n2=0) # batch_size X num_vertices(fine) X 3
179
+
180
+ out = {}
181
+ out['pred_cam'] = pred_cam
182
+ out['pct_pose'] = pct_out['pred_pose'] if self.pct is not None else torch.zeros((batch_size, 34, 2)).to(device)
183
+ out['pred_3d_joints'] = pred_3d_joints
184
+ out['pred_3d_vertices_coarse'] = pred_3d_vertices_coarse
185
+ out['pred_3d_vertices_intermediate'] = pred_3d_vertices_intermediate
186
+ out['pred_3d_vertices_fine'] = pred_3d_vertices_fine
187
+
188
+ return out
189
+
190
+
191
+ defaults_args = argparse.Namespace(
192
+ pos_type = 'sine',
193
+ transformer_dropout = 0.1,
194
+ transformer_nhead = 8,
195
+ conv_1x1_dim = 2048,
196
+ tokenizer_codebook_token_dim = 512,
197
+ model_dim_1 = 512,
198
+ feedforward_dim_1 = 2048,
199
+ model_dim_2 = 128,
200
+ feedforward_dim_2 = 512,
201
+ enc_hidden_dim = 512,
202
+ enc_hidden_inter_dim = 512,
203
+ token_inter_dim = 64,
204
+ enc_dropout = 0.0,
205
+ enc_num_blocks = 4,
206
+ num_joints = 34,
207
+ token_num = 34
208
+ )
209
+
210
+ default_pct_args = argparse.Namespace(
211
+ pct_backbone_channel = 1536,
212
+ tokenizer_guide_ratio=0.5,
213
+ cls_head_conv_channels=256,
214
+ cls_head_hidden_dim=64,
215
+ cls_head_num_blocks=4,
216
+ cls_head_hidden_inter_dim=256,
217
+ cls_head_token_inter_dim=64,
218
+ cls_head_dropout=0.0,
219
+ cls_head_conv_num_blocks=2,
220
+ cls_head_dilation=1,
221
+ # tokenizer
222
+ tokenizer_encoder_drop_rate=0.2,
223
+ tokenizer_encoder_num_blocks=4,
224
+ tokenizer_encoder_hidden_dim=512,
225
+ tokenizer_encoder_token_inter_dim=64,
226
+ tokenizer_encoder_hidden_inter_dim=512,
227
+ tokenizer_encoder_dropout=0.0,
228
+ tokenizer_decoder_num_blocks=1,
229
+ tokenizer_decoder_hidden_dim=32,
230
+ tokenizer_decoder_token_inter_dim=64,
231
+ tokenizer_decoder_hidden_inter_dim=64,
232
+ tokenizer_decoder_dropout=0.0,
233
+ tokenizer_codebook_token_num=34,
234
+ tokenizer_codebook_token_dim=512,
235
+ tokenizer_codebook_token_class_num=2048,
236
+ tokenizer_codebook_ema_decay=0.9,
237
+ )
238
+
239
+ backbone_config=dict(
240
+ embed_dim=192,
241
+ depths=[2, 2, 18, 2],
242
+ num_heads=[6, 12, 24, 48],
243
+ window_size=[16, 16, 16, 8],
244
+ pretrain_window_size=[12, 12, 12, 6],
245
+ ape=False,
246
+ drop_path_rate=0.5,
247
+ patch_norm=True,
248
+ use_checkpoint=True,
249
+ rpe_interpolation='geo',
250
+ use_shift=[True, True, False, False],
251
+ relative_coords_table_type='norm8_log_bylayer',
252
+ attn_type='cosine_mh',
253
+ rpe_output_type='sigmoid',
254
+ postnorm=True,
255
+ mlp_type='normal',
256
+ out_indices=(3,),
257
+ patch_embed_type='normal',
258
+ patch_merge_type='normal',
259
+ strid16=False,
260
+ frozen_stages=5,
261
+ )
262
+
263
+ def get_model(backbone_str = 'resnet50', device = torch.device('cpu'), checkpoint_file = None):
264
+ if backbone_str == 'hrnet-w48':
265
+ defaults_args.conv_1x1_dim = 384
266
+ # update hrnet config by yaml
267
+ hrnet_yaml = osp.join(CUR_DIR,'postometro_utils', 'pose_w48_256x192_adam_lr1e-3.yaml')
268
+ hrnet_update_config(hrnet_config, hrnet_yaml)
269
+ backbone = get_pose_hrnet(hrnet_config, None)
270
+ else:
271
+ backbone = get_pose_resnet(resnet_config, is_train=False)
272
+ mesh_upsampler = Mesh(device=device)
273
+ pct_swin_backbone = SwinV2TransformerRPE2FC(**backbone_config)
274
+ # initialize pct head
275
+ pct = PCT(default_pct_args, pct_swin_backbone, 'classifier', default_pct_args.pct_backbone_channel, (256, 256), 17, None, None).to(device)
276
+ model = PostoMETRO(defaults_args, backbone, mesh_upsampler, pct=pct).to(device)
277
+ print("[INFO] model loaded, params: {}, {}".format(backbone_str, device))
278
+ if checkpoint_file:
279
+ cpu_device = torch.device('cpu')
280
+ state_dict = torch.load(checkpoint_file, map_location=cpu_device)
281
+ model.load_state_dict(state_dict, strict=True)
282
+ del state_dict
283
+ print("[INFO] checkpoint loaded, params: {}, {}".format(backbone_str, device))
284
+ return model
285
+
286
+ if __name__ == '__main__':
287
+ test_model = get_model(device=torch.device('cuda'))
288
+ images = torch.randn(1,3,256,256).to(torch.device('cuda'))
289
+ test_out = test_model(images)
290
+ print("[TEST] resnet50, cuda : pass")
291
+
292
+ test_model = get_model()
293
+ images = torch.randn(1,3,256,256)
294
+ test_out = test_model(images)
295
+ print("[TEST] resnet50, cpu : pass")
296
+
297
+ test_model = get_model(backbone_str='hrnet-w48', device=torch.device('cuda'))
298
+ images = torch.randn(1,3,256,256).to(torch.device('cuda'))
299
+ test_out = test_model(images)
300
+ print("[TEST] hrnet-w48, cuda : pass")
301
+
302
+ test_model = get_model(backbone_str='hrnet-w48')
303
+ images = torch.randn(1,3,256,256)
304
+ test_out = test_model(images)
305
+ print("[TEST] hrnet-w48, cpu : pass")
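The attention mask built in PostoMETRO.__init__ comes from a sparse SMPL-431 adjacency matrix stored in data/smpl_431_adjmat_*.pt, which is not shown in this diff. The sketch below rebuilds the same mask layout from a tiny hand-made adjacency (3 joints, 4 vertices) purely for illustration: joint tokens attend everywhere, vertex tokens attend to all joints but only to adjacent vertices.

import torch

num_joints, num_vertices = 3, 4                         # toy sizes, not the SMPL-431 mesh

# sparse vertex-vertex adjacency: vertex i may attend to vertex j where adj[i, j] != 0
indices = torch.tensor([[0, 1, 1, 2, 2, 3, 0, 1, 2, 3],
                        [1, 0, 2, 1, 3, 2, 0, 1, 2, 3]])
values = torch.ones(indices.shape[1])
adjacency = torch.sparse_coo_tensor(indices, values, (num_vertices, num_vertices)).to_dense()

# True means "masked out" in an nn.Transformer-style attention mask
vertex_mask = (adjacency == 0)                                                   # 4 x 4
zeros_1 = torch.zeros(num_vertices, num_joints, dtype=torch.bool)                # vertices see all joints
zeros_2 = torch.zeros(num_joints, num_joints + num_vertices, dtype=torch.bool)   # joints see everything

attention_mask = torch.cat([zeros_2, torch.cat([zeros_1, vertex_mask], dim=1)], dim=0)
print(attention_mask.shape)                             # torch.Size([7, 7])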
main/postometro_utils/__pycache__/geometric_layers.cpython-39.pyc ADDED
Binary file (14.2 kB).
 
main/postometro_utils/__pycache__/modules.cpython-39.pyc ADDED
Binary file (3.41 kB).
 
main/postometro_utils/__pycache__/pose_hrnet.cpython-39.pyc ADDED
Binary file (11.1 kB).
 
main/postometro_utils/__pycache__/pose_hrnet_config.cpython-39.pyc ADDED
Binary file (2.63 kB).
 
main/postometro_utils/__pycache__/pose_resnet.cpython-39.pyc ADDED
Binary file (7.36 kB).
 
main/postometro_utils/__pycache__/pose_resnet_config.cpython-39.pyc ADDED
Binary file (5.02 kB).
 
main/postometro_utils/__pycache__/positional_encoding.cpython-39.pyc ADDED
Binary file (2.2 kB).
 
main/postometro_utils/__pycache__/renderer_pyrender.cpython-39.pyc ADDED
Binary file (6.62 kB).
 
main/postometro_utils/__pycache__/smpl.cpython-39.pyc ADDED
Binary file (10.2 kB).
 
main/postometro_utils/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (8.17 kB).
 
main/postometro_utils/geometric_layers.py ADDED
@@ -0,0 +1,679 @@
1
+ # ----------------------------------------------------------------------------------------------
2
+ # METRO (https://github.com/microsoft/MeshTransformer)
3
+ # Copyright (c) Microsoft Corporation. All Rights Reserved [see https://github.com/microsoft/MeshTransformer/blob/main/LICENSE for details]
4
+ # Licensed under the MIT license.
5
+ # ----------------------------------------------------------------------------------------------
6
+ """
7
+ Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula
8
+
9
+ Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
10
+ """
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ def rodrigues(theta):
16
+ """Convert axis-angle representation to rotation matrix.
17
+ Args:
18
+ theta: size = [B, 3]
19
+ Returns:
20
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
21
+ """
22
+ l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
23
+ angle = torch.unsqueeze(l1norm, -1)
24
+ normalized = torch.div(theta, angle)
25
+ angle = angle * 0.5
26
+ v_cos = torch.cos(angle)
27
+ v_sin = torch.sin(angle)
28
+ quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
29
+ return quat2mat(quat)
30
+
31
+ def quat2mat(quat):
32
+ """Convert quaternion coefficients to rotation matrix.
33
+ Args:
34
+ quat: size = [B, 4] 4 <===>(w, x, y, z)
35
+ Returns:
36
+ Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
37
+ """
38
+ norm_quat = quat
39
+ norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
40
+ w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
41
+
42
+ B = quat.size(0)
43
+
44
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
45
+ wx, wy, wz = w*x, w*y, w*z
46
+ xy, xz, yz = x*y, x*z, y*z
47
+
48
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
49
+ 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
50
+ 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
51
+ return rotMat
52
+
53
+ def orthographic_projection(X, camera):
54
+ """Perform orthographic projection of 3D points X using the camera parameters
55
+ Args:
56
+ X: size = [B, N, 3]
57
+ camera: size = [B, 3]
58
+ Returns:
59
+ Projected 2D points -- size = [B, N, 2]
60
+ """
61
+ camera = camera.view(-1, 1, 3)
62
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
63
+ shape = X_trans.shape
64
+ X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
65
+ return X_2d
66
+
67
+ def orthographic_projection_reshape(X, camera):
68
+ """Perform orthographic projection of 3D points X using the camera parameters
69
+ Args:
70
+ X: size = [B, N, 3]
71
+ camera: size = [B, 3]
72
+ Returns:
73
+ Projected 2D points -- size = [B, N, 2]
74
+ """
75
+ camera = camera.reshape(-1, 1, 3)
76
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
77
+ shape = X_trans.shape
78
+ X_2d = (camera[:, :, 0] * X_trans.reshape(shape[0], -1)).reshape(shape)
79
+ return X_2d
80
+
81
+ def orthographic_projection_reshape(X, camera):
82
+ """Perform orthographic projection of 3D points X using the camera parameters
83
+ Args:
84
+ X: size = [B, N, 3]
85
+ camera: size = [B, 3]
86
+ Returns:
87
+ Projected 2D points -- size = [B, N, 2]
88
+ """
89
+ camera = camera.reshape(-1, 1, 3)
90
+ X_trans = X[:, :, :2] + camera[:, :, 1:]
91
+ shape = X_trans.shape
92
+ X_2d = (camera[:, :, 0] * X_trans.reshape(shape[0], -1)).reshape(shape)
93
+ return X_2d
94
+
95
+
96
+ def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
97
+ """
98
+ Return a tensor where each element has the absolute value taken from the
99
+ corresponding element of a, with sign taken from the corresponding
100
+ element of b. This is like the standard copysign floating-point operation,
101
+ but is not careful about negative 0 and NaN.
102
+
103
+ Args:
104
+ a: source tensor.
105
+ b: tensor whose signs will be used, of the same shape as a.
106
+
107
+ Returns:
108
+ Tensor of the same shape as a with the signs of b.
109
+ """
110
+ signs_differ = (a < 0) != (b < 0)
111
+ return torch.where(signs_differ, -a, a)
112
+
113
+
114
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
115
+ """
116
+ Returns torch.sqrt(torch.max(0, x))
117
+ but with a zero subgradient where x is 0.
118
+ """
119
+ ret = torch.zeros_like(x)
120
+ positive_mask = x > 0
121
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
122
+ return ret
123
+
124
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
125
+ """
126
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
127
+ using Gram--Schmidt orthogonalization per Section B of [1].
128
+ Args:
129
+ d6: 6D rotation representation, of size (*, 6)
130
+
131
+ Returns:
132
+ batch of rotation matrices of size (*, 3, 3)
133
+
134
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
135
+ On the Continuity of Rotation Representations in Neural Networks.
136
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
137
+ Retrieved from http://arxiv.org/abs/1812.07035
138
+ """
139
+
140
+ a1, a2 = d6[..., :3], d6[..., 3:]
141
+ b1 = F.normalize(a1, dim=-1)
142
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
143
+ b2 = F.normalize(b2, dim=-1)
144
+ b3 = torch.cross(b1, b2, dim=-1)
145
+ return torch.stack((b1, b2, b3), dim=-2)
146
+
147
+
148
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
149
+ """
150
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
151
+ by dropping the last row. Note that 6D representation is not unique.
152
+ Args:
153
+ matrix: batch of rotation matrices of size (*, 3, 3)
154
+
155
+ Returns:
156
+ 6D rotation representation, of size (*, 6)
157
+
158
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
159
+ On the Continuity of Rotation Representations in Neural Networks.
160
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
161
+ Retrieved from http://arxiv.org/abs/1812.07035
162
+ """
163
+ batch_dim = matrix.size()[:-2]
164
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
165
+
166
+ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
167
+ """
168
+ Convert rotations given as axis/angle to quaternions.
169
+
170
+ Args:
171
+ axis_angle: Rotations given as a vector in axis angle form,
172
+ as a tensor of shape (..., 3), where the magnitude is
173
+ the angle turned anticlockwise in radians around the
174
+ vector's direction.
175
+
176
+ Returns:
177
+ quaternions with real part first, as tensor of shape (..., 4).
178
+ """
179
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
180
+ half_angles = angles * 0.5
181
+ eps = 1e-6
182
+ small_angles = angles.abs() < eps
183
+ sin_half_angles_over_angles = torch.empty_like(angles)
184
+ sin_half_angles_over_angles[~small_angles] = (
185
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
186
+ )
187
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
188
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
189
+ sin_half_angles_over_angles[small_angles] = (
190
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
191
+ )
192
+ quaternions = torch.cat(
193
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
194
+ )
195
+ return quaternions
196
+
197
+
198
+ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
199
+ """
200
+ Convert rotations given as quaternions to axis/angle.
201
+
202
+ Args:
203
+ quaternions: quaternions with real part first,
204
+ as tensor of shape (..., 4).
205
+
206
+ Returns:
207
+ Rotations given as a vector in axis angle form, as a tensor
208
+ of shape (..., 3), where the magnitude is the angle
209
+ turned anticlockwise in radians around the vector's
210
+ direction.
211
+ """
212
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
213
+ half_angles = torch.atan2(norms, quaternions[..., :1])
214
+ angles = 2 * half_angles
215
+ eps = 1e-6
216
+ small_angles = angles.abs() < eps
217
+ sin_half_angles_over_angles = torch.empty_like(angles)
218
+ sin_half_angles_over_angles[~small_angles] = (
219
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
220
+ )
221
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
222
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
223
+ sin_half_angles_over_angles[small_angles] = (
224
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
225
+ )
226
+ return quaternions[..., 1:] / sin_half_angles_over_angles
227
+
228
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
229
+ """
230
+ Convert rotations given as quaternions to rotation matrices.
231
+
232
+ Args:
233
+ quaternions: quaternions with real part first,
234
+ as tensor of shape (..., 4).
235
+
236
+ Returns:
237
+ Rotation matrices as tensor of shape (..., 3, 3).
238
+ """
239
+ r, i, j, k = torch.unbind(quaternions, -1)
240
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
241
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
242
+
243
+ o = torch.stack(
244
+ (
245
+ 1 - two_s * (j * j + k * k),
246
+ two_s * (i * j - k * r),
247
+ two_s * (i * k + j * r),
248
+ two_s * (i * j + k * r),
249
+ 1 - two_s * (i * i + k * k),
250
+ two_s * (j * k - i * r),
251
+ two_s * (i * k - j * r),
252
+ two_s * (j * k + i * r),
253
+ 1 - two_s * (i * i + j * j),
254
+ ),
255
+ -1,
256
+ )
257
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
258
+
259
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
260
+ """
261
+ Convert rotations given as rotation matrices to quaternions.
262
+
263
+ Args:
264
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
265
+
266
+ Returns:
267
+ quaternions with real part first, as tensor of shape (..., 4).
268
+ """
269
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
270
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
271
+
272
+ batch_dim = matrix.shape[:-2]
273
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
274
+ matrix.reshape(batch_dim + (9,)), dim=-1
275
+ )
276
+
277
+ q_abs = _sqrt_positive_part(
278
+ torch.stack(
279
+ [
280
+ 1.0 + m00 + m11 + m22,
281
+ 1.0 + m00 - m11 - m22,
282
+ 1.0 - m00 + m11 - m22,
283
+ 1.0 - m00 - m11 + m22,
284
+ ],
285
+ dim=-1,
286
+ )
287
+ )
288
+
289
+ # we produce the desired quaternion multiplied by each of r, i, j, k
290
+ quat_by_rijk = torch.stack(
291
+ [
292
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
293
+ # `int`.
294
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
295
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
296
+ # `int`.
297
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
298
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
299
+ # `int`.
300
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
301
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
302
+ # `int`.
303
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
304
+ ],
305
+ dim=-2,
306
+ )
307
+
308
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
309
+ # the candidate won't be picked.
310
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
311
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
312
+
313
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
314
+ # forall i; we pick the best-conditioned one (with the largest denominator)
315
+
316
+ return quat_candidates[
317
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
318
+ ].reshape(batch_dim + (4,))
319
+
320
+ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
321
+ """
322
+ Convert rotations given as axis/angle to rotation matrices.
323
+
324
+ Args:
325
+ axis_angle: Rotations given as a vector in axis angle form,
326
+ as a tensor of shape (..., 3), where the magnitude is
327
+ the angle turned anticlockwise in radians around the
328
+ vector's direction.
329
+
330
+ Returns:
331
+ Rotation matrices as tensor of shape (..., 3, 3).
332
+ """
333
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
334
+
335
+
336
+ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
337
+ """
338
+ Convert rotations given as rotation matrices to axis/angle.
339
+
340
+ Args:
341
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
342
+
343
+ Returns:
344
+ Rotations given as a vector in axis angle form, as a tensor
345
+ of shape (..., 3), where the magnitude is the angle
346
+ turned anticlockwise in radians around the vector's
347
+ direction.
348
+ """
349
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
350
+
351
+ def axis_angle_to_rotation_6d(axis_angle: torch.Tensor) -> torch.Tensor:
352
+ """
353
+ Convert rotations given as axis/angle to the 6D rotation representation.
354
+
355
+ Args:
356
+ axis_angle: Rotations given as a vector in axis angle form,
357
+ as a tensor of shape (..., 3), where the magnitude is
358
+ the angle turned anticlockwise in radians around the
359
+ vector's direction.
360
+
361
+ Returns:
362
+ 6D rotation representation, of size (*, 6)
363
+ """
364
+ return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))
365
+
366
+ def rotation_6d_to_axis_angle(d6):
367
+ """
368
+ Converts 6D rotation representation by Zhou et al. [1] to an axis-angle vector,
369
+ going through the rotation matrix built with Gram--Schmidt orthogonalization per Section B of [1].
370
+ Args:
371
+ d6: 6D rotation representation, of size (*, 6)
372
+
373
+ Returns:
374
+ axis_angle: Rotations given as a vector in axis angle form,
375
+ as a tensor of shape (..., 3), where the magnitude is
376
+ the angle turned anticlockwise in radians around the
377
+ vector's direction.
378
+
379
+
380
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
381
+ On the Continuity of Rotation Representations in Neural Networks.
382
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
383
+ Retrieved from http://arxiv.org/abs/1812.07035
384
+ """
385
+
386
+ return matrix_to_axis_angle(rotation_6d_to_matrix(d6))
387
+
388
+
389
+ def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
390
+ """
391
+ Return a tensor where each element has the absolute value taken from the
392
+ corresponding element of a, with sign taken from the corresponding
393
+ element of b. This is like the standard copysign floating-point operation,
394
+ but is not careful about negative 0 and NaN.
395
+
396
+ Args:
397
+ a: source tensor.
398
+ b: tensor whose signs will be used, of the same shape as a.
399
+
400
+ Returns:
401
+ Tensor of the same shape as a with the signs of b.
402
+ """
403
+ signs_differ = (a < 0) != (b < 0)
404
+ return torch.where(signs_differ, -a, a)
405
+
406
+
407
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
408
+ """
409
+ Returns torch.sqrt(torch.max(0, x))
410
+ but with a zero subgradient where x is 0.
411
+ """
412
+ ret = torch.zeros_like(x)
413
+ positive_mask = x > 0
414
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
415
+ return ret
416
+
417
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
418
+ """
419
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
420
+ using Gram--Schmidt orthogonalization per Section B of [1].
421
+ Args:
422
+ d6: 6D rotation representation, of size (*, 6)
423
+
424
+ Returns:
425
+ batch of rotation matrices of size (*, 3, 3)
426
+
427
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
428
+ On the Continuity of Rotation Representations in Neural Networks.
429
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
430
+ Retrieved from http://arxiv.org/abs/1812.07035
431
+ """
432
+
433
+ a1, a2 = d6[..., :3], d6[..., 3:]
434
+ b1 = F.normalize(a1, dim=-1)
435
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
436
+ b2 = F.normalize(b2, dim=-1)
437
+ b3 = torch.cross(b1, b2, dim=-1)
438
+ return torch.stack((b1, b2, b3), dim=-2)
439
+
440
+
441
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
442
+ """
443
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
444
+ by dropping the last row. Note that 6D representation is not unique.
445
+ Args:
446
+ matrix: batch of rotation matrices of size (*, 3, 3)
447
+
448
+ Returns:
449
+ 6D rotation representation, of size (*, 6)
450
+
451
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
452
+ On the Continuity of Rotation Representations in Neural Networks.
453
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
454
+ Retrieved from http://arxiv.org/abs/1812.07035
455
+ """
456
+ batch_dim = matrix.size()[:-2]
457
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
458
+
459
+ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
460
+ """
461
+ Convert rotations given as axis/angle to quaternions.
462
+
463
+ Args:
464
+ axis_angle: Rotations given as a vector in axis angle form,
465
+ as a tensor of shape (..., 3), where the magnitude is
466
+ the angle turned anticlockwise in radians around the
467
+ vector's direction.
468
+
469
+ Returns:
470
+ quaternions with real part first, as tensor of shape (..., 4).
471
+ """
472
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
473
+ half_angles = angles * 0.5
474
+ eps = 1e-6
475
+ small_angles = angles.abs() < eps
476
+ sin_half_angles_over_angles = torch.empty_like(angles)
477
+ sin_half_angles_over_angles[~small_angles] = (
478
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
479
+ )
480
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
481
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
482
+ sin_half_angles_over_angles[small_angles] = (
483
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
484
+ )
485
+ quaternions = torch.cat(
486
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
487
+ )
488
+ return quaternions
489
+
490
+
491
+ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
492
+ """
493
+ Convert rotations given as quaternions to axis/angle.
494
+
495
+ Args:
496
+ quaternions: quaternions with real part first,
497
+ as tensor of shape (..., 4).
498
+
499
+ Returns:
500
+ Rotations given as a vector in axis angle form, as a tensor
501
+ of shape (..., 3), where the magnitude is the angle
502
+ turned anticlockwise in radians around the vector's
503
+ direction.
504
+ """
505
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
506
+ half_angles = torch.atan2(norms, quaternions[..., :1])
507
+ angles = 2 * half_angles
508
+ eps = 1e-6
509
+ small_angles = angles.abs() < eps
510
+ sin_half_angles_over_angles = torch.empty_like(angles)
511
+ sin_half_angles_over_angles[~small_angles] = (
512
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
513
+ )
514
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
515
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
516
+ sin_half_angles_over_angles[small_angles] = (
517
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
518
+ )
519
+ return quaternions[..., 1:] / sin_half_angles_over_angles
520
+
521
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
522
+ """
523
+ Convert rotations given as quaternions to rotation matrices.
524
+
525
+ Args:
526
+ quaternions: quaternions with real part first,
527
+ as tensor of shape (..., 4).
528
+
529
+ Returns:
530
+ Rotation matrices as tensor of shape (..., 3, 3).
531
+ """
532
+ r, i, j, k = torch.unbind(quaternions, -1)
533
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
534
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
535
+
536
+ o = torch.stack(
537
+ (
538
+ 1 - two_s * (j * j + k * k),
539
+ two_s * (i * j - k * r),
540
+ two_s * (i * k + j * r),
541
+ two_s * (i * j + k * r),
542
+ 1 - two_s * (i * i + k * k),
543
+ two_s * (j * k - i * r),
544
+ two_s * (i * k - j * r),
545
+ two_s * (j * k + i * r),
546
+ 1 - two_s * (i * i + j * j),
547
+ ),
548
+ -1,
549
+ )
550
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
551
+
552
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
553
+ """
554
+ Convert rotations given as rotation matrices to quaternions.
555
+
556
+ Args:
557
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
558
+
559
+ Returns:
560
+ quaternions with real part first, as tensor of shape (..., 4).
561
+ """
562
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
563
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
564
+
565
+ batch_dim = matrix.shape[:-2]
566
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
567
+ matrix.reshape(batch_dim + (9,)), dim=-1
568
+ )
569
+
570
+ q_abs = _sqrt_positive_part(
571
+ torch.stack(
572
+ [
573
+ 1.0 + m00 + m11 + m22,
574
+ 1.0 + m00 - m11 - m22,
575
+ 1.0 - m00 + m11 - m22,
576
+ 1.0 - m00 - m11 + m22,
577
+ ],
578
+ dim=-1,
579
+ )
580
+ )
581
+
582
+ # we produce the desired quaternion multiplied by each of r, i, j, k
583
+ quat_by_rijk = torch.stack(
584
+ [
585
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
586
+ # `int`.
587
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
588
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
589
+ # `int`.
590
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
591
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
592
+ # `int`.
593
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
594
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
595
+ # `int`.
596
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
597
+ ],
598
+ dim=-2,
599
+ )
600
+
601
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
602
+ # the candidate won't be picked.
603
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
604
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
605
+
606
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
607
+ # forall i; we pick the best-conditioned one (with the largest denominator)
608
+
609
+ return quat_candidates[
610
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
611
+ ].reshape(batch_dim + (4,))
612
+
613
+ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
614
+ """
615
+ Convert rotations given as axis/angle to rotation matrices.
616
+
617
+ Args:
618
+ axis_angle: Rotations given as a vector in axis angle form,
619
+ as a tensor of shape (..., 3), where the magnitude is
620
+ the angle turned anticlockwise in radians around the
621
+ vector's direction.
622
+
623
+ Returns:
624
+ Rotation matrices as tensor of shape (..., 3, 3).
625
+ """
626
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
627
+
628
+
629
+ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
630
+ """
631
+ Convert rotations given as rotation matrices to axis/angle.
632
+
633
+ Args:
634
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
635
+
636
+ Returns:
637
+ Rotations given as a vector in axis angle form, as a tensor
638
+ of shape (..., 3), where the magnitude is the angle
639
+ turned anticlockwise in radians around the vector's
640
+ direction.
641
+ """
642
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
643
+
644
+ def axis_angle_to_rotation_6d(axis_angle: torch.Tensor) -> torch.Tensor:
645
+ """
646
+ Convert rotations given as axis/angle to the 6D rotation representation.
647
+
648
+ Args:
649
+ axis_angle: Rotations given as a vector in axis angle form,
650
+ as a tensor of shape (..., 3), where the magnitude is
651
+ the angle turned anticlockwise in radians around the
652
+ vector's direction.
653
+
654
+ Returns:
655
+ 6D rotation representation, of size (*, 6)
656
+ """
657
+ return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))
658
+
659
+ def rotation_6d_to_axis_angle(d6):
660
+ """
661
+ Converts 6D rotation representation by Zhou et al. [1] to an axis-angle vector,
662
+ going through the rotation matrix built with Gram--Schmidt orthogonalization per Section B of [1].
663
+ Args:
664
+ d6: 6D rotation representation, of size (*, 6)
665
+
666
+ Returns:
667
+ axis_angle: Rotations given as a vector in axis angle form,
668
+ as a tensor of shape (..., 3), where the magnitude is
669
+ the angle turned anticlockwise in radians around the
670
+ vector's direction.
671
+
672
+
673
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
674
+ On the Continuity of Rotation Representations in Neural Networks.
675
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
676
+ Retrieved from http://arxiv.org/abs/1812.07035
677
+ """
678
+
679
+ return matrix_to_axis_angle(rotation_6d_to_matrix(d6))
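The helpers above chain together, so a quick round-trip check is a useful sanity test. A minimal sketch, assuming main/ is on PYTHONPATH so the file imports as postometro_utils.geometric_layers (the quaternion/6D helpers it calls are defined earlier in the same file); comparing rotation matrices sidesteps the sign and 2*pi ambiguity of axis-angle:

import torch
from postometro_utils.geometric_layers import (
    axis_angle_to_matrix, matrix_to_axis_angle,
    axis_angle_to_rotation_6d, rotation_6d_to_axis_angle,
)

aa = 0.5 * torch.randn(8, 24, 3)            # batch of per-joint axis-angle rotations
R = axis_angle_to_matrix(aa)                # (..., 3, 3)

d6 = axis_angle_to_rotation_6d(aa)          # (..., 6)
aa_from_6d = rotation_6d_to_axis_angle(d6)  # (..., 3)
aa_from_R = matrix_to_axis_angle(R)

# both round trips should reproduce the same rotation (up to numerical error)
print(torch.allclose(axis_angle_to_matrix(aa_from_6d), R, atol=1e-5))
print(torch.allclose(axis_angle_to_matrix(aa_from_R), R, atol=1e-5))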
main/postometro_utils/modules.py ADDED
@@ -0,0 +1,117 @@
1
+ # --------------------------------------------------------
2
+ # Borrowed from unofficial MLPMixer (https://github.com/920232796/MlpMixer-pytorch)
3
+ # Borrowed from ResNet
4
+ # Modified by Zigang Geng (zigang@mail.ustc.edu.cn)
5
+ # --------------------------------------------------------
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class FCBlock(nn.Module):
12
+ def __init__(self, dim, out_dim):
13
+ super().__init__()
14
+
15
+ self.ff = nn.Sequential(
16
+ nn.Linear(dim, out_dim),
17
+ nn.LayerNorm(out_dim),
18
+ nn.ReLU(inplace=True),
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.ff(x)
23
+
24
+
25
+ class MLPBlock(nn.Module):
26
+ def __init__(self, dim, inter_dim, dropout_ratio):
27
+ super().__init__()
28
+
29
+ self.ff = nn.Sequential(
30
+ nn.Linear(dim, inter_dim),
31
+ nn.GELU(),
32
+ nn.Dropout(dropout_ratio),
33
+ nn.Linear(inter_dim, dim),
34
+ nn.Dropout(dropout_ratio)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.ff(x)
39
+
40
+
41
+ class MixerLayer(nn.Module):
42
+ def __init__(self,
43
+ hidden_dim,
44
+ hidden_inter_dim,
45
+ token_dim,
46
+ token_inter_dim,
47
+ dropout_ratio):
48
+ super().__init__()
49
+
50
+ self.layernorm1 = nn.LayerNorm(hidden_dim)
51
+ self.MLP_token = MLPBlock(token_dim, token_inter_dim, dropout_ratio)
52
+ self.layernorm2 = nn.LayerNorm(hidden_dim)
53
+ self.MLP_channel = MLPBlock(hidden_dim, hidden_inter_dim, dropout_ratio)
54
+
55
+ def forward(self, x):
56
+ y = self.layernorm1(x)
57
+ y = y.transpose(2, 1)
58
+ y = self.MLP_token(y)
59
+ y = y.transpose(2, 1)
60
+ z = self.layernorm2(x + y)
61
+ z = self.MLP_channel(z)
62
+ out = x + y + z
63
+ return out
64
+
65
+
66
+ class BasicBlock(nn.Module):
67
+ expansion = 1
68
+
69
+ def __init__(self, inplanes, planes, stride=1,
70
+ downsample=None, dilation=1):
71
+ super(BasicBlock, self).__init__()
72
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
73
+ padding=dilation, bias=False, dilation=dilation)
74
+ self.bn1 = nn.BatchNorm2d(planes, momentum=0.1)
75
+ self.relu = nn.ReLU(inplace=True)
76
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
77
+ padding=dilation, bias=False, dilation=dilation)
78
+ self.bn2 = nn.BatchNorm2d(planes, momentum=0.1)
79
+ self.downsample = downsample
80
+ self.stride = stride
81
+
82
+
83
+ def forward(self, x):
84
+ residual = x
85
+
86
+ out = self.conv1(x)
87
+ out = self.bn1(out)
88
+ out = self.relu(out)
89
+
90
+ out = self.conv2(out)
91
+ out = self.bn2(out)
92
+
93
+ if self.downsample is not None:
94
+ residual = self.downsample(x)
95
+
96
+ out += residual
97
+ out = self.relu(out)
98
+
99
+ return out
100
+
101
+ def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
102
+ layers = []
103
+ for i in range(len(feat_dims)-1):
104
+ layers.append(
105
+ nn.Conv2d(
106
+ in_channels=feat_dims[i],
107
+ out_channels=feat_dims[i+1],
108
+ kernel_size=kernel,
109
+ stride=stride,
110
+ padding=padding
111
+ ))
112
+ # Do not use BN and ReLU for final estimation
113
+ if i < len(feat_dims)-2 or (i == len(feat_dims)-2 and bnrelu_final):
114
+ layers.append(nn.BatchNorm2d(feat_dims[i+1]))
115
+ layers.append(nn.ReLU(inplace=True))
116
+
117
+ return nn.Sequential(*layers)
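A short usage sketch for the blocks above (all sizes are illustrative, import path assumed): MixerLayer mixes along the token axis and then along the channel axis, so it expects (batch, token_dim, hidden_dim) inputs; make_conv_layers builds a plain conv stack, optionally dropping BN/ReLU after the last layer.

import torch
from postometro_utils.modules import FCBlock, MixerLayer, make_conv_layers

tokens = torch.randn(2, 34, 256)                       # (batch, token_dim, hidden_dim)
mixer = MixerLayer(hidden_dim=256, hidden_inter_dim=512,
                   token_dim=34, token_inter_dim=64, dropout_ratio=0.1)
print(mixer(tokens).shape)                             # torch.Size([2, 34, 256])

proj = FCBlock(256, 128)                               # Linear -> LayerNorm -> ReLU
print(proj(tokens).shape)                              # torch.Size([2, 34, 128])

head = make_conv_layers([2048, 256, 17], kernel=1, stride=1, padding=0,
                        bnrelu_final=False)            # no BN/ReLU on the last conv
print(head(torch.randn(2, 2048, 8, 8)).shape)          # torch.Size([2, 17, 8, 8])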
main/postometro_utils/pose_hrnet.py ADDED
@@ -0,0 +1,502 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+
11
+ import os
12
+ import logging
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+
18
+ BN_MOMENTUM = 0.1
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def conv3x3(in_planes, out_planes, stride=1):
23
+ """3x3 convolution with padding"""
24
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25
+ padding=1, bias=False)
26
+
27
+
28
+ class BasicBlock(nn.Module):
29
+ expansion = 1
30
+
31
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
32
+ super(BasicBlock, self).__init__()
33
+ self.conv1 = conv3x3(inplanes, planes, stride)
34
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
35
+ self.relu = nn.ReLU(inplace=True)
36
+ self.conv2 = conv3x3(planes, planes)
37
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
38
+ self.downsample = downsample
39
+ self.stride = stride
40
+
41
+ def forward(self, x):
42
+ residual = x
43
+
44
+ out = self.conv1(x)
45
+ out = self.bn1(out)
46
+ out = self.relu(out)
47
+
48
+ out = self.conv2(out)
49
+ out = self.bn2(out)
50
+
51
+ if self.downsample is not None:
52
+ residual = self.downsample(x)
53
+
54
+ out += residual
55
+ out = self.relu(out)
56
+
57
+ return out
58
+
59
+
60
+ class Bottleneck(nn.Module):
61
+ expansion = 4
62
+
63
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
64
+ super(Bottleneck, self).__init__()
65
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
66
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
67
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
68
+ padding=1, bias=False)
69
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
70
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
71
+ bias=False)
72
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
73
+ momentum=BN_MOMENTUM)
74
+ self.relu = nn.ReLU(inplace=True)
75
+ self.downsample = downsample
76
+ self.stride = stride
77
+
78
+ def forward(self, x):
79
+ residual = x
80
+
81
+ out = self.conv1(x)
82
+ out = self.bn1(out)
83
+ out = self.relu(out)
84
+
85
+ out = self.conv2(out)
86
+ out = self.bn2(out)
87
+ out = self.relu(out)
88
+
89
+ out = self.conv3(out)
90
+ out = self.bn3(out)
91
+
92
+ if self.downsample is not None:
93
+ residual = self.downsample(x)
94
+
95
+ out += residual
96
+ out = self.relu(out)
97
+
98
+ return out
99
+
100
+
101
+ class HighResolutionModule(nn.Module):
102
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
103
+ num_channels, fuse_method, multi_scale_output=True):
104
+ super(HighResolutionModule, self).__init__()
105
+ self._check_branches(
106
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
107
+
108
+ self.num_inchannels = num_inchannels
109
+ self.fuse_method = fuse_method
110
+ self.num_branches = num_branches
111
+
112
+ self.multi_scale_output = multi_scale_output
113
+
114
+ self.branches = self._make_branches(
115
+ num_branches, blocks, num_blocks, num_channels)
116
+ self.fuse_layers = self._make_fuse_layers()
117
+ self.relu = nn.ReLU(True)
118
+
119
+ def _check_branches(self, num_branches, blocks, num_blocks,
120
+ num_inchannels, num_channels):
121
+ if num_branches != len(num_blocks):
122
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
123
+ num_branches, len(num_blocks))
124
+ logger.error(error_msg)
125
+ raise ValueError(error_msg)
126
+
127
+ if num_branches != len(num_channels):
128
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
129
+ num_branches, len(num_channels))
130
+ logger.error(error_msg)
131
+ raise ValueError(error_msg)
132
+
133
+ if num_branches != len(num_inchannels):
134
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
135
+ num_branches, len(num_inchannels))
136
+ logger.error(error_msg)
137
+ raise ValueError(error_msg)
138
+
139
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
140
+ stride=1):
141
+ downsample = None
142
+ if stride != 1 or \
143
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
144
+ downsample = nn.Sequential(
145
+ nn.Conv2d(
146
+ self.num_inchannels[branch_index],
147
+ num_channels[branch_index] * block.expansion,
148
+ kernel_size=1, stride=stride, bias=False
149
+ ),
150
+ nn.BatchNorm2d(
151
+ num_channels[branch_index] * block.expansion,
152
+ momentum=BN_MOMENTUM
153
+ ),
154
+ )
155
+
156
+ layers = []
157
+ layers.append(
158
+ block(
159
+ self.num_inchannels[branch_index],
160
+ num_channels[branch_index],
161
+ stride,
162
+ downsample
163
+ )
164
+ )
165
+ self.num_inchannels[branch_index] = \
166
+ num_channels[branch_index] * block.expansion
167
+ for i in range(1, num_blocks[branch_index]):
168
+ layers.append(
169
+ block(
170
+ self.num_inchannels[branch_index],
171
+ num_channels[branch_index]
172
+ )
173
+ )
174
+
175
+ return nn.Sequential(*layers)
176
+
177
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
178
+ branches = []
179
+
180
+ for i in range(num_branches):
181
+ branches.append(
182
+ self._make_one_branch(i, block, num_blocks, num_channels)
183
+ )
184
+
185
+ return nn.ModuleList(branches)
186
+
187
+ def _make_fuse_layers(self):
188
+ if self.num_branches == 1:
189
+ return None
190
+
191
+ num_branches = self.num_branches
192
+ num_inchannels = self.num_inchannels
193
+ fuse_layers = []
194
+ for i in range(num_branches if self.multi_scale_output else 1):
195
+ fuse_layer = []
196
+ for j in range(num_branches):
197
+ if j > i:
198
+ fuse_layer.append(
199
+ nn.Sequential(
200
+ nn.Conv2d(
201
+ num_inchannels[j],
202
+ num_inchannels[i],
203
+ 1, 1, 0, bias=False
204
+ ),
205
+ nn.BatchNorm2d(num_inchannels[i]),
206
+ nn.Upsample(scale_factor=2**(j-i), mode='nearest')
207
+ )
208
+ )
209
+ elif j == i:
210
+ fuse_layer.append(None)
211
+ else:
212
+ conv3x3s = []
213
+ for k in range(i-j):
214
+ if k == i - j - 1:
215
+ num_outchannels_conv3x3 = num_inchannels[i]
216
+ conv3x3s.append(
217
+ nn.Sequential(
218
+ nn.Conv2d(
219
+ num_inchannels[j],
220
+ num_outchannels_conv3x3,
221
+ 3, 2, 1, bias=False
222
+ ),
223
+ nn.BatchNorm2d(num_outchannels_conv3x3)
224
+ )
225
+ )
226
+ else:
227
+ num_outchannels_conv3x3 = num_inchannels[j]
228
+ conv3x3s.append(
229
+ nn.Sequential(
230
+ nn.Conv2d(
231
+ num_inchannels[j],
232
+ num_outchannels_conv3x3,
233
+ 3, 2, 1, bias=False
234
+ ),
235
+ nn.BatchNorm2d(num_outchannels_conv3x3),
236
+ nn.ReLU(True)
237
+ )
238
+ )
239
+ fuse_layer.append(nn.Sequential(*conv3x3s))
240
+ fuse_layers.append(nn.ModuleList(fuse_layer))
241
+
242
+ return nn.ModuleList(fuse_layers)
243
+
244
+ def get_num_inchannels(self):
245
+ return self.num_inchannels
246
+
247
+ def forward(self, x):
248
+ if self.num_branches == 1:
249
+ return [self.branches[0](x[0])]
250
+
251
+ for i in range(self.num_branches):
252
+ x[i] = self.branches[i](x[i])
253
+
254
+ x_fuse = []
255
+
256
+ for i in range(len(self.fuse_layers)):
257
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
258
+ for j in range(1, self.num_branches):
259
+ if i == j:
260
+ y = y + x[j]
261
+ else:
262
+ y = y + self.fuse_layers[i][j](x[j])
263
+ x_fuse.append(self.relu(y))
264
+
265
+ return x_fuse
266
+
267
+
268
+ blocks_dict = {
269
+ 'BASIC': BasicBlock,
270
+ 'BOTTLENECK': Bottleneck
271
+ }
272
+
273
+
274
+ class PoseHighResolutionNet(nn.Module):
275
+
276
+ def __init__(self, cfg, **kwargs):
277
+ self.inplanes = 64
278
+ extra = cfg['MODEL']['EXTRA']
279
+ super(PoseHighResolutionNet, self).__init__()
280
+
281
+ # stem net
282
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
283
+ bias=False)
284
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
285
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
286
+ bias=False)
287
+ self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
288
+ self.relu = nn.ReLU(inplace=True)
289
+ self.layer1 = self._make_layer(Bottleneck, 64, 4)
290
+
291
+ self.stage2_cfg = extra['STAGE2']
292
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
293
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
294
+ num_channels = [
295
+ num_channels[i] * block.expansion for i in range(len(num_channels))
296
+ ]
297
+ self.transition1 = self._make_transition_layer([256], num_channels)
298
+ self.stage2, pre_stage_channels = self._make_stage(
299
+ self.stage2_cfg, num_channels)
300
+
301
+ self.stage3_cfg = extra['STAGE3']
302
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
303
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
304
+ num_channels = [
305
+ num_channels[i] * block.expansion for i in range(len(num_channels))
306
+ ]
307
+ self.transition2 = self._make_transition_layer(
308
+ pre_stage_channels, num_channels)
309
+ self.stage3, pre_stage_channels = self._make_stage(
310
+ self.stage3_cfg, num_channels)
311
+
312
+ self.stage4_cfg = extra['STAGE4']
313
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
314
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
315
+ num_channels = [
316
+ num_channels[i] * block.expansion for i in range(len(num_channels))
317
+ ]
318
+ self.transition3 = self._make_transition_layer(
319
+ pre_stage_channels, num_channels)
320
+ self.stage4, pre_stage_channels = self._make_stage(
321
+ self.stage4_cfg, num_channels,
322
+ multi_scale_output=True)
323
+ # multi_scale_output=False)
324
+
325
+ self.final_layer = nn.Conv2d(
326
+ in_channels=pre_stage_channels[0],
327
+ out_channels=cfg['MODEL']['NUM_JOINTS'],
328
+ kernel_size=extra['FINAL_CONV_KERNEL'],
329
+ stride=1,
330
+ padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
331
+ )
332
+
333
+ self.pretrained_layers = extra['PRETRAINED_LAYERS']
334
+
335
+ def _make_transition_layer(
336
+ self, num_channels_pre_layer, num_channels_cur_layer):
337
+ num_branches_cur = len(num_channels_cur_layer)
338
+ num_branches_pre = len(num_channels_pre_layer)
339
+
340
+ transition_layers = []
341
+ for i in range(num_branches_cur):
342
+ if i < num_branches_pre:
343
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
344
+ transition_layers.append(
345
+ nn.Sequential(
346
+ nn.Conv2d(
347
+ num_channels_pre_layer[i],
348
+ num_channels_cur_layer[i],
349
+ 3, 1, 1, bias=False
350
+ ),
351
+ nn.BatchNorm2d(num_channels_cur_layer[i]),
352
+ nn.ReLU(inplace=True)
353
+ )
354
+ )
355
+ else:
356
+ transition_layers.append(None)
357
+ else:
358
+ conv3x3s = []
359
+ for j in range(i+1-num_branches_pre):
360
+ inchannels = num_channels_pre_layer[-1]
361
+ outchannels = num_channels_cur_layer[i] \
362
+ if j == i-num_branches_pre else inchannels
363
+ conv3x3s.append(
364
+ nn.Sequential(
365
+ nn.Conv2d(
366
+ inchannels, outchannels, 3, 2, 1, bias=False
367
+ ),
368
+ nn.BatchNorm2d(outchannels),
369
+ nn.ReLU(inplace=True)
370
+ )
371
+ )
372
+ transition_layers.append(nn.Sequential(*conv3x3s))
373
+
374
+ return nn.ModuleList(transition_layers)
375
+
376
+ def _make_layer(self, block, planes, blocks, stride=1):
377
+ downsample = None
378
+ if stride != 1 or self.inplanes != planes * block.expansion:
379
+ downsample = nn.Sequential(
380
+ nn.Conv2d(
381
+ self.inplanes, planes * block.expansion,
382
+ kernel_size=1, stride=stride, bias=False
383
+ ),
384
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
385
+ )
386
+
387
+ layers = []
388
+ layers.append(block(self.inplanes, planes, stride, downsample))
389
+ self.inplanes = planes * block.expansion
390
+ for i in range(1, blocks):
391
+ layers.append(block(self.inplanes, planes))
392
+
393
+ return nn.Sequential(*layers)
394
+
395
+ def _make_stage(self, layer_config, num_inchannels,
396
+ multi_scale_output=True):
397
+ num_modules = layer_config['NUM_MODULES']
398
+ num_branches = layer_config['NUM_BRANCHES']
399
+ num_blocks = layer_config['NUM_BLOCKS']
400
+ num_channels = layer_config['NUM_CHANNELS']
401
+ block = blocks_dict[layer_config['BLOCK']]
402
+ fuse_method = layer_config['FUSE_METHOD']
403
+
404
+ modules = []
405
+ for i in range(num_modules):
406
+ # multi_scale_output is only used last module
407
+ if not multi_scale_output and i == num_modules - 1:
408
+ reset_multi_scale_output = False
409
+ else:
410
+ reset_multi_scale_output = True
411
+
412
+ modules.append(
413
+ HighResolutionModule(
414
+ num_branches,
415
+ block,
416
+ num_blocks,
417
+ num_inchannels,
418
+ num_channels,
419
+ fuse_method,
420
+ reset_multi_scale_output
421
+ )
422
+ )
423
+ num_inchannels = modules[-1].get_num_inchannels()
424
+
425
+ return nn.Sequential(*modules), num_inchannels
426
+
427
+ def forward(self, x):
428
+ x = self.conv1(x)
429
+ x = self.bn1(x)
430
+ x = self.relu(x)
431
+ x = self.conv2(x)
432
+ x = self.bn2(x)
433
+ x = self.relu(x)
434
+ x = self.layer1(x)
435
+
436
+ x_list = []
437
+ for i in range(self.stage2_cfg['NUM_BRANCHES']):
438
+ if self.transition1[i] is not None:
439
+ x_list.append(self.transition1[i](x))
440
+ else:
441
+ x_list.append(x)
442
+ y_list = self.stage2(x_list)
443
+
444
+ x_list = []
445
+ for i in range(self.stage3_cfg['NUM_BRANCHES']):
446
+ if self.transition2[i] is not None:
447
+ x_list.append(self.transition2[i](y_list[-1]))
448
+ else:
449
+ x_list.append(y_list[i])
450
+ y_list = self.stage3(x_list)
451
+
452
+ x_list = []
453
+ for i in range(self.stage4_cfg['NUM_BRANCHES']):
454
+ if self.transition3[i] is not None:
455
+ x_list.append(self.transition3[i](y_list[-1]))
456
+ else:
457
+ x_list.append(y_list[i])
458
+ y_list = self.stage4(x_list)
459
+
460
+ return y_list[-1]
461
+ # x = self.final_layer(y_list[0])
462
+ # return x
463
+
464
+ def init_weights(self, pretrained=''):
465
+ logger.info('=> init weights from normal distribution')
466
+ for m in self.modules():
467
+ if isinstance(m, nn.Conv2d):
468
+ # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
469
+ nn.init.normal_(m.weight, std=0.001)
470
+ for name, _ in m.named_parameters():
471
+ if name in ['bias']:
472
+ nn.init.constant_(m.bias, 0)
473
+ elif isinstance(m, nn.BatchNorm2d):
474
+ nn.init.constant_(m.weight, 1)
475
+ nn.init.constant_(m.bias, 0)
476
+ elif isinstance(m, nn.ConvTranspose2d):
477
+ nn.init.normal_(m.weight, std=0.001)
478
+ for name, _ in m.named_parameters():
479
+ if name in ['bias']:
480
+ nn.init.constant_(m.bias, 0)
481
+
482
+ if os.path.isfile(pretrained):
483
+ pretrained_state_dict = torch.load(pretrained)
484
+ logger.info('=> loading pretrained model {}'.format(pretrained))
485
+
486
+ need_init_state_dict = {}
487
+ for name, m in pretrained_state_dict.items():
488
+ if name.split('.')[0] in self.pretrained_layers \
489
+ or self.pretrained_layers[0] == '*':
490
+ need_init_state_dict[name] = m
491
+ out = self.load_state_dict(need_init_state_dict, strict=False)
492
+ elif pretrained:
493
+ logger.error('=> please download pre-trained models first!')
494
+ raise ValueError('{} does not exist!'.format(pretrained))
495
+
496
+
497
+ def get_pose_hrnet(cfg, pretrained, **kwargs):
498
+ model = PoseHighResolutionNet(cfg, **kwargs)
499
+ if pretrained is not None:
500
+ model.init_weights(pretrained=pretrained)
501
+
502
+ return model
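A sketch of building this backbone with the yacs config and the HRNet-W48 yaml added in this commit (import paths and the yaml path, given relative to the repo root, are assumptions). Note that forward() as written returns the last stage-4 branch, i.e. the coarsest feature map, since the final_layer heatmap head is commented out:

import torch
from postometro_utils.pose_hrnet import get_pose_hrnet
from postometro_utils.pose_hrnet_config import _C as hrnet_cfg, update_config

update_config(hrnet_cfg, 'main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml')
backbone = get_pose_hrnet(hrnet_cfg, pretrained=None)   # pass a checkpoint path to init weights

feat = backbone(torch.randn(1, 3, 256, 192))            # (B, 3, H, W)
print(feat.shape)                                       # coarsest branch, here (1, 384, 8, 6)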
main/postometro_utils/pose_hrnet_config.py ADDED
@@ -0,0 +1,137 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # Modified by Ke Sun (sunk@mail.ustc.edu.cn)
6
+ # ------------------------------------------------------------------------------
7
+
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+
12
+ import os
13
+
14
+ from yacs.config import CfgNode as CN
15
+
16
+
17
+ _C = CN()
18
+
19
+ _C.OUTPUT_DIR = ''
20
+ _C.LOG_DIR = ''
21
+ _C.DATA_DIR = ''
22
+ _C.GPUS = (0,)
23
+ _C.WORKERS = 4
24
+ _C.PRINT_FREQ = 20
25
+ _C.AUTO_RESUME = False
26
+ _C.PIN_MEMORY = True
27
+ _C.RANK = 0
28
+
29
+ # Cudnn related params
30
+ _C.CUDNN = CN()
31
+ _C.CUDNN.BENCHMARK = True
32
+ _C.CUDNN.DETERMINISTIC = False
33
+ _C.CUDNN.ENABLED = True
34
+
35
+ # common params for NETWORK
36
+ _C.MODEL = CN()
37
+ _C.MODEL.NAME = 'cls_hrnet'
38
+ _C.MODEL.INIT_WEIGHTS = True
39
+ _C.MODEL.PRETRAINED = ''
40
+ _C.MODEL.NUM_JOINTS = 17
41
+ _C.MODEL.NUM_CLASSES = 1000
42
+ _C.MODEL.TAG_PER_JOINT = True
43
+ _C.MODEL.TARGET_TYPE = 'gaussian'
44
+ _C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
45
+ _C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
46
+ _C.MODEL.SIGMA = 2
47
+ _C.MODEL.EXTRA = CN(new_allowed=True)
48
+
49
+ _C.LOSS = CN()
50
+ _C.LOSS.USE_OHKM = False
51
+ _C.LOSS.TOPK = 8
52
+ _C.LOSS.USE_TARGET_WEIGHT = True
53
+ _C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
54
+
55
+ # DATASET related params
56
+ _C.DATASET = CN()
57
+ _C.DATASET.ROOT = ''
58
+ _C.DATASET.DATASET = 'mpii'
59
+ _C.DATASET.TRAIN_SET = 'train'
60
+ _C.DATASET.TEST_SET = 'valid'
61
+ _C.DATASET.DATA_FORMAT = 'jpg'
62
+ _C.DATASET.HYBRID_JOINTS_TYPE = ''
63
+ _C.DATASET.SELECT_DATA = False
64
+
65
+ # training data augmentation
66
+ _C.DATASET.FLIP = True
67
+ _C.DATASET.SCALE_FACTOR = 0.25
68
+ _C.DATASET.ROT_FACTOR = 30
69
+ _C.DATASET.PROB_HALF_BODY = 0.0
70
+ _C.DATASET.NUM_JOINTS_HALF_BODY = 8
71
+ _C.DATASET.COLOR_RGB = False
72
+
73
+ # train
74
+ _C.TRAIN = CN()
75
+
76
+ _C.TRAIN.LR_FACTOR = 0.1
77
+ _C.TRAIN.LR_STEP = [90, 110]
78
+ _C.TRAIN.LR = 0.001
79
+
80
+ _C.TRAIN.OPTIMIZER = 'adam'
81
+ _C.TRAIN.MOMENTUM = 0.9
82
+ _C.TRAIN.WD = 0.0001
83
+ _C.TRAIN.NESTEROV = False
84
+ _C.TRAIN.GAMMA1 = 0.99
85
+ _C.TRAIN.GAMMA2 = 0.0
86
+
87
+ _C.TRAIN.BEGIN_EPOCH = 0
88
+ _C.TRAIN.END_EPOCH = 140
89
+
90
+ _C.TRAIN.RESUME = False
91
+ _C.TRAIN.CHECKPOINT = ''
92
+
93
+ _C.TRAIN.BATCH_SIZE_PER_GPU = 32
94
+ _C.TRAIN.SHUFFLE = True
95
+
96
+ # testing
97
+ _C.TEST = CN()
98
+
99
+ # size of images for each device
100
+ _C.TEST.BATCH_SIZE_PER_GPU = 32
101
+ # Test Model Epoch
102
+ _C.TEST.FLIP_TEST = False
103
+ _C.TEST.POST_PROCESS = False
104
+ _C.TEST.SHIFT_HEATMAP = False
105
+
106
+ _C.TEST.USE_GT_BBOX = False
107
+
108
+ # nms
109
+ _C.TEST.IMAGE_THRE = 0.1
110
+ _C.TEST.NMS_THRE = 0.6
111
+ _C.TEST.SOFT_NMS = False
112
+ _C.TEST.OKS_THRE = 0.5
113
+ _C.TEST.IN_VIS_THRE = 0.0
114
+ _C.TEST.COCO_BBOX_FILE = ''
115
+ _C.TEST.BBOX_THRE = 1.0
116
+ _C.TEST.MODEL_FILE = ''
117
+
118
+ # debug
119
+ _C.DEBUG = CN()
120
+ _C.DEBUG.DEBUG = False
121
+ _C.DEBUG.SAVE_BATCH_IMAGES_GT = False
122
+ _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
123
+ _C.DEBUG.SAVE_HEATMAPS_GT = False
124
+ _C.DEBUG.SAVE_HEATMAPS_PRED = False
125
+
126
+
127
+ def update_config(cfg, config_file):
128
+ cfg.defrost()
129
+ cfg.merge_from_file(config_file)
130
+ cfg.freeze()
131
+
132
+
133
+ if __name__ == '__main__':
134
+ import sys
135
+ with open(sys.argv[1], 'w') as f:
136
+ print(_C, file=f)
137
+
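Because MODEL.EXTRA is declared with new_allowed=True, the nested HRNet STAGE2-4 settings from a yaml such as pose_w48_256x192_adam_lr1e-3.yaml merge cleanly through update_config; defaults can also be overridden in code via yacs. A small sketch (import path assumed):

from postometro_utils.pose_hrnet_config import _C

cfg = _C.clone()
cfg.defrost()
cfg.merge_from_list(['MODEL.NUM_JOINTS', 14])   # illustrative override
cfg.freeze()
print(cfg.MODEL.NUM_JOINTS)                     # 14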
main/postometro_utils/pose_resnet.py ADDED
@@ -0,0 +1,318 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+
11
+ import os
12
+ import logging
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from collections import OrderedDict
17
+
18
+
19
+ BN_MOMENTUM = 0.1
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def conv3x3(in_planes, out_planes, stride=1):
24
+ """3x3 convolution with padding"""
25
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26
+ padding=1, bias=False)
27
+
28
+
29
+ class BasicBlock(nn.Module):
30
+ expansion = 1
31
+
32
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
33
+ super(BasicBlock, self).__init__()
34
+ self.conv1 = conv3x3(inplanes, planes, stride)
35
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
36
+ self.relu = nn.ReLU(inplace=True)
37
+ self.conv2 = conv3x3(planes, planes)
38
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
39
+ self.downsample = downsample
40
+ self.stride = stride
41
+
42
+ def forward(self, x):
43
+ residual = x
44
+
45
+ out = self.conv1(x)
46
+ out = self.bn1(out)
47
+ out = self.relu(out)
48
+
49
+ out = self.conv2(out)
50
+ out = self.bn2(out)
51
+
52
+ if self.downsample is not None:
53
+ residual = self.downsample(x)
54
+
55
+ out += residual
56
+ out = self.relu(out)
57
+
58
+ return out
59
+
60
+
61
+ class Bottleneck(nn.Module):
62
+ expansion = 4
63
+
64
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
65
+ super(Bottleneck, self).__init__()
66
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
67
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
68
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
69
+ padding=1, bias=False)
70
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
71
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
72
+ bias=False)
73
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
74
+ momentum=BN_MOMENTUM)
75
+ self.relu = nn.ReLU(inplace=True)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+
79
+ def forward(self, x):
80
+ residual = x
81
+
82
+ out = self.conv1(x)
83
+ out = self.bn1(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv2(out)
87
+ out = self.bn2(out)
88
+ out = self.relu(out)
89
+
90
+ out = self.conv3(out)
91
+ out = self.bn3(out)
92
+
93
+ if self.downsample is not None:
94
+ residual = self.downsample(x)
95
+
96
+ out += residual
97
+ out = self.relu(out)
98
+
99
+ return out
100
+
101
+
102
+ class Bottleneck_CAFFE(nn.Module):
103
+ expansion = 4
104
+
105
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
106
+ super(Bottleneck_CAFFE, self).__init__()
107
+ # add stride to conv1x1
108
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
110
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
111
+ padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
113
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
114
+ bias=False)
115
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion,
116
+ momentum=BN_MOMENTUM)
117
+ self.relu = nn.ReLU(inplace=True)
118
+ self.downsample = downsample
119
+ self.stride = stride
120
+
121
+ def forward(self, x):
122
+ residual = x
123
+
124
+ out = self.conv1(x)
125
+ out = self.bn1(out)
126
+ out = self.relu(out)
127
+
128
+ out = self.conv2(out)
129
+ out = self.bn2(out)
130
+ out = self.relu(out)
131
+
132
+ out = self.conv3(out)
133
+ out = self.bn3(out)
134
+
135
+ if self.downsample is not None:
136
+ residual = self.downsample(x)
137
+
138
+ out += residual
139
+ out = self.relu(out)
140
+
141
+ return out
142
+
143
+
144
+ class PoseResNet(nn.Module):
145
+
146
+ def __init__(self, block, layers, cfg, **kwargs):
147
+ self.inplanes = 64
148
+ extra = cfg.MODEL.EXTRA
149
+ self.deconv_with_bias = extra.DECONV_WITH_BIAS
150
+
151
+ super(PoseResNet, self).__init__()
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
153
+ bias=False)
154
+ self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
155
+ self.relu = nn.ReLU(inplace=True)
156
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
157
+ self.layer1 = self._make_layer(block, 64, layers[0])
158
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
159
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
160
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
161
+
162
+ # used for deconv layers
163
+ # self.deconv_layers = self._make_deconv_layer(
164
+ # extra.NUM_DECONV_LAYERS,
165
+ # extra.NUM_DECONV_FILTERS,
166
+ # extra.NUM_DECONV_KERNELS,
167
+ # )
168
+
169
+ # self.final_layer = nn.Conv2d(
170
+ # in_channels=extra.NUM_DECONV_FILTERS[-1],
171
+ # out_channels=cfg.MODEL.NUM_JOINTS,
172
+ # kernel_size=extra.FINAL_CONV_KERNEL,
173
+ # stride=1,
174
+ # padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
175
+ # )
176
+
177
+ def _make_layer(self, block, planes, blocks, stride=1):
178
+ downsample = None
179
+ if stride != 1 or self.inplanes != planes * block.expansion:
180
+ downsample = nn.Sequential(
181
+ nn.Conv2d(self.inplanes, planes * block.expansion,
182
+ kernel_size=1, stride=stride, bias=False),
183
+ nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
184
+ )
185
+
186
+ layers = []
187
+ layers.append(block(self.inplanes, planes, stride, downsample))
188
+ self.inplanes = planes * block.expansion
189
+ for i in range(1, blocks):
190
+ layers.append(block(self.inplanes, planes))
191
+
192
+ return nn.Sequential(*layers)
193
+
194
+ def _get_deconv_cfg(self, deconv_kernel, index):
195
+ if deconv_kernel == 4:
196
+ padding = 1
197
+ output_padding = 0
198
+ elif deconv_kernel == 3:
199
+ padding = 1
200
+ output_padding = 1
201
+ elif deconv_kernel == 2:
202
+ padding = 0
203
+ output_padding = 0
204
+
205
+ return deconv_kernel, padding, output_padding
206
+
207
+ def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
208
+ assert num_layers == len(num_filters), \
209
+ 'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
210
+ assert num_layers == len(num_kernels), \
211
+ 'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
212
+
213
+ layers = []
214
+ for i in range(num_layers):
215
+ kernel, padding, output_padding = \
216
+ self._get_deconv_cfg(num_kernels[i], i)
217
+
218
+ planes = num_filters[i]
219
+ layers.append(
220
+ nn.ConvTranspose2d(
221
+ in_channels=self.inplanes,
222
+ out_channels=planes,
223
+ kernel_size=kernel,
224
+ stride=2,
225
+ padding=padding,
226
+ output_padding=output_padding,
227
+ bias=self.deconv_with_bias))
228
+ layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
229
+ layers.append(nn.ReLU(inplace=True))
230
+ self.inplanes = planes
231
+
232
+ return nn.Sequential(*layers)
233
+
234
+ def forward(self, x, skip_early = False, use_pct = False):
235
+ if not use_pct:
236
+ x = self.conv1(x)
237
+ x = self.bn1(x)
238
+ x = self.relu(x)
239
+ x = self.maxpool(x)
240
+ x = self.layer1(x)
241
+ x = self.layer2(x)
242
+ x = self.layer3(x)
243
+ x = self.layer4(x)
244
+
245
+ return x
246
+
247
+ if skip_early:
248
+ x = self.conv1(x)
249
+ x = self.bn1(x)
250
+ x = self.relu(x)
251
+ x = self.maxpool(x)
252
+ return x
253
+
254
+ x = self.layer1(x)
255
+ x = self.layer2(x)
256
+ x = self.layer3(x)
257
+ x = self.layer4(x)
258
+
259
+ return x
260
+
261
+ def init_weights(self, pretrained=''):
262
+ if os.path.isfile(pretrained):
263
+ # pretrained_state_dict = torch.load(pretrained)
264
+ logger.info('=> loading pretrained model {}'.format(pretrained))
265
+ # self.load_state_dict(pretrained_state_dict, strict=False)
266
+ checkpoint = torch.load(pretrained)
267
+ if isinstance(checkpoint, OrderedDict):
268
+ state_dict = checkpoint
269
+ elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
270
+ state_dict_old = checkpoint['state_dict']
271
+ state_dict = OrderedDict()
272
+ # delete 'module.' because it is saved from DataParallel module
273
+ for key in state_dict_old.keys():
274
+ if key.startswith('module.'):
275
+ # state_dict[key[7:]] = state_dict[key]
276
+ # state_dict.pop(key)
277
+ state_dict[key[7:]] = state_dict_old[key]
278
+ else:
279
+ state_dict[key] = state_dict_old[key]
280
+ else:
281
+ raise RuntimeError(
282
+ 'No state_dict found in checkpoint file {}'.format(pretrained))
283
+ state_dict_old = state_dict
284
+ state_dict = OrderedDict()
285
+ for k,v in state_dict_old.items():
286
+ if 'deconv_layers' in k or 'final_layer' in k:
287
+ continue
288
+ else:
289
+ state_dict[k] = state_dict_old[k]
290
+ self.load_state_dict(state_dict, strict=True)
291
+ else:
292
+ logger.error('=> imagenet pretrained model does not exist')
293
+ logger.error('=> please download it first')
294
+ raise ValueError('imagenet pretrained model does not exist')
295
+
296
+
297
+ resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
298
+ 34: (BasicBlock, [3, 4, 6, 3]),
299
+ 50: (Bottleneck, [3, 4, 6, 3]),
300
+ 101: (Bottleneck, [3, 4, 23, 3]),
301
+ 152: (Bottleneck, [3, 8, 36, 3])}
302
+
303
+
304
+ def get_pose_net(cfg, is_train, **kwargs):
305
+ num_layers = cfg.MODEL.EXTRA.NUM_LAYERS
306
+ style = cfg.MODEL.STYLE
307
+
308
+ block_class, layers = resnet_spec[num_layers]
309
+
310
+ if style == 'caffe':
311
+ block_class = Bottleneck_CAFFE
312
+
313
+ model = PoseResNet(block_class, layers, cfg, **kwargs)
314
+
315
+ if is_train and cfg.MODEL.INIT_WEIGHTS:
316
+ model.init_weights(cfg.MODEL.PRETRAINED)
317
+
318
+ return model
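A usage sketch for the ResNet-50 variant, built with the edict config defined in pose_resnet_config.py below (import paths assumed). The extra use_pct / skip_early flags in forward() let the stem and the residual stages be run as two separate calls, presumably so another branch can consume the stem features:

import torch
from postometro_utils.pose_resnet import get_pose_net
from postometro_utils.pose_resnet_config import config as resnet_cfg

backbone = get_pose_net(resnet_cfg, is_train=False)        # skip ImageNet init
feat = backbone(torch.randn(1, 3, 256, 256))               # full forward path
print(feat.shape)                                          # (1, 2048, 8, 8)

stem = backbone(torch.randn(1, 3, 256, 256), skip_early=True, use_pct=True)   # stem only
deep = backbone(stem, use_pct=True)                        # residual stages on stem features
print(stem.shape, deep.shape)                              # (1, 64, 64, 64) (1, 2048, 8, 8)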
main/postometro_utils/pose_resnet_config.py ADDED
@@ -0,0 +1,229 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft
3
+ # Licensed under the MIT License.
4
+ # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+
11
+ import os
12
+ import yaml
13
+
14
+ import numpy as np
15
+ from easydict import EasyDict as edict
16
+
17
+
18
+ config = edict()
19
+
20
+ config.OUTPUT_DIR = ''
21
+ config.LOG_DIR = ''
22
+ config.DATA_DIR = ''
23
+ config.GPUS = '0'
24
+ config.WORKERS = 4
25
+ config.PRINT_FREQ = 20
26
+
27
+ # Cudnn related params
28
+ config.CUDNN = edict()
29
+ config.CUDNN.BENCHMARK = True
30
+ config.CUDNN.DETERMINISTIC = False
31
+ config.CUDNN.ENABLED = True
32
+
33
+ # pose_resnet related params
34
+ POSE_RESNET = edict()
35
+ POSE_RESNET.NUM_LAYERS = 50
36
+ POSE_RESNET.DECONV_WITH_BIAS = False
37
+ POSE_RESNET.NUM_DECONV_LAYERS = 3
38
+ POSE_RESNET.NUM_DECONV_FILTERS = [256, 256, 256]
39
+ POSE_RESNET.NUM_DECONV_KERNELS = [4, 4, 4]
40
+ POSE_RESNET.FINAL_CONV_KERNEL = 1
41
+ POSE_RESNET.TARGET_TYPE = 'gaussian'
42
+ POSE_RESNET.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
43
+ POSE_RESNET.SIGMA = 2
44
+
45
+ MODEL_EXTRAS = {
46
+ 'pose_resnet': POSE_RESNET,
47
+ }
48
+
49
+ # common params for NETWORK
50
+ config.MODEL = edict()
51
+ config.MODEL.NAME = 'pose_resnet'
52
+ config.MODEL.INIT_WEIGHTS = True
53
+ config.MODEL.PRETRAINED = ''
54
+ config.MODEL.NUM_JOINTS = 16
55
+ config.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
56
+ config.MODEL.EXTRA = MODEL_EXTRAS[config.MODEL.NAME]
57
+
58
+ config.MODEL.STYLE = 'pytorch'
59
+
60
+ config.LOSS = edict()
61
+ config.LOSS.USE_TARGET_WEIGHT = True
62
+
63
+ # DATASET related params
64
+ config.DATASET = edict()
65
+ config.DATASET.ROOT = ''
66
+ config.DATASET.DATASET = 'mpii'
67
+ config.DATASET.TRAIN_SET = 'train'
68
+ config.DATASET.TEST_SET = 'valid'
69
+ config.DATASET.DATA_FORMAT = 'jpg'
70
+ config.DATASET.HYBRID_JOINTS_TYPE = ''
71
+ config.DATASET.SELECT_DATA = False
72
+
73
+ # training data augmentation
74
+ config.DATASET.FLIP = True
75
+ config.DATASET.SCALE_FACTOR = 0.25
76
+ config.DATASET.ROT_FACTOR = 30
77
+
78
+ # train
79
+ config.TRAIN = edict()
80
+
81
+ config.TRAIN.LR_FACTOR = 0.1
82
+ config.TRAIN.LR_STEP = [90, 110]
83
+ config.TRAIN.LR = 0.001
84
+
85
+ config.TRAIN.OPTIMIZER = 'adam'
86
+ config.TRAIN.MOMENTUM = 0.9
87
+ config.TRAIN.WD = 0.0001
88
+ config.TRAIN.NESTEROV = False
89
+ config.TRAIN.GAMMA1 = 0.99
90
+ config.TRAIN.GAMMA2 = 0.0
91
+
92
+ config.TRAIN.BEGIN_EPOCH = 0
93
+ config.TRAIN.END_EPOCH = 140
94
+
95
+ config.TRAIN.RESUME = False
96
+ config.TRAIN.CHECKPOINT = ''
97
+
98
+ config.TRAIN.BATCH_SIZE = 32
99
+ config.TRAIN.SHUFFLE = True
100
+
101
+ # testing
102
+ config.TEST = edict()
103
+
104
+ # size of images for each device
105
+ config.TEST.BATCH_SIZE = 32
106
+ # Test Model Epoch
107
+ config.TEST.FLIP_TEST = False
108
+ config.TEST.POST_PROCESS = True
109
+ config.TEST.SHIFT_HEATMAP = True
110
+
111
+ config.TEST.USE_GT_BBOX = False
112
+ # nms
113
+ config.TEST.OKS_THRE = 0.5
114
+ config.TEST.IN_VIS_THRE = 0.0
115
+ config.TEST.COCO_BBOX_FILE = ''
116
+ config.TEST.BBOX_THRE = 1.0
117
+ config.TEST.MODEL_FILE = ''
118
+ config.TEST.IMAGE_THRE = 0.0
119
+ config.TEST.NMS_THRE = 1.0
120
+
121
+ # debug
122
+ config.DEBUG = edict()
123
+ config.DEBUG.DEBUG = False
124
+ config.DEBUG.SAVE_BATCH_IMAGES_GT = False
125
+ config.DEBUG.SAVE_BATCH_IMAGES_PRED = False
126
+ config.DEBUG.SAVE_HEATMAPS_GT = False
127
+ config.DEBUG.SAVE_HEATMAPS_PRED = False
128
+
129
+
130
+ def _update_dict(k, v):
131
+ if k == 'DATASET':
132
+ if 'MEAN' in v and v['MEAN']:
133
+ v['MEAN'] = np.array([eval(x) if isinstance(x, str) else x
134
+ for x in v['MEAN']])
135
+ if 'STD' in v and v['STD']:
136
+ v['STD'] = np.array([eval(x) if isinstance(x, str) else x
137
+ for x in v['STD']])
138
+ if k == 'MODEL':
139
+ if 'EXTRA' in v and 'HEATMAP_SIZE' in v['EXTRA']:
140
+ if isinstance(v['EXTRA']['HEATMAP_SIZE'], int):
141
+ v['EXTRA']['HEATMAP_SIZE'] = np.array(
142
+ [v['EXTRA']['HEATMAP_SIZE'], v['EXTRA']['HEATMAP_SIZE']])
143
+ else:
144
+ v['EXTRA']['HEATMAP_SIZE'] = np.array(
145
+ v['EXTRA']['HEATMAP_SIZE'])
146
+ if 'IMAGE_SIZE' in v:
147
+ if isinstance(v['IMAGE_SIZE'], int):
148
+ v['IMAGE_SIZE'] = np.array([v['IMAGE_SIZE'], v['IMAGE_SIZE']])
149
+ else:
150
+ v['IMAGE_SIZE'] = np.array(v['IMAGE_SIZE'])
151
+ for vk, vv in v.items():
152
+ if vk in config[k]:
153
+ config[k][vk] = vv
154
+ else:
155
+ raise ValueError("{}.{} not exist in config.py".format(k, vk))
156
+
157
+
158
+ def update_config(config_file):
159
+ exp_config = None
160
+ with open(config_file) as f:
161
+ exp_config = edict(yaml.safe_load(f))
162
+ for k, v in exp_config.items():
163
+ if k in config:
164
+ if isinstance(v, dict):
165
+ _update_dict(k, v)
166
+ else:
167
+ if k == 'SCALES':
168
+ config[k][0] = (tuple(v))
169
+ else:
170
+ config[k] = v
171
+ else:
172
+ raise ValueError("{} not exist in config.py".format(k))
173
+
174
+
175
+ def gen_config(config_file):
176
+ cfg = dict(config)
177
+ for k, v in cfg.items():
178
+ if isinstance(v, edict):
179
+ cfg[k] = dict(v)
180
+
181
+ with open(config_file, 'w') as f:
182
+ yaml.dump(dict(cfg), f, default_flow_style=False)
183
+
184
+
185
+ def update_dir(model_dir, log_dir, data_dir):
186
+ if model_dir:
187
+ config.OUTPUT_DIR = model_dir
188
+
189
+ if log_dir:
190
+ config.LOG_DIR = log_dir
191
+
192
+ if data_dir:
193
+ config.DATA_DIR = data_dir
194
+
195
+ config.DATASET.ROOT = os.path.join(
196
+ config.DATA_DIR, config.DATASET.ROOT)
197
+
198
+ config.TEST.COCO_BBOX_FILE = os.path.join(
199
+ config.DATA_DIR, config.TEST.COCO_BBOX_FILE)
200
+
201
+ config.MODEL.PRETRAINED = os.path.join(
202
+ config.DATA_DIR, config.MODEL.PRETRAINED)
203
+
204
+
205
+ def get_model_name(cfg):
206
+ name = cfg.MODEL.NAME
207
+ full_name = cfg.MODEL.NAME
208
+ extra = cfg.MODEL.EXTRA
209
+ if name in ['pose_resnet']:
210
+ name = '{model}_{num_layers}'.format(
211
+ model=name,
212
+ num_layers=extra.NUM_LAYERS)
213
+ deconv_suffix = ''.join(
214
+ 'd{}'.format(num_filters)
215
+ for num_filters in extra.NUM_DECONV_FILTERS)
216
+ full_name = '{height}x{width}_{name}_{deconv_suffix}'.format(
217
+ height=cfg.MODEL.IMAGE_SIZE[1],
218
+ width=cfg.MODEL.IMAGE_SIZE[0],
219
+ name=name,
220
+ deconv_suffix=deconv_suffix)
221
+ else:
222
+ raise ValueError('Unknown model: {}'.format(cfg.MODEL))
223
+
224
+ return name, full_name
225
+
226
+
227
+ if __name__ == '__main__':
228
+ import sys
229
+ gen_config(sys.argv[1])
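The module exposes a single global edict named config; a few of its helpers in a small sketch (import path assumed, values are the defaults above):

from postometro_utils.pose_resnet_config import config, get_model_name, update_dir

print(config.MODEL.NAME, config.MODEL.EXTRA.NUM_LAYERS)    # pose_resnet 50
print(get_model_name(config))                              # ('pose_resnet_50', '256x256_pose_resnet_50_d256d256d256')
update_dir(model_dir='output', log_dir='log', data_dir='data')
print(config.DATA_DIR, config.DATASET.ROOT)                # paths re-rooted under data/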
main/postometro_utils/pose_w48_256x192_adam_lr1e-3.yaml ADDED
@@ -0,0 +1,127 @@
1
+ AUTO_RESUME: true
2
+ CUDNN:
3
+ BENCHMARK: true
4
+ DETERMINISTIC: false
5
+ ENABLED: true
6
+ DATA_DIR: ''
7
+ GPUS: (0,1,2,3)
8
+ OUTPUT_DIR: 'output'
9
+ LOG_DIR: 'log'
10
+ WORKERS: 24
11
+ PRINT_FREQ: 100
12
+
13
+ DATASET:
14
+ COLOR_RGB: true
15
+ DATASET: 'coco'
16
+ DATA_FORMAT: jpg
17
+ FLIP: true
18
+ NUM_JOINTS_HALF_BODY: 8
19
+ PROB_HALF_BODY: 0.3
20
+ ROOT: 'data/coco/'
21
+ ROT_FACTOR: 45
22
+ SCALE_FACTOR: 0.35
23
+ TEST_SET: 'val2017'
24
+ TRAIN_SET: 'train2017'
25
+ MODEL:
26
+ INIT_WEIGHTS: true
27
+ NAME: pose_hrnet
28
+ NUM_JOINTS: 17
29
+ PRETRAINED: 'models/pytorch/imagenet/hrnet_w48-8ef0771d.pth'
30
+ TARGET_TYPE: gaussian
31
+ IMAGE_SIZE:
32
+ - 192
33
+ - 256
34
+ HEATMAP_SIZE:
35
+ - 48
36
+ - 64
37
+ SIGMA: 2
38
+ EXTRA:
39
+ PRETRAINED_LAYERS:
40
+ - 'conv1'
41
+ - 'bn1'
42
+ - 'conv2'
43
+ - 'bn2'
44
+ - 'layer1'
45
+ - 'transition1'
46
+ - 'stage2'
47
+ - 'transition2'
48
+ - 'stage3'
49
+ - 'transition3'
50
+ - 'stage4'
51
+ FINAL_CONV_KERNEL: 1
52
+ STAGE2:
53
+ NUM_MODULES: 1
54
+ NUM_BRANCHES: 2
55
+ BLOCK: BASIC
56
+ NUM_BLOCKS:
57
+ - 4
58
+ - 4
59
+ NUM_CHANNELS:
60
+ - 48
61
+ - 96
62
+ FUSE_METHOD: SUM
63
+ STAGE3:
64
+ NUM_MODULES: 4
65
+ NUM_BRANCHES: 3
66
+ BLOCK: BASIC
67
+ NUM_BLOCKS:
68
+ - 4
69
+ - 4
70
+ - 4
71
+ NUM_CHANNELS:
72
+ - 48
73
+ - 96
74
+ - 192
75
+ FUSE_METHOD: SUM
76
+ STAGE4:
77
+ NUM_MODULES: 3
78
+ NUM_BRANCHES: 4
79
+ BLOCK: BASIC
80
+ NUM_BLOCKS:
81
+ - 4
82
+ - 4
83
+ - 4
84
+ - 4
85
+ NUM_CHANNELS:
86
+ - 48
87
+ - 96
88
+ - 192
89
+ - 384
90
+ FUSE_METHOD: SUM
91
+ LOSS:
92
+ USE_TARGET_WEIGHT: true
93
+ TRAIN:
94
+ BATCH_SIZE_PER_GPU: 32
95
+ SHUFFLE: true
96
+ BEGIN_EPOCH: 0
97
+ END_EPOCH: 210
98
+ OPTIMIZER: adam
99
+ LR: 0.001
100
+ LR_FACTOR: 0.1
101
+ LR_STEP:
102
+ - 170
103
+ - 200
104
+ WD: 0.0001
105
+ GAMMA1: 0.99
106
+ GAMMA2: 0.0
107
+ MOMENTUM: 0.9
108
+ NESTEROV: false
109
+ TEST:
110
+ BATCH_SIZE_PER_GPU: 32
111
+ COCO_BBOX_FILE: 'data/coco/person_detection_results/COCO_val2017_detections_AP_H_56_person.json'
112
+ BBOX_THRE: 1.0
113
+ IMAGE_THRE: 0.0
114
+ IN_VIS_THRE: 0.2
115
+ MODEL_FILE: ''
116
+ NMS_THRE: 1.0
117
+ OKS_THRE: 0.9
118
+ USE_GT_BBOX: true
119
+ FLIP_TEST: true
120
+ POST_PROCESS: true
121
+ SHIFT_HEATMAP: true
122
+ DEBUG:
123
+ DEBUG: true
124
+ SAVE_BATCH_IMAGES_GT: true
125
+ SAVE_BATCH_IMAGES_PRED: true
126
+ SAVE_HEATMAPS_GT: true
127
+ SAVE_HEATMAPS_PRED: true
main/postometro_utils/positional_encoding.py ADDED
@@ -0,0 +1,57 @@
1
+ # ----------------------------------------------------------------------------------------------
2
+ # FastMETRO Official Code
3
+ # Copyright (c) POSTECH Algorithmic Machine Intelligence Lab. (P-AMI Lab.) All Rights Reserved
4
+ # Licensed under the MIT license.
5
+ # ----------------------------------------------------------------------------------------------
6
+ # Modified from DETR (https://github.com/facebookresearch/detr)
7
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved [see https://github.com/facebookresearch/detr/blob/main/LICENSE for details]
8
+ # ----------------------------------------------------------------------------------------------
9
+
10
+ import math
11
+ import torch
12
+ from torch import nn
13
+
14
+ class PositionEmbeddingSine(nn.Module):
15
+ """
16
+ This is a more standard version of the position embedding, very similar to the one
17
+ used by the Attention is all you need paper, generalized to work on images.
18
+ """
19
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
20
+ super().__init__()
21
+ self.num_pos_feats = num_pos_feats
22
+ self.temperature = temperature
23
+ self.normalize = normalize
24
+ if scale is not None and normalize is False:
25
+ raise ValueError("normalize should be True if scale is passed")
26
+ if scale is None:
27
+ scale = 2 * math.pi
28
+ self.scale = scale
29
+
30
+ def forward(self, bs, h, w, device):
31
+ ones = torch.ones((bs, h, w), dtype=torch.bool, device=device)
32
+ y_embed = ones.cumsum(1, dtype=torch.float32)
33
+ x_embed = ones.cumsum(2, dtype=torch.float32)
34
+ if self.normalize:
35
+ eps = 1e-6
36
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
37
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
38
+
39
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
40
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats) # cancel warning
41
+
42
+ pos_x = x_embed[:, :, :, None] / dim_t
43
+ pos_y = y_embed[:, :, :, None] / dim_t
44
+ pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
45
+ pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
46
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
47
+ return pos
48
+
49
+
50
+ def build_position_encoding(pos_type, hidden_dim):
51
+ N_steps = hidden_dim // 2
52
+ if pos_type == 'sine':
53
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
54
+ else:
55
+ raise ValueError(f"not supported {pos_type}")
56
+
57
+ return position_embedding
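Usage sketch: for a feature map of height h and width w, the returned embedding has hidden_dim channels (2 * N_steps), matching the transformer width. Import path assumed:

import torch
from postometro_utils.positional_encoding import build_position_encoding

pos_embed = build_position_encoding(pos_type='sine', hidden_dim=512)
pos = pos_embed(2, 8, 8, torch.device('cpu'))   # (bs, h, w, device)
print(pos.shape)                                # torch.Size([2, 512, 8, 8])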
main/postometro_utils/renderer_pyrender.py ADDED
@@ -0,0 +1,225 @@
1
+ # ----------------------------------------------------------------------------------------------
2
+ # Modified from Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
3
+ # Copyright (c) Hongsuk Choi. All Rights Reserved [see https://github.com/hongsukchoi/Pose2Mesh_RELEASE/blob/main/LICENSE for details]
4
+ # ----------------------------------------------------------------------------------------------
5
+
6
+ import os
7
+ os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
8
+ import torch
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+ import math
12
+ import cv2
13
+ import trimesh
14
+ import pyrender
15
+ from pyrender.constants import RenderFlags
16
+
17
+ def crop_bbox(bbox_meta, resolution, rgb, valid_mask):
18
+ bbox, original_img_height, original_img_width = bbox_meta['bbox'], *bbox_meta['img_hw']
19
+ start_x = int(bbox[0])
20
+ start_y = int(bbox[1])
21
+ end_x = start_x + int(resolution[0]) # w + start_x
22
+ end_y = start_y + int(resolution[1]) # h + start_y
23
+ real_start_x, real_start_y, real_end_x, real_end_y = max(0, start_x), max(0, start_y), min(original_img_width, end_x), min(original_img_height, end_y)
24
+ max_height, max_width = rgb.shape[:2]
25
+ real_rgb = rgb[(real_start_y - start_y):((real_end_y - end_y) if real_end_y < end_y else max_height),
26
+ (real_start_x - start_x):((real_end_x - end_x) if real_end_x < end_x else max_width)].copy()
27
+ real_valid_mask = valid_mask[(real_start_y - start_y):((real_end_y - end_y) if real_end_y < end_y else max_height),
28
+ (real_start_x - start_x):((real_end_x - end_x) if real_end_x < end_x else max_width)].copy()
29
+ return {'bbox': [real_start_x, real_start_y, real_end_x, real_end_y], 'img_hw': [original_img_height, original_img_width]}, real_rgb, real_valid_mask
30
+
31
+
32
+ class WeakPerspectiveCamera(pyrender.Camera):
33
+ def __init__(self, scale, translation, znear=pyrender.camera.DEFAULT_Z_NEAR, zfar=None, name=None):
34
+ super(WeakPerspectiveCamera, self).__init__(znear=znear, zfar=zfar, name=name)
35
+ self.scale = scale
36
+ self.translation = translation
37
+
38
+ def get_projection_matrix(self, width=None, height=None):
39
+ P = np.eye(4)
40
+ P[0, 0] = self.scale[0]
41
+ P[1, 1] = self.scale[1]
42
+ P[0, 3] = self.translation[0] * self.scale[0]
43
+ P[1, 3] = -self.translation[1] * self.scale[1]
44
+ P[2, 2] = -1
45
+ return P
46
+
47
+
48
+ class PyRender_Renderer:
49
+ def __init__(self, resolution=(256, 256), faces=None, orig_img=False, wireframe=False):
50
+ self.resolution = resolution
51
+ self.faces = faces
52
+ self.orig_img = orig_img
53
+ self.wireframe = wireframe
54
+ self.renderer = pyrender.OffscreenRenderer(viewport_width=self.resolution[0],
55
+ viewport_height=self.resolution[1],
56
+ point_size=1.0)
57
+
58
+ # set the scene & create light source
59
+ self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], ambient_light=(0.05, 0.05, 0.05))
60
+ light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=3.0)
61
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(-45), [1, 0, 0])
62
+ self.scene.add(light, pose=light_pose)
63
+ light_pose = trimesh.transformations.rotation_matrix(np.radians(45), [0, 1, 0])
64
+ self.scene.add(light, pose=light_pose)
65
+
66
+ # mesh colors
67
+ self.colors_dict = {'blue': np.array([0.35, 0.60, 0.92]),
68
+ 'neutral': np.array([0.7, 0.7, 0.6]),
69
+ 'pink': np.array([0.7, 0.5, 0.5]),
70
+ 'white': np.array([1.0, 0.98, 0.94]),
71
+ 'green': np.array([0.5, 0.55, 0.3]),
72
+ 'sky': np.array([0.3, 0.5, 0.55])}
73
+
74
+ def __call__(self, verts, bbox_meta, img=np.zeros((224, 224, 3)), cam=np.array([1, 0, 0]),
75
+ angle=None, axis=None, mesh_filename=None, color_type=None, color=[0.7, 0.7, 0.6]):
76
+ if color_type is not None:
77
+ color = self.colors_dict[color_type]
78
+
79
+ mesh = trimesh.Trimesh(vertices=verts, faces=self.faces, process=False)
80
+ Rx = trimesh.transformations.rotation_matrix(math.radians(180), [1, 0, 0])
81
+ mesh.apply_transform(Rx)
82
+ if mesh_filename is not None:
83
+ mesh.export(mesh_filename)
84
+ if angle and axis:
85
+ R = trimesh.transformations.rotation_matrix(math.radians(angle), axis)
86
+ mesh.apply_transform(R)
87
+
88
+ sy, tx, ty = cam
89
+ sx = sy
90
+ camera = WeakPerspectiveCamera(scale=[sx, sy], translation=[tx, ty], zfar=1000.0)
91
+
92
+ material = pyrender.MetallicRoughnessMaterial(
93
+ metallicFactor=0.2,
94
+ roughnessFactor=1.0,
95
+ alphaMode='OPAQUE',
96
+ baseColorFactor=(color[0], color[1], color[2], 1.0)
97
+ )
98
+
99
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
100
+ mesh_node = self.scene.add(mesh, 'mesh')
101
+
102
+ camera_pose = np.eye(4)
103
+ cam_node = self.scene.add(camera, pose=camera_pose)
104
+
105
+ if self.wireframe:
106
+ render_flags = RenderFlags.RGBA | RenderFlags.ALL_WIREFRAME
107
+ else:
108
+ render_flags = RenderFlags.RGBA
109
+
110
+ rgb, depth = self.renderer.render(self.scene, flags=render_flags)
111
+ valid_mask = (depth > 0)[:, :, np.newaxis] # bbox size
112
+ # adjust bbox (no out of boundary)
113
+ bbox_meta, rgb, valid_mask = crop_bbox(bbox_meta, [self.resolution[0], self.resolution[1]], rgb, valid_mask)
114
+ # parse bbox
115
+ start_x, start_y, end_x, end_y, original_img_height, original_img_width = *bbox_meta['bbox'], *bbox_meta['img_hw']
116
+ # start_x = int(bbox_meta['bbox'][0])
117
+ # start_y = int(bbox_meta['bbox'][1])
118
+ # end_x = start_x + int(self.resolution[0]) # w + start_x
119
+ # end_y = start_y + int(self.resolution[1]) # h + start_y
120
+ whole_img_mask = np.zeros((original_img_height, original_img_width,1))
121
+ whole_img_mask[start_y:end_y, start_x:end_x] = valid_mask
122
+ whole_rgb = np.zeros((original_img_height, original_img_width,4))
123
+ whole_rgb[start_y:end_y, start_x:end_x,:3] = rgb
124
+ output_img = whole_rgb[:, :, :3] * whole_img_mask + (1 - whole_img_mask) * img
125
+ image = output_img.astype(np.uint8)
126
+
127
+ self.scene.remove_node(mesh_node)
128
+ self.scene.remove_node(cam_node)
129
+
130
+ return image
131
+
132
+
133
+ def visualize_reconstruction_pyrender(img, vertices, camera, renderer, color='blue', focal_length=1000):
134
+ img = (img * 255).astype(np.uint8)
135
+ save_mesh_path = None
136
+ rend_color = color
137
+
138
+ # Render front view
139
+ rend_img = renderer(vertices,
140
+ img=img,
141
+ cam=camera,
142
+ color_type=rend_color,
143
+ mesh_filename=save_mesh_path)
144
+
145
+ combined = np.hstack([img, rend_img])
146
+
147
+ return combined
148
+
149
+ def visualize_reconstruction_multi_view_pyrender(img, vertices, camera, renderer, color='blue', focal_length=1000):
150
+ img = (img * 255).astype(np.uint8)
151
+ save_mesh_path = None
152
+ rend_color = color
153
+
154
+ # Render front view
155
+ rend_img = renderer(vertices,
156
+ img=img,
157
+ cam=camera,
158
+ color_type=rend_color,
159
+ mesh_filename=save_mesh_path)
160
+
161
+ # Render side views
162
+ aroundy0 = cv2.Rodrigues(np.array([0, np.radians(0.), 0]))[0]
163
+ aroundy1 = cv2.Rodrigues(np.array([0, np.radians(90.), 0]))[0]
164
+ aroundy2 = cv2.Rodrigues(np.array([0, np.radians(180.), 0]))[0]
165
+ aroundy3 = cv2.Rodrigues(np.array([0, np.radians(270.), 0]))[0]
166
+ aroundy4 = cv2.Rodrigues(np.array([0, np.radians(45.), 0]))[0]
167
+ center = vertices.mean(axis=0)
168
+ rot_vertices0 = np.dot((vertices - center), aroundy0) + center
169
+ rot_vertices1 = np.dot((vertices - center), aroundy1) + center
170
+ rot_vertices2 = np.dot((vertices - center), aroundy2) + center
171
+ rot_vertices3 = np.dot((vertices - center), aroundy3) + center
172
+ rot_vertices4 = np.dot((vertices - center), aroundy4) + center
173
+
174
+ # Render side-view shape
175
+ img_side0 = renderer(rot_vertices0,
176
+ img=np.ones_like(img)*255,
177
+ cam=camera,
178
+ color_type=rend_color,
179
+ mesh_filename=save_mesh_path)
180
+ img_side1 = renderer(rot_vertices1,
181
+ img=np.ones_like(img)*255,
182
+ cam=camera,
183
+ color_type=rend_color,
184
+ mesh_filename=save_mesh_path)
185
+ img_side2 = renderer(rot_vertices2,
186
+ img=np.ones_like(img)*255,
187
+ cam=camera,
188
+ color_type=rend_color,
189
+ mesh_filename=save_mesh_path)
190
+ img_side3 = renderer(rot_vertices3,
191
+ img=np.ones_like(img)*255,
192
+ cam=camera,
193
+ color_type=rend_color,
194
+ mesh_filename=save_mesh_path)
195
+ img_side4 = renderer(rot_vertices4,
196
+ img=np.ones_like(img)*255,
197
+ cam=camera,
198
+ color_type=rend_color,
199
+ mesh_filename=save_mesh_path)
200
+
201
+ combined = np.hstack([img, rend_img, img_side0, img_side1, img_side2, img_side3, img_side4])
202
+
203
+ return combined
204
+
205
+ def visualize_reconstruction_smpl_pyrender(img, vertices, camera, renderer, smpl_vertices, color='blue', focal_length=1000):
206
+ img = (img * 255).astype(np.uint8)
207
+ save_mesh_path = None
208
+ rend_color = color
209
+
210
+ # Render front view
211
+ rend_img = renderer(vertices,
212
+ img=img,
213
+ cam=camera,
214
+ color_type=rend_color,
215
+ mesh_filename=save_mesh_path)
216
+
217
+ rend_img_smpl = renderer(smpl_vertices,
218
+ img=img,
219
+ cam=camera,
220
+ color_type=rend_color,
221
+ mesh_filename=save_mesh_path)
222
+
223
+ combined = np.hstack([img, rend_img, rend_img_smpl])
224
+
225
+ return combined
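The renderer maps SMPL weak-perspective camera parameters (s, tx, ty) to an orthographic-style projection through WeakPerspectiveCamera; a minimal sketch of just that piece is below (importing the module needs pyrender, trimesh and OpenCV installed, and the offscreen PyRender_Renderer additionally needs an OSMesa-capable build, which is why PYOPENGL_PLATFORM is set at import time). Values are illustrative:

from postometro_utils.renderer_pyrender import WeakPerspectiveCamera  # path assumed

cam = WeakPerspectiveCamera(scale=[0.9, 0.9], translation=[0.05, -0.02], zfar=1000.0)
print(cam.get_projection_matrix())
# [[ 0.9    0.     0.     0.045]
#  [ 0.     0.9    0.     0.018]
#  [ 0.     0.    -1.     0.   ]
#  [ 0.     0.     0.     1.   ]]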