luoxue-star committed
Commit 48cafca · 1 Parent(s): 0410ce8

init commit

.gitignore ADDED
@@ -0,0 +1,4 @@
+ example_data
+ *.pyc
+ demo_out
+ logs/
amr/__init__.py ADDED
File without changes
amr/configs/__init__.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ from typing import Dict
+ from yacs.config import CfgNode as CN
+
+ CACHE_DIR_HAMER = "./logs"
+
+
+ def to_lower(x: Dict) -> Dict:
+     """
+     Convert all dictionary keys to lowercase
+     Args:
+         x (dict): Input dictionary
+     Returns:
+         dict: Output dictionary with all keys converted to lowercase
+     """
+     return {k.lower(): v for k, v in x.items()}
+
+
+ _C = CN(new_allowed=True)
+
+ _C.GENERAL = CN(new_allowed=True)
+ _C.GENERAL.RESUME = True
+ _C.GENERAL.TIME_TO_RUN = 3300
+ _C.GENERAL.VAL_STEPS = 100
+ _C.GENERAL.LOG_STEPS = 100
+ _C.GENERAL.CHECKPOINT_STEPS = 20000
+ _C.GENERAL.CHECKPOINT_DIR = "checkpoints"
+ _C.GENERAL.SUMMARY_DIR = "tensorboard"
+ _C.GENERAL.NUM_GPUS = 1
+ _C.GENERAL.NUM_WORKERS = 4
+ _C.GENERAL.MIXED_PRECISION = True
+ _C.GENERAL.ALLOW_CUDA = True
+ _C.GENERAL.PIN_MEMORY = False
+ _C.GENERAL.DISTRIBUTED = False
+ _C.GENERAL.LOCAL_RANK = 0
+ _C.GENERAL.USE_SYNCBN = False
+ _C.GENERAL.WORLD_SIZE = 1
+
+ _C.TRAIN = CN(new_allowed=True)
+ _C.TRAIN.NUM_EPOCHS = 100
+ _C.TRAIN.BATCH_SIZE = 32
+ _C.TRAIN.SHUFFLE = True
+ _C.TRAIN.WARMUP = False
+ _C.TRAIN.NORMALIZE_PER_IMAGE = False
+ _C.TRAIN.CLIP_GRAD = False
+ _C.TRAIN.CLIP_GRAD_VALUE = 1.0
+ _C.LOSS_WEIGHTS = CN(new_allowed=True)
+
+ _C.DATASETS = CN(new_allowed=True)
+
+ _C.MODEL = CN(new_allowed=True)
+ _C.MODEL.IMAGE_SIZE = 224
+
+ _C.EXTRA = CN(new_allowed=True)
+ _C.EXTRA.FOCAL_LENGTH = 5000
+
+ _C.DATASETS.CONFIG = CN(new_allowed=True)
+ _C.DATASETS.CONFIG.SCALE_FACTOR = 0.3
+ _C.DATASETS.CONFIG.ROT_FACTOR = 30
+ _C.DATASETS.CONFIG.TRANS_FACTOR = 0.02
+ _C.DATASETS.CONFIG.COLOR_SCALE = 0.2
+ _C.DATASETS.CONFIG.ROT_AUG_RATE = 0.6
+ _C.DATASETS.CONFIG.TRANS_AUG_RATE = 0.5
+ _C.DATASETS.CONFIG.DO_FLIP = False
+ _C.DATASETS.CONFIG.FLIP_AUG_RATE = 0.5
+ _C.DATASETS.CONFIG.EXTREME_CROP_AUG_RATE = 0.10
+
+
+ def default_config() -> CN:
+     """
+     Get a yacs CfgNode object with the default config values.
+     """
+     # Return a clone so that the defaults will not be altered
+     # This is for the "local variable" use pattern
+     return _C.clone()
+
+
+ def dataset_config() -> CN:
+     """
+     Get dataset config file
+     Returns:
+         CfgNode: Dataset config as a yacs CfgNode object.
+     """
+     cfg = CN(new_allowed=True)
+     config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'datasets_tar.yaml')
+     cfg.merge_from_file(config_file)
+     cfg.freeze()
+     return cfg
+
+
+ def get_config(config_file: str, merge: bool = True, update_cachedir: bool = False) -> CN:
+     """
+     Read a config file and optionally merge it with the default config file.
+     Args:
+         config_file (str): Path to config file.
+         merge (bool): Whether to merge with the default config or not.
+     Returns:
+         CfgNode: Config as a yacs CfgNode object.
+     """
+     if merge:
+         cfg = default_config()
+     else:
+         cfg = CN(new_allowed=True)
+     cfg.merge_from_file(config_file)
+
+     if update_cachedir:
+         def update_path(path: str) -> str:
+             if os.path.isabs(path):
+                 return path
+             return os.path.join(CACHE_DIR_HAMER, path)
+
+     cfg.freeze()
+     return cfg
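For orientation, a minimal usage sketch of the helpers above (not part of the commit). The YAML filename is a placeholder, and `dataset_config()` additionally assumes a `datasets_tar.yaml` file next to this module:

# Illustrative only; "my_experiment.yaml" is a hypothetical override file.
from amr.configs import default_config, get_config

cfg = default_config()                               # mutable clone of the defaults
print(cfg.MODEL.IMAGE_SIZE)                          # 224

cfg = get_config("my_experiment.yaml", merge=True)   # defaults + file overrides, then frozen
print(cfg.TRAIN.BATCH_SIZE)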
amr/datasets/__init__.py ADDED
File without changes
amr/datasets/utils.py ADDED
@@ -0,0 +1,1038 @@
1
+ """
2
+ Parts of the code are taken or adapted from
3
+ https://github.com/mkocabas/EpipolarPose/blob/master/lib/utils/img_utils.py
4
+ """
5
+ import torch
6
+ import numpy as np
7
+ from skimage.transform import rotate, resize
8
+ from skimage.filters import gaussian
9
+ import random
10
+ import cv2
11
+ from typing import List, Dict, Tuple
12
+ from yacs.config import CfgNode
13
+ from typing import Union
14
+
15
+
16
+ def expand_to_aspect_ratio(input_shape, target_aspect_ratio=None):
17
+ """Increase the size of the bounding box to match the target shape."""
18
+ if target_aspect_ratio is None:
19
+ return input_shape
20
+
21
+ try:
22
+ w, h = input_shape
23
+ except (ValueError, TypeError):
24
+ return input_shape
25
+
26
+ w_t, h_t = target_aspect_ratio
27
+ if h / w < h_t / w_t:
28
+ h_new = max(w * h_t / w_t, h)
29
+ w_new = w
30
+ else:
31
+ h_new = h
32
+ w_new = max(h * w_t / h_t, w)
33
+ if h_new < h or w_new < w:
34
+ breakpoint()
35
+ return np.array([w_new, h_new])
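A quick numeric illustration of expand_to_aspect_ratio (a sketch, not part of the commit): a 300x200 box grows only along its short side to reach the 192:256 target ratio.

import numpy as np
from amr.datasets.utils import expand_to_aspect_ratio

print(expand_to_aspect_ratio(np.array([300., 200.]), target_aspect_ratio=[192, 256]))
# -> [300. 400.]  (400 / 300 == 256 / 192; the width is left untouched)
print(expand_to_aspect_ratio(np.array([300., 200.]), target_aspect_ratio=None))
# -> input returned unchanged when no target ratio is given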
36
+
37
+
38
+ def do_augmentation(aug_config: CfgNode) -> Tuple:
39
+ """
40
+ Compute random augmentation parameters.
41
+ Args:
42
+ aug_config (CfgNode): Config containing augmentation parameters.
43
+ Returns:
44
+ scale (float): Box rescaling factor.
45
+ rot (float): Random image rotation.
46
+ do_flip (bool): Whether to flip image or not.
47
+ do_extreme_crop (bool): Whether to apply extreme cropping (as proposed in EFT).
48
+ color_scale (List): Color rescaling factor
49
+ tx (float): Random translation along the x axis.
50
+ ty (float): Random translation along the y axis.
51
+ """
52
+
53
+ tx = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
54
+ ty = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.TRANS_FACTOR
55
+ scale = np.clip(np.random.randn(), -1.0, 1.0) * aug_config.SCALE_FACTOR + 1.0
56
+ rot = np.clip(np.random.randn(), -2.0,
57
+ 2.0) * aug_config.ROT_FACTOR if random.random() <= aug_config.ROT_AUG_RATE else 0
58
+ do_flip = aug_config.DO_FLIP and random.random() <= aug_config.FLIP_AUG_RATE
59
+ do_extreme_crop = random.random() <= aug_config.EXTREME_CROP_AUG_RATE
60
+ extreme_crop_lvl = aug_config.get('EXTREME_CROP_AUG_LEVEL', 0)
61
+ # extreme_crop_lvl = 0
62
+ c_up = 1.0 + aug_config.COLOR_SCALE
63
+ c_low = 1.0 - aug_config.COLOR_SCALE
64
+ color_scale = [random.uniform(c_low, c_up), random.uniform(c_low, c_up), random.uniform(c_low, c_up)]
65
+ return scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty
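A small sketch of calling do_augmentation with a config that mirrors the _C.DATASETS.CONFIG defaults from amr/configs/__init__.py (the CfgNode built here is an assumption for illustration):

import random
import numpy as np
from yacs.config import CfgNode as CN
from amr.datasets.utils import do_augmentation

aug_cfg = CN(new_allowed=True)
aug_cfg.SCALE_FACTOR = 0.3
aug_cfg.ROT_FACTOR = 30
aug_cfg.TRANS_FACTOR = 0.02
aug_cfg.COLOR_SCALE = 0.2
aug_cfg.ROT_AUG_RATE = 0.6
aug_cfg.TRANS_AUG_RATE = 0.5
aug_cfg.DO_FLIP = False
aug_cfg.FLIP_AUG_RATE = 0.5
aug_cfg.EXTREME_CROP_AUG_RATE = 0.10

random.seed(0)
np.random.seed(0)
scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(aug_cfg)
# With these values: scale lies in [0.7, 1.3], rot in [-60, 60] degrees (and is 0 about 40% of the time),
# color_scale holds three per-channel factors in [0.8, 1.2], and tx/ty are small box-relative offsets.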
66
+
67
+
68
+ def rotate_2d(pt_2d: np.array, rot_rad: float) -> np.array:
69
+ """
70
+ Rotate a 2D point on the x-y plane.
71
+ Args:
72
+ pt_2d (np.array): Input 2D point with shape (2,).
73
+ rot_rad (float): Rotation angle
74
+ Returns:
75
+ np.array: Rotated 2D point.
76
+ """
77
+ x = pt_2d[0]
78
+ y = pt_2d[1]
79
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
80
+ xx = x * cs - y * sn
81
+ yy = x * sn + y * cs
82
+ return np.array([xx, yy], dtype=np.float32)
83
+
84
+
85
+ def gen_trans_from_patch_cv(c_x: float, c_y: float,
86
+ src_width: float, src_height: float,
87
+ dst_width: float, dst_height: float,
88
+ scale: float, rot: float) -> np.array:
89
+ """
90
+ Create transformation matrix for the bounding box crop.
91
+ Args:
92
+ c_x (float): Bounding box center x coordinate in the original image.
93
+ c_y (float): Bounding box center y coordinate in the original image.
94
+ src_width (float): Bounding box width.
95
+ src_height (float): Bounding box height.
96
+ dst_width (float): Output box width.
97
+ dst_height (float): Output box height.
98
+ scale (float): Rescaling factor for the bounding box (augmentation).
99
+ rot (float): Random rotation applied to the box.
100
+ Returns:
101
+ trans (np.array): Target geometric transformation.
102
+ """
103
+ # augment size with scale
104
+ src_w = src_width * scale
105
+ src_h = src_height * scale
106
+ src_center = np.zeros(2)
107
+ src_center[0] = c_x
108
+ src_center[1] = c_y
109
+ # augment rotation
110
+ rot_rad = np.pi * rot / 180
111
+ src_downdir = rotate_2d(np.array([0, src_h * 0.5], dtype=np.float32), rot_rad)
112
+ src_rightdir = rotate_2d(np.array([src_w * 0.5, 0], dtype=np.float32), rot_rad)
113
+
114
+ dst_w = dst_width
115
+ dst_h = dst_height
116
+ dst_center = np.array([dst_w * 0.5, dst_h * 0.5], dtype=np.float32)
117
+ dst_downdir = np.array([0, dst_h * 0.5], dtype=np.float32)
118
+ dst_rightdir = np.array([dst_w * 0.5, 0], dtype=np.float32)
119
+
120
+ src = np.zeros((3, 2), dtype=np.float32)
121
+ src[0, :] = src_center
122
+ src[1, :] = src_center + src_downdir
123
+ src[2, :] = src_center + src_rightdir
124
+
125
+ dst = np.zeros((3, 2), dtype=np.float32)
126
+ dst[0, :] = dst_center
127
+ dst[1, :] = dst_center + dst_downdir
128
+ dst[2, :] = dst_center + dst_rightdir
129
+
130
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
131
+
132
+ return trans
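To see what the affine transform encodes, a short sketch (illustrative, not part of the commit) mapping a 200x200 box centred at (320, 240) onto a 256x256 patch:

import numpy as np
from amr.datasets.utils import gen_trans_from_patch_cv

trans = gen_trans_from_patch_cv(c_x=320, c_y=240, src_width=200, src_height=200,
                                dst_width=256, dst_height=256, scale=1.0, rot=0.0)
print(trans.shape)                         # (2, 3), ready for cv2.warpAffine
print(trans @ np.array([320., 240., 1.]))  # box centre maps to the patch centre [128. 128.]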
133
+
134
+
135
+ def trans_point2d(pt_2d: np.array, trans: np.array):
136
+ """
137
+ Transform a 2D point using translation matrix trans.
138
+ Args:
139
+ pt_2d (np.array): Input 2D point with shape (2,).
140
+ trans (np.array): Transformation matrix.
141
+ Returns:
142
+ np.array: Transformed 2D point.
143
+ """
144
+ src_pt = np.array([pt_2d[0], pt_2d[1], 1.]).T
145
+ dst_pt = np.dot(trans, src_pt)
146
+ return dst_pt[0:2]
147
+
148
+
149
+ def get_transform(center, scale, res, rot=0):
150
+ """Generate transformation matrix."""
151
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
152
+ h = 200 * scale
153
+ t = np.zeros((3, 3))
154
+ t[0, 0] = float(res[1]) / h
155
+ t[1, 1] = float(res[0]) / h
156
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
157
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
158
+ t[2, 2] = 1
159
+ if not rot == 0:
160
+ rot = -rot # To match direction of rotation from cropping
161
+ rot_mat = np.zeros((3, 3))
162
+ rot_rad = rot * np.pi / 180
163
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
164
+ rot_mat[0, :2] = [cs, -sn]
165
+ rot_mat[1, :2] = [sn, cs]
166
+ rot_mat[2, 2] = 1
167
+ # Need to rotate around center
168
+ t_mat = np.eye(3)
169
+ t_mat[0, 2] = -res[1] / 2
170
+ t_mat[1, 2] = -res[0] / 2
171
+ t_inv = t_mat.copy()
172
+ t_inv[:2, 2] *= -1
173
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
174
+ return t
175
+
176
+
177
+ def transform(pt, center, scale, res, invert=0, rot=0, as_int=True):
178
+ """Transform pixel location to different reference."""
179
+ """Taken from PARE: https://github.com/mkocabas/PARE/blob/6e0caca86c6ab49ff80014b661350958e5b72fd8/pare/utils/image_utils.py"""
180
+ t = get_transform(center, scale, res, rot=rot)
181
+ if invert:
182
+ t = np.linalg.inv(t)
183
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
184
+ new_pt = np.dot(t, new_pt)
185
+ if as_int:
186
+ new_pt = new_pt.astype(int)
187
+ return new_pt[:2] + 1
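Note that the scale argument here is in the 200-pixel-box convention (h = 200 * scale). A hedged round-trip sketch:

from amr.datasets.utils import transform

center, scale, res = [320, 240], 1.0, [256, 256]      # scale=1.0 corresponds to a 200-pixel box
pt = transform([320, 240], center, scale, res, as_int=False)
back = transform(pt, center, scale, res, invert=1, as_int=False)
print(pt)    # about [127.7, 127.7], roughly the patch centre under the 256/200 scaling
print(back)  # approx [320., 240.], recovering the original pixel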
188
+
189
+
190
+ def crop_img(img, ul, br, border_mode=cv2.BORDER_CONSTANT, border_value=0):
191
+ c_x = (ul[0] + br[0]) / 2
192
+ c_y = (ul[1] + br[1]) / 2
193
+ bb_width = patch_width = br[0] - ul[0]
194
+ bb_height = patch_height = br[1] - ul[1]
195
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, 1.0, 0)
196
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
197
+ flags=cv2.INTER_LINEAR,
198
+ borderMode=border_mode,
199
+ borderValue=border_value
200
+ )
201
+
202
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
203
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
204
+ img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
205
+ flags=cv2.INTER_LINEAR,
206
+ borderMode=cv2.BORDER_CONSTANT,
207
+ )
208
+
209
+ return img_patch
210
+
211
+
212
+ def generate_image_patch_skimage(img: np.array, c_x: float, c_y: float,
213
+ bb_width: float, bb_height: float,
214
+ patch_width: float, patch_height: float,
215
+ do_flip: bool, scale: float, rot: float,
216
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
217
+ """
218
+ Crop image according to the supplied bounding box.
219
+ Args:
220
+ img (np.array): Input image of shape (H, W, 3)
221
+ c_x (float): Bounding box center x coordinate in the original image.
222
+ c_y (float): Bounding box center y coordinate in the original image.
223
+ bb_width (float): Bounding box width.
224
+ bb_height (float): Bounding box height.
225
+ patch_width (float): Output box width.
226
+ patch_height (float): Output box height.
227
+ do_flip (bool): Whether to flip image or not.
228
+ scale (float): Rescaling factor for the bounding box (augmentation).
229
+ rot (float): Random rotation applied to the box.
230
+ Returns:
231
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
232
+ trans (np.array): Transformation matrix.
233
+ """
234
+
235
+ img_height, img_width, img_channels = img.shape
236
+ if do_flip:
237
+ img = img[:, ::-1, :]
238
+ c_x = img_width - c_x - 1
239
+
240
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
241
+
242
+ # img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)), flags=cv2.INTER_LINEAR)
243
+
244
+ # skimage
245
+ center = np.zeros(2)
246
+ center[0] = c_x
247
+ center[1] = c_y
248
+ res = np.zeros(2)
249
+ res[0] = patch_width
250
+ res[1] = patch_height
251
+ # assumes bb_width = bb_height
252
+ # assumes patch_width = patch_height
253
+ assert bb_width == bb_height, f'{bb_width=} != {bb_height=}'
254
+ assert patch_width == patch_height, f'{patch_width=} != {patch_height=}'
255
+ scale1 = scale * bb_width / 200.
256
+
257
+ # Upper left point
258
+ ul = np.array(transform([1, 1], center, scale1, res, invert=1, as_int=False)) - 1
259
+ # Bottom right point
260
+ br = np.array(transform([res[0] + 1,
261
+ res[1] + 1], center, scale1, res, invert=1, as_int=False)) - 1
262
+
263
+ # Padding so that when rotated proper amount of context is included
264
+ try:
265
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + 1
266
+ except:
267
+ breakpoint()
268
+ if not rot == 0:
269
+ ul -= pad
270
+ br += pad
271
+
272
+ if False:
273
+ # Old way of cropping image
274
+ ul_int = ul.astype(int)
275
+ br_int = br.astype(int)
276
+ new_shape = [br_int[1] - ul_int[1], br_int[0] - ul_int[0]]
277
+ if len(img.shape) > 2:
278
+ new_shape += [img.shape[2]]
279
+ new_img = np.zeros(new_shape)
280
+
281
+ # Range to fill new array
282
+ new_x = max(0, -ul_int[0]), min(br_int[0], len(img[0])) - ul_int[0]
283
+ new_y = max(0, -ul_int[1]), min(br_int[1], len(img)) - ul_int[1]
284
+ # Range to sample from original image
285
+ old_x = max(0, ul_int[0]), min(len(img[0]), br_int[0])
286
+ old_y = max(0, ul_int[1]), min(len(img), br_int[1])
287
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
288
+ old_x[0]:old_x[1]]
289
+
290
+ # New way of cropping image
291
+ new_img = crop_img(img, ul, br, border_mode=border_mode, border_value=border_value).astype(np.float32)
292
+
293
+ # print(f'{new_img.shape=}')
294
+ # print(f'{new_img1.shape=}')
295
+ # print(f'{np.allclose(new_img, new_img1)=}')
296
+ # print(f'{img.dtype=}')
297
+
298
+ if not rot == 0:
299
+ # Remove padding
300
+
301
+ new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
302
+ new_img = new_img[pad:-pad, pad:-pad]
303
+
304
+ if new_img.shape[0] < 1 or new_img.shape[1] < 1:
305
+ print(f'{img.shape=}')
306
+ print(f'{new_img.shape=}')
307
+ print(f'{ul=}')
308
+ print(f'{br=}')
309
+ print(f'{pad=}')
310
+ print(f'{rot=}')
311
+
312
+ breakpoint()
313
+
314
+ # resize image
315
+ new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
316
+
317
+ new_img = np.clip(new_img, 0, 255).astype(np.uint8)
318
+
319
+ return new_img, trans
320
+
321
+
322
+ def generate_image_patch_cv2(img: np.array, c_x: float, c_y: float,
323
+ bb_width: float, bb_height: float,
324
+ patch_width: float, patch_height: float,
325
+ do_flip: bool, scale: float, rot: float,
326
+ border_mode=cv2.BORDER_CONSTANT, border_value=0) -> Tuple[np.array, np.array]:
327
+ """
328
+ Crop the input image and return the crop and the corresponding transformation matrix.
329
+ Args:
330
+ img (np.array): Input image of shape (H, W, 3)
331
+ c_x (float): Bounding box center x coordinate in the original image.
332
+ c_y (float): Bounding box center y coordinate in the original image.
333
+ bb_width (float): Bounding box width.
334
+ bb_height (float): Bounding box height.
335
+ patch_width (float): Output box width.
336
+ patch_height (float): Output box height.
337
+ do_flip (bool): Whether to flip image or not.
338
+ scale (float): Rescaling factor for the bounding box (augmentation).
339
+ rot (float): Random rotation applied to the box.
340
+ Returns:
341
+ img_patch (np.array): Cropped image patch of shape (patch_height, patch_height, 3)
342
+ trans (np.array): Transformation matrix.
343
+ """
344
+
345
+ img_height, img_width, img_channels = img.shape
346
+ if do_flip:
347
+ img = img[:, ::-1, :]
348
+ c_x = img_width - c_x - 1
349
+
350
+ trans = gen_trans_from_patch_cv(c_x, c_y, bb_width, bb_height, patch_width, patch_height, scale, rot)
351
+
352
+ img_patch = cv2.warpAffine(img, trans, (int(patch_width), int(patch_height)),
353
+ flags=cv2.INTER_LINEAR,
354
+ borderMode=border_mode,
355
+ borderValue=border_value,
356
+ )
357
+ # Force borderValue=cv2.BORDER_CONSTANT for alpha channel
358
+ if (img.shape[2] == 4) and (border_mode != cv2.BORDER_CONSTANT):
359
+ img_patch[:, :, 3] = cv2.warpAffine(img[:, :, 3], trans, (int(patch_width), int(patch_height)),
360
+ flags=cv2.INTER_LINEAR,
361
+ borderMode=cv2.BORDER_CONSTANT,
362
+ )
363
+
364
+ return img_patch, trans
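A minimal sketch of cropping a synthetic frame with generate_image_patch_cv2 (illustrative only):

import numpy as np
from amr.datasets.utils import generate_image_patch_cv2

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)   # fake BGR frame
patch, trans = generate_image_patch_cv2(img, c_x=320, c_y=240,
                                        bb_width=200, bb_height=200,
                                        patch_width=256, patch_height=256,
                                        do_flip=False, scale=1.0, rot=0.0)
print(patch.shape, trans.shape)   # (256, 256, 3) (2, 3)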
365
+
366
+
367
+ def convert_cvimg_to_tensor(cvimg: np.array):
368
+ """
369
+ Convert image from HWC to CHW format.
370
+ Args:
371
+ cvimg (np.array): Image of shape (H, W, 3) as loaded by OpenCV.
372
+ Returns:
373
+ np.array: Output image of shape (3, H, W).
374
+ """
375
+ # from h,w,c(OpenCV) to c,h,w
376
+ img = cvimg.copy()
377
+ img = np.transpose(img, (2, 0, 1))
378
+ # from int to float
379
+ img = img.astype(np.float32)
380
+ return img
381
+
382
+
383
+ def fliplr_params(smal_params: Dict, has_smal_params: Dict) -> Tuple[Dict, Dict]:
384
+ """
385
+ Flip SMAL parameters when flipping the image.
386
+ Args:
387
+ smal_params (Dict): SMAL parameter annotations.
388
+ has_smal_params (Dict): Whether SMAL annotations are valid.
389
+ Returns:
390
+ Dict, Dict: Flipped SMAL parameters and valid flags.
391
+ """
392
+ global_orient = smal_params['global_orient'].copy()
393
+ pose = smal_params['pose'].copy()
394
+ betas = smal_params['betas'].copy()
395
+ translation = smal_params['translation'].copy()
396
+ has_global_orient = has_smal_params['global_orient'].copy()
397
+ has_pose = has_smal_params['pose'].copy()
398
+ has_betas = has_smal_params['betas'].copy()
399
+ has_translation = has_smal_params['translation'].copy()
400
+
401
+ global_orient[1::3] *= -1
402
+ global_orient[2::3] *= -1
403
+ pose[1::3] *= -1
404
+ pose[2::3] *= -1
405
+ translation[1::3] *= -1
406
+ translation[2::3] *= -1
407
+
408
+ smal_params = {'global_orient': global_orient.astype(np.float32),
409
+ 'pose': pose.astype(np.float32),
410
+ 'betas': betas.astype(np.float32),
411
+ 'translation': translation.astype(np.float32)
412
+ }
413
+
414
+ has_smal_params = {'global_orient': has_global_orient,
415
+ 'pose': has_pose,
416
+ 'betas': has_betas,
417
+ 'translation': has_translation
418
+ }
419
+
420
+ return smal_params, has_smal_params
421
+
422
+
423
+ def fliplr_keypoints(joints: np.array, width: float, flip_permutation: List[int]) -> np.array:
424
+ """
425
+ Flip 2D or 3D keypoints.
426
+ Args:
427
+ joints (np.array): Array of shape (N, 3) or (N, 4) containing 2D or 3D keypoint locations and confidence.
428
+ flip_permutation (List): Permutation to apply after flipping.
429
+ Returns:
430
+ np.array: Flipped 2D or 3D keypoints with shape (N, 3) or (N, 4) respectively.
431
+ """
432
+ joints = joints.copy()
433
+ # Flip horizontal
434
+ joints[:, 0] = width - joints[:, 0] - 1
435
+ joints = joints[flip_permutation, :]
436
+
437
+ return joints
438
+
439
+
440
+ def keypoint_3d_processing(keypoints_3d: np.array, rot: float, flip: bool) -> np.array:
+     """
+     Process 3D keypoints (rotation/flipping).
+     Args:
+         keypoints_3d (np.array): Input array of shape (N, 4) containing the 3D keypoints and confidence.
+         rot (float): Random rotation applied to the keypoints.
+         flip (bool): Whether to flip the keypoints or not.
+     Returns:
+         np.array: Transformed 3D keypoints with shape (N, 4).
+     """
+     # in-plane rotation
+     rot_mat = np.eye(3, dtype=np.float32)
+     if not rot == 0:
+         rot_rad = -rot * np.pi / 180
+         sn, cs = np.sin(rot_rad), np.cos(rot_rad)
+         rot_mat[0, :2] = [cs, -sn]
+         rot_mat[1, :2] = [sn, cs]
+     keypoints_3d[:, :-1] = np.einsum('ij,kj->ki', rot_mat, keypoints_3d[:, :-1])
+     # flip the x coordinates (width=1 turns x into -x; identity permutation)
+     if flip:
+         keypoints_3d = fliplr_keypoints(keypoints_3d, 1, list(range(len(keypoints_3d))))
+     keypoints_3d = keypoints_3d.astype('float32')
+     return keypoints_3d
462
+
463
+
464
+ def rot_aa(aa: np.array, rot: float) -> np.array:
465
+ """
466
+ Rotate axis angle parameters.
467
+ Args:
468
+ aa (np.array): Axis-angle vector of shape (3,).
469
+ rot (np.array): Rotation angle in degrees.
470
+ Returns:
471
+ np.array: Rotated axis-angle vector.
472
+ """
473
+ # pose parameters
474
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
475
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
476
+ [0, 0, 1]])
477
+ # find the rotation of the subject in the camera frame
478
+ per_rdg, _ = cv2.Rodrigues(aa)
479
+ # apply the global rotation to the global orientation
480
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
481
+ aa = (resrot.T)[0]
482
+ return aa.astype(np.float32)
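A small sanity check for rot_aa (a sketch; the sign convention follows the code above, where a positive in-plane image rotation corresponds to a negative rotation about the camera z-axis):

import numpy as np
from amr.datasets.utils import rot_aa

aa = np.array([0., 0., np.pi / 2], dtype=np.float32)   # 90 degrees about the camera z-axis
print(rot_aa(aa, rot=90))                              # approx [0, 0, 0]: the augmentation cancels it
print(rot_aa(np.zeros(3, dtype=np.float32), rot=0))    # identity stays identity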
483
+
484
+
485
+ def smal_param_processing(smal_params: Dict, has_smal_params: Dict, rot: float, do_flip: bool) -> Tuple[Dict, Dict]:
486
+ """
487
+ Apply random augmentations to the SMAL parameters.
488
+ Args:
489
+ smal_params (Dict): SMAL parameter annotations.
490
+ has_smal_params (Dict): Whether SMAL annotations are valid.
491
+ rot (float): Random rotation applied to the keypoints.
492
+ do_flip (bool): Whether to flip keypoints or not.
493
+ Returns:
494
+ Dict, Dict: Transformed SMAL parameters and valid flags.
495
+ """
496
+ if do_flip:
497
+ smal_params, has_smal_params = fliplr_params(smal_params, has_smal_params)
498
+ smal_params['global_orient'] = rot_aa(smal_params['global_orient'], rot)
499
+ # The camera location does not change, so the translation does not change either.
500
+ # smal_params['transl'] = np.dot(np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
501
+ # [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
502
+ # [0, 0, 1]], dtype=np.float32), smal_params['transl'])
503
+ return smal_params, has_smal_params
504
+
505
+
506
+ def get_example(img_path: Union[str,np.ndarray], center_x: float, center_y: float,
507
+ width: float, height: float,
508
+ keypoints_2d: np.array, keypoints_3d: np.array,
509
+ smal_params: Dict, has_smal_params: Dict,
510
+ patch_width: int, patch_height: int,
511
+ mean: np.array, std: np.array,
512
+ do_augment: bool, augm_config: CfgNode,
513
+ is_bgr: bool = True,
514
+ use_skimage_antialias: bool = False,
515
+ border_mode: int = cv2.BORDER_CONSTANT,
516
+ return_trans: bool = False,) -> Tuple:
517
+ """
518
+ Get an example from the dataset and (possibly) apply random augmentations.
519
+ Args:
520
+ img_path (str): Image filename
521
+ center_x (float): Bounding box center x coordinate in the original image.
522
+ center_y (float): Bounding box center y coordinate in the original image.
523
+ width (float): Bounding box width.
524
+ height (float): Bounding box height.
525
+ keypoints_2d (np.array): Array with shape (N,3) containing the 2D keypoints in the original image coordinates.
526
+ keypoints_3d (np.array): Array with shape (N,4) containing the 3D keypoints.
527
+ smal_params (Dict): SMAL parameter annotations.
528
+ has_smal_params (Dict): Whether SMAL annotations are valid.
529
+ patch_width (float): Output box width.
530
+ patch_height (float): Output box height.
531
+ mean (np.array): Array of shape (3,) containing the mean for normalizing the input image.
532
+ std (np.array): Array of shape (3,) containing the std for normalizing the input image.
533
+ do_augment (bool): Whether to apply data augmentation or not.
534
+ aug_config (CfgNode): Config containing augmentation parameters.
535
+ Returns:
536
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size
537
+ img_patch (np.array): Cropped image patch of shape (3, patch_height, patch_height)
538
+ keypoints_2d (np.array): Array with shape (N,3) containing the transformed 2D keypoints.
539
+ keypoints_3d (np.array): Array with shape (N,4) containing the transformed 3D keypoints.
540
+ smal_params (Dict): Transformed SMAL parameters.
541
+ has_smal_params (Dict): Valid flag for transformed SMAL parameters.
542
+ img_size (np.array): Image size of the original image.
543
+ """
544
+ if isinstance(img_path, str):
545
+ # 1. load image
546
+ cvimg = cv2.imread(img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
547
+ if not isinstance(cvimg, np.ndarray):
548
+ raise IOError("Fail to read %s" % img_path)
549
+ elif isinstance(img_path, np.ndarray):
550
+ cvimg = img_path
551
+ else:
552
+ raise TypeError('img_path must be either a string or a numpy array')
553
+ img_height, img_width, img_channels = cvimg.shape
554
+
555
+ img_size = np.array([img_height, img_width], dtype=np.int32)
556
+
557
+ # 2. get augmentation params
558
+ if do_augment:
559
+ # box rescale factor, rotation angle, flip or not flip, crop or not crop, ..., color scale, translation x, ...
560
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = do_augmentation(augm_config)
561
+ else:
562
+ scale, rot, do_flip, do_extreme_crop, extreme_crop_lvl, color_scale, tx, ty = 1.0, 0, False, False, 0, [1.0,
563
+ 1.0,
564
+ 1.0], 0., 0.
565
+ if width < 1 or height < 1:
566
+ breakpoint()
567
+
568
+ if do_extreme_crop:
569
+ if extreme_crop_lvl == 0:
570
+ center_x1, center_y1, width1, height1 = extreme_cropping(center_x, center_y, width, height, keypoints_2d)
571
+ elif extreme_crop_lvl == 1:
572
+ center_x1, center_y1, width1, height1 = extreme_cropping_aggressive(center_x, center_y, width, height,
573
+ keypoints_2d)
574
+
575
+ THRESH = 4
576
+ if width1 < THRESH or height1 < THRESH:
577
+ # print(f'{do_extreme_crop=}')
578
+ # print(f'width: {width}, height: {height}')
579
+ # print(f'width1: {width1}, height1: {height1}')
580
+ # print(f'center_x: {center_x}, center_y: {center_y}')
581
+ # print(f'center_x1: {center_x1}, center_y1: {center_y1}')
582
+ # print(f'keypoints_2d: {keypoints_2d}')
583
+ # print(f'\n\n', flush=True)
584
+ # breakpoint()
585
+ pass
586
+ # print(f'skip ==> width1: {width1}, height1: {height1}, width: {width}, height: {height}')
587
+ else:
588
+ center_x, center_y, width, height = center_x1, center_y1, width1, height1
589
+
590
+ center_x += width * tx
591
+ center_y += height * ty
592
+
593
+ # Process 3D keypoints
594
+ keypoints_3d = keypoint_3d_processing(keypoints_3d, rot, do_flip)
595
+
596
+ # 3. generate image patch
597
+ if use_skimage_antialias:
598
+ # Blur image to avoid aliasing artifacts
599
+ downsampling_factor = (patch_width / (width * scale))
600
+ if downsampling_factor > 1.1:
601
+ cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True,
602
+ truncate=3.0)
603
+ # augmentation image, translation matrix
604
+ img_patch_cv, trans = generate_image_patch_cv2(cvimg,
605
+ center_x, center_y,
606
+ width, height,
607
+ patch_width, patch_height,
608
+ do_flip, scale, rot,
609
+ border_mode=border_mode)
610
+ # img_patch_cv, trans = generate_image_patch_skimage(cvimg,
611
+ # center_x, center_y,
612
+ # width, height,
613
+ # patch_width, patch_height,
614
+ # do_flip, scale, rot,
615
+ # border_mode=border_mode)
616
+
617
+ image = img_patch_cv.copy()
618
+ if is_bgr:
619
+ image = image[:, :, ::-1]
620
+ img_patch_cv = image.copy()
621
+ img_patch = convert_cvimg_to_tensor(image) # [h, w, 4] -> [4, h, w]
622
+
623
+ smal_params, has_smal_params = smal_param_processing(smal_params, has_smal_params, rot, do_flip)
624
+
625
+ # apply normalization
626
+ for n_c in range(min(img_channels, 3)):
627
+ img_patch[n_c, :, :] = np.clip(img_patch[n_c, :, :] * color_scale[n_c], 0, 255)
628
+ if mean is not None and std is not None:
629
+ img_patch[n_c, :, :] = (img_patch[n_c, :, :] - mean[n_c]) / std[n_c]
630
+
631
+ if do_flip:
632
+ keypoints_2d = fliplr_keypoints(keypoints_2d, img_width, list(range(len(keypoints_2d))))
633
+
634
+ for n_jt in range(len(keypoints_2d)):
635
+ keypoints_2d[n_jt, 0:2] = trans_point2d(keypoints_2d[n_jt, 0:2], trans)
636
+ keypoints_2d[:, :-1] = keypoints_2d[:, :-1] / patch_width - 0.5
637
+
638
+ if not return_trans:
639
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size
640
+ else:
641
+ return img_patch, keypoints_2d, keypoints_3d, smal_params, has_smal_params, img_size, trans
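An end-to-end sketch of get_example with dummy annotations (illustrative only; the keypoint count and SMAL parameter sizes below are placeholders, not the values this repository actually uses):

import numpy as np
from yacs.config import CfgNode as CN
from amr.datasets.utils import get_example

img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
keypoints_2d = 100. * np.ones((26, 3), dtype=np.float32)
keypoints_3d = np.zeros((26, 4), dtype=np.float32)
smal_params = {'global_orient': np.zeros(3, dtype=np.float32),
               'pose': np.zeros(99, dtype=np.float32),        # placeholder size
               'betas': np.zeros(10, dtype=np.float32),
               'translation': np.zeros(3, dtype=np.float32)}
has_smal_params = {k: np.array(1., dtype=np.float32) for k in smal_params}

img_patch, kp2d, kp3d, smal_params, has_smal_params, img_size = get_example(
    img, center_x=320, center_y=240, width=200, height=200,
    keypoints_2d=keypoints_2d, keypoints_3d=keypoints_3d,
    smal_params=smal_params, has_smal_params=has_smal_params,
    patch_width=256, patch_height=256,
    mean=255. * np.array([0.485, 0.456, 0.406]),
    std=255. * np.array([0.229, 0.224, 0.225]),
    do_augment=False, augm_config=CN(new_allowed=True))      # augm_config is unused when do_augment=False
print(img_patch.shape)   # (3, 256, 256) normalized CHW float32
print(img_size)          # [480 640] (height, width of the source image)
print(kp2d[0])           # first keypoint in crop coordinates, rescaled by 1/patch_width and shifted by 0.5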
642
+
643
+
644
+ def crop_to_hips(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
645
+ """
646
+ Extreme cropping: Crop the box up to the hip locations.
647
+ Args:
648
+ center_x (float): x coordinate of the bounding box center.
649
+ center_y (float): y coordinate of the bounding box center.
650
+ width (float): Bounding box width.
651
+ height (float): Bounding box height.
652
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
653
+ Returns:
654
+ center_x (float): x coordinate of the new bounding box center.
655
+ center_y (float): y coordinate of the new bounding box center.
656
+ width (float): New bounding box width.
657
+ height (float): New bounding box height.
658
+ """
659
+ keypoints_2d = keypoints_2d.copy()
660
+ lower_body_keypoints = [10, 11, 13, 14, 19, 20, 21, 22, 23, 24, 25 + 0, 25 + 1, 25 + 4, 25 + 5]
661
+ keypoints_2d[lower_body_keypoints, :] = 0
662
+ if keypoints_2d[:, -1].sum() > 1:
663
+ center, scale = get_bbox(keypoints_2d)
664
+ center_x = center[0]
665
+ center_y = center[1]
666
+ width = 1.1 * scale[0]
667
+ height = 1.1 * scale[1]
668
+ return center_x, center_y, width, height
669
+
670
+
671
+ def crop_to_shoulders(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
672
+ """
673
+ Extreme cropping: Crop the box up to the shoulder locations.
674
+ Args:
675
+ center_x (float): x coordinate of the bounding box center.
676
+ center_y (float): y coordinate of the bounding box center.
677
+ width (float): Bounding box width.
678
+ height (float): Bounding box height.
679
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
680
+ Returns:
681
+ center_x (float): x coordinate of the new bounding box center.
682
+ center_y (float): y coordinate of the new bounding box center.
683
+ width (float): New bounding box width.
684
+ height (float): New bounding box height.
685
+ """
686
+ keypoints_2d = keypoints_2d.copy()
687
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in
688
+ [0, 1, 2, 3, 4, 5, 6, 7,
689
+ 10, 11, 14, 15, 16]]
690
+ keypoints_2d[lower_body_keypoints, :] = 0
691
+ center, scale = get_bbox(keypoints_2d)
692
+ if keypoints_2d[:, -1].sum() > 1:
693
+ center, scale = get_bbox(keypoints_2d)
694
+ center_x = center[0]
695
+ center_y = center[1]
696
+ width = 1.2 * scale[0]
697
+ height = 1.2 * scale[1]
698
+ return center_x, center_y, width, height
699
+
700
+
701
+ def crop_to_head(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
702
+ """
703
+ Extreme cropping: Crop the box and keep only the head.
704
+ Args:
705
+ center_x (float): x coordinate of the bounding box center.
706
+ center_y (float): y coordinate of the bounding box center.
707
+ width (float): Bounding box width.
708
+ height (float): Bounding box height.
709
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
710
+ Returns:
711
+ center_x (float): x coordinate of the new bounding box center.
712
+ center_y (float): y coordinate of the new bounding box center.
713
+ width (float): New bounding box width.
714
+ height (float): New bounding box height.
715
+ """
716
+ keypoints_2d = keypoints_2d.copy()
717
+ lower_body_keypoints = [3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 19, 20, 21, 22, 23, 24] + [25 + i for i in
718
+ [0, 1, 2, 3, 4, 5, 6, 7, 8,
719
+ 9, 10, 11, 14, 15, 16]]
720
+ keypoints_2d[lower_body_keypoints, :] = 0
721
+ if keypoints_2d[:, -1].sum() > 1:
722
+ center, scale = get_bbox(keypoints_2d)
723
+ center_x = center[0]
724
+ center_y = center[1]
725
+ width = 1.3 * scale[0]
726
+ height = 1.3 * scale[1]
727
+ return center_x, center_y, width, height
728
+
729
+
730
+ def crop_torso_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
731
+ """
732
+ Extreme cropping: Crop the box and keep only the torso.
733
+ Args:
734
+ center_x (float): x coordinate of the bounding box center.
735
+ center_y (float): y coordinate of the bounding box center.
736
+ width (float): Bounding box width.
737
+ height (float): Bounding box height.
738
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
739
+ Returns:
740
+ center_x (float): x coordinate of the new bounding box center.
741
+ center_y (float): y coordinate of the new bounding box center.
742
+ width (float): New bounding box width.
743
+ height (float): New bounding box height.
744
+ """
745
+ keypoints_2d = keypoints_2d.copy()
746
+ nontorso_body_keypoints = [0, 3, 4, 6, 7, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [25 + i for i in
747
+ [0, 1, 4, 5, 6,
748
+ 7, 10, 11, 13,
749
+ 17, 18]]
750
+ keypoints_2d[nontorso_body_keypoints, :] = 0
751
+ if keypoints_2d[:, -1].sum() > 1:
752
+ center, scale = get_bbox(keypoints_2d)
753
+ center_x = center[0]
754
+ center_y = center[1]
755
+ width = 1.1 * scale[0]
756
+ height = 1.1 * scale[1]
757
+ return center_x, center_y, width, height
758
+
759
+
760
+ def crop_rightarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
761
+ """
762
+ Extreme cropping: Crop the box and keep only the right arm.
763
+ Args:
764
+ center_x (float): x coordinate of the bounding box center.
765
+ center_y (float): y coordinate of the bounding box center.
766
+ width (float): Bounding box width.
767
+ height (float): Bounding box height.
768
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
769
+ Returns:
770
+ center_x (float): x coordinate of the new bounding box center.
771
+ center_y (float): y coordinate of the new bounding box center.
772
+ width (float): New bounding box width.
773
+ height (float): New bounding box height.
774
+ """
775
+ keypoints_2d = keypoints_2d.copy()
776
+ nonrightarm_body_keypoints = [0, 1, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [
777
+ 25 + i for i in [0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]]
778
+ keypoints_2d[nonrightarm_body_keypoints, :] = 0
779
+ if keypoints_2d[:, -1].sum() > 1:
780
+ center, scale = get_bbox(keypoints_2d)
781
+ center_x = center[0]
782
+ center_y = center[1]
783
+ width = 1.1 * scale[0]
784
+ height = 1.1 * scale[1]
785
+ return center_x, center_y, width, height
786
+
787
+
788
+ def crop_leftarm_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
789
+ """
790
+ Extreme cropping: Crop the box and keep only the left arm.
791
+ Args:
792
+ center_x (float): x coordinate of the bounding box center.
793
+ center_y (float): y coordinate of the bounding box center.
794
+ width (float): Bounding box width.
795
+ height (float): Bounding box height.
796
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
797
+ Returns:
798
+ center_x (float): x coordinate of the new bounding box center.
799
+ center_y (float): y coordinate of the new bounding box center.
800
+ width (float): New bounding box width.
801
+ height (float): New bounding box height.
802
+ """
803
+ keypoints_2d = keypoints_2d.copy()
804
+ nonleftarm_body_keypoints = [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] + [
805
+ 25 + i for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18]]
806
+ keypoints_2d[nonleftarm_body_keypoints, :] = 0
807
+ if keypoints_2d[:, -1].sum() > 1:
808
+ center, scale = get_bbox(keypoints_2d)
809
+ center_x = center[0]
810
+ center_y = center[1]
811
+ width = 1.1 * scale[0]
812
+ height = 1.1 * scale[1]
813
+ return center_x, center_y, width, height
814
+
815
+
816
+ def crop_legs_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
817
+ """
818
+ Extreme cropping: Crop the box and keep only the legs.
819
+ Args:
820
+ center_x (float): x coordinate of the bounding box center.
821
+ center_y (float): y coordinate of the bounding box center.
822
+ width (float): Bounding box width.
823
+ height (float): Bounding box height.
824
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
825
+ Returns:
826
+ center_x (float): x coordinate of the new bounding box center.
827
+ center_y (float): y coordinate of the new bounding box center.
828
+ width (float): New bounding box width.
829
+ height (float): New bounding box height.
830
+ """
831
+ keypoints_2d = keypoints_2d.copy()
832
+ nonlegs_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 15, 16, 17, 18] + [25 + i for i in
833
+ [6, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18]]
834
+ keypoints_2d[nonlegs_body_keypoints, :] = 0
835
+ if keypoints_2d[:, -1].sum() > 1:
836
+ center, scale = get_bbox(keypoints_2d)
837
+ center_x = center[0]
838
+ center_y = center[1]
839
+ width = 1.1 * scale[0]
840
+ height = 1.1 * scale[1]
841
+ return center_x, center_y, width, height
842
+
843
+
844
+ def crop_rightleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
845
+ """
846
+ Extreme cropping: Crop the box and keep only the right leg.
847
+ Args:
848
+ center_x (float): x coordinate of the bounding box center.
849
+ center_y (float): y coordinate of the bounding box center.
850
+ width (float): Bounding box width.
851
+ height (float): Bounding box height.
852
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
853
+ Returns:
854
+ center_x (float): x coordinate of the new bounding box center.
855
+ center_y (float): y coordinate of the new bounding box center.
856
+ width (float): New bounding box width.
857
+ height (float): New bounding box height.
858
+ """
859
+ keypoints_2d = keypoints_2d.copy()
860
+ nonrightleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] + [25 + i for i in
861
+ [3, 4, 5, 6, 7,
862
+ 8, 9, 10, 11,
863
+ 12, 13, 14, 15,
864
+ 16, 17, 18]]
865
+ keypoints_2d[nonrightleg_body_keypoints, :] = 0
866
+ if keypoints_2d[:, -1].sum() > 1:
867
+ center, scale = get_bbox(keypoints_2d)
868
+ center_x = center[0]
869
+ center_y = center[1]
870
+ width = 1.1 * scale[0]
871
+ height = 1.1 * scale[1]
872
+ return center_x, center_y, width, height
873
+
874
+
875
+ def crop_leftleg_only(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array):
876
+ """
877
+ Extreme cropping: Crop the box and keep only the left leg.
878
+ Args:
879
+ center_x (float): x coordinate of the bounding box center.
880
+ center_y (float): y coordinate of the bounding box center.
881
+ width (float): Bounding box width.
882
+ height (float): Bounding box height.
883
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
884
+ Returns:
885
+ center_x (float): x coordinate of the new bounding box center.
886
+ center_y (float): y coordinate of the new bounding box center.
887
+ width (float): New bounding box width.
888
+ height (float): New bounding box height.
889
+ """
890
+ keypoints_2d = keypoints_2d.copy()
891
+ nonleftleg_body_keypoints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16, 17, 18, 22, 23, 24] + [25 + i for i in
892
+ [0, 1, 2, 6, 7, 8,
893
+ 9, 10, 11, 12,
894
+ 13, 14, 15, 16,
895
+ 17, 18]]
896
+ keypoints_2d[nonleftleg_body_keypoints, :] = 0
897
+ if keypoints_2d[:, -1].sum() > 1:
898
+ center, scale = get_bbox(keypoints_2d)
899
+ center_x = center[0]
900
+ center_y = center[1]
901
+ width = 1.1 * scale[0]
902
+ height = 1.1 * scale[1]
903
+ return center_x, center_y, width, height
904
+
905
+
906
+ def full_body(keypoints_2d: np.array) -> bool:
907
+ """
908
+ Check if all main body joints are visible.
909
+ Args:
910
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
911
+ Returns:
912
+ bool: True if all main body joints are visible.
913
+ """
914
+
915
+ body_keypoints_openpose = [2, 3, 4, 5, 6, 7, 10, 11, 13, 14]
916
+ body_keypoints = [25 + i for i in [8, 7, 6, 9, 10, 11, 1, 0, 4, 5]]
917
+ return (np.maximum(keypoints_2d[body_keypoints, -1], keypoints_2d[body_keypoints_openpose, -1]) > 0).sum() == len(
918
+ body_keypoints)
919
+
920
+
921
+ def upper_body(keypoints_2d: np.array):
922
+ """
923
+ Check if all upper body joints are visible.
924
+ Args:
925
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
926
+ Returns:
927
+ bool: True if no lower body joints are visible and at least two upper body joints are visible.
928
+ """
929
+ lower_body_keypoints_openpose = [10, 11, 13, 14]
930
+ lower_body_keypoints = [25 + i for i in [1, 0, 4, 5]]
931
+ upper_body_keypoints_openpose = [0, 1, 15, 16, 17, 18]
932
+ upper_body_keypoints = [25 + 8, 25 + 9, 25 + 12, 25 + 13, 25 + 17, 25 + 18]
933
+ return ((keypoints_2d[lower_body_keypoints + lower_body_keypoints_openpose, -1] > 0).sum() == 0) \
934
+ and ((keypoints_2d[upper_body_keypoints + upper_body_keypoints_openpose, -1] > 0).sum() >= 2)
935
+
936
+
937
+ def get_bbox(keypoints_2d: np.array, rescale: float = 1.2) -> Tuple:
938
+ """
939
+ Get center and scale for bounding box from openpose detections.
940
+ Args:
941
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
942
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
943
+ Returns:
944
+ center (np.array): Array of shape (2,) containing the new bounding box center.
945
+ scale (float): New bounding box scale.
946
+ """
947
+ valid = keypoints_2d[:, -1] > 0
948
+ valid_keypoints = keypoints_2d[valid][:, :-1]
949
+ center = 0.5 * (valid_keypoints.max(axis=0) + valid_keypoints.min(axis=0))
950
+ bbox_size = (valid_keypoints.max(axis=0) - valid_keypoints.min(axis=0))
951
+ # adjust bounding box tightness
952
+ scale = bbox_size
953
+ scale *= rescale
954
+ return center, scale
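A quick numeric example for get_bbox (a sketch):

import numpy as np
from amr.datasets.utils import get_bbox

kps = np.array([[100., 50., 1.],
                [200., 250., 1.],
                [150., 150., 1.],
                [0., 0., 0.]])          # last keypoint invisible (confidence 0)
center, scale = get_bbox(kps, rescale=1.2)
print(center)   # [150. 150.]  midpoint of the visible keypoints
print(scale)    # [120. 240.]  tight extent (100 x 200) scaled by 1.2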
955
+
956
+
957
+ def extreme_cropping(center_x: float, center_y: float, width: float, height: float, keypoints_2d: np.array) -> Tuple:
958
+ """
959
+ Perform extreme cropping
960
+ Args:
961
+ center_x (float): x coordinate of bounding box center.
962
+ center_y (float): y coordinate of bounding box center.
963
+ width (float): bounding box width.
964
+ height (float): bounding box height.
965
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
966
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
967
+ Returns:
968
+ center_x (float): x coordinate of bounding box center.
969
+ center_y (float): y coordinate of bounding box center.
970
+ width (float): bounding box width.
971
+ height (float): bounding box height.
972
+ """
973
+ p = torch.rand(1).item()
974
+ if full_body(keypoints_2d):
975
+ if p < 0.7:
976
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
977
+ elif p < 0.9:
978
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
979
+ else:
980
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
981
+ elif upper_body(keypoints_2d):
982
+ if p < 0.9:
983
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
984
+ else:
985
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
986
+
987
+ return center_x, center_y, max(width, height), max(width, height)
988
+
989
+
990
+ def extreme_cropping_aggressive(center_x: float, center_y: float, width: float, height: float,
991
+ keypoints_2d: np.array) -> Tuple:
992
+ """
993
+ Perform aggressive extreme cropping
994
+ Args:
995
+ center_x (float): x coordinate of bounding box center.
996
+ center_y (float): y coordinate of bounding box center.
997
+ width (float): bounding box width.
998
+ height (float): bounding box height.
999
+ keypoints_2d (np.array): Array of shape (N, 3) containing 2D keypoint locations.
1000
+ rescale (float): Scale factor to rescale bounding boxes computed from the keypoints.
1001
+ Returns:
1002
+ center_x (float): x coordinate of bounding box center.
1003
+ center_y (float): y coordinate of bounding box center.
1004
+ width (float): bounding box width.
1005
+ height (float): bounding box height.
1006
+ """
1007
+ p = torch.rand(1).item()
1008
+ if full_body(keypoints_2d):
1009
+ if p < 0.2:
1010
+ center_x, center_y, width, height = crop_to_hips(center_x, center_y, width, height, keypoints_2d)
1011
+ elif p < 0.3:
1012
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1013
+ elif p < 0.4:
1014
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1015
+ elif p < 0.5:
1016
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
1017
+ elif p < 0.6:
1018
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
1019
+ elif p < 0.7:
1020
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
1021
+ elif p < 0.8:
1022
+ center_x, center_y, width, height = crop_legs_only(center_x, center_y, width, height, keypoints_2d)
1023
+ elif p < 0.9:
1024
+ center_x, center_y, width, height = crop_rightleg_only(center_x, center_y, width, height, keypoints_2d)
1025
+ else:
1026
+ center_x, center_y, width, height = crop_leftleg_only(center_x, center_y, width, height, keypoints_2d)
1027
+ elif upper_body(keypoints_2d):
1028
+ if p < 0.2:
1029
+ center_x, center_y, width, height = crop_to_shoulders(center_x, center_y, width, height, keypoints_2d)
1030
+ elif p < 0.4:
1031
+ center_x, center_y, width, height = crop_to_head(center_x, center_y, width, height, keypoints_2d)
1032
+ elif p < 0.6:
1033
+ center_x, center_y, width, height = crop_torso_only(center_x, center_y, width, height, keypoints_2d)
1034
+ elif p < 0.8:
1035
+ center_x, center_y, width, height = crop_rightarm_only(center_x, center_y, width, height, keypoints_2d)
1036
+ else:
1037
+ center_x, center_y, width, height = crop_leftarm_only(center_x, center_y, width, height, keypoints_2d)
1038
+ return center_x, center_y, max(width, height), max(width, height)
amr/datasets/vitdet_dataset.py ADDED
@@ -0,0 +1,92 @@
+ from typing import Dict
+
+ import cv2
+ import numpy as np
+ from skimage.filters import gaussian
+ from yacs.config import CfgNode
+ import torch
+
+ from .utils import (convert_cvimg_to_tensor,
+                     expand_to_aspect_ratio,
+                     generate_image_patch_cv2)
+
+ DEFAULT_MEAN = 255. * np.array([0.485, 0.456, 0.406])
+ DEFAULT_STD = 255. * np.array([0.229, 0.224, 0.225])
+
+
+ class ViTDetDataset(torch.utils.data.Dataset):
+
+     def __init__(self,
+                  cfg: CfgNode,
+                  img_cv2: np.array,
+                  boxes: np.array,
+                  rescale_factor=1,
+                  train: bool = False,
+                  **kwargs):
+         super().__init__()
+         self.cfg = cfg
+         self.img_cv2 = img_cv2
+         self.boxes = boxes
+
+         assert train is False, "ViTDetDataset is only for inference"
+         self.train = train
+         self.img_size = cfg.MODEL.IMAGE_SIZE
+         self.mean = 255. * np.array(self.cfg.MODEL.IMAGE_MEAN)
+         self.std = 255. * np.array(self.cfg.MODEL.IMAGE_STD)
+
+         # Preprocess annotations
+         boxes = boxes.astype(np.float32)
+         self.center = (boxes[:, 2:4] + boxes[:, 0:2]) / 2.0
+         self.scale = rescale_factor * (boxes[:, 2:4] - boxes[:, 0:2]) / 200.0
+         self.personid = np.arange(len(boxes), dtype=np.int32)
+
+     def __len__(self) -> int:
+         return len(self.personid)
+
+     def __getitem__(self, idx: int) -> Dict[str, np.array]:
+
+         center = self.center[idx].copy()
+         center_x = center[0]
+         center_y = center[1]
+
+         scale = self.scale[idx]
+         BBOX_SHAPE = self.cfg.MODEL.get('BBOX_SHAPE', None)
+         bbox_size = expand_to_aspect_ratio(scale * 200, target_aspect_ratio=BBOX_SHAPE).max()
+
+         patch_width = patch_height = self.img_size
+
+         flip = False
+
+         # 3. generate image patch
+         # if use_skimage_antialias:
+         cvimg = self.img_cv2.copy()
+         if True:
+             # Blur image to avoid aliasing artifacts
+             downsampling_factor = ((bbox_size * 1.0) / patch_width)
+             print(f'{downsampling_factor=}')
+             downsampling_factor = downsampling_factor / 2.0
+             if downsampling_factor > 1.1:
+                 cvimg = gaussian(cvimg, sigma=(downsampling_factor - 1) / 2, channel_axis=2, preserve_range=True)
+
+         img_patch_cv, trans = generate_image_patch_cv2(cvimg,
+                                                        center_x, center_y,
+                                                        bbox_size, bbox_size,
+                                                        patch_width, patch_height,
+                                                        flip, 1.0, 0.0,
+                                                        border_mode=cv2.BORDER_CONSTANT)
+         img_patch_cv = img_patch_cv[:, :, ::-1]
+         img_patch = convert_cvimg_to_tensor(img_patch_cv)
+
+         # apply normalization
+         for n_c in range(min(self.img_cv2.shape[2], 3)):
+             img_patch[n_c, :, :] = (img_patch[n_c, :, :] - self.mean[n_c]) / self.std[n_c]
+
+         item = {
+             'img': img_patch / 255.,
+             'personid': int(self.personid[idx]),
+             'box_center': self.center[idx].copy(),
+             'box_size': bbox_size,
+             'img_size': 1.0 * np.array([cvimg.shape[1], cvimg.shape[0]]),
+             'focal_length': np.array([self.cfg.EXTRA.FOCAL_LENGTH, self.cfg.EXTRA.FOCAL_LENGTH]),
+         }
+         return item
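A hedged usage sketch for ViTDetDataset (the config stub below only fills the fields this class reads; in practice the values come from the trained model's config, and the box is a made-up x1, y1, x2, y2 detection):

import numpy as np
import torch
from yacs.config import CfgNode as CN
from amr.datasets.vitdet_dataset import ViTDetDataset

cfg = CN(new_allowed=True)
cfg.MODEL = CN(new_allowed=True)
cfg.MODEL.IMAGE_SIZE = 256
cfg.MODEL.IMAGE_MEAN = [0.485, 0.456, 0.406]   # assumed ImageNet statistics
cfg.MODEL.IMAGE_STD = [0.229, 0.224, 0.225]
cfg.MODEL.BBOX_SHAPE = [192, 256]
cfg.EXTRA = CN(new_allowed=True)
cfg.EXTRA.FOCAL_LENGTH = 5000

img_cv2 = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)   # BGR frame as loaded by cv2
boxes = np.array([[100., 80., 300., 400.]], dtype=np.float32)        # one detection

dataset = ViTDetDataset(cfg, img_cv2, boxes)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False)
batch = next(iter(loader))
print(batch['img'].shape)     # torch.Size([1, 3, 256, 256])
print(batch['box_center'])    # tensor([[200., 240.]])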
amr/models/__init__.py ADDED
@@ -0,0 +1,28 @@
+ from .smal_warapper import SMAL
+ from ..configs import CACHE_DIR_HAMER
+ from .amr import AMR
+
+ DEFAULT_CHECKPOINT = f'{CACHE_DIR_HAMER}/train/runs/AniMer/checkpoints/checkpoint.ckpt'
+
+
+ def load_amr(checkpoint_path=DEFAULT_CHECKPOINT):
+     from pathlib import Path
+     from ..configs import get_config
+     model_cfg = str(Path(checkpoint_path).parent.parent / '.hydra/config.yaml')
+     model_cfg = get_config(model_cfg, update_cachedir=True)
+
+     # Override some config values, to crop bbox correctly
+     if (model_cfg.MODEL.BACKBONE.TYPE == 'vit') and ('BBOX_SHAPE' not in model_cfg.MODEL):
+         model_cfg.defrost()
+         assert model_cfg.MODEL.IMAGE_SIZE == 256, f"MODEL.IMAGE_SIZE ({model_cfg.MODEL.IMAGE_SIZE}) should be 256 for ViT backbone"
+         model_cfg.MODEL.BBOX_SHAPE = [192, 256]
+         model_cfg.freeze()
+
+     # Update config to be compatible with demo
+     if ('PRETRAINED_WEIGHTS' in model_cfg.MODEL.BACKBONE):
+         model_cfg.defrost()
+         model_cfg.MODEL.BACKBONE.pop('PRETRAINED_WEIGHTS')
+         model_cfg.freeze()
+
+     model = AMR.load_from_checkpoint(checkpoint_path, strict=False, cfg=model_cfg)
+     return model, model_cfg
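A short sketch of the intended entry point (the default checkpoint path is only meaningful once a trained run has been downloaded; the run directory two levels above the checkpoint is expected to contain the Hydra config.yaml):

import torch
from amr.models import load_amr, DEFAULT_CHECKPOINT

model, model_cfg = load_amr(DEFAULT_CHECKPOINT)
model = model.to('cuda' if torch.cuda.is_available() else 'cpu').eval()
print(model_cfg.MODEL.IMAGE_SIZE, model_cfg.MODEL.BBOX_SHAPE)   # 256 and [192, 256] for the ViT backbone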
amr/models/amr.py ADDED
@@ -0,0 +1,104 @@
1
+ import torch
2
+ import pickle
3
+ import pytorch_lightning as pl
4
+ from typing import Any, Dict
5
+ from yacs.config import CfgNode
6
+ from ..utils.geometry import aa_to_rotmat, perspective_projection
7
+ from ..utils.pylogger import get_pylogger
8
+ from .backbones import create_backbone
9
+ from .heads import build_smal_head
10
+ from . import SMAL
11
+
12
+ log = get_pylogger(__name__)
13
+
14
+
15
+ class AMR(pl.LightningModule):
16
+
17
+ def __init__(self, cfg: CfgNode, init_renderer: bool = True):
18
+ """
19
+ Setup AMR model
20
+ Args:
21
+ cfg (CfgNode): Config file as a yacs CfgNode
22
+ """
23
+ super().__init__()
24
+
25
+ # Save hyperparameters
26
+ self.save_hyperparameters(logger=False, ignore=['init_renderer'])
27
+
28
+ self.cfg = cfg
29
+ # Create backbone feature extractor
30
+ self.backbone = create_backbone(cfg)
31
+
32
+ # Create SMAL head
33
+ self.smal_head = build_smal_head(cfg)
34
+
35
+ # Instantiate SMAL model
36
+ smal_model_path = cfg.SMAL.MODEL_PATH
37
+ with open(smal_model_path, 'rb') as f:
38
+ smal_cfg = pickle.load(f, encoding="latin1")
39
+ self.smal = SMAL(**smal_cfg)
40
+
41
+ def forward_step(self, batch: Dict) -> Dict:
42
+ """
43
+ Run a forward step of the network
44
+ Args:
45
+ batch (Dict): Dictionary containing batch data
46
+ Returns:
47
+ Dict: Dictionary containing the regression output
48
+ """
49
+
50
+ # Use RGB image as input
51
+ x = batch['img']
52
+ batch_size = x.shape[0]
53
+
54
+ # Compute conditioning features using the backbone
55
+ conditioning_feats, cls = self.backbone(x[:, :, :, 32:-32]) # [256, 192]
56
+ # conditioning_feats = self.backbone.forward_features(x)['x_norm_patchtokens']
57
+ # pred_mano_params:{'betas':[batch_size, 10], 'global_orient': [batch_size, 1, 3, 3],
58
+ # 'pose':[batch_size, 33, 3, 3], 'translation': [batch_size, 3]}
59
+ # pred_cam:[batch_size, 3]
60
+ pred_smal_params, pred_cam, _ = self.smal_head(conditioning_feats)
61
+
62
+ # Store useful regression outputs to the output dict
63
+ output = {}
64
+
65
+ output['pred_cam'] = pred_cam
66
+ output['pred_smal_params'] = {k: v.clone() for k, v in pred_smal_params.items()}
67
+
68
+ # Compute camera translation
69
+ focal_length = batch['focal_length']
70
+ pred_cam_t = torch.stack([pred_cam[:, 1],
71
+ pred_cam[:, 2],
72
+ 2 * focal_length[:, 0] / (self.cfg.MODEL.IMAGE_SIZE * pred_cam[:, 0] + 1e-9)], dim=-1)
73
+ output['pred_cam_t'] = pred_cam_t
74
+ output['focal_length'] = focal_length
75
+
76
+ # Compute model vertices, joints and the projected joints
77
+ pred_smal_params['global_orient'] = pred_smal_params['global_orient'].reshape(batch_size, -1, 3, 3)
78
+ pred_smal_params['pose'] = pred_smal_params['pose'].reshape(batch_size, -1, 3, 3)
79
+ pred_smal_params['betas'] = pred_smal_params['betas'].reshape(batch_size, -1)
80
+ smal_output = self.smal(**pred_smal_params, pose2rot=False)
81
+
82
+ pred_keypoints_3d = smal_output.joints
83
+ pred_vertices = smal_output.vertices
84
+ output['pred_keypoints_3d'] = pred_keypoints_3d.reshape(batch_size, -1, 3)
85
+ output['pred_vertices'] = pred_vertices.reshape(batch_size, -1, 3)
86
+ pred_cam_t = pred_cam_t.reshape(-1, 3)
87
+ focal_length = focal_length.reshape(-1, 2)
88
+ pred_keypoints_2d = perspective_projection(pred_keypoints_3d,
89
+ translation=pred_cam_t,
90
+ focal_length=focal_length / self.cfg.MODEL.IMAGE_SIZE)
91
+
92
+ output['pred_keypoints_2d'] = pred_keypoints_2d.reshape(batch_size, -1, 2)
93
+ return output
94
+
95
+ def forward(self, batch: Dict) -> Dict:
96
+ """
97
+ Run a forward step of the network in val mode
98
+ Args:
99
+ batch (Dict): Dictionary containing batch data
100
+ Returns:
101
+ Dict: Dictionary containing the regression output
102
+ """
103
+ return self.forward_step(batch)
104
+
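The camera handling in `forward_step` converts the predicted weak-perspective camera (s, tx, ty) into a full translation with depth tz = 2f / (IMAGE_SIZE * s). Below is a standalone sketch of that conversion with dummy tensors; the helper name and values are illustrative, not part of the commit.

import torch

def weak_perspective_to_translation(pred_cam, focal_length, image_size=256):
    # Mirrors the conversion in AMR.forward_step: (s, tx, ty) -> (tx, ty, tz),
    # with tz = 2 * f / (image_size * s), guarded by a small epsilon.
    return torch.stack([
        pred_cam[:, 1],
        pred_cam[:, 2],
        2 * focal_length[:, 0] / (image_size * pred_cam[:, 0] + 1e-9),
    ], dim=-1)

cam_t = weak_perspective_to_translation(torch.rand(4, 3) + 0.5, torch.full((4, 2), 5000.0))
assert cam_t.shape == (4, 3)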
amr/models/backbones/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ from .vit import vit, vitl
2
+
3
+ def create_backbone(cfg):
4
+ if cfg.MODEL.BACKBONE.TYPE == 'vit':
5
+ return vit(cfg)
6
+ else:
7
+ raise NotImplementedError('Backbone type is not implemented')
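A minimal sketch of driving this factory with a hand-built yacs config; the fields mirror the ones read above, everything else is an assumption for illustration.

from yacs.config import CfgNode as CN
from amr.models.backbones import create_backbone

cfg = CN(new_allowed=True)
cfg.MODEL = CN(new_allowed=True)
cfg.MODEL.BACKBONE = CN(new_allowed=True)
cfg.MODEL.BACKBONE.TYPE = 'vit'        # the only type wired up in this factory

backbone = create_backbone(cfg)        # returns the ViT defined in vit.py below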
amr/models/backbones/vit.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ from functools import partial
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint as checkpoint
9
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
10
+
11
+
12
+ def vit(cfg):
13
+ return ViT(
14
+ img_size=(256, 192),
15
+ patch_size=16,
16
+ embed_dim=1280,
17
+ depth=32,
18
+ num_heads=16,
19
+ ratio=1,
20
+ use_checkpoint=False,
21
+ mlp_ratio=4,
22
+ qkv_bias=True,
23
+ drop_path_rate=0.55,
24
+ use_cls=cfg.MODEL.get("USE_CLS", False),
25
+ )
26
+
27
+
28
+ def vitl(cfg):
29
+ return ViT(
30
+ img_size=(256, 192),
31
+ patch_size=16,
32
+ embed_dim=1024,
33
+ depth=24,
34
+ num_heads=16,
35
+ ratio=1,
36
+ use_checkpoint=False,
37
+ mlp_ratio=4,
38
+ qkv_bias=True,
39
+ drop_path_rate=0.5,
40
+ use_cls=cfg.MODEL.get("USE_CLS", False),
41
+ )
42
+
43
+
44
+ def get_abs_pos(abs_pos, h, w, ori_h, ori_w, has_cls_token=True):
45
+ """
46
+ Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
47
+ dimension for the original embeddings.
48
+ Args:
49
+ abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
50
+ has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
51
+ h, w, ori_h, ori_w (int): target and original sizes of the token grid.
52
+
53
+ Returns:
54
+ Absolute positional embeddings after processing, with shape (1, h*w, C) (cls token re-attached if present).
55
+ """
56
+ cls_token = None
57
+ B, L, C = abs_pos.shape
58
+ if has_cls_token:
59
+ cls_token = abs_pos[:, 0:1]
60
+ abs_pos = abs_pos[:, 1:]
61
+
62
+ if ori_h != h or ori_w != w:
63
+ new_abs_pos = F.interpolate(
64
+ abs_pos.reshape(1, ori_h, ori_w, -1).permute(0, 3, 1, 2),
65
+ size=(h, w),
66
+ mode="bicubic",
67
+ align_corners=False,
68
+ ).permute(0, 2, 3, 1).reshape(B, -1, C)
69
+
70
+ else:
71
+ new_abs_pos = abs_pos
72
+
73
+ if cls_token is not None:
74
+ new_abs_pos = torch.cat([cls_token, new_abs_pos], dim=1)
75
+ return new_abs_pos
76
+
77
+
78
+ class DropPath(nn.Module):
79
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
80
+ """
81
+
82
+ def __init__(self, drop_prob=None):
83
+ super(DropPath, self).__init__()
84
+ self.drop_prob = drop_prob
85
+
86
+ def forward(self, x):
87
+ return drop_path(x, self.drop_prob, self.training)
88
+
89
+ def extra_repr(self):
90
+ return 'p={}'.format(self.drop_prob)
91
+
92
+
93
+ class Mlp(nn.Module):
94
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
95
+ super().__init__()
96
+ out_features = out_features or in_features
97
+ hidden_features = hidden_features or in_features
98
+ self.fc1 = nn.Linear(in_features, hidden_features)
99
+ self.act = act_layer()
100
+ self.fc2 = nn.Linear(hidden_features, out_features)
101
+ self.drop = nn.Dropout(drop)
102
+
103
+ def forward(self, x):
104
+ x = self.fc1(x)
105
+ x = self.act(x)
106
+ x = self.fc2(x)
107
+ x = self.drop(x)
108
+ return x
109
+
110
+
111
+ class Attention(nn.Module):
112
+ def __init__(
113
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
114
+ proj_drop=0., attn_head_dim=None):
115
+ super().__init__()
116
+ self.num_heads = num_heads
117
+ head_dim = dim // num_heads
118
+ self.dim = dim
119
+
120
+ if attn_head_dim is not None:
121
+ head_dim = attn_head_dim
122
+ all_head_dim = head_dim * self.num_heads
123
+
124
+ self.scale = qk_scale or head_dim ** -0.5
125
+
126
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=qkv_bias)
127
+
128
+ self.attn_drop = nn.Dropout(attn_drop)
129
+ self.proj = nn.Linear(all_head_dim, dim)
130
+ self.proj_drop = nn.Dropout(proj_drop)
131
+
132
+ def forward(self, x):
133
+ B, N, C = x.shape
134
+ qkv = self.qkv(x)
135
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
136
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
137
+
138
+ q = q * self.scale
139
+ attn = (q @ k.transpose(-2, -1))
140
+ attn = attn.softmax(dim=-1)
141
+ attn = self.attn_drop(attn)
142
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
143
+
144
+ x = self.proj(x)
145
+ x = self.proj_drop(x)
146
+
147
+ return x
148
+
149
+
150
+ class Block(nn.Module):
151
+
152
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
153
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
154
+ norm_layer=nn.LayerNorm, attn_head_dim=None,
155
+ ):
156
+ super().__init__()
157
+
158
+ self.norm1 = norm_layer(dim)
159
+ self.attn = Attention(
160
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
161
+ attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim
162
+ )
163
+
164
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
165
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
166
+ self.norm2 = norm_layer(dim)
167
+ mlp_hidden_dim = int(dim * mlp_ratio)
168
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
169
+
170
+ def forward(self, x):
171
+ x = x + self.drop_path(self.attn(self.norm1(x)))
172
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
173
+ return x
174
+
175
+
176
+ class PatchEmbed(nn.Module):
177
+ """ Image to Patch Embedding
178
+ """
179
+
180
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, ratio=1):
181
+ super().__init__()
182
+ img_size = to_2tuple(img_size)
183
+ patch_size = to_2tuple(patch_size)
184
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (ratio ** 2)
185
+ self.patch_shape = (int(img_size[0] // patch_size[0] * ratio), int(img_size[1] // patch_size[1] * ratio))
186
+ self.origin_patch_shape = (int(img_size[0] // patch_size[0]), int(img_size[1] // patch_size[1]))
187
+ self.img_size = img_size
188
+ self.patch_size = patch_size
189
+ self.num_patches = num_patches
190
+
191
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=(patch_size[0] // ratio),
192
+ padding=4 + 2 * (ratio // 2 - 1))
193
+
194
+ def forward(self, x, **kwargs):
195
+ B, C, H, W = x.shape
196
+ x = self.proj(x)
197
+ Hp, Wp = x.shape[2], x.shape[3]
198
+
199
+ x = x.flatten(2).transpose(1, 2)
200
+ return x, (Hp, Wp)
201
+
202
+
203
+ class HybridEmbed(nn.Module):
204
+ """ CNN Feature Map Embedding
205
+ Extract feature map from CNN, flatten, project to embedding dim.
206
+ """
207
+
208
+ def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
209
+ super().__init__()
210
+ assert isinstance(backbone, nn.Module)
211
+ img_size = to_2tuple(img_size)
212
+ self.img_size = img_size
213
+ self.backbone = backbone
214
+ if feature_size is None:
215
+ with torch.no_grad():
216
+ training = backbone.training
217
+ if training:
218
+ backbone.eval()
219
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
220
+ feature_size = o.shape[-2:]
221
+ feature_dim = o.shape[1]
222
+ backbone.train(training)
223
+ else:
224
+ feature_size = to_2tuple(feature_size)
225
+ feature_dim = self.backbone.feature_info.channels()[-1]
226
+ self.num_patches = feature_size[0] * feature_size[1]
227
+ self.proj = nn.Linear(feature_dim, embed_dim)
228
+
229
+ def forward(self, x):
230
+ x = self.backbone(x)[-1]
231
+ x = x.flatten(2).transpose(1, 2)
232
+ x = self.proj(x)
233
+ return x
234
+
235
+
236
+ class ViT(nn.Module):
237
+
238
+ def __init__(self,
239
+ img_size=224, patch_size=16, in_chans=3, num_classes=80, embed_dim=768, depth=12,
240
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
241
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=None, use_checkpoint=False,
242
+ frozen_stages=-1, ratio=1, last_norm=True, use_cls=False,
243
+ patch_padding='pad', freeze_attn=False, freeze_ffn=False,
244
+ ):
245
+ # Protect mutable default arguments
246
+ super(ViT, self).__init__()
247
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
248
+ self.num_classes = num_classes
249
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
250
+ self.frozen_stages = frozen_stages
251
+ self.use_checkpoint = use_checkpoint
252
+ self.patch_padding = patch_padding
253
+ self.freeze_attn = freeze_attn
254
+ self.freeze_ffn = freeze_ffn
255
+ self.depth = depth
256
+
257
+ if hybrid_backbone is not None:
258
+ self.patch_embed = HybridEmbed(
259
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
260
+ else:
261
+ self.patch_embed = PatchEmbed(
262
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, ratio=ratio)
263
+ num_patches = self.patch_embed.num_patches
264
+
265
+ # since the pretraining model has class token
266
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
267
+
268
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
269
+
270
+ self.blocks = nn.ModuleList([
271
+ Block(
272
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
273
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
274
+ )
275
+ for i in range(depth)])
276
+
277
+ self.last_norm = norm_layer(embed_dim) if last_norm else nn.Identity()
278
+
279
+ if self.pos_embed is not None:
280
+ trunc_normal_(self.pos_embed, std=.02)
281
+
282
+ self.use_cls = use_cls
283
+ if use_cls:
284
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
285
+ nn.init.normal_(self.cls_token, std=1e-6)
286
+ else:
287
+ self.cls_token = None
288
+
289
+ self._freeze_stages()
290
+
291
+ def _freeze_stages(self):
292
+ """Freeze parameters."""
293
+ if self.frozen_stages >= 0:
294
+ self.patch_embed.eval()
295
+ for param in self.patch_embed.parameters():
296
+ param.requires_grad = False
297
+
298
+ for i in range(1, self.frozen_stages + 1):
299
+ m = self.blocks[i]
300
+ m.eval()
301
+ for param in m.parameters():
302
+ param.requires_grad = False
303
+
304
+ if self.freeze_attn:
305
+ for i in range(0, self.depth):
306
+ m = self.blocks[i]
307
+ m.attn.eval()
308
+ m.norm1.eval()
309
+ for param in m.attn.parameters():
310
+ param.requires_grad = False
311
+ for param in m.norm1.parameters():
312
+ param.requires_grad = False
313
+
314
+ if self.freeze_ffn:
315
+ self.pos_embed.requires_grad = False
316
+ self.patch_embed.eval()
317
+ for param in self.patch_embed.parameters():
318
+ param.requires_grad = False
319
+ for i in range(0, self.depth):
320
+ m = self.blocks[i]
321
+ m.mlp.eval()
322
+ m.norm2.eval()
323
+ for param in m.mlp.parameters():
324
+ param.requires_grad = False
325
+ for param in m.norm2.parameters():
326
+ param.requires_grad = False
327
+
328
+ def init_weights(self):
329
+ """Initialize the weights in backbone.
330
+ Args:
331
+ pretrained (str, optional): Path to pre-trained weights.
332
+ Defaults to None.
333
+ """
334
+
335
+ def _init_weights(m):
336
+ if isinstance(m, nn.Linear):
337
+ trunc_normal_(m.weight, std=.02)
338
+ if isinstance(m, nn.Linear) and m.bias is not None:
339
+ nn.init.constant_(m.bias, 0)
340
+ elif isinstance(m, nn.LayerNorm):
341
+ nn.init.constant_(m.bias, 0)
342
+ nn.init.constant_(m.weight, 1.0)
343
+
344
+ self.apply(_init_weights)
345
+
346
+ def get_num_layers(self):
347
+ return len(self.blocks)
348
+
349
+ @torch.jit.ignore
350
+ def no_weight_decay(self):
351
+ return {'pos_embed', 'cls_token'}
352
+
353
+ def forward_features(self, x):
354
+ B, C, H, W = x.shape
355
+ x, (Hp, Wp) = self.patch_embed(x)
356
+
357
+ if self.pos_embed is not None:
358
+ # fit for multiple GPU training
359
+ # since the first element for pos embed (sin-cos manner) is zero, it will cause no difference
360
+ x = x + self.pos_embed[:, 1:] + self.pos_embed[:, :1]
361
+
362
+ x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) if self.use_cls else x
363
+ for blk in self.blocks:
364
+ if self.use_checkpoint:
365
+ x = checkpoint.checkpoint(blk, x)
366
+ else:
367
+ x = blk(x)
368
+
369
+ x = self.last_norm(x)
370
+
371
+ cls = x[:, 0] if self.use_cls else None
372
+ x = x[:, 1:] if self.use_cls else x
373
+ xp = x.permute(0, 2, 1).reshape(B, -1, Hp, Wp).contiguous()
374
+
375
+ return xp, cls
376
+
377
+ def forward(self, x):
378
+ x, cls = self.forward_features(x)
379
+ return x, cls
380
+
381
+ def train(self, mode=True):
382
+ """Convert the model into training mode."""
383
+ super().train(mode)
384
+ self._freeze_stages()
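Shape sanity check for the ViT built by vit(cfg) above: a patch size of 16 on a 256x192 crop gives a 16x12 token grid. Depth is reduced here so the sketch runs quickly; all values are assumptions, not a test shipped with this commit.

import torch
from amr.models.backbones.vit import ViT

model = ViT(img_size=(256, 192), patch_size=16, embed_dim=1280, depth=2,
            num_heads=16, mlp_ratio=4, qkv_bias=True)   # depth=32 in vit(cfg)
feat, cls_tok = model(torch.randn(1, 3, 256, 192))
print(feat.shape)   # torch.Size([1, 1280, 16, 12]) -> (B, C, H/16, W/16)
print(cls_tok)      # None unless use_cls=True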
amr/models/components/__init__.py ADDED
File without changes
amr/models/components/pose_transformer.py ADDED
@@ -0,0 +1,358 @@
1
+ from inspect import isfunction
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from einops.layers.torch import Rearrange
7
+ from torch import nn
8
+
9
+ from .t_cond_mlp import (
10
+ AdaptiveLayerNorm1D,
11
+ FrequencyEmbedder,
12
+ normalization_layer,
13
+ )
14
+ # from .vit import Attention, FeedForward
15
+
16
+
17
+ def exists(val):
18
+ return val is not None
19
+
20
+
21
+ def default(val, d):
22
+ if exists(val):
23
+ return val
24
+ return d() if isfunction(d) else d
25
+
26
+
27
+ class PreNorm(nn.Module):
28
+ def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
29
+ super().__init__()
30
+ self.norm = normalization_layer(norm, dim, norm_cond_dim)
31
+ self.fn = fn
32
+
33
+ def forward(self, x: torch.Tensor, *args, **kwargs):
34
+ if isinstance(self.norm, AdaptiveLayerNorm1D):
35
+ return self.fn(self.norm(x, *args), **kwargs)
36
+ else:
37
+ return self.fn(self.norm(x), **kwargs)
38
+
39
+
40
+ class FeedForward(nn.Module):
41
+ def __init__(self, dim, hidden_dim, dropout=0.0):
42
+ super().__init__()
43
+ self.net = nn.Sequential(
44
+ nn.Linear(dim, hidden_dim),
45
+ nn.GELU(),
46
+ nn.Dropout(dropout),
47
+ nn.Linear(hidden_dim, dim),
48
+ nn.Dropout(dropout),
49
+ )
50
+
51
+ def forward(self, x):
52
+ return self.net(x)
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
57
+ super().__init__()
58
+ inner_dim = dim_head * heads
59
+ project_out = not (heads == 1 and dim_head == dim)
60
+
61
+ self.heads = heads
62
+ self.scale = dim_head**-0.5
63
+
64
+ self.attend = nn.Softmax(dim=-1)
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
68
+
69
+ self.to_out = (
70
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
71
+ if project_out
72
+ else nn.Identity()
73
+ )
74
+
75
+ def forward(self, x):
76
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
77
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
78
+
79
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
80
+
81
+ attn = self.attend(dots)
82
+ attn = self.dropout(attn)
83
+
84
+ out = torch.matmul(attn, v)
85
+ out = rearrange(out, "b h n d -> b n (h d)")
86
+ return self.to_out(out)
87
+
88
+
89
+ class CrossAttention(nn.Module):
90
+ def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
91
+ super().__init__()
92
+ inner_dim = dim_head * heads
93
+ project_out = not (heads == 1 and dim_head == dim)
94
+
95
+ self.heads = heads
96
+ self.scale = dim_head**-0.5
97
+
98
+ self.attend = nn.Softmax(dim=-1)
99
+ self.dropout = nn.Dropout(dropout)
100
+
101
+ context_dim = default(context_dim, dim)
102
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
103
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
104
+
105
+ self.to_out = (
106
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
107
+ if project_out
108
+ else nn.Identity()
109
+ )
110
+
111
+ def forward(self, x, context=None):
112
+ context = default(context, x)
113
+ k, v = self.to_kv(context).chunk(2, dim=-1)
114
+ q = self.to_q(x)
115
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
116
+
117
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
118
+
119
+ attn = self.attend(dots)
120
+ attn = self.dropout(attn)
121
+
122
+ out = torch.matmul(attn, v)
123
+ out = rearrange(out, "b h n d -> b n (h d)")
124
+ return self.to_out(out)
125
+
126
+
127
+ class Transformer(nn.Module):
128
+ def __init__(
129
+ self,
130
+ dim: int,
131
+ depth: int,
132
+ heads: int,
133
+ dim_head: int,
134
+ mlp_dim: int,
135
+ dropout: float = 0.0,
136
+ norm: str = "layer",
137
+ norm_cond_dim: int = -1,
138
+ ):
139
+ super().__init__()
140
+ self.layers = nn.ModuleList([])
141
+ for _ in range(depth):
142
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
143
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
144
+ self.layers.append(
145
+ nn.ModuleList(
146
+ [
147
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
148
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
149
+ ]
150
+ )
151
+ )
152
+
153
+ def forward(self, x: torch.Tensor, *args):
154
+ for attn, ff in self.layers:
155
+ x = attn(x, *args) + x
156
+ x = ff(x, *args) + x
157
+ return x
158
+
159
+
160
+ class TransformerCrossAttn(nn.Module):
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ depth: int,
165
+ heads: int,
166
+ dim_head: int,
167
+ mlp_dim: int,
168
+ dropout: float = 0.0,
169
+ norm: str = "layer",
170
+ norm_cond_dim: int = -1,
171
+ context_dim: Optional[int] = None,
172
+ ):
173
+ super().__init__()
174
+ self.layers = nn.ModuleList([])
175
+ for _ in range(depth):
176
+ sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
177
+ ca = CrossAttention(
178
+ dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
179
+ )
180
+ ff = FeedForward(dim, mlp_dim, dropout=dropout)
181
+ self.layers.append(
182
+ nn.ModuleList(
183
+ [
184
+ PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
185
+ PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
186
+ PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
187
+ ]
188
+ )
189
+ )
190
+
191
+ def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
192
+ if context_list is None:
193
+ context_list = [context] * len(self.layers)
194
+ if len(context_list) != len(self.layers):
195
+ raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
196
+
197
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
198
+ x = self_attn(x, *args) + x
199
+ x = cross_attn(x, *args, context=context_list[i]) + x
200
+ x = ff(x, *args) + x
201
+ return x
202
+
203
+
204
+ class DropTokenDropout(nn.Module):
205
+ def __init__(self, p: float = 0.1):
206
+ super().__init__()
207
+ if p < 0 or p > 1:
208
+ raise ValueError(
209
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
210
+ )
211
+ self.p = p
212
+
213
+ def forward(self, x: torch.Tensor):
214
+ # x: (batch_size, seq_len, dim)
215
+ if self.training and self.p > 0:
216
+ zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
217
+ # TODO: permutation idx for each batch using torch.argsort
218
+ if zero_mask.any():
219
+ x = x[:, ~zero_mask, :]
220
+ return x
221
+
222
+
223
+ class ZeroTokenDropout(nn.Module):
224
+ def __init__(self, p: float = 0.1):
225
+ super().__init__()
226
+ if p < 0 or p > 1:
227
+ raise ValueError(
228
+ "dropout probability has to be between 0 and 1, " "but got {}".format(p)
229
+ )
230
+ self.p = p
231
+
232
+ def forward(self, x: torch.Tensor):
233
+ # x: (batch_size, seq_len, dim)
234
+ if self.training and self.p > 0:
235
+ zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
236
+ # Zero-out the masked tokens
237
+ x[zero_mask, :] = 0
238
+ return x
239
+
240
+
241
+ class TransformerEncoder(nn.Module):
242
+ def __init__(
243
+ self,
244
+ num_tokens: int,
245
+ token_dim: int,
246
+ dim: int,
247
+ depth: int,
248
+ heads: int,
249
+ mlp_dim: int,
250
+ dim_head: int = 64,
251
+ dropout: float = 0.0,
252
+ emb_dropout: float = 0.0,
253
+ emb_dropout_type: str = "drop",
254
+ emb_dropout_loc: str = "token",
255
+ norm: str = "layer",
256
+ norm_cond_dim: int = -1,
257
+ token_pe_numfreq: int = -1,
258
+ ):
259
+ super().__init__()
260
+ if token_pe_numfreq > 0:
261
+ token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
262
+ self.to_token_embedding = nn.Sequential(
263
+ Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
264
+ FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
265
+ Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
266
+ nn.Linear(token_dim_new, dim),
267
+ )
268
+ else:
269
+ self.to_token_embedding = nn.Linear(token_dim, dim)
270
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
271
+ if emb_dropout_type == "drop":
272
+ self.dropout = DropTokenDropout(emb_dropout)
273
+ elif emb_dropout_type == "zero":
274
+ self.dropout = ZeroTokenDropout(emb_dropout)
275
+ else:
276
+ raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
277
+ self.emb_dropout_loc = emb_dropout_loc
278
+
279
+ self.transformer = Transformer(
280
+ dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
281
+ )
282
+
283
+ def forward(self, inp: torch.Tensor, *args, **kwargs):
284
+ x = inp
285
+
286
+ if self.emb_dropout_loc == "input":
287
+ x = self.dropout(x)
288
+ x = self.to_token_embedding(x)
289
+
290
+ if self.emb_dropout_loc == "token":
291
+ x = self.dropout(x)
292
+ b, n, _ = x.shape
293
+ x += self.pos_embedding[:, :n]
294
+
295
+ if self.emb_dropout_loc == "token_afterpos":
296
+ x = self.dropout(x)
297
+ x = self.transformer(x, *args)
298
+ return x
299
+
300
+
301
+ class TransformerDecoder(nn.Module):
302
+ def __init__(
303
+ self,
304
+ num_tokens: int,
305
+ token_dim: int,
306
+ dim: int,
307
+ depth: int,
308
+ heads: int,
309
+ mlp_dim: int,
310
+ dim_head: int = 64,
311
+ dropout: float = 0.0,
312
+ emb_dropout: float = 0.0,
313
+ emb_dropout_type: str = 'drop',
314
+ norm: str = "layer",
315
+ norm_cond_dim: int = -1,
316
+ context_dim: Optional[int] = None,
317
+ skip_token_embedding: bool = False,
318
+ ):
319
+ super().__init__()
320
+ if not skip_token_embedding:
321
+ self.to_token_embedding = nn.Linear(token_dim, dim)
322
+ else:
323
+ self.to_token_embedding = nn.Identity()
324
+ if token_dim != dim:
325
+ raise ValueError(
326
+ f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
327
+ )
328
+
329
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
330
+ if emb_dropout_type == "drop":
331
+ self.dropout = DropTokenDropout(emb_dropout)
332
+ elif emb_dropout_type == "zero":
333
+ self.dropout = ZeroTokenDropout(emb_dropout)
334
+ elif emb_dropout_type == "normal":
335
+ self.dropout = nn.Dropout(emb_dropout)
336
+
337
+ self.transformer = TransformerCrossAttn(
338
+ dim,
339
+ depth,
340
+ heads,
341
+ dim_head,
342
+ mlp_dim,
343
+ dropout,
344
+ norm=norm,
345
+ norm_cond_dim=norm_cond_dim,
346
+ context_dim=context_dim,
347
+ )
348
+
349
+ def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
350
+ x = self.to_token_embedding(inp)
351
+ b, n, _ = x.shape
352
+
353
+ x = self.dropout(x)
354
+ x += self.pos_embedding[:, :n]
355
+
356
+ x = self.transformer(x, *args, context=context, context_list=context_list)
357
+ return x
358
+
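Minimal shape check for the cross-attention TransformerDecoder above; the dimensions are illustrative and do not come from a shipped config.

import torch
from amr.models.components.pose_transformer import TransformerDecoder

decoder = TransformerDecoder(num_tokens=1, token_dim=1, dim=64, depth=2, heads=4,
                             mlp_dim=128, context_dim=32)
token = torch.zeros(8, 1, 1)         # a single zero query token per sample
context = torch.randn(8, 192, 32)    # e.g. flattened backbone feature tokens
out = decoder(token, context=context)
print(out.shape)                     # torch.Size([8, 1, 64])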
amr/models/components/t_cond_mlp.py ADDED
@@ -0,0 +1,199 @@
1
+ import copy
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+
6
+
7
+ class AdaptiveLayerNorm1D(torch.nn.Module):
8
+ def __init__(self, data_dim: int, norm_cond_dim: int):
9
+ super().__init__()
10
+ if data_dim <= 0:
11
+ raise ValueError(f"data_dim must be positive, but got {data_dim}")
12
+ if norm_cond_dim <= 0:
13
+ raise ValueError(f"norm_cond_dim must be positive, but got {norm_cond_dim}")
14
+ self.norm = torch.nn.LayerNorm(
15
+ data_dim
16
+ ) # TODO: Check if elementwise_affine=True is correct
17
+ self.linear = torch.nn.Linear(norm_cond_dim, 2 * data_dim)
18
+ torch.nn.init.zeros_(self.linear.weight)
19
+ torch.nn.init.zeros_(self.linear.bias)
20
+
21
+ def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
22
+ # x: (batch, ..., data_dim)
23
+ # t: (batch, norm_cond_dim)
24
+ # return: (batch, data_dim)
25
+ x = self.norm(x)
26
+ alpha, beta = self.linear(t).chunk(2, dim=-1)
27
+
28
+ # Add singleton dimensions to alpha and beta
29
+ if x.dim() > 2:
30
+ alpha = alpha.view(alpha.shape[0], *([1] * (x.dim() - 2)), alpha.shape[1])
31
+ beta = beta.view(beta.shape[0], *([1] * (x.dim() - 2)), beta.shape[1])
32
+
33
+ return x * (1 + alpha) + beta
34
+
35
+
36
+ class SequentialCond(torch.nn.Sequential):
37
+ def forward(self, input, *args, **kwargs):
38
+ for module in self:
39
+ if isinstance(module, (AdaptiveLayerNorm1D, SequentialCond, ResidualMLPBlock)):
40
+ # print(f'Passing on args to {module}', [a.shape for a in args])
41
+ input = module(input, *args, **kwargs)
42
+ else:
43
+ # print(f'Skipping passing args to {module}', [a.shape for a in args])
44
+ input = module(input)
45
+ return input
46
+
47
+
48
+ def normalization_layer(norm: Optional[str], dim: int, norm_cond_dim: int = -1):
49
+ if norm == "batch":
50
+ return torch.nn.BatchNorm1d(dim)
51
+ elif norm == "layer":
52
+ return torch.nn.LayerNorm(dim)
53
+ elif norm == "ada":
54
+ assert norm_cond_dim > 0, f"norm_cond_dim must be positive, got {norm_cond_dim}"
55
+ return AdaptiveLayerNorm1D(dim, norm_cond_dim)
56
+ elif norm is None:
57
+ return torch.nn.Identity()
58
+ else:
59
+ raise ValueError(f"Unknown norm: {norm}")
60
+
61
+
62
+ def linear_norm_activ_dropout(
63
+ input_dim: int,
64
+ output_dim: int,
65
+ activation: torch.nn.Module = torch.nn.ReLU(),
66
+ bias: bool = True,
67
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
68
+ dropout: float = 0.0,
69
+ norm_cond_dim: int = -1,
70
+ ) -> SequentialCond:
71
+ layers = []
72
+ layers.append(torch.nn.Linear(input_dim, output_dim, bias=bias))
73
+ if norm is not None:
74
+ layers.append(normalization_layer(norm, output_dim, norm_cond_dim))
75
+ layers.append(copy.deepcopy(activation))
76
+ if dropout > 0.0:
77
+ layers.append(torch.nn.Dropout(dropout))
78
+ return SequentialCond(*layers)
79
+
80
+
81
+ def create_simple_mlp(
82
+ input_dim: int,
83
+ hidden_dims: List[int],
84
+ output_dim: int,
85
+ activation: torch.nn.Module = torch.nn.ReLU(),
86
+ bias: bool = True,
87
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
88
+ dropout: float = 0.0,
89
+ norm_cond_dim: int = -1,
90
+ ) -> SequentialCond:
91
+ layers = []
92
+ prev_dim = input_dim
93
+ for hidden_dim in hidden_dims:
94
+ layers.extend(
95
+ linear_norm_activ_dropout(
96
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
97
+ )
98
+ )
99
+ prev_dim = hidden_dim
100
+ layers.append(torch.nn.Linear(prev_dim, output_dim, bias=bias))
101
+ return SequentialCond(*layers)
102
+
103
+
104
+ class ResidualMLPBlock(torch.nn.Module):
105
+ def __init__(
106
+ self,
107
+ input_dim: int,
108
+ hidden_dim: int,
109
+ num_hidden_layers: int,
110
+ output_dim: int,
111
+ activation: torch.nn.Module = torch.nn.ReLU(),
112
+ bias: bool = True,
113
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
114
+ dropout: float = 0.0,
115
+ norm_cond_dim: int = -1,
116
+ ):
117
+ super().__init__()
118
+ if not (input_dim == output_dim == hidden_dim):
119
+ raise NotImplementedError(
120
+ f"input_dim {input_dim} != output_dim {output_dim} is not implemented"
121
+ )
122
+
123
+ layers = []
124
+ prev_dim = input_dim
125
+ for i in range(num_hidden_layers):
126
+ layers.append(
127
+ linear_norm_activ_dropout(
128
+ prev_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
129
+ )
130
+ )
131
+ prev_dim = hidden_dim
132
+ self.model = SequentialCond(*layers)
133
+ self.skip = torch.nn.Identity()
134
+
135
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
136
+ return x + self.model(x, *args, **kwargs)
137
+
138
+
139
+ class ResidualMLP(torch.nn.Module):
140
+ def __init__(
141
+ self,
142
+ input_dim: int,
143
+ hidden_dim: int,
144
+ num_hidden_layers: int,
145
+ output_dim: int,
146
+ activation: torch.nn.Module = torch.nn.ReLU(),
147
+ bias: bool = True,
148
+ norm: Optional[str] = "layer", # Options: ada/batch/layer
149
+ dropout: float = 0.0,
150
+ num_blocks: int = 1,
151
+ norm_cond_dim: int = -1,
152
+ ):
153
+ super().__init__()
154
+ self.input_dim = input_dim
155
+ self.model = SequentialCond(
156
+ linear_norm_activ_dropout(
157
+ input_dim, hidden_dim, activation, bias, norm, dropout, norm_cond_dim
158
+ ),
159
+ *[
160
+ ResidualMLPBlock(
161
+ hidden_dim,
162
+ hidden_dim,
163
+ num_hidden_layers,
164
+ hidden_dim,
165
+ activation,
166
+ bias,
167
+ norm,
168
+ dropout,
169
+ norm_cond_dim,
170
+ )
171
+ for _ in range(num_blocks)
172
+ ],
173
+ torch.nn.Linear(hidden_dim, output_dim, bias=bias),
174
+ )
175
+
176
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
177
+ return self.model(x, *args, **kwargs)
178
+
179
+
180
+ class FrequencyEmbedder(torch.nn.Module):
181
+ def __init__(self, num_frequencies, max_freq_log2):
182
+ super().__init__()
183
+ frequencies = 2 ** torch.linspace(0, max_freq_log2, steps=num_frequencies)
184
+ self.register_buffer("frequencies", frequencies)
185
+
186
+ def forward(self, x):
187
+ # x should be of size (N,) or (N, D)
188
+ N = x.size(0)
189
+ if x.dim() == 1: # (N,)
190
+ x = x.unsqueeze(1) # (N, D) where D=1
191
+ x_unsqueezed = x.unsqueeze(-1) # (N, D, 1)
192
+ scaled = self.frequencies.view(1, 1, -1) * x_unsqueezed # (N, D, num_frequencies)
193
+ s = torch.sin(scaled)
194
+ c = torch.cos(scaled)
195
+ embedded = torch.cat([s, c, x_unsqueezed], dim=-1).view(
196
+ N, -1
197
+ ) # (N, D * 2 * num_frequencies + D)
198
+ return embedded
199
+
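Output-size sketch for FrequencyEmbedder: D input features expand to D * (2 * num_frequencies + 1) features (sin, cos, plus the raw value). The numbers below are arbitrary.

import torch
from amr.models.components.t_cond_mlp import FrequencyEmbedder

emb = FrequencyEmbedder(num_frequencies=4, max_freq_log2=3)
x = torch.randn(16, 3)
print(emb(x).shape)   # torch.Size([16, 27]) == 3 * (2 * 4 + 1)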
amr/models/heads/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .smal_head import build_smal_head
amr/models/heads/smal_head.py ADDED
@@ -0,0 +1,116 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import einops
5
+ from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat
6
+ from ..components.pose_transformer import TransformerDecoder
7
+
8
+
9
+ def build_smal_head(cfg):
10
+ smal_head_type = cfg.MODEL.SMAL_HEAD.get('TYPE', 'amr')
11
+ if smal_head_type == 'transformer_decoder':
12
+ return SMALTransformerDecoderHead(cfg)
13
+ else:
14
+ raise ValueError('Unknown SMAL head type: {}'.format(smal_head_type))
15
+
16
+
17
+ class SMALTransformerDecoderHead(nn.Module):
18
+ """ Cross-attention based SMAL Transformer decoder
19
+ """
20
+ # Cat (e.g. House Cat/Tiger/Lion), Canine (e.g. Dog/Wolf), Equine (e.g. Horse/Zebra), Bovine (e.g. Cow), Hippo
21
+ def __init__(self, cfg):
22
+ super().__init__()
23
+ self.cfg = cfg
24
+ self.joint_rep_type = cfg.MODEL.SMAL_HEAD.get('JOINT_REP', '6d')
25
+ self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]
26
+ npose = self.joint_rep_dim * (cfg.SMAL.NUM_JOINTS + 1)
27
+ self.npose = npose
28
+ self.input_is_mean_shape = cfg.MODEL.SMAL_HEAD.get('TRANSFORMER_INPUT', 'zero') == 'mean_shape'
29
+ transformer_args = dict(
30
+ num_tokens=1,
31
+ token_dim=(npose + 10 + 3) if self.input_is_mean_shape else 1,
32
+ dim=1024,
33
+ )
34
+
35
+ # transformer_args = (transformer_args | dict(cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER))
36
+ # For compatibility
37
+ transformer_args = {**transformer_args, **dict(cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER)}
38
+
39
+ self.transformer = TransformerDecoder(
40
+ **transformer_args
41
+ )
42
+ dim = transformer_args['dim']
43
+ self.decpose = nn.Linear(dim, npose)
44
+ self.decshape = nn.Linear(dim, 41)
45
+ self.deccam = nn.Linear(dim, 3)
46
+
47
+ if cfg.MODEL.SMAL_HEAD.get('INIT_DECODER_XAVIER', False):
48
+ # True by default in MLP. False by default in Transformer
49
+ nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
50
+ nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
51
+ nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)
52
+
53
+ init_pose = torch.zeros(size=(1, npose), dtype=torch.float32)
54
+ init_betas = torch.zeros(size=(1, 41), dtype=torch.float32)
55
+ init_cam = torch.zeros(size=(1, 3), dtype=torch.float32)
56
+ self.register_buffer('init_pose', init_pose)
57
+ self.register_buffer('init_betas', init_betas)
58
+ self.register_buffer('init_cam', init_cam)
59
+
60
+ def forward(self, x, **kwargs):
61
+ batch_size = x.shape[0]
62
+ # category = kwargs["category"]
63
+ # vit pretrained backbone is channel-first. Change to token-first
64
+ x = einops.rearrange(x, 'b c h w -> b (h w) c') if len(x.shape) == 4 else x
65
+
66
+ init_pose = self.init_pose.expand(batch_size, -1)
67
+ init_betas = self.init_betas.expand(batch_size, -1) if not self.cfg.MODEL.SMAL_HEAD.get("RES", False) else torch.mean(self.init_betas, dim=0, keepdim=True)
68
+ # self.init_betas[kwargs["category"]]
69
+ init_cam = self.init_cam.expand(batch_size, -1)
70
+
71
+ pred_pose = init_pose
72
+ pred_betas = init_betas
73
+ pred_cam = init_cam
74
+ pred_pose_list = []
75
+ pred_betas_list = []
76
+ pred_cam_list = []
77
+ for i in range(self.cfg.MODEL.SMAL_HEAD.get('IEF_ITERS', 3)):
78
+ # Input token to transformer is zero token
79
+ if self.input_is_mean_shape:
80
+ token = torch.cat([pred_pose, pred_betas, pred_cam], dim=1)[:, None, :]
81
+ else:
82
+ token = torch.zeros(batch_size, 1, 1).to(x.device)
83
+
84
+ # Pass through transformer
85
+ token_out = self.transformer(token, context=x)
86
+ token_out = token_out.squeeze(1) # (B, C)
87
+
88
+ # Readout from token_out
89
+ pred_pose = self.decpose(token_out) + pred_pose
90
+ pred_betas = self.decshape(token_out) + pred_betas
91
+ pred_cam = self.deccam(token_out) + pred_cam
92
+ pred_pose_list.append(pred_pose)
93
+ pred_betas_list.append(pred_betas)
94
+ pred_cam_list.append(pred_cam)
95
+
96
+ # Convert self.joint_rep_type -> rotmat
97
+ joint_conversion_fn = {
98
+ '6d': rot6d_to_rotmat,
99
+ 'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
100
+ }[self.joint_rep_type]
101
+
102
+ pred_smal_params_list = {}
103
+ pred_smal_params_list['pose'] = torch.cat(
104
+ [joint_conversion_fn(pbp).view(batch_size, -1, 3, 3)[:, 1:, :, :] for pbp in pred_pose_list], dim=0)
105
+ pred_smal_params_list['betas'] = torch.cat(pred_betas_list, dim=0)
106
+ pred_smal_params_list['cam'] = torch.cat(pred_cam_list, dim=0)
107
+ pred_pose = joint_conversion_fn(pred_pose).view(batch_size, self.cfg.SMAL.NUM_JOINTS + 1, 3, 3)
108
+
109
+ pred_smal_params = {'global_orient': pred_pose[:, [0]],
110
+ 'pose': pred_pose[:, 1:],
111
+ 'betas': pred_betas,
112
+ }
113
+ return pred_smal_params, pred_cam, pred_smal_params_list
114
+
115
+
116
+
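A standalone shape check for the head above with a hand-built, hypothetical config; the real hyperparameters live in the training configs, not in this sketch.

import torch
from yacs.config import CfgNode as CN
from amr.models.heads import build_smal_head

cfg = CN(new_allowed=True)
cfg.SMAL = CN(new_allowed=True)
cfg.SMAL.NUM_JOINTS = 33                              # assumed joint count
cfg.MODEL = CN(new_allowed=True)
cfg.MODEL.SMAL_HEAD = CN(new_allowed=True)
cfg.MODEL.SMAL_HEAD.TYPE = 'transformer_decoder'
cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER = CN(new_allowed=True)
cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER.depth = 2     # small values for a quick check
cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER.heads = 4
cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER.mlp_dim = 256
cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER.dim_head = 64
cfg.MODEL.SMAL_HEAD.TRANSFORMER_DECODER.context_dim = 1280

head = build_smal_head(cfg)
feats = torch.randn(2, 1280, 16, 12)                  # backbone feature map (B, C, H, W)
params, cam, _ = head(feats)
print(params['pose'].shape, params['betas'].shape, cam.shape)
# torch.Size([2, 33, 3, 3]) torch.Size([2, 41]) torch.Size([2, 3])

The head refines its estimate for IEF_ITERS iterations (3 by default), feeding the current estimate back in as the query token when TRANSFORMER_INPUT is 'mean_shape'.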
amr/models/smal_warapper.py ADDED
@@ -0,0 +1,128 @@
1
+ import json
2
+ from torch import nn
3
+ import torch
4
+ import numpy as np
5
+ import pickle
6
+ import cv2
7
+ from typing import Optional, Tuple, NewType
8
+ from dataclasses import dataclass
9
+ import smplx
10
+ from smplx.lbs import vertices2joints, lbs
11
+ from smplx.utils import MANOOutput, to_tensor, ModelOutput
12
+ from smplx.vertex_ids import vertex_ids
13
+
14
+ Tensor = NewType('Tensor', torch.Tensor)
15
+ keypoint_vertices_idx = [[1068, 1080, 1029, 1226], [2660, 3030, 2675, 3038], [910], [360, 1203, 1235, 1230],
16
+ [3188, 3156, 2327, 3183], [1976, 1974, 1980, 856], [3854, 2820, 3852, 3858], [452, 1811],
17
+ [416, 235, 182], [2156, 2382, 2203], [829], [2793], [60, 114, 186, 59],
18
+ [2091, 2037, 2036, 2160], [384, 799, 1169, 431], [2351, 2763, 2397, 3127],
19
+ [221, 104], [2754, 2192], [191, 1158, 3116, 2165],
20
+ [28, 1109, 1110, 1111, 1835, 1836, 3067, 3068, 3069],
21
+ [498, 499, 500, 501, 502, 503], [2463, 2464, 2465, 2466, 2467, 2468],
22
+ [764, 915, 916, 917, 934, 935, 956], [2878, 2879, 2880, 2897, 2898, 2919, 3751],
23
+ [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762],
24
+ [0, 464, 465, 726, 1824, 2429, 2430, 2690]]
25
+
26
+ name2id35 = {'RFoot': 14, 'RFootBack': 24, 'spine1': 4, 'Head': 16, 'LLegBack3': 19, 'RLegBack1': 21, 'pelvis0': 1,
27
+ 'RLegBack3': 23, 'LLegBack2': 18, 'spine0': 3, 'spine3': 6, 'spine2': 5, 'Mouth': 32, 'Neck': 15,
28
+ 'LFootBack': 20, 'LLegBack1': 17, 'RLeg3': 13, 'RLeg2': 12, 'LLeg1': 7, 'LLeg3': 9, 'RLeg1': 11,
29
+ 'LLeg2': 8, 'spine': 2, 'LFoot': 10, 'Tail7': 31, 'Tail6': 30, 'Tail5': 29, 'Tail4': 28, 'Tail3': 27,
30
+ 'Tail2': 26, 'Tail1': 25, 'RLegBack2': 22, 'root': 0, 'LEar': 33, 'REar': 34, 'EndNose': 35, 'Chin': 36,
31
+ 'RightEarTip': 37, 'LeftEarTip': 38, 'LeftEye': 39, 'RightEye': 40}
32
+
33
+ @dataclass
34
+ class SMALOutput(ModelOutput):
35
+ betas: Optional[Tensor] = None
36
+ pose: Optional[Tensor] = None
37
+
38
+
39
+ class SMALLayer(nn.Module):
40
+ def __init__(self, num_betas=41, **kwargs):
41
+ super().__init__()
42
+ self.num_betas = num_betas
43
+ self.register_buffer("shapedirs", torch.from_numpy(np.array(kwargs['shapedirs'], dtype=np.float32))[:, :, :num_betas]) # [3889, 3, 41]
44
+ self.register_buffer("v_template", torch.from_numpy(np.array(kwargs['v_template']).astype(np.float32))) # [3889, 3]
45
+ self.register_buffer("posedirs", torch.from_numpy(np.array(kwargs['posedirs'], dtype=np.float32)).reshape(-1,
46
+ 34*9).T) # [34*9, 11667]
47
+ self.register_buffer("J_regressor", torch.from_numpy(kwargs['J_regressor'].toarray().astype(np.float32))) # [33, 3389]
48
+ self.register_buffer("lbs_weights", torch.from_numpy(np.array(kwargs['weights'], dtype=np.float32))) # [3889, 33]
49
+ self.register_buffer("faces", torch.from_numpy(np.array(kwargs['f'], dtype=np.int32))) # [7774, 3]
50
+
51
+ kintree_table = kwargs['kintree_table']
52
+ # self.register_buffer("parents", torch.from_numpy(kintree_table[0].astype(np.int32)))
53
+ id_to_col = {kintree_table[1, i]: i for i in range(kintree_table.shape[1])}
54
+ self.register_buffer("parents", torch.tensor([0] + [id_to_col[kintree_table[0, i]] for i in range(1, kintree_table.shape[1])],
55
+ dtype=torch.long))
56
+
57
+ def forward(
58
+ self,
59
+ betas: Optional[Tensor] = None,
60
+ global_orient: Optional[Tensor] = None,
61
+ pose: Optional[Tensor] = None,
62
+ transl: Optional[Tensor] = None,
63
+ return_verts: bool = True,
64
+ return_full_pose: bool = False,
65
+ **kwargs):
66
+ """
67
+ Args:
68
+ betas: [batch_size, num_betas] (41 by default)
69
+ global_orient: [batch_size, 1, 3, 3]
70
+ pose: [batch_size, num_joints, 3, 3]
71
+ transl: [batch_size, 3]
72
+ return_verts:
73
+ return_full_pose:
74
+ **kwargs:
75
+ Returns:
76
+ """
77
+ device, dtype = betas.device, betas.dtype
78
+ if global_orient is None:
79
+ batch_size = 1
80
+ global_orient = torch.eye(3, device=device, dtype=dtype).view(
81
+ 1, 1, 3, 3).expand(batch_size, -1, -1, -1).contiguous()
82
+ else:
83
+ batch_size = global_orient.shape[0]
84
+ if pose is None:
85
+ pose = torch.eye(3, device=device, dtype=dtype).view(
86
+ 1, 1, 3, 3).expand(batch_size, 34, -1, -1).contiguous()
87
+ if betas is None:
88
+ betas = torch.zeros(
89
+ [batch_size, self.num_betas], dtype=dtype, device=device)
90
+ if transl is None:
91
+ transl = torch.zeros([batch_size, 3], dtype=dtype, device=device)
92
+
93
+ full_pose = torch.cat([global_orient, pose], dim=1)
94
+ vertices, joints = lbs(betas, full_pose, self.v_template,
95
+ self.shapedirs, self.posedirs,
96
+ self.J_regressor, self.parents,
97
+ self.lbs_weights, pose2rot=False)
98
+
99
+ if transl is not None:
100
+ joints = joints + transl.unsqueeze(dim=1)
101
+ vertices = vertices + transl.unsqueeze(dim=1)
102
+
103
+ output = SMALOutput(
104
+ vertices=vertices if return_verts else None,
105
+ joints=joints if return_verts else None,
106
+ betas=betas,
107
+ global_orient=global_orient,
108
+ pose=pose,
109
+ transl=transl,
110
+ full_pose=full_pose if return_full_pose else None,
111
+ )
112
+ return output
113
+
114
+
115
+ class SMAL(SMALLayer):
116
+ def __init__(self, **kwargs):
117
+ super(SMAL, self).__init__(**kwargs)
118
+
119
+ def forward(self, *args, **kwargs):
120
+ smal_output = super(SMAL, self).forward(**kwargs)
121
+
122
+ keypoint = []
123
+ for kp_v in keypoint_vertices_idx:
124
+ keypoint.append(smal_output.vertices[:, kp_v, :].mean(dim=1))
125
+ smal_output.joints = torch.stack(keypoint, dim=1)
126
+ return smal_output
127
+
128
+
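Note on the wrapper above: SMAL.forward replaces the regressed skeleton joints with 26 keypoints obtained by averaging fixed groups of mesh vertices, so downstream code sees joints of shape (B, 26, 3) alongside vertices of shape (B, 3889, 3). A tiny check of the grouping table:

from amr.models.smal_warapper import keypoint_vertices_idx

print(len(keypoint_vertices_idx))   # 26 keypoint groups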
amr/utils/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+ from typing import Any
3
+
4
+
5
+ def recursive_to(x: Any, target: torch.device):
6
+ """
7
+ Recursively transfer a batch of data to the target device
8
+ Args:
9
+ x (Any): Batch of data.
10
+ target (torch.device): Target device.
11
+ Returns:
12
+ Batch of data where all tensors are transfered to the target device.
13
+ """
14
+ if isinstance(x, dict):
15
+ return {k: recursive_to(v, target) for k, v in x.items()}
16
+ elif isinstance(x, torch.Tensor):
17
+ return x.to(target)
18
+ elif isinstance(x, list):
19
+ return [recursive_to(i, target) for i in x]
20
+ else:
21
+ return x
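Usage sketch for recursive_to with placeholder batch contents: tensors are moved to the target device, nested dicts and lists are traversed, and any other leaf is returned unchanged.

import torch
from amr.utils import recursive_to

batch = {'img': torch.randn(2, 3, 256, 256),
         'focal_length': torch.ones(2, 2),
         'meta': {'imgname': ['a.jpg', 'b.jpg']}}
batch = recursive_to(batch, torch.device('cpu'))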
amr/utils/geometry.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import Optional
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def aa_to_rotmat(theta: torch.Tensor):
7
+ """
8
+ Convert axis-angle representation to rotation matrix.
9
+ Works by first converting it to a quaternion.
10
+ Args:
11
+ theta (torch.Tensor): Tensor of shape (B, 3) containing axis-angle representations.
12
+ Returns:
13
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
14
+ """
15
+ norm = torch.norm(theta + 1e-8, p=2, dim=1)
16
+ angle = torch.unsqueeze(norm, -1)
17
+ normalized = torch.div(theta, angle)
18
+ angle = angle * 0.5
19
+ v_cos = torch.cos(angle)
20
+ v_sin = torch.sin(angle)
21
+ quat = torch.cat([v_cos, v_sin * normalized], dim=1)
22
+ return quat_to_rotmat(quat)
23
+
24
+
25
+ def quat_to_rotmat(quat: torch.Tensor) -> torch.Tensor:
26
+ """
27
+ Convert quaternion representation to rotation matrix.
28
+ Args:
29
+ quat (torch.Tensor) of shape (B, 4); 4 <===> (w, x, y, z).
30
+ Returns:
31
+ torch.Tensor: Corresponding rotation matrices with shape (B, 3, 3).
32
+ """
33
+ norm_quat = quat
34
+ norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
35
+ w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, 2], norm_quat[:, 3]
36
+
37
+ B = quat.size(0)
38
+
39
+ w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
40
+ wx, wy, wz = w * x, w * y, w * z
41
+ xy, xz, yz = x * y, x * z, y * z
42
+
43
+ rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
44
+ 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
45
+ 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
46
+ return rotMat
47
+
48
+
49
+ def rot6d_to_rotmat(x: torch.Tensor) -> torch.Tensor:
50
+ """
51
+ Convert 6D rotation representation to 3x3 rotation matrix.
52
+ Based on Zhou et al., "On the Continuity of Rotation Representations in Neural Networks", CVPR 2019
53
+ Args:
54
+ x (torch.Tensor): (B,6) Batch of 6-D rotation representations.
55
+ Returns:
56
+ torch.Tensor: Batch of corresponding rotation matrices with shape (B,3,3).
57
+ """
58
+ x = x.reshape(-1, 2, 3).permute(0, 2, 1).contiguous()
59
+ a1 = x[:, :, 0]
60
+ a2 = x[:, :, 1]
61
+ b1 = F.normalize(a1)
62
+ b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
63
+ b3 = torch.cross(b1, b2, dim=1)
64
+ return torch.stack((b1, b2, b3), dim=-1)
65
+
66
+
67
+ def perspective_projection(points: torch.Tensor,
68
+ translation: torch.Tensor,
69
+ focal_length: torch.Tensor,
70
+ camera_center: Optional[torch.Tensor] = None,
71
+ rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
72
+ """
73
+ Computes the perspective projection of a set of 3D points.
74
+ Args:
75
+ points (torch.Tensor): Tensor of shape (B, N, 3) containing the input 3D points.
76
+ translation (torch.Tensor): Tensor of shape (B, 3) containing the 3D camera translation.
77
+ focal_length (torch.Tensor): Tensor of shape (B, 2) containing the focal length in pixels.
78
+ camera_center (torch.Tensor): Tensor of shape (B, 2) containing the camera center in pixels.
79
+ rotation (torch.Tensor): Tensor of shape (B, 3, 3) containing the camera rotation.
80
+ Returns:
81
+ torch.Tensor: Tensor of shape (B, N, 2) containing the projection of the input points.
82
+ """
83
+ batch_size = points.shape[0]
84
+ if rotation is None:
85
+ rotation = torch.eye(3, device=points.device, dtype=points.dtype).unsqueeze(0).expand(batch_size, -1, -1)
86
+ if camera_center is None:
87
+ camera_center = torch.zeros(batch_size, 2, device=points.device, dtype=points.dtype)
88
+ # Populate intrinsic camera matrix K.
89
+ K = torch.zeros([batch_size, 3, 3], device=points.device, dtype=points.dtype)
90
+ K[:, 0, 0] = focal_length[:, 0]
91
+ K[:, 1, 1] = focal_length[:, 1]
92
+ K[:, 2, 2] = 1.
93
+ K[:, :-1, -1] = camera_center
94
+
95
+ # Transform points
96
+ points = torch.einsum('bij,bkj->bki', rotation, points)
97
+ points = points + translation.unsqueeze(1)
98
+
99
+ # Apply perspective distortion
100
+ projected_points = points / points[:, :, -1].unsqueeze(-1)
101
+
102
+ # Apply camera intrinsics
103
+ projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
104
+
105
+ return projected_points[:, :, :-1]
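Quick consistency checks for the geometry helpers above (not part of the commit): 6D representations should map to orthonormal rotation matrices, and perspective_projection returns 2D points in a zero-centered image plane since camera_center defaults to zero.

import torch
from amr.utils.geometry import rot6d_to_rotmat, perspective_projection

R = rot6d_to_rotmat(torch.randn(4, 6))
print(torch.allclose(R @ R.transpose(1, 2), torch.eye(3).expand(4, 3, 3), atol=1e-5))  # True

pts = torch.zeros(1, 5, 3)
uv = perspective_projection(pts, translation=torch.tensor([[0., 0., 2.]]),
                            focal_length=torch.ones(1, 2))
print(uv.shape)   # torch.Size([1, 5, 2]); points on the optical axis project to (0, 0)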
amr/utils/pylogger.py ADDED
@@ -0,0 +1,17 @@
1
+ import logging
2
+
3
+ from pytorch_lightning.utilities import rank_zero_only
4
+
5
+
6
+ def get_pylogger(name=__name__) -> logging.Logger:
7
+ """Initializes multi-GPU-friendly python command line logger."""
8
+
9
+ logger = logging.getLogger(name)
10
+
11
+ # this ensures all logging levels get marked with the rank zero decorator
12
+ # otherwise logs would get multiplied for each GPU process in multi-GPU setup
13
+ logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
14
+ for level in logging_levels:
15
+ setattr(logger, level, rank_zero_only(getattr(logger, level)))
16
+
17
+ return logger
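Usage is the same as a standard logging logger; the rank_zero_only wrapping just ensures each message is emitted once in multi-GPU runs.

from amr.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
log.info("initialized model")   # printed once on rank 0 under DDP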
amr/utils/renderer.py ADDED
@@ -0,0 +1,409 @@
1
+ import os
2
+
3
+ if 'PYOPENGL_PLATFORM' not in os.environ:
4
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
5
+ import torch
6
+ import numpy as np
7
+ import pyrender
8
+ import trimesh
9
+ import cv2
10
+ from yacs.config import CfgNode
11
+ from typing import List, Optional
12
+
13
+
14
+ def cam_crop_to_full(cam_bbox, box_center, box_size, img_size, focal_length=5000.):
15
+ # Convert cam_bbox to full image
16
+ img_w, img_h = img_size[:, 0], img_size[:, 1]
17
+ cx, cy, b = box_center[:, 0], box_center[:, 1], box_size
18
+ w_2, h_2 = img_w / 2., img_h / 2.
19
+ bs = b * cam_bbox[:, 0] + 1e-9
20
+ tz = 2 * focal_length / bs
21
+ tx = (2 * (cx - w_2) / bs) + cam_bbox[:, 1]
22
+ ty = (2 * (cy - h_2) / bs) + cam_bbox[:, 2]
23
+ full_cam = torch.stack([tx, ty, tz], dim=-1)
24
+ return full_cam
25
+
26
+
27
+ def get_light_poses(n_lights=5, elevation=np.pi / 3, dist=12):
28
+ # get lights in a circle around origin at elevation
29
+ thetas = elevation * np.ones(n_lights)
30
+ phis = 2 * np.pi * np.arange(n_lights) / n_lights
31
+ poses = []
32
+ trans = make_translation(torch.tensor([0, 0, dist]))
33
+ for phi, theta in zip(phis, thetas):
34
+ rot = make_rotation(rx=-theta, ry=phi, order="xyz")
35
+ poses.append((rot @ trans).numpy())
36
+ return poses
37
+
38
+
39
+ def make_translation(t):
40
+ return make_4x4_pose(torch.eye(3), t)
41
+
42
+
43
+ def make_rotation(rx=0, ry=0, rz=0, order="xyz"):
44
+ Rx = rotx(rx)
45
+ Ry = roty(ry)
46
+ Rz = rotz(rz)
47
+ if order == "xyz":
48
+ R = Rz @ Ry @ Rx
49
+ elif order == "xzy":
50
+ R = Ry @ Rz @ Rx
51
+ elif order == "yxz":
52
+ R = Rz @ Rx @ Ry
53
+ elif order == "yzx":
54
+ R = Rx @ Rz @ Ry
55
+ elif order == "zyx":
56
+ R = Rx @ Ry @ Rz
57
+ elif order == "zxy":
58
+ R = Ry @ Rx @ Rz
59
+ return make_4x4_pose(R, torch.zeros(3))
60
+
61
+
62
+ def make_4x4_pose(R, t):
63
+ """
64
+ :param R (*, 3, 3)
65
+ :param t (*, 3)
66
+ return (*, 4, 4)
67
+ """
68
+ dims = R.shape[:-2]
69
+ pose_3x4 = torch.cat([R, t.view(*dims, 3, 1)], dim=-1)
70
+ bottom = (
71
+ torch.tensor([0, 0, 0, 1], device=R.device)
72
+ .reshape(*(1,) * len(dims), 1, 4)
73
+ .expand(*dims, 1, 4)
74
+ )
75
+ return torch.cat([pose_3x4, bottom], dim=-2)
76
+
77
+
78
+ def rotx(theta):
79
+ return torch.tensor(
80
+ [
81
+ [1, 0, 0],
82
+ [0, np.cos(theta), -np.sin(theta)],
83
+ [0, np.sin(theta), np.cos(theta)],
84
+ ],
85
+ dtype=torch.float32,
86
+ )
87
+
88
+
89
+ def roty(theta):
90
+ return torch.tensor(
91
+ [
92
+ [np.cos(theta), 0, np.sin(theta)],
93
+ [0, 1, 0],
94
+ [-np.sin(theta), 0, np.cos(theta)],
95
+ ],
96
+ dtype=torch.float32,
97
+ )
98
+
99
+
100
+ def rotz(theta):
101
+ return torch.tensor(
102
+ [
103
+ [np.cos(theta), -np.sin(theta), 0],
104
+ [np.sin(theta), np.cos(theta), 0],
105
+ [0, 0, 1],
106
+ ],
107
+ dtype=torch.float32,
108
+ )
109
+
110
+
111
+ def create_raymond_lights() -> List[pyrender.Node]:
112
+ """
113
+ Return raymond light nodes for the scene.
114
+ """
115
+ thetas = np.pi * np.array([1.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0])
116
+ phis = np.pi * np.array([0.0, 2.0 / 3.0, 4.0 / 3.0])
117
+
118
+ nodes = []
119
+
120
+ for phi, theta in zip(phis, thetas):
121
+ xp = np.sin(theta) * np.cos(phi)
122
+ yp = np.sin(theta) * np.sin(phi)
123
+ zp = np.cos(theta)
124
+
125
+ z = np.array([xp, yp, zp])
126
+ z = z / np.linalg.norm(z)
127
+ x = np.array([-z[1], z[0], 0.0])
128
+ if np.linalg.norm(x) == 0:
129
+ x = np.array([1.0, 0.0, 0.0])
130
+ x = x / np.linalg.norm(x)
131
+ y = np.cross(z, x)
132
+
133
+ matrix = np.eye(4)
134
+ matrix[:3, :3] = np.c_[x, y, z]
135
+ nodes.append(pyrender.Node(
136
+ light=pyrender.DirectionalLight(color=np.ones(3), intensity=1.0),
137
+ matrix=matrix
138
+ ))
139
+
140
+ return nodes
141
+
142
+
143
+ class Renderer:
144
+
145
+ def __init__(self, cfg: CfgNode, faces: np.array):
146
+ """
147
+ Wrapper around the pyrender renderer to render SMAL meshes.
148
+ Args:
149
+ cfg (CfgNode): Model config file.
150
+ faces (np.array): Array of shape (F, 3) containing the mesh faces.
151
+ """
152
+ self.cfg = cfg
153
+ self.focal_length = cfg.EXTRA.FOCAL_LENGTH
154
+ self.img_res = cfg.MODEL.IMAGE_SIZE
155
+
156
+ self.camera_center = [self.img_res // 2, self.img_res // 2]
157
+ self.faces = faces.cpu().numpy()
158
+
159
+ def __call__(self,
160
+ vertices: np.array,
161
+ camera_translation: np.array,
162
+ image: torch.Tensor,
163
+ full_frame: bool = False,
164
+ imgname: Optional[str] = None,
165
+ side_view=False, rot_angle=90,
166
+ mesh_base_color=(1.0, 1.0, 0.9),
167
+ scene_bg_color=(0, 0, 0),
168
+ return_rgba=False,
169
+ ) -> np.array:
170
+ """
171
+ Render meshes on input image
172
+ Args:
173
+ vertices (np.array): Array of shape (V, 3) containing the mesh vertices.
174
+ camera_translation (np.array): Array of shape (3,) with the camera translation.
175
+ image (torch.Tensor): Tensor of shape (3, H, W) containing the image crop with normalized pixel values.
176
+ full_frame (bool): If True, then render on the full image.
177
+ imgname (Optional[str]): Contains the original image filename. Used only if full_frame == True.
178
+ """
179
+
180
+ if full_frame:
181
+ image = cv2.imread(imgname).astype(np.float32)[:, :, ::-1] / 255.
182
+ else:
183
+ image = (image.clone() * 255.) * (torch.tensor(self.cfg.MODEL.IMAGE_STD, device=image.device).reshape(3, 1, 1) * 255.)
184
+ image = image + (torch.tensor(self.cfg.MODEL.IMAGE_MEAN, device=image.device).reshape(3, 1, 1) * 255)
185
+ image = image.permute(1, 2, 0).cpu().numpy() / 255.
186
+
187
+ renderer = pyrender.OffscreenRenderer(viewport_width=image.shape[1],
188
+ viewport_height=image.shape[0],
189
+ point_size=1.0)
190
+ material = pyrender.MetallicRoughnessMaterial(
191
+ metallicFactor=0.0,
192
+ alphaMode='OPAQUE',
193
+ baseColorFactor=(*mesh_base_color, 1.0))
194
+
195
+ camera_translation[0] *= -1.
196
+
197
+ mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
198
+ if side_view:
199
+ rot = trimesh.transformations.rotation_matrix(
200
+ np.radians(rot_angle), [0, 1, 0])
201
+ mesh.apply_transform(rot)
202
+ rot = trimesh.transformations.rotation_matrix(
203
+ np.radians(180), [1, 0, 0])
204
+ mesh.apply_transform(rot)
205
+ mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
206
+
207
+ scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
208
+ ambient_light=(0.3, 0.3, 0.3))
209
+ scene.add(mesh, 'mesh')
210
+
211
+ camera_pose = np.eye(4)
212
+ camera_pose[:3, 3] = camera_translation
213
+ camera_center = [image.shape[1] / 2., image.shape[0] / 2.]
214
+ camera = pyrender.IntrinsicsCamera(fx=self.focal_length, fy=self.focal_length,
215
+ cx=camera_center[0], cy=camera_center[1], zfar=1e12)
216
+ scene.add(camera, pose=camera_pose)
217
+
218
+ light_nodes = create_raymond_lights()
219
+ for node in light_nodes:
220
+ scene.add_node(node)
221
+
222
+ color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
223
+ color = color.astype(np.float32) / 255.0
224
+ renderer.delete()
225
+
226
+ if return_rgba:
227
+ return color
228
+
229
+ valid_mask = (color[:, :, -1])[:, :, np.newaxis]
230
+ if not side_view:
231
+ output_img = (color[:, :, :3] * valid_mask + (1 - valid_mask) * image)
232
+ else:
233
+ output_img = color[:, :, :3]
234
+
235
+ output_img = output_img.astype(np.float32)
236
+ return output_img
237
+
238
+     def vertices_to_trimesh(self, vertices, camera_translation, mesh_base_color=(1.0, 1.0, 0.9),
+                             rot_axis=[1, 0, 0], rot_angle=0):
+         # material = pyrender.MetallicRoughnessMaterial(
+         #     metallicFactor=0.0,
+         #     alphaMode='OPAQUE',
+         #     baseColorFactor=(*mesh_base_color, 1.0))
+         vertex_colors = np.array([(*mesh_base_color, 1.0)] * vertices.shape[0])
+         mesh = trimesh.Trimesh(vertices.copy() + camera_translation, self.faces.copy(), vertex_colors=vertex_colors)
+         # mesh = trimesh.Trimesh(vertices.copy(), self.faces.copy())
+
+         rot = trimesh.transformations.rotation_matrix(
+             np.radians(rot_angle), rot_axis)
+         mesh.apply_transform(rot)
+
+         rot = trimesh.transformations.rotation_matrix(
+             np.radians(180), [1, 0, 0])
+         mesh.apply_transform(rot)
+         return mesh
+
+     def render_rgba(
+             self,
+             vertices: np.array,
+             cam_t=None,
+             rot=None,
+             rot_axis=[1, 0, 0],
+             rot_angle=0,
+             camera_z=3,
+             # camera_translation: np.array,
+             mesh_base_color=(1.0, 1.0, 0.9),
+             scene_bg_color=(0, 0, 0),
+             render_res=[256, 256],
+             focal_length=None,
+     ):
+
+         renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
+                                               viewport_height=render_res[1],
+                                               point_size=1.0)
+         # material = pyrender.MetallicRoughnessMaterial(
+         #     metallicFactor=0.0,
+         #     alphaMode='OPAQUE',
+         #     baseColorFactor=(*mesh_base_color, 1.0))
+
+         focal_length = focal_length if focal_length is not None else self.focal_length
+
+         if cam_t is not None:
+             camera_translation = cam_t.copy()
+             camera_translation[0] *= -1.
+         else:
+             camera_translation = np.array([0, 0, camera_z * focal_length / render_res[1]])
+
+         mesh = self.vertices_to_trimesh(vertices, np.array([0, 0, 0]), mesh_base_color, rot_axis, rot_angle)
+         mesh = pyrender.Mesh.from_trimesh(mesh)
+         # mesh = pyrender.Mesh.from_trimesh(mesh, material=material)
+
+         scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
+                                ambient_light=(0.3, 0.3, 0.3))
+         scene.add(mesh, 'mesh')
+
+         camera_pose = np.eye(4)
+         camera_pose[:3, 3] = camera_translation
+         camera_center = [render_res[0] / 2., render_res[1] / 2.]
+         camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
+                                            cx=camera_center[0], cy=camera_center[1], zfar=1e12)
+
+         # Create camera node and add it to pyRender scene
+         camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
+         scene.add_node(camera_node)
+         self.add_point_lighting(scene, camera_node)
+         self.add_lighting(scene, camera_node)
+
+         light_nodes = create_raymond_lights()
+         for node in light_nodes:
+             scene.add_node(node)
+
+         color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+         color = color.astype(np.float32) / 255.0
+         renderer.delete()
+
+         return color
+
+     def render_rgba_multiple(
+             self,
+             vertices: List[np.array],
+             cam_t: List[np.array],
+             rot_axis=[1, 0, 0],
+             rot_angle=0,
+             mesh_base_color=(1.0, 1.0, 0.9),
+             scene_bg_color=(0, 0, 0),
+             render_res=[256, 256],
+             focal_length=None,
+     ):
+
+         renderer = pyrender.OffscreenRenderer(viewport_width=render_res[0],
+                                               viewport_height=render_res[1],
+                                               point_size=1.0)
+         # material = pyrender.MetallicRoughnessMaterial(
+         #     metallicFactor=0.0,
+         #     alphaMode='OPAQUE',
+         #     baseColorFactor=(*mesh_base_color, 1.0))
+
+         mesh_list = [pyrender.Mesh.from_trimesh(
+             self.vertices_to_trimesh(vvv, ttt.copy(), mesh_base_color, rot_axis, rot_angle))
+             for vvv, ttt in zip(vertices, cam_t)]
+
+         scene = pyrender.Scene(bg_color=[*scene_bg_color, 0.0],
+                                ambient_light=(0.3, 0.3, 0.3))
+         for i, mesh in enumerate(mesh_list):
+             scene.add(mesh, f'mesh_{i}')
+
+         camera_pose = np.eye(4)
+         # camera_pose[:3, 3] = camera_translation
+         camera_center = [render_res[0] / 2., render_res[1] / 2.]
+         focal_length = focal_length if focal_length is not None else self.focal_length
+         camera = pyrender.IntrinsicsCamera(fx=focal_length, fy=focal_length,
+                                            cx=camera_center[0], cy=camera_center[1], zfar=1e12)
+
+         # Create camera node and add it to pyRender scene
+         camera_node = pyrender.Node(camera=camera, matrix=camera_pose)
+         scene.add_node(camera_node)
+         self.add_point_lighting(scene, camera_node)
+         self.add_lighting(scene, camera_node)
+
+         light_nodes = create_raymond_lights()
+         for node in light_nodes:
+             scene.add_node(node)
+
+         color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA)
+         color = color.astype(np.float32) / 255.0
+         renderer.delete()
+
+         return color
+
+     def add_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
+         # from phalp.visualize.py_renderer import get_light_poses
+         light_poses = get_light_poses()
+         light_poses.append(np.eye(4))
+         cam_pose = scene.get_pose(cam_node)
+         for i, pose in enumerate(light_poses):
+             matrix = cam_pose @ pose
+             node = pyrender.Node(
+                 name=f"light-{i:02d}",
+                 light=pyrender.DirectionalLight(color=color, intensity=intensity),
+                 matrix=matrix,
+             )
+             if scene.has_node(node):
+                 continue
+             scene.add_node(node)
+
+     def add_point_lighting(self, scene, cam_node, color=np.ones(3), intensity=1.0):
+         # from phalp.visualize.py_renderer import get_light_poses
+         light_poses = get_light_poses(dist=0.5)
+         light_poses.append(np.eye(4))
+         cam_pose = scene.get_pose(cam_node)
+         for i, pose in enumerate(light_poses):
+             matrix = cam_pose @ pose
+             # node = pyrender.Node(
+             #     name=f"light-{i:02d}",
+             #     light=pyrender.DirectionalLight(color=color, intensity=intensity),
+             #     matrix=matrix,
+             # )
+             node = pyrender.Node(
+                 name=f"plight-{i:02d}",
+                 light=pyrender.PointLight(color=color, intensity=intensity),
+                 matrix=matrix,
+             )
+             if scene.has_node(node):
+                 continue
+             scene.add_node(node)
+
+
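Editor's note: a minimal usage sketch for the Renderer class above (not part of the commit). It assumes the config/config.yaml added further down, a stand-in faces tensor (the real one comes from model.smal.faces, as in app.py), and an offscreen-capable pyrender backend (e.g. PYOPENGL_PLATFORM=egl on a headless machine).

import numpy as np
import torch
from amr.configs import get_config
from amr.utils.renderer import Renderer

model_cfg = get_config('config/config.yaml')
faces = torch.tensor([[0, 1, 2]], dtype=torch.int64)  # stand-in for model.smal.faces
renderer = Renderer(model_cfg, faces=faces)

# A single tiny triangle, rendered RGBA at the default camera distance.
verts = np.array([[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.1, 0.0]])
rgba = renderer.render_rgba(verts, render_res=[256, 256])
print(rgba.shape)  # (256, 256, 4), float32 in [0, 1]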
app.py ADDED
@@ -0,0 +1,176 @@
+ from pathlib import Path
+ import torch
+ import os
+ import cv2
+ import numpy as np
+ import tempfile
+ from tqdm import tqdm
+ import torch.utils
+ import trimesh
+ import torch.utils.data
+ import gradio as gr
+ from typing import Union, List, Tuple, Dict
+ from amr.models import AMR
+ from amr.configs import get_config
+ from amr.utils import recursive_to
+ from amr.datasets.vitdet_dataset import ViTDetDataset, DEFAULT_MEAN, DEFAULT_STD
+ from amr.utils.renderer import Renderer, cam_crop_to_full
+ from huggingface_hub import snapshot_download
+
+ LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353)
+
+ # Load model config
+ path_model_cfg = 'config/config.yaml'
+ model_cfg = get_config(path_model_cfg)
+
+ # Load model
+ repo_id = "luoxue-star/AniMer"
+ local_dir = snapshot_download(repo_id=repo_id)
+ PATH_CHECKPOINT = os.path.join(local_dir, "checkpoint.ckpt")
+ model = AMR.load_from_checkpoint(checkpoint_path=PATH_CHECKPOINT, map_location="cpu",
+                                  cfg=model_cfg, strict=False)
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ model = model.to(device)
+ model.eval()
+
+ # Setup the renderer
+ renderer = Renderer(model_cfg, faces=model.smal.faces)
+
+ # Make output directory if it does not exist
+ OUTPUT_FOLDER = "demo_out"
+ os.makedirs(OUTPUT_FOLDER, exist_ok=True)
+
+
+ def predict(im):
+     return im["composite"]
+
+
+ def inference(img: Dict) -> Tuple[Union[np.ndarray, None], List[str]]:
+     img = np.array(img["composite"])[:, :, :-1]
+     boxes = np.array([[0, 0, img.shape[1], img.shape[0]]])  # x1, y1, x2, y2
+
+     # Run AniMer on the cropped image
+     dataset = ViTDetDataset(model_cfg, img, boxes)
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+     all_verts = []
+     all_cam_t = []
+     temp_name = next(tempfile._get_candidate_names())
+     for batch in tqdm(dataloader):
+         batch = recursive_to(batch, device)
+         with torch.no_grad():
+             out = model(batch)
+
+         pred_cam = out['pred_cam']
+         box_center = batch["box_center"].float()
+         box_size = batch["box_size"].float()
+         img_size = batch["img_size"].float()
+         scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max()
+         pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size,
+                                            scaled_focal_length).detach().cpu().numpy()
+
+         # Render the result
+         batch_size = batch['img'].shape[0]
+         for n in range(batch_size):
+             person_id = int(batch['personid'][n])
+             input_patch = (batch['img'][n].cpu() * 255 * (DEFAULT_STD[:, None, None]) + (
+                 DEFAULT_MEAN[:, None, None])) / 255.
+             input_patch = input_patch.permute(1, 2, 0).numpy()
+
+             verts = out['pred_vertices'][n].detach().cpu().numpy()
+             cam_t = pred_cam_t_full[n]
+             all_verts.append(verts)
+             all_cam_t.append(cam_t)
+
+     # Render mesh onto the original image
+     if len(all_verts):
+         misc_args = dict(
+             mesh_base_color=LIGHT_BLUE,
+             scene_bg_color=(1, 1, 1),
+             focal_length=scaled_focal_length,
+         )
+
+         cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=img_size[n], **misc_args)
+         # Overlay image
+         input_img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR).astype(np.float32)[:, :, ::-1] / 255.0
+         input_img = np.concatenate([input_img, np.ones_like(input_img[:, :, :1])], axis=2)  # Add alpha channel
+         input_img_overlay = input_img[:, :, :3] * (1 - cam_view[:, :, 3:]) + cam_view[:, :, :3] * cam_view[:, :, 3:]
+         output_img = (255 * input_img_overlay[:, :, ::-1]).astype(np.uint8)[:, :, [2, 1, 0]]
+
+         # Return mesh path
+         trimeshes = [renderer.vertices_to_trimesh(vvv, ttt.copy(), LIGHT_BLUE) for vvv, ttt in zip(all_verts, all_cam_t)]
+         # Join meshes
+         mesh = trimesh.util.concatenate(trimeshes)
+         # Save mesh to file
+         mesh_name = os.path.join(OUTPUT_FOLDER, next(tempfile._get_candidate_names()) + '.obj')
+         trimesh.exchange.export.export_mesh(mesh, mesh_name)
+
+         return (output_img, mesh_name)
+     else:
+         return (None, [])
+
+
+ # with gr.Blocks(title="AniMer", css=".gradio-container") as demo:
+ #     gr.HTML("""<div style="font-weight:bold; text-align:center; color:royalblue;">AniMer</div>""")
+ #     with gr.Row():
+ #         with gr.Column():
+ #             input_image = gr.ImageEditor(label="Input image", sources=["upload", "clipboard"],
+ #                                          brush=False, eraser=False, layers=False, transforms="crop",
+ #                                          interactive=True,
+ #                                          )
+ #             crop_image = gr.Image(label="Crop image", sources=[])
+ #             input_image.change(predict, outputs=crop_image, inputs=input_image, show_progress="hidden")
+ #         with gr.Column():
+ #             output_image = gr.Image(label="Overlap image")
+ #             output_mesh = gr.Model3D(display_mode="wireframe", label="3D Mesh")
+
+ #     gr.HTML("""<br/>""")
+
+ #     with gr.Row():
+ #         send_btn = gr.Button("Inference")
+ #         send_btn.click(fn=inference, inputs=[crop_image], outputs=[output_image, output_mesh])
+
+ #     example_images = gr.Examples([
+ #         ['example_data/000000015956_horse.png'],
+ #         ['example_data/n02101388_1188.png'],
+ #         ['example_data/n02412080_12159.png'],
+ #         ['example_data/000000101684_zebra.png']
+ #     ],
+ #         inputs=[input_image])
+
+ # demo.launch(debug=True)
+
+
+ demo = gr.Interface(
+     fn=inference,
+     analytics_enabled=False,
+     inputs=gr.ImageEditor(label="Input image", sources=["upload", "clipboard"], type='pil',
+                           brush=False, eraser=False, layers=False, transforms="crop",
+                           interactive=True),
+     outputs=[
+         gr.Image(label="Overlap image"),
+         gr.Model3D(display_mode="wireframe", label="3D Mesh"),
+     ],
+     title="AniMer",
+     description="""
+ # AniMer: Animal Pose and Shape Estimation Using Family Aware Transformer
+ https://luoxue-star.github.io/AniMer_project_page/
+ ## Steps for Use
+ 1. **Input**: Select an example image or upload your own image.
+ 2. **Crop**: Crop the animal in the image.
+ 3. **Output**:
+     - Overlapping Image
+     - 3D Mesh
+ """,
+     examples=[
+         'example_data/000000015956_horse.png',
+         'example_data/n02101388_1188.png',
+         'example_data/n02412080_12159.png',
+         'example_data/000000101684_zebra.png',
+     ],
+ )
+
+ demo.launch()
+
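Editor's note: inference() above is wired to a gr.ImageEditor, whose payload is a dict with a "composite" image; the function drops the last channel, so it expects RGBA input. A hedged sketch (not part of the commit) of calling it without the Gradio UI, using one of the example files referenced above:

from PIL import Image

pil_img = Image.open('example_data/000000015956_horse.png').convert('RGBA')
overlay, mesh_path = inference({"composite": pil_img})
print(overlay.shape, mesh_path)  # overlay image array and path to the exported .obj in demo_out/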
config/config.yaml ADDED
@@ -0,0 +1,64 @@
+ task_name: train
+ tags:
+ - dev
+ train: true
+ test: false
+ ckpt_path: null
+ seed: null
+ trainer:
+   _target_: pytorch_lightning.Trainer
+   default_root_dir: ${paths.output_dir}
+   accelerator: gpu
+   devices: 1
+   deterministic: false
+   num_sanity_val_steps: 0
+   log_every_n_steps: ${GENERAL.LOG_STEPS}
+   val_check_interval: ${GENERAL.VAL_STEPS}
+   check_val_every_n_epoch: ${GENERAL.VAL_EPOCHS}
+   precision: 16-mixed
+   max_steps: ${GENERAL.TOTAL_STEPS}
+   limit_val_batches: 80
+ paths:
+   root_dir: ${oc.env:PROJECT_ROOT}
+   data_dir: ${paths.root_dir}/data/
+   log_dir: logs/
+   output_dir: ${hydra:runtime.output_dir}
+   work_dir: ${hydra:runtime.cwd}
+ extras:
+   ignore_warnings: false
+   enforce_tags: true
+   print_config: true
+ exp_name: AniMer
+ SMAL:
+   MODEL_PATH: ./data/my_smpl_00781_4_all.pkl
+   NUM_JOINTS: 34
+ EXTRA:
+   FOCAL_LENGTH: 1000
+   NUM_LOG_IMAGES: 4
+   NUM_LOG_SAMPLES_PER_IMAGE: 4
+   PELVIS_IND: 0
+ MODEL:
+   IMAGE_SIZE: 256
+   IMAGE_MEAN:
+   - 0.485
+   - 0.456
+   - 0.406
+   IMAGE_STD:
+   - 0.229
+   - 0.224
+   - 0.225
+   BACKBONE:
+     TYPE: vit
+   SMAL_HEAD:
+     TYPE: transformer_decoder
+     IN_CHANNELS: 2048
+     IEF_ITERS: 1
+     TRANSFORMER_DECODER:
+       depth: 6
+       heads: 8
+       mlp_dim: 1024
+       dim_head: 64
+       dropout: 0.0
+       emb_dropout: 0.0
+       norm: layer
+       context_dim: 1280
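Editor's note: app.py loads this file through get_config and reads it as a yacs CfgNode with attribute access. A small sketch (not part of the commit) of the fields the demo actually uses:

from amr.configs import get_config

cfg = get_config('config/config.yaml')
print(cfg.MODEL.IMAGE_SIZE)    # 256: crop resolution fed to the ViT backbone
print(cfg.EXTRA.FOCAL_LENGTH)  # 1000: used by the Renderer and cam_crop_to_full
print(cfg.SMAL.MODEL_PATH)     # ./data/my_smpl_00781_4_all.pkl (the LFS file below)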
data/my_smpl_00781_4_all.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22831db0e0e564dc95128e098da19995c2dda39b1aa18acc1335a6e62e0e3a59
+ size 33686326