ac5113 committed on
Commit b807ddb
1 Parent(s): c2fb5a4

added missing files

utils/__init__.py ADDED
File without changes
utils/cluster.py ADDED
@@ -0,0 +1,99 @@
+ import os
+ import sys
+ import stat
+ import shutil
+ import subprocess
+
+ from loguru import logger
+
+ GPUS = {
+     'v100-v16': ('"Tesla V100-PCIE-16GB"', 'tesla', 16000),
+     'v100-p32': ('"Tesla V100-PCIE-32GB"', 'tesla', 32000),
+     'v100-s32': ('"Tesla V100-SXM2-32GB"', 'tesla', 32000),
+     'v100-p16': ('"Tesla P100-PCIE-16GB"', 'tesla', 16000),  # NOTE: key says v100 but the device name is a P100
+ }
+
+
+ def get_gpus(min_mem=10000, arch=('tesla', 'quadro', 'rtx')):
+     gpu_names = []
+     for k, (gpu_name, gpu_arch, gpu_mem) in GPUS.items():
+         if gpu_mem >= min_mem and gpu_arch in arch:
+             gpu_names.append(gpu_name)
+
+     assert len(gpu_names) > 0, 'Suitable GPU model could not be found'
+
+     return gpu_names
+
+
+ def execute_task_on_cluster(
+         script,
+         exp_name,
+         output_dir,
+         condor_dir,
+         cfg_file,
+         num_exp=1,
+         exp_opts=None,
+         bid_amount=10,
+         num_workers=2,
+         memory=64000,
+         gpu_min_mem=10000,
+         gpu_arch=('tesla', 'quadro', 'rtx'),
+         num_gpus=1,
+ ):
+     # Copy the config to a new experiment directory and source it from there.
+     # This makes sure the correct config is used even if you change the config file
+     # after starting the experiment and before the first job is submitted.
+     temp_config_dir = os.path.join(os.path.dirname(condor_dir), 'temp_configs', exp_name)
+     os.makedirs(temp_config_dir, exist_ok=True)
+     new_cfg_file = os.path.join(temp_config_dir, 'config.yaml')
+     shutil.copy(src=cfg_file, dst=new_cfg_file)
+
+     gpus = get_gpus(min_mem=gpu_min_mem, arch=gpu_arch)
+
+     gpus = ' || '.join([f'CUDADeviceName=={x}' for x in gpus])
+
+     condor_log_dir = os.path.join(condor_dir, 'condorlog', exp_name)
+     os.makedirs(condor_log_dir, exist_ok=True)
+     submission = f'executable = {condor_log_dir}/{exp_name}_run.sh\n' \
+                  'arguments = $(Process) $(Cluster)\n' \
+                  f'error = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).err\n' \
+                  f'output = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).out\n' \
+                  f'log = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).log\n' \
+                  f'request_memory = {memory}\n' \
+                  f'request_cpus={int(num_workers)}\n' \
+                  f'request_gpus={num_gpus}\n' \
+                  f'requirements={gpus}\n' \
+                  f'+MaxRunningPrice = 500\n' \
+                  f'queue {num_exp}'
+     # f'request_cpus={int(num_workers/2)}\n' \
+     # f'+RunningPriceExceededAction = \"kill\"\n' \
+     print('<<< Condor Submission >>>')
+     print(submission)
+
+     with open(f'{condor_log_dir}/{exp_name}_submit.sub', 'w') as f:
+         f.write(submission)
+
+     # output_dir = os.path.join(output_dir, exp_name)
+     logger.info(f'The logs for this experiment can be found under: {condor_log_dir}')
+     logger.info(f'The outputs for this experiment can be found under: {output_dir}')
+     # This is the trick: notice there is no --cluster flag here.
+     bash = 'export PYTHONUNBUFFERED=1\n export PATH=$PATH\n ' \
+            f'{sys.executable} {script} --cfg {new_cfg_file} --cfg_id $1'
+
+     if exp_opts is not None:
+         bash += ' --opts '
+         for opt in exp_opts:
+             bash += f'{opt} '
+         bash += 'SYSTEM.CLUSTER_NODE $2.$1'
+     else:
+         bash += ' --opts SYSTEM.CLUSTER_NODE $2.$1'
+
+     executable_path = f'{condor_log_dir}/{exp_name}_run.sh'
+
+     with open(executable_path, 'w') as f:
+         f.write(bash)
+
+     os.chmod(executable_path, stat.S_IRWXU)
+
+     cmd = ['condor_submit_bid', f'{bid_amount}', f'{condor_log_dir}/{exp_name}_submit.sub']
+     logger.info('Executing ' + ' '.join(cmd))
+     subprocess.call(cmd)
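
For context, a minimal usage sketch of the submission helper (not part of the commit; the experiment name and paths are illustrative placeholders):

from utils.cluster import execute_task_on_cluster

execute_task_on_cluster(
    script='train.py',
    exp_name='deco_debug',           # placeholder experiment name
    output_dir='deco_results/',
    condor_dir='/tmp/condor',        # placeholder condor working directory
    cfg_file='configs/rich.yaml',    # placeholder config path
    num_exp=1,
    bid_amount=10,
)

This writes the .sub and _run.sh files under condor_dir and then calls condor_submit_bid, so it only completes on a machine with HTCondor available.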
utils/colorwheel.py ADDED
@@ -0,0 +1,22 @@
+ import cv2
+ import numpy as np
+
+
+ def make_color_wheel_image(img_width, img_height):
+     """
+     Creates a color-wheel image of the given width and height.
+     Args:
+         img_width (int): width of the output image
+         img_height (int): height of the output image
+
+     Returns:
+         opencv image (numpy array): color-wheel image
+     """
+     # OpenCV's 8-bit HSV expects hue in [0, 180), hence the division by 2.
+     hue = np.fromfunction(lambda i, j: (np.arctan2(i - img_height / 2, img_width / 2 - j) + np.pi) * (180 / np.pi) / 2,
+                           (img_height, img_width), dtype=float)
+     saturation = np.ones((img_height, img_width)) * 255
+     value = np.ones((img_height, img_width)) * 255
+     hsv = np.dstack((hue, saturation, value))
+     color_map = cv2.cvtColor(np.array(hsv, dtype=np.uint8), cv2.COLOR_HSV2BGR)
+     return color_map
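
A quick sketch of exercising this helper (illustrative; the output path is a placeholder):

import cv2
from utils.colorwheel import make_color_wheel_image

wheel = make_color_wheel_image(256, 256)  # BGR uint8 image, hue varies with angle
cv2.imwrite('color_wheel.png', wheel)     # placeholder output path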
utils/config.py ADDED
@@ -0,0 +1,196 @@
+ import itertools
+ import operator
+ import os
+ import shutil
+ import time
+ from functools import reduce
+ from typing import List, Union
+
+ import configargparse
+ import yaml
+ from flatten_dict import flatten, unflatten
+ from loguru import logger
+ from yacs.config import CfgNode as CN
+
+ from utils.cluster import execute_task_on_cluster
+ from utils.default_hparams import hparams
+
+
+ def parse_args():
+     def add_common_cmdline_args(parser):
+         # for cluster runs
+         parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
+         parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
+         parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
+         parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
+         parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
+         parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
+         parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
+         parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
+                             nargs='*', help='allowed GPU architectures')
+         parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
+         return parser
+
+     # main parser
+     arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
+     cfg_parser = configargparse.YAMLConfigFileParser
+     description = 'PyTorch implementation of DECO'
+
+     parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
+                                            config_file_parser_class=cfg_parser,
+                                            description=description,
+                                            prog='deco')
+
+     parser = add_common_cmdline_args(parser)
+
+     args = parser.parse_args()
+     print(args, end='\n\n')
+
+     return args
+
+
+ def get_hparams_defaults():
+     """Get a yacs CfgNode object with the default hyperparameter values."""
+     # Return a clone so that the defaults will not be altered.
+     # This is for the "local variable" use pattern.
+     return hparams.clone()
+
+
+ def update_hparams(hparams_file):
+     hparams = get_hparams_defaults()
+     hparams.merge_from_file(hparams_file)
+     return hparams.clone()
+
+
+ def update_hparams_from_dict(cfg_dict):
+     hparams = get_hparams_defaults()
+     cfg = hparams.load_cfg(str(cfg_dict))
+     hparams.merge_from_other_cfg(cfg)
+     return hparams.clone()
+
+
+ def get_grid_search_configs(config, excluded_keys=[]):
+     """
+     :param config: dictionary with the configurations
+     :return: the different configurations and the names of the tuned hyperparameters
+     """
+
+     def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
+         """
+         Boolean-to-string conversion.
+         :param x: list or bool to be converted
+         :return: the converted string(s)
+         """
+         if isinstance(x, bool):
+             return [str(x)]
+         for i, j in enumerate(x):
+             x[i] = str(j)
+         return x
+
+     # keys in excluded_keys are excluded from the grid search
+     flattened_config_dict = flatten(config, reducer='path')
+     hyper_params = []
+
+     for k, v in flattened_config_dict.items():
+         if isinstance(v, list):
+             if k in excluded_keys:
+                 flattened_config_dict[k] = ['+'.join(v)]
+             elif len(v) > 1:
+                 hyper_params += [k]
+
+         if isinstance(v, list) and isinstance(v[0], bool):
+             flattened_config_dict[k] = bool_to_string(v)
+
+         if not isinstance(v, list):
+             if isinstance(v, bool):
+                 flattened_config_dict[k] = bool_to_string(v)
+             else:
+                 flattened_config_dict[k] = [v]
+
+     keys, values = zip(*flattened_config_dict.items())
+     experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]
+
+     for exp_id, exp in enumerate(experiments):
+         for param in excluded_keys:
+             exp[param] = exp[param].strip().split('+')
+         for param_name, param_value in exp.items():
+             if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
+                 exp[param_name] = [x == 'True' for x in param_value]
+             if param_value in ['True', 'False']:
+                 exp[param_name] = param_value == 'True'
+
+         experiments[exp_id] = unflatten(exp, splitter='path')
+
+     return experiments, hyper_params
+
+
+ def get_from_dict(dictionary, keys):
+     return reduce(operator.getitem, keys, dictionary)
+
+
+ def save_dict_to_yaml(obj, filename, mode='w'):
+     with open(filename, mode) as f:
+         yaml.dump(obj, f, default_flow_style=False)
+
+
+ def run_grid_search_experiments(
+         args,
+         script='train.py',
+         change_wt_name=True,
+ ):
+     cfg = yaml.safe_load(open(args.cfg))
+     # Parse the config file and split it into a list of configs, one per
+     # combination of the tuned hyperparameters; also return the names of
+     # the tuned hyperparameters.
+     different_configs, hyperparams = get_grid_search_configs(
+         cfg,
+         excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
+     )
+     logger.info(f'Grid search hparams: \n {hyperparams}')
+
+     # The config file may be missing some default values, so we need to add them.
+     different_configs = [update_hparams_from_dict(c) for c in different_configs]
+     logger.info(f'======> Number of experiment configurations is {len(different_configs)}')
+
+     config_to_run = CN(different_configs[args.cfg_id])
+
+     if args.cluster:
+         execute_task_on_cluster(
+             script=script,
+             exp_name=config_to_run.EXP_NAME,
+             output_dir=config_to_run.OUTPUT_DIR,
+             condor_dir=config_to_run.CONDOR_DIR,
+             cfg_file=args.cfg,
+             num_exp=len(different_configs),
+             bid_amount=args.bid,
+             num_workers=config_to_run.DATASET.NUM_WORKERS,
+             memory=args.memory,
+             exp_opts=args.opts,
+             gpu_min_mem=args.gpu_min_mem,
+             gpu_arch=args.gpu_arch,
+         )
+         exit()
+
+     # ==== create logdir using the hyperparameter settings
+     logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
+     logdir = f'{logtime}_{config_to_run.EXP_NAME}'
+     wt_file = config_to_run.EXP_NAME + '_'
+     for hp in hyperparams:
+         v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
+         logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
+         wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'
+     logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
+     os.makedirs(logdir, exist_ok=True)
+     config_to_run.LOGDIR = logdir
+
+     wt_file += 'best.pth'
+     wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
+     if change_wt_name:
+         config_to_run.TRAINING.BEST_MODEL_PATH = wt_path
+
+     shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))
+
+     # save the resolved config
+     save_dict_to_yaml(
+         unflatten(flatten(config_to_run)),
+         os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
+     )
+
+     return config_to_run
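
For orientation, the helpers above are typically chained from an entry point roughly like this (a sketch, assuming a train.py invoked the way the parser expects; not part of this commit):

from utils.config import parse_args, run_grid_search_experiments

if __name__ == '__main__':
    args = parse_args()                                    # expects at least --cfg
    hparams = run_grid_search_experiments(args, script='train.py')
    print(hparams.LOGDIR)                                  # resolved log directory for this run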
utils/default_hparams.py ADDED
@@ -0,0 +1,45 @@
+ from yacs.config import CfgNode as CN
+
+ # Set default hparams to construct a new default config.
+ # Make sure the defaults are the same as in the parser.
+ hparams = CN()
+
+ # General settings
+ hparams.EXP_NAME = 'default'
+ hparams.PROJECT_NAME = 'default'
+ hparams.OUTPUT_DIR = 'deco_results/'
+ hparams.CONDOR_DIR = '/is/cluster/work/achatterjee/condor/rich/'
+ hparams.LOGDIR = ''
+
+ # Dataset hparams
+ hparams.DATASET = CN()
+ hparams.DATASET.BATCH_SIZE = 64
+ hparams.DATASET.NUM_WORKERS = 4
+ hparams.DATASET.NORMALIZE_IMAGES = True
+
+ # Optimizer hparams
+ hparams.OPTIMIZER = CN()
+ hparams.OPTIMIZER.TYPE = 'adam'
+ hparams.OPTIMIZER.LR = 5e-5
+ hparams.OPTIMIZER.NUM_UPDATE_LR = 10
+
+ # Training hparams
+ hparams.TRAINING = CN()
+ hparams.TRAINING.ENCODER = 'hrnet'
+ hparams.TRAINING.CONTEXT = True
+ hparams.TRAINING.NUM_EPOCHS = 50
+ hparams.TRAINING.SUMMARY_STEPS = 100
+ hparams.TRAINING.CHECKPOINT_EPOCHS = 5
+ hparams.TRAINING.NUM_EARLY_STOP = 10
+ hparams.TRAINING.DATASETS = ['rich']
+ hparams.TRAINING.DATASET_MIX_PDF = ['1.']
+ hparams.TRAINING.DATASET_ROOT_PATH = '/is/cluster/work/achatterjee/rich/npzs'
+ hparams.TRAINING.BEST_MODEL_PATH = '/is/cluster/work/achatterjee/weights/rich/exp/rich_exp.pth'
+ hparams.TRAINING.LOSS_WEIGHTS = 1.
+ hparams.TRAINING.PAL_LOSS_WEIGHTS = 1.
+
+ # Validation hparams
+ hparams.VALIDATION = CN()
+ hparams.VALIDATION.SUMMARY_STEPS = 100
+ hparams.VALIDATION.DATASETS = ['rich']
+ hparams.VALIDATION.MAIN_DATASET = 'rich'
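
The intended override pattern is clone-then-merge (a sketch; the YAML path is a placeholder):

from utils.default_hparams import hparams

cfg = hparams.clone()                          # never mutate the module-level defaults
cfg.merge_from_file('configs/rich.yaml')       # placeholder path
cfg.merge_from_list(['OPTIMIZER.LR', 1e-4])    # command-line style override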
utils/diff_renderer.py ADDED
@@ -0,0 +1,287 @@
+ # from https://gitlab.tuebingen.mpg.de/mkocabas/projects/-/blob/master/pare/pare/utils/diff_renderer.py
+
+ import torch
+ import numpy as np
+ import torch.nn as nn
+
+ from pytorch3d.renderer import (
+     PerspectiveCameras,
+     RasterizationSettings,
+     DirectionalLights,
+     BlendParams,
+     HardFlatShader,
+     MeshRasterizer,
+     TexturesVertex,
+     TexturesAtlas
+ )
+ from pytorch3d.structures import Meshes
+
+ from .image_utils import get_default_camera
+ from .smpl_uv import get_tenet_texture
+
+
+ class MeshRendererWithDepth(nn.Module):
+     """
+     A class for rendering a batch of heterogeneous meshes. The class should
+     be initialized with a rasterizer and a shader, each of which has a
+     forward function.
+     """
+
+     def __init__(self, rasterizer, shader):
+         super().__init__()
+         self.rasterizer = rasterizer
+         self.shader = shader
+
+     def forward(self, meshes_world, **kwargs) -> torch.Tensor:
+         """
+         Render a batch of images from a batch of meshes by rasterizing and then
+         shading.
+
+         NOTE: If the blur radius for rasterization is > 0.0, some pixels can
+         have one or more barycentric coordinates lying outside the range [0, 1].
+         For a pixel with out-of-bounds barycentric coordinates with respect to a
+         face f, clipping is required before interpolating the texture uv
+         coordinates and z-buffer so that the colors and depths are limited to
+         the range for the corresponding face.
+         """
+         fragments = self.rasterizer(meshes_world, **kwargs)
+         images = self.shader(fragments, meshes_world, **kwargs)
+
+         mask = (fragments.zbuf > -1).float()
+
+         zbuf = fragments.zbuf.view(images.shape[0], -1)
+         # min-max normalize the depth per image
+         depth = (zbuf - zbuf.min(-1, keepdim=True).values) / \
+                 (zbuf.max(-1, keepdim=True).values - zbuf.min(-1, keepdim=True).values)
+         depth = depth.reshape(*(images.shape[:3] + (1,)))
+
+         images = torch.cat([images[:, :, :, :3], mask, depth], dim=-1)
+         return images
+
+
+ class DifferentiableRenderer(nn.Module):
+     def __init__(
+             self,
+             img_h,
+             img_w,
+             focal_length,
+             device='cuda',
+             background_color=(0.0, 0.0, 0.0),
+             texture_mode='smplpix',
+             vertex_colors=None,
+             face_textures=None,
+             smpl_faces=None,
+             is_train=False,
+             is_cam_batch=False,
+     ):
+         super(DifferentiableRenderer, self).__init__()
+         self.img_h = img_h
+         self.img_w = img_w
+         self.device = device
+         self.focal_length = focal_length
+         K, R = get_default_camera(focal_length, img_h, img_w, is_cam_batch=is_cam_batch)
+         K, R = K.to(device), R.to(device)
+
+         # T = torch.tensor([[0, 0, 2.5 * self.focal_length / max(self.img_h, self.img_w)]]).to(device)
+         if is_cam_batch:
+             T = torch.zeros((K.shape[0], 3)).to(device)
+         else:
+             T = torch.tensor([[0.0, 0.0, 0.0]]).to(device)
+         self.background_color = background_color
+         self.renderer = None
+
+         if texture_mode == 'smplpix':
+             face_colors = get_tenet_texture(mode=texture_mode).to(device).float()
+             vertex_colors = torch.from_numpy(
+                 np.load(f'data/smpl/{texture_mode}_vertex_colors.npy')[:, :3]
+             ).unsqueeze(0).to(device).float()
+         if texture_mode == 'partseg':
+             vertex_colors = vertex_colors[..., :3].unsqueeze(0).to(device)
+             face_colors = face_textures.to(device)
+         if texture_mode == 'deco':
+             vertex_colors = vertex_colors[..., :3].to(device)
+             face_colors = face_textures.to(device)
+
+         self.register_buffer('K', K)
+         self.register_buffer('R', R)
+         self.register_buffer('T', T)
+         self.register_buffer('face_colors', face_colors)
+         self.register_buffer('vertex_colors', vertex_colors)
+         self.register_buffer('smpl_faces', smpl_faces)
+
+         self.set_requires_grad(is_train)
+
+     def set_requires_grad(self, val=False):
+         self.K.requires_grad_(val)
+         self.R.requires_grad_(val)
+         self.T.requires_grad_(val)
+         self.face_colors.requires_grad_(val)
+         self.vertex_colors.requires_grad_(val)
+         # requires_grad_ is not defined for LongTensor, so only flip it for float faces
+         if isinstance(self.smpl_faces, torch.FloatTensor):
+             self.smpl_faces.requires_grad_(val)
+
+     def forward(self, vertices, faces=None, R=None, T=None):
+         raise NotImplementedError
+
+
+ class Pytorch3D(DifferentiableRenderer):
+     def __init__(
+             self,
+             img_h,
+             img_w,
+             focal_length,
+             device='cuda',
+             background_color=(0.0, 0.0, 0.0),
+             texture_mode='smplpix',
+             vertex_colors=None,
+             face_textures=None,
+             smpl_faces=None,
+             model_type='smpl',
+             is_train=False,
+             is_cam_batch=False,
+     ):
+         super(Pytorch3D, self).__init__(
+             img_h,
+             img_w,
+             focal_length,
+             device=device,
+             background_color=background_color,
+             texture_mode=texture_mode,
+             vertex_colors=vertex_colors,
+             face_textures=face_textures,
+             smpl_faces=smpl_faces,
+             is_train=is_train,
+             is_cam_batch=is_cam_batch,
+         )
+
+         # This R converts the camera from pyrender NDC to the OpenGL coordinate
+         # frame. It is basically R(180, X) @ R(180, Y), defined manually here
+         # for convenience.
+         self.R = self.R @ torch.tensor(
+             [[[-1.0, 0.0, 0.0],
+               [0.0, -1.0, 0.0],
+               [0.0, 0.0, 1.0]]],
+             dtype=self.R.dtype, device=self.R.device,
+         )
+
+         if is_cam_batch:
+             focal_length = self.focal_length
+         else:
+             focal_length = self.focal_length[None, :]
+
+         principal_point = ((self.img_w // 2, self.img_h // 2),)
+         image_size = ((self.img_h, self.img_w),)
+
+         cameras = PerspectiveCameras(
+             device=self.device,
+             focal_length=focal_length,
+             principal_point=principal_point,
+             R=self.R,
+             T=self.T,
+             in_ndc=False,
+             image_size=image_size,
+         )
+
+         for param in cameras.parameters():
+             param.requires_grad_(False)
+
+         raster_settings = RasterizationSettings(
+             image_size=(self.img_h, self.img_w),
+             blur_radius=0.0,
+             max_faces_per_bin=20000,
+             faces_per_pixel=1,
+         )
+
+         lights = DirectionalLights(
+             device=self.device,
+             ambient_color=((1.0, 1.0, 1.0),),
+             diffuse_color=((0.0, 0.0, 0.0),),
+             specular_color=((0.0, 0.0, 0.0),),
+             direction=((0, 1, 0),),
+         )
+
+         blend_params = BlendParams(background_color=self.background_color)
+
+         shader = HardFlatShader(device=self.device,
+                                 cameras=cameras,
+                                 blend_params=blend_params,
+                                 lights=lights)
+
+         self.textures = TexturesVertex(verts_features=self.vertex_colors)
+
+         self.renderer = MeshRendererWithDepth(
+             rasterizer=MeshRasterizer(
+                 cameras=cameras,
+                 raster_settings=raster_settings
+             ),
+             shader=shader,
+         )
+
+     def forward(self, vertices, faces=None, R=None, T=None, face_atlas=None):
+         batch_size = vertices.shape[0]
+         if faces is None:
+             faces = self.smpl_faces.expand(batch_size, -1, -1)
+
+         if R is None:
+             R = self.R.expand(batch_size, -1, -1)
+
+         if T is None:
+             T = self.T.expand(batch_size, -1)
+
+         # convert the camera translation to the pytorch3d coordinate frame
+         T = torch.bmm(R, T.unsqueeze(-1)).squeeze(-1)
+
+         vertex_textures = TexturesVertex(
+             verts_features=self.vertex_colors.expand(batch_size, -1, -1)
+         )
+
+         # face textures are needed because vertex textures cause interpolation at part boundaries
+         if face_atlas is not None:
+             face_textures = TexturesAtlas(atlas=face_atlas)
+         else:
+             face_textures = TexturesAtlas(atlas=self.face_colors)
+
+         # we may need to rotate the mesh
+         meshes = Meshes(verts=vertices, faces=faces, textures=face_textures)
+         images = self.renderer(meshes, R=R, T=T)
+         images = images.permute(0, 3, 1, 2)
+         return images
+
+
+ class NeuralMeshRenderer(DifferentiableRenderer):
+     def __init__(self, *args, **kwargs):
+         import neural_renderer as nr
+
+         super(NeuralMeshRenderer, self).__init__(*args, **kwargs)
+
+         # NOTE: the parent class stores img_h/img_w rather than img_size;
+         # nr.Renderer takes a single size, so square inputs are assumed here.
+         self.img_size = max(self.img_h, self.img_w)
+         self.neural_renderer = nr.Renderer(
+             dist_coeffs=None,
+             orig_size=self.img_size,
+             image_size=self.img_size,
+             light_intensity_ambient=1,
+             light_intensity_directional=0,
+             anti_aliasing=False,
+         )
+
+     def forward(self, vertices, faces=None, R=None, T=None):
+         batch_size = vertices.shape[0]
+         if faces is None:
+             faces = self.smpl_faces.expand(batch_size, -1, -1)
+
+         if R is None:
+             R = self.R.expand(batch_size, -1, -1)
+
+         if T is None:
+             T = self.T.expand(batch_size, -1)
+         rgb, depth, mask = self.neural_renderer(
+             vertices,
+             faces,
+             textures=self.face_colors.expand(batch_size, -1, -1, -1, -1, -1),
+             K=self.K.expand(batch_size, -1, -1),
+             R=R,
+             t=T.unsqueeze(1),
+         )
+         return torch.cat([rgb, depth.unsqueeze(1), mask.unsqueeze(1)], dim=1)
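
A minimal sketch of driving the Pytorch3D renderer (illustrative only: it assumes a CUDA device, the SMPL texture assets referenced above, and pre-loaded smpl_faces (1, F, 3) and vertices (B, 6890, 3) tensors):

import torch
from utils.diff_renderer import Pytorch3D

renderer = Pytorch3D(img_h=224, img_w=224,
                     focal_length=torch.tensor([[5000.0, 5000.0]]),
                     smpl_faces=smpl_faces)   # assumed long tensor of SMPL faces
images = renderer(vertices.cuda())           # (B, 5, H, W): RGB, mask, and depth channels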
utils/get_cfg.py ADDED
@@ -0,0 +1,17 @@
+ from yacs.config import CfgNode
+
+ _VALID_TYPES = {tuple, list, str, int, float, bool}
+
+
+ def convert_to_dict(cfg_node, key_list=[]):
+     """Convert a yacs config node to a plain dictionary."""
+     if not isinstance(cfg_node, CfgNode):
+         if type(cfg_node) not in _VALID_TYPES:
+             print("Key {} with value {} is not a valid type; valid types: {}".format(
+                 ".".join(key_list), type(cfg_node), _VALID_TYPES))
+         return cfg_node
+     else:
+         cfg_dict = dict(cfg_node)
+         for k, v in cfg_dict.items():
+             cfg_dict[k] = convert_to_dict(v, key_list + [k])
+         return cfg_dict
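
A usage sketch (illustrative):

from utils.default_hparams import hparams
from utils.get_cfg import convert_to_dict

cfg_dict = convert_to_dict(hparams)   # plain nested dict, e.g. for YAML dumping
print(cfg_dict['OPTIMIZER']['LR'])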
utils/hrnet.py ADDED
@@ -0,0 +1,625 @@
+ import os
+
+ import torch
+ import torch.nn as nn
+ from loguru import logger
+ import torch.nn.functional as F
+ from yacs.config import CfgNode as CN
+
+ models = [
+     'hrnet_w32',
+     'hrnet_w48',
+ ]
+
+ BN_MOMENTUM = 0.1
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     """3x3 convolution with padding"""
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, bias=False)
+
+
+ class BasicBlock(nn.Module):
+     expansion = 1
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(BasicBlock, self).__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+                                padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
+                                bias=False)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+                                   momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class HighResolutionModule(nn.Module):
+     def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+                  num_channels, fuse_method, multi_scale_output=True):
+         super(HighResolutionModule, self).__init__()
+         self._check_branches(
+             num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+         self.num_inchannels = num_inchannels
+         self.fuse_method = fuse_method
+         self.num_branches = num_branches
+
+         self.multi_scale_output = multi_scale_output
+
+         self.branches = self._make_branches(
+             num_branches, blocks, num_blocks, num_channels)
+         self.fuse_layers = self._make_fuse_layers()
+         self.relu = nn.ReLU(True)
+
+     def _check_branches(self, num_branches, blocks, num_blocks,
+                         num_inchannels, num_channels):
+         if num_branches != len(num_blocks):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+                 num_branches, len(num_blocks))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+         if num_branches != len(num_channels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+                 num_branches, len(num_channels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+         if num_branches != len(num_inchannels):
+             error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+                 num_branches, len(num_inchannels))
+             logger.error(error_msg)
+             raise ValueError(error_msg)
+
+     def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
+                          stride=1):
+         downsample = None
+         if stride != 1 or \
+                 self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(
+                     self.num_inchannels[branch_index],
+                     num_channels[branch_index] * block.expansion,
+                     kernel_size=1, stride=stride, bias=False
+                 ),
+                 nn.BatchNorm2d(
+                     num_channels[branch_index] * block.expansion,
+                     momentum=BN_MOMENTUM
+                 ),
+             )
+
+         layers = []
+         layers.append(
+             block(
+                 self.num_inchannels[branch_index],
+                 num_channels[branch_index],
+                 stride,
+                 downsample
+             )
+         )
+         self.num_inchannels[branch_index] = \
+             num_channels[branch_index] * block.expansion
+         for i in range(1, num_blocks[branch_index]):
+             layers.append(
+                 block(
+                     self.num_inchannels[branch_index],
+                     num_channels[branch_index]
+                 )
+             )
+
+         return nn.Sequential(*layers)
+
+     def _make_branches(self, num_branches, block, num_blocks, num_channels):
+         branches = []
+
+         for i in range(num_branches):
+             branches.append(
+                 self._make_one_branch(i, block, num_blocks, num_channels)
+             )
+
+         return nn.ModuleList(branches)
+
+     def _make_fuse_layers(self):
+         if self.num_branches == 1:
+             return None
+
+         num_branches = self.num_branches
+         num_inchannels = self.num_inchannels
+         fuse_layers = []
+         for i in range(num_branches if self.multi_scale_output else 1):
+             fuse_layer = []
+             for j in range(num_branches):
+                 if j > i:
+                     fuse_layer.append(
+                         nn.Sequential(
+                             nn.Conv2d(
+                                 num_inchannels[j],
+                                 num_inchannels[i],
+                                 1, 1, 0, bias=False
+                             ),
+                             nn.BatchNorm2d(num_inchannels[i]),
+                             nn.Upsample(scale_factor=2**(j-i), mode='nearest')
+                         )
+                     )
+                 elif j == i:
+                     fuse_layer.append(None)
+                 else:
+                     conv3x3s = []
+                     for k in range(i-j):
+                         if k == i - j - 1:
+                             num_outchannels_conv3x3 = num_inchannels[i]
+                             conv3x3s.append(
+                                 nn.Sequential(
+                                     nn.Conv2d(
+                                         num_inchannels[j],
+                                         num_outchannels_conv3x3,
+                                         3, 2, 1, bias=False
+                                     ),
+                                     nn.BatchNorm2d(num_outchannels_conv3x3)
+                                 )
+                             )
+                         else:
+                             num_outchannels_conv3x3 = num_inchannels[j]
+                             conv3x3s.append(
+                                 nn.Sequential(
+                                     nn.Conv2d(
+                                         num_inchannels[j],
+                                         num_outchannels_conv3x3,
+                                         3, 2, 1, bias=False
+                                     ),
+                                     nn.BatchNorm2d(num_outchannels_conv3x3),
+                                     nn.ReLU(True)
+                                 )
+                             )
+                     fuse_layer.append(nn.Sequential(*conv3x3s))
+             fuse_layers.append(nn.ModuleList(fuse_layer))
+
+         return nn.ModuleList(fuse_layers)
+
+     def get_num_inchannels(self):
+         return self.num_inchannels
+
+     def forward(self, x):
+         if self.num_branches == 1:
+             return [self.branches[0](x[0])]
+
+         for i in range(self.num_branches):
+             x[i] = self.branches[i](x[i])
+
+         x_fuse = []
+
+         for i in range(len(self.fuse_layers)):
+             y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+             for j in range(1, self.num_branches):
+                 if i == j:
+                     y = y + x[j]
+                 else:
+                     y = y + self.fuse_layers[i][j](x[j])
+             x_fuse.append(self.relu(y))
+
+         return x_fuse
+
+
+ blocks_dict = {
+     'BASIC': BasicBlock,
+     'BOTTLENECK': Bottleneck
+ }
+
+
+ class PoseHighResolutionNet(nn.Module):
+
+     def __init__(self, cfg):
+         self.inplanes = 64
+         extra = cfg['MODEL']['EXTRA']
+         super(PoseHighResolutionNet, self).__init__()
+
+         self.cfg = extra
+
+         # stem net
+         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+                                bias=False)
+         self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+         self.relu = nn.ReLU(inplace=True)
+         self.layer1 = self._make_layer(Bottleneck, 64, 4)
+
+         self.stage2_cfg = extra['STAGE2']
+         num_channels = self.stage2_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage2_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))
+         ]
+         self.transition1 = self._make_transition_layer([256], num_channels)
+         self.stage2, pre_stage_channels = self._make_stage(
+             self.stage2_cfg, num_channels)
+
+         self.stage3_cfg = extra['STAGE3']
+         num_channels = self.stage3_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage3_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))
+         ]
+         self.transition2 = self._make_transition_layer(
+             pre_stage_channels, num_channels)
+         self.stage3, pre_stage_channels = self._make_stage(
+             self.stage3_cfg, num_channels)
+
+         self.stage4_cfg = extra['STAGE4']
+         num_channels = self.stage4_cfg['NUM_CHANNELS']
+         block = blocks_dict[self.stage4_cfg['BLOCK']]
+         num_channels = [
+             num_channels[i] * block.expansion for i in range(len(num_channels))
+         ]
+         self.transition3 = self._make_transition_layer(
+             pre_stage_channels, num_channels)
+         self.stage4, pre_stage_channels = self._make_stage(
+             self.stage4_cfg, num_channels, multi_scale_output=True)
+
+         self.final_layer = nn.Conv2d(
+             in_channels=pre_stage_channels[0],
+             out_channels=cfg['MODEL']['NUM_JOINTS'],
+             kernel_size=extra['FINAL_CONV_KERNEL'],
+             stride=1,
+             padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
+         )
+
+         self.pretrained_layers = extra['PRETRAINED_LAYERS']
+
+         if extra.DOWNSAMPLE and extra.USE_CONV:
+             self.downsample_stage_1 = self._make_downsample_layer(3, num_channel=self.stage2_cfg['NUM_CHANNELS'][0])
+             self.downsample_stage_2 = self._make_downsample_layer(2, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
+             self.downsample_stage_3 = self._make_downsample_layer(1, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
+         elif not extra.DOWNSAMPLE and extra.USE_CONV:
+             self.upsample_stage_2 = self._make_upsample_layer(1, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
+             self.upsample_stage_3 = self._make_upsample_layer(2, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
+             self.upsample_stage_4 = self._make_upsample_layer(3, num_channel=self.stage4_cfg['NUM_CHANNELS'][-1])
+
+     def _make_transition_layer(
+             self, num_channels_pre_layer, num_channels_cur_layer):
+         num_branches_cur = len(num_channels_cur_layer)
+         num_branches_pre = len(num_channels_pre_layer)
+
+         transition_layers = []
+         for i in range(num_branches_cur):
+             if i < num_branches_pre:
+                 if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                     transition_layers.append(
+                         nn.Sequential(
+                             nn.Conv2d(
+                                 num_channels_pre_layer[i],
+                                 num_channels_cur_layer[i],
+                                 3, 1, 1, bias=False
+                             ),
+                             nn.BatchNorm2d(num_channels_cur_layer[i]),
+                             nn.ReLU(inplace=True)
+                         )
+                     )
+                 else:
+                     transition_layers.append(None)
+             else:
+                 conv3x3s = []
+                 for j in range(i+1-num_branches_pre):
+                     inchannels = num_channels_pre_layer[-1]
+                     outchannels = num_channels_cur_layer[i] \
+                         if j == i-num_branches_pre else inchannels
+                     conv3x3s.append(
+                         nn.Sequential(
+                             nn.Conv2d(
+                                 inchannels, outchannels, 3, 2, 1, bias=False
+                             ),
+                             nn.BatchNorm2d(outchannels),
+                             nn.ReLU(inplace=True)
+                         )
+                     )
+                 transition_layers.append(nn.Sequential(*conv3x3s))
+
+         return nn.ModuleList(transition_layers)
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 nn.Conv2d(
+                     self.inplanes, planes * block.expansion,
+                     kernel_size=1, stride=stride, bias=False
+                 ),
+                 nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for i in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+     def _make_stage(self, layer_config, num_inchannels,
+                     multi_scale_output=True):
+         num_modules = layer_config['NUM_MODULES']
+         num_branches = layer_config['NUM_BRANCHES']
+         num_blocks = layer_config['NUM_BLOCKS']
+         num_channels = layer_config['NUM_CHANNELS']
+         block = blocks_dict[layer_config['BLOCK']]
+         fuse_method = layer_config['FUSE_METHOD']
+
+         modules = []
+         for i in range(num_modules):
+             # multi_scale_output is only used by the last module
+             if not multi_scale_output and i == num_modules - 1:
+                 reset_multi_scale_output = False
+             else:
+                 reset_multi_scale_output = True
+
+             modules.append(
+                 HighResolutionModule(
+                     num_branches,
+                     block,
+                     num_blocks,
+                     num_inchannels,
+                     num_channels,
+                     fuse_method,
+                     reset_multi_scale_output
+                 )
+             )
+             num_inchannels = modules[-1].get_num_inchannels()
+
+         return nn.Sequential(*modules), num_inchannels
+
+     def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
+         layers = []
+         for i in range(num_layers):
+             layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
+             layers.append(
+                 nn.Conv2d(
+                     in_channels=num_channel, out_channels=num_channel,
+                     kernel_size=kernel_size, stride=1, padding=1, bias=False,
+                 )
+             )
+             layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
+             layers.append(nn.ReLU(inplace=True))
+
+         return nn.Sequential(*layers)
+
+     def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
+         layers = []
+         for i in range(num_layers):
+             layers.append(
+                 nn.Conv2d(
+                     in_channels=num_channel, out_channels=num_channel,
+                     kernel_size=kernel_size, stride=2, padding=1, bias=False,
+                 )
+             )
+             layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
+             layers.append(nn.ReLU(inplace=True))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+         x = self.conv2(x)
+         x = self.bn2(x)
+         x = self.relu(x)
+         x = self.layer1(x)
+
+         x_list = []
+         for i in range(self.stage2_cfg['NUM_BRANCHES']):
+             if self.transition1[i] is not None:
+                 x_list.append(self.transition1[i](x))
+             else:
+                 x_list.append(x)
+         y_list = self.stage2(x_list)
+
+         x_list = []
+         for i in range(self.stage3_cfg['NUM_BRANCHES']):
+             if self.transition2[i] is not None:
+                 x_list.append(self.transition2[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         y_list = self.stage3(x_list)
+
+         x_list = []
+         for i in range(self.stage4_cfg['NUM_BRANCHES']):
+             if self.transition3[i] is not None:
+                 x_list.append(self.transition3[i](y_list[-1]))
+             else:
+                 x_list.append(y_list[i])
+         x = self.stage4(x_list)
+
+         if self.cfg.DOWNSAMPLE:
+             if self.cfg.USE_CONV:
+                 # Downsampling with strided convolutions
+                 x1 = self.downsample_stage_1(x[0])
+                 x2 = self.downsample_stage_2(x[1])
+                 x3 = self.downsample_stage_3(x[2])
+                 x = torch.cat([x1, x2, x3, x[3]], 1)
+             else:
+                 # Downsampling with interpolation
+                 x0_h, x0_w = x[3].size(2), x[3].size(3)
+                 x1 = F.interpolate(x[0], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
+                 x2 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
+                 x3 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
+                 x = torch.cat([x1, x2, x3, x[3]], 1)
+         else:
+             if self.cfg.USE_CONV:
+                 # Upsampling with interpolation + convolutions
+                 x1 = self.upsample_stage_2(x[1])
+                 x2 = self.upsample_stage_3(x[2])
+                 x3 = self.upsample_stage_4(x[3])
+                 x = torch.cat([x[0], x1, x2, x3], 1)
+             else:
+                 # Upsampling with interpolation
+                 x0_h, x0_w = x[0].size(2), x[0].size(3)
+                 x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
+                 x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
+                 x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
+                 x = torch.cat([x[0], x1, x2, x3], 1)
+
+         return x
+
+     def init_weights(self, pretrained=''):
+         logger.info('=> init weights from normal distribution')
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                 nn.init.normal_(m.weight, std=0.001)
+                 for name, _ in m.named_parameters():
+                     if name in ['bias']:
+                         nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+             elif isinstance(m, nn.ConvTranspose2d):
+                 nn.init.normal_(m.weight, std=0.001)
+                 for name, _ in m.named_parameters():
+                     if name in ['bias']:
+                         nn.init.constant_(m.bias, 0)
+
+         if os.path.isfile(pretrained):
+             pretrained_state_dict = torch.load(pretrained)
+             logger.info('=> loading pretrained model {}'.format(pretrained))
+
+             need_init_state_dict = {}
+             for name, m in pretrained_state_dict.items():
+                 if name.split('.')[0] in self.pretrained_layers \
+                         or self.pretrained_layers[0] == '*':
+                     need_init_state_dict[name] = m
+             self.load_state_dict(need_init_state_dict, strict=False)
+         elif pretrained:
+             logger.warning('IMPORTANT WARNING!! Please download pre-trained models if you are in TRAINING mode!')
+             # raise ValueError('{} does not exist!'.format(pretrained))
+
+
+ def get_pose_net(cfg, is_train):
+     model = PoseHighResolutionNet(cfg)
+
+     if is_train and cfg['MODEL']['INIT_WEIGHTS']:
+         model.init_weights(cfg['MODEL']['PRETRAINED'])
+
+     return model
+
+
+ def get_cfg_defaults(pretrained, width=32, downsample=False, use_conv=False):
+     # pose_multi_resolution_net related params
+     HRNET = CN()
+     HRNET.PRETRAINED_LAYERS = [
+         'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
+         'stage2', 'transition2', 'stage3', 'transition3', 'stage4',
+     ]
+     HRNET.STEM_INPLANES = 64
+     HRNET.FINAL_CONV_KERNEL = 1
+     HRNET.STAGE2 = CN()
+     HRNET.STAGE2.NUM_MODULES = 1
+     HRNET.STAGE2.NUM_BRANCHES = 2
+     HRNET.STAGE2.NUM_BLOCKS = [4, 4]
+     HRNET.STAGE2.NUM_CHANNELS = [width, width*2]
+     HRNET.STAGE2.BLOCK = 'BASIC'
+     HRNET.STAGE2.FUSE_METHOD = 'SUM'
+     HRNET.STAGE3 = CN()
+     HRNET.STAGE3.NUM_MODULES = 4
+     HRNET.STAGE3.NUM_BRANCHES = 3
+     HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
+     HRNET.STAGE3.NUM_CHANNELS = [width, width*2, width*4]
+     HRNET.STAGE3.BLOCK = 'BASIC'
+     HRNET.STAGE3.FUSE_METHOD = 'SUM'
+     HRNET.STAGE4 = CN()
+     HRNET.STAGE4.NUM_MODULES = 3
+     HRNET.STAGE4.NUM_BRANCHES = 4
+     HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
+     HRNET.STAGE4.NUM_CHANNELS = [width, width*2, width*4, width*8]
+     HRNET.STAGE4.BLOCK = 'BASIC'
+     HRNET.STAGE4.FUSE_METHOD = 'SUM'
+     HRNET.DOWNSAMPLE = downsample
+     HRNET.USE_CONV = use_conv
+
+     cfg = CN()
+     cfg.MODEL = CN()
+     cfg.MODEL.INIT_WEIGHTS = True
+     cfg.MODEL.PRETRAINED = pretrained  # 'data/pretrained_models/hrnet_w32-36af842e.pth'
+     cfg.MODEL.EXTRA = HRNET
+     cfg.MODEL.NUM_JOINTS = 24
+     return cfg
+
+
+ def hrnet_w32(
+         pretrained=True,
+         pretrained_ckpt='data/weights/pose_hrnet_w32_256x192.pth',
+         downsample=False,
+         use_conv=False,
+ ):
+     cfg = get_cfg_defaults(pretrained_ckpt, width=32, downsample=downsample, use_conv=use_conv)
+     return get_pose_net(cfg, is_train=True)
+
+
+ def hrnet_w48(
+         pretrained=True,
+         pretrained_ckpt='data/weights/pose_hrnet_w48_256x192.pth',
+         downsample=False,
+         use_conv=False,
+ ):
+     cfg = get_cfg_defaults(pretrained_ckpt, width=48, downsample=downsample, use_conv=use_conv)
+     return get_pose_net(cfg, is_train=True)
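
A quick sanity-check sketch for the backbone (illustrative; without the checkpoint file on disk, init_weights falls back to random initialization with a warning):

import torch
from utils.hrnet import hrnet_w32

backbone = hrnet_w32()                        # default: no downsampling, interpolation upsampling
feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)                            # torch.Size([1, 480, 56, 56]) = concat of 32+64+128+256 channels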
utils/image_utils.py ADDED
@@ -0,0 +1,444 @@
1
+ """
2
+ This file contains functions that are used to perform data augmentation.
3
+ """
4
+ import cv2
5
+ import torch
6
+ import json
7
+ from skimage.transform import rotate, resize
8
+ import numpy as np
9
+ import jpeg4py as jpeg
10
+ from trimesh.visual import color
11
+
12
+ # from ..core import constants
13
+ # from .vibe_image_utils import gen_trans_from_patch_cv
14
+ from .kp_utils import map_smpl_to_common, get_smpl_joint_names
15
+
16
+ def get_transform(center, scale, res, rot=0):
17
+ """Generate transformation matrix."""
18
+ h = 200 * scale
19
+ t = np.zeros((3, 3))
20
+ t[0, 0] = float(res[1]) / h
21
+ t[1, 1] = float(res[0]) / h
22
+ t[0, 2] = res[1] * (-float(center[0]) / h + .5)
23
+ t[1, 2] = res[0] * (-float(center[1]) / h + .5)
24
+ t[2, 2] = 1
25
+ if not rot == 0:
26
+ rot = -rot # To match direction of rotation from cropping
27
+ rot_mat = np.zeros((3, 3))
28
+ rot_rad = rot * np.pi / 180
29
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
30
+ rot_mat[0, :2] = [cs, -sn]
31
+ rot_mat[1, :2] = [sn, cs]
32
+ rot_mat[2, 2] = 1
33
+ # Need to rotate around center
34
+ t_mat = np.eye(3)
35
+ t_mat[0, 2] = -res[1] / 2
36
+ t_mat[1, 2] = -res[0] / 2
37
+ t_inv = t_mat.copy()
38
+ t_inv[:2, 2] *= -1
39
+ t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
40
+ return t
41
+
42
+
43
+ def transform(pt, center, scale, res, invert=0, rot=0):
44
+ """Transform pixel location to different reference."""
45
+ t = get_transform(center, scale, res, rot=rot)
46
+ if invert:
47
+ t = np.linalg.inv(t)
48
+ new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
49
+ new_pt = np.dot(t, new_pt)
50
+ return new_pt[:2].astype(int) + 1
51
+
52
+
53
+ def crop(img, center, scale, res, rot=0):
54
+ """Crop image according to the supplied bounding box."""
55
+ # Upper left point
56
+ ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
57
+ # Bottom right point
58
+ br = np.array(transform([res[0] + 1,
59
+ res[1] + 1], center, scale, res, invert=1)) - 1
60
+
61
+ # Padding so that when rotated proper amount of context is included
62
+ pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
63
+ if not rot == 0:
64
+ ul -= pad
65
+ br += pad
66
+
67
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
68
+ if len(img.shape) > 2:
69
+ new_shape += [img.shape[2]]
70
+ new_img = np.zeros(new_shape)
71
+
72
+ # Range to fill new array
73
+ new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
74
+ new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
75
+ # Range to sample from original image
76
+ old_x = max(0, ul[0]), min(len(img[0]), br[0])
77
+ old_y = max(0, ul[1]), min(len(img), br[1])
78
+ new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
79
+ old_x[0]:old_x[1]]
80
+
81
+ if not rot == 0:
82
+ # Remove padding
83
+
84
+ new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
85
+ new_img = new_img[pad:-pad, pad:-pad]
86
+
87
+ # resize image
88
+ new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
89
+ return new_img
90
+
91
+
92
+ def crop_cv2(img, center, scale, res, rot=0):
93
+ c_x, c_y = center
94
+ c_x, c_y = int(round(c_x)), int(round(c_y))
95
+ patch_width, patch_height = int(round(res[0])), int(round(res[1]))
96
+ bb_width = bb_height = int(round(scale * 200.))
97
+
98
+ trans = gen_trans_from_patch_cv(
99
+ c_x, c_y, bb_width, bb_height,
100
+ patch_width, patch_height,
101
+ scale=1.0, rot=rot, inv=False,
102
+ )
103
+
104
+ crop_img = cv2.warpAffine(
105
+ img, trans, (int(patch_width), int(patch_height)),
106
+ flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT
107
+ )
108
+
109
+ return crop_img
110
+
111
+
112
+ def get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start):
113
+ y1 = int((height - crop_height) * h_start)
114
+ y2 = y1 + crop_height
115
+ x1 = int((width - crop_width) * w_start)
116
+ x2 = x1 + crop_width
117
+ return x1, y1, x2, y2
118
+
119
+
120
+ def random_crop(center, scale, crop_scale_factor, axis='all'):
121
+ '''
122
+ center: bbox center [x,y]
123
+ scale: bbox height / 200
124
+ crop_scale_factor: amount of cropping to be applied
125
+ axis: axis which cropping will be applied
126
+ "x": center the y axis and get random crops in x
127
+ "y": center the x axis and get random crops in y
128
+ "all": randomly crop from all locations
129
+ '''
130
+ orig_size = int(scale * 200.)
131
+ ul = (center - (orig_size / 2.)).astype(int)
132
+
133
+ crop_size = int(orig_size * crop_scale_factor)
134
+
135
+ if axis == 'all':
136
+ h_start = np.random.rand()
137
+ w_start = np.random.rand()
138
+ elif axis == 'x':
139
+ h_start = np.random.rand()
140
+ w_start = 0.5
141
+ elif axis == 'y':
142
+ h_start = 0.5
143
+ w_start = np.random.rand()
144
+ else:
145
+ raise ValueError(f'axis {axis} is undefined!')
146
+
147
+ x1, y1, x2, y2 = get_random_crop_coords(
148
+ height=orig_size,
149
+ width=orig_size,
150
+ crop_height=crop_size,
151
+ crop_width=crop_size,
152
+ h_start=h_start,
153
+ w_start=w_start,
154
+ )
155
+ scale = (y2 - y1) / 200.
156
+ center = ul + np.array([(y1 + y2) / 2, (x1 + x2) / 2])
157
+ return center, scale
158
+
159
+
160
+ def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
161
+ """'Undo' the image cropping/resizing.
162
+ This function is used when evaluating mask/part segmentation.
163
+ """
164
+ res = img.shape[:2]
165
+ # Upper left point
166
+ ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
167
+ # Bottom right point
168
+ br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
169
+ # size of cropped image
170
+ crop_shape = [br[1] - ul[1], br[0] - ul[0]]
171
+
172
+ new_shape = [br[1] - ul[1], br[0] - ul[0]]
173
+ if len(img.shape) > 2:
174
+ new_shape += [img.shape[2]]
175
+ new_img = np.zeros(orig_shape, dtype=np.uint8)
176
+ # Range to fill new array
177
+ new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
178
+ new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
179
+ # Range to sample from original image
180
+ old_x = max(0, ul[0]), min(orig_shape[1], br[0])
181
+ old_y = max(0, ul[1]), min(orig_shape[0], br[1])
182
+ img = resize(img, crop_shape) #, interp='nearest') # scipy.misc.imresize(img, crop_shape, interp='nearest')
183
+ new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
184
+ return new_img
185
+
186
+
187
+ def rot_aa(aa, rot):
188
+ """Rotate axis angle parameters."""
189
+ # pose parameters
190
+ R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
191
+ [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
192
+ [0, 0, 1]])
193
+ # find the rotation of the body in camera frame
194
+ per_rdg, _ = cv2.Rodrigues(aa)
195
+ # apply the global rotation to the global orientation
196
+ resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
197
+ aa = (resrot.T)[0]
198
+ return aa
199
+
200
+
201
+ def flip_img(img):
202
+ """Flip rgb images or masks.
203
+ channels come last, e.g. (256,256,3).
204
+ """
205
+ img = np.fliplr(img)
206
+ return img
207
+
208
+
209
+ def flip_kp(kp):
210
+ """Flip keypoints."""
211
+ if len(kp) == 24:
212
+ flipped_parts = constants.J24_FLIP_PERM
213
+ elif len(kp) == 49:
214
+ flipped_parts = constants.J49_FLIP_PERM
215
+ kp = kp[flipped_parts]
216
+ kp[:, 0] = - kp[:, 0]
217
+ return kp
218
+
219
+
220
+ def flip_pose(pose):
221
+ """Flip pose.
222
+ The flipping is based on SMPL parameters.
223
+ """
224
+ flipped_parts = constants.SMPL_POSE_FLIP_PERM
225
+ pose = pose[flipped_parts]
226
+ # we also negate the second and the third dimension of the axis-angle
227
+ pose[1::3] = -pose[1::3]
228
+ pose[2::3] = -pose[2::3]
229
+ return pose
230
+
231
+
232
+ def denormalize_images(images):
233
+ images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
234
+ images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
235
+ return images
236
+
237
+
238
+ def read_img(img_fn):
239
+ # return pil_img.fromarray(
240
+ # cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB))
241
+ # with open(img_fn, 'rb') as f:
242
+ # img = pil_img.open(f).convert('RGB')
243
+ # return img
244
+ if img_fn.endswith('jpeg') or img_fn.endswith('jpg'):
245
+ try:
246
+ with open(img_fn, 'rb') as f:
247
+ img = np.array(jpeg.JPEG(f).decode())
248
+ except jpeg.JPEGRuntimeError:
249
+ # logger.warning('{} produced a JPEGRuntimeError', img_fn)
250
+ img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
251
+ else:
252
+ # elif img_fn.endswith('png') or img_fn.endswith('JPG') or img_fn.endswith(''):
253
+ img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
254
+ return img.astype(np.float32)
255
+
256
+
+ def generate_heatmaps_2d(joints, joints_vis, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
+     '''
+     :param joints: [num_joints, 3]
+     :param joints_vis: [num_joints, 3]
+     :return: target, target_weight(1: visible, 0: invisible)
+     '''
+     target_weight = np.ones((num_joints, 1), dtype=np.float32)
+     target_weight[:, 0] = joints_vis[:, 0]
+
+     target = np.zeros((num_joints, heatmap_size, heatmap_size), dtype=np.float32)
+
+     tmp_size = sigma * 3
+
+     # denormalize joints into heatmap coordinates
+     joints = (joints + 1.) * (image_size / 2.)
+
+     for joint_id in range(num_joints):
+         feat_stride = image_size / heatmap_size
+         mu_x = int(joints[joint_id][0] / feat_stride + 0.5)
+         mu_y = int(joints[joint_id][1] / feat_stride + 0.5)
+         # Check that any part of the gaussian is in-bounds
+         ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
+         br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
+         if ul[0] >= heatmap_size or ul[1] >= heatmap_size \
+                 or br[0] < 0 or br[1] < 0:
+             # If not, mark this joint invisible and skip it
+             target_weight[joint_id] = 0
+             continue
+
+         # Generate gaussian
+         size = 2 * tmp_size + 1
+         x = np.arange(0, size, 1, np.float32)
+         y = x[:, np.newaxis]
+         x0 = y0 = size // 2
+         # The gaussian is not normalized, we want the center value to equal 1
+         g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
+
+         # Usable gaussian range
+         g_x = max(0, -ul[0]), min(br[0], heatmap_size) - ul[0]
+         g_y = max(0, -ul[1]), min(br[1], heatmap_size) - ul[1]
+         # Image range
+         img_x = max(0, ul[0]), min(br[0], heatmap_size)
+         img_y = max(0, ul[1]), min(br[1], heatmap_size)
+
+         v = target_weight[joint_id]
+         if v > 0.5:
+             target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
+                 g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
+
+     return target, target_weight
+
+
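Illustration only (not part of the commit): expected shapes when calling `generate_heatmaps_2d` with its defaults; the random joints are placeholders.

import numpy as np

joints = np.random.uniform(-1, 1, (24, 3))        # normalized to [-1, 1], as the function assumes
joints_vis = np.ones((24, 3), dtype=np.float32)
target, target_weight = generate_heatmaps_2d(joints, joints_vis)
print(target.shape, target_weight.shape)          # (24, 56, 56) (24, 1)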
+ def generate_part_labels(vertices, faces, cam_t, neural_renderer, body_part_texture, K, R, part_bins):
+     batch_size = vertices.shape[0]
+
+     body_parts, depth, mask = neural_renderer(
+         vertices,
+         faces.expand(batch_size, -1, -1),
+         textures=body_part_texture.expand(batch_size, -1, -1, -1, -1, -1),
+         K=K.expand(batch_size, -1, -1),
+         R=R.expand(batch_size, -1, -1),
+         t=cam_t.unsqueeze(1),
+     )
+
+     render_rgb = body_parts.clone()
+
+     body_parts = body_parts.permute(0, 2, 3, 1)
+     body_parts *= 255.  # multiply by 255 to make the label values distinct
+     body_parts, _ = body_parts.max(-1)  # reduce to single channel
+
+     body_parts = torch.bucketize(body_parts.detach(), part_bins, right=True)  # np.digitize(body_parts, bins, right=True)
+
+     # add 1 so that, after masking, the background gets label 0
+     body_parts = body_parts.long() + 1
+     body_parts = body_parts * mask.detach()
+
+     return body_parts.long(), render_rgb
+
+
+ def generate_heatmaps_2d_batch(joints, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
+     batch_size = joints.shape[0]
+
+     joints = joints.detach().cpu().numpy()
+     joints_vis = np.ones_like(joints)
+
+     heatmaps = []
+     heatmaps_vis = []
+     for i in range(batch_size):
+         hm, hm_vis = generate_heatmaps_2d(joints[i], joints_vis[i], num_joints, heatmap_size, image_size, sigma)
+         heatmaps.append(hm)
+         heatmaps_vis.append(hm_vis)
+
+     return torch.from_numpy(np.stack(heatmaps)).float().to('cuda'), \
+         torch.from_numpy(np.stack(heatmaps_vis)).float().to('cuda')
+
+
+ def get_body_part_texture(faces, model_type='smpl', non_parametric=False):
+     if model_type == 'smpl':
+         n_vertices = 6890
+         segmentation_path = 'data/smpl_vert_segmentation.json'
+     elif model_type == 'smplx':
+         n_vertices = 10475
+         segmentation_path = 'data/smplx_vert_segmentation.json'
+
+     with open(segmentation_path, 'rb') as f:
+         part_segmentation = json.load(f)
+
+     # map all vertex ids to the joint ids
+     joint_names = get_smpl_joint_names()
+     smplx_extra_joint_names = ['leftEye', 'eyeballs', 'rightEye']
+     body_vert_idx = np.zeros((n_vertices), dtype=np.int32) - 1  # -1 for missing label
+     for i, (k, v) in enumerate(part_segmentation.items()):
+         if k in smplx_extra_joint_names and model_type == 'smplx':
+             k = 'head'  # map all extra smplx face joints to head
+         body_joint_idx = joint_names.index(k)
+         body_vert_idx[v] = body_joint_idx
+
+     # pare implementation
+     # import joblib
+     # part_segmentation = joblib.load('data/smpl_partSegmentation_mapping.pkl')
+     # body_vert_idx = part_segmentation['smpl_index']
+
+     n_parts = 24.
+
+     if non_parametric:
+         # reduce the number of body parts to 14
+         # by mapping some joints to others
+         n_parts = 14.
+         joint_mapping = map_smpl_to_common()
+
+         for jm in joint_mapping:
+             for j in jm[0]:
+                 body_vert_idx[body_vert_idx == j] = jm[1]
+
+     vertex_colors = np.ones((n_vertices, 4))
+     vertex_colors[:, :3] = body_vert_idx[..., None]
+
+     vertex_colors = color.to_rgba(vertex_colors)
+     vertex_colors = vertex_colors[:, :3] / 255.
+
+     face_colors = vertex_colors[faces].min(axis=1)
+     texture = np.zeros((1, faces.shape[0], 1, 1, 3), dtype=np.float32)
+     # texture[0, :, 0, 0, :] = face_colors[:, :3] / n_parts
+     texture[0, :, 0, 0, :] = face_colors[:, :3]
+
+     vertex_colors = torch.from_numpy(vertex_colors).float()
+     texture = torch.from_numpy(texture).float()
+     return vertex_colors, texture
+
+
+ def get_default_camera(focal_length, img_h, img_w, is_cam_batch=False):
+     if not is_cam_batch:
+         K = torch.eye(3)
+         K[0, 0] = focal_length
+         K[1, 1] = focal_length
+         K[2, 2] = 1
+         K[0, 2] = img_w / 2.
+         K[1, 2] = img_h / 2.
+         K = K[None, :, :]
+         R = torch.eye(3)[None, :, :]
+     else:
+         bs = focal_length.shape[0]
+         K = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
+         K[:, 0, 0] = focal_length[:, 0]
+         K[:, 1, 1] = focal_length[:, 1]
+         K[:, 2, 2] = 1
+         K[:, 0, 2] = img_w / 2.
+         K[:, 1, 2] = img_h / 2.
+         R = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
+     return K, R
+
+
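Illustration only (not part of the commit): the two calling conventions of `get_default_camera`; the focal length and image size below are arbitrary.

import torch

K, R = get_default_camera(focal_length=5000., img_h=224, img_w=224)  # K, R: (1, 3, 3)
focal = 5000. * torch.ones(8, 2)                                     # per-sample (fx, fy)
K_b, R_b = get_default_camera(focal, 224, 224, is_cam_batch=True)    # K_b, R_b: (8, 3, 3)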
+ def read_exif_data(img_fname):
+     import PIL.Image
+     import PIL.ExifTags
+
+     img = PIL.Image.open(img_fname)
+     exif_data = img._getexif()
+
+     if exif_data is None:
+         return None
+
+     exif = {
+         PIL.ExifTags.TAGS[k]: v
+         for k, v in exif_data.items()
+         if k in PIL.ExifTags.TAGS
+     }
+     return exif
utils/kp_utils.py ADDED
@@ -0,0 +1,1114 @@
+ import numpy as np
+
+
+ def keypoint_hflip(kp, img_width):
+     # Flip a keypoint horizontally around the y-axis
+     # kp N,2
+     if len(kp.shape) == 2:
+         kp[:, 0] = (img_width - 1.) - kp[:, 0]
+     elif len(kp.shape) == 3:
+         kp[:, :, 0] = (img_width - 1.) - kp[:, :, 0]
+     return kp
+
+
+ def convert_kps(joints2d, src, dst):
+     src_names = eval(f'get_{src}_joint_names')()
+     dst_names = eval(f'get_{dst}_joint_names')()
+
+     out_joints2d = np.zeros((joints2d.shape[0], len(dst_names), joints2d.shape[-1]))
+
+     for idx, jn in enumerate(dst_names):
+         if jn in src_names:
+             out_joints2d[:, idx] = joints2d[:, src_names.index(jn)]
+
+     return out_joints2d
+
+
+ def get_perm_idxs(src, dst):
+     src_names = eval(f'get_{src}_joint_names')()
+     dst_names = eval(f'get_{dst}_joint_names')()
+     idxs = [src_names.index(h) for h in dst_names if h in src_names]
+     return idxs
+
+
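Illustration only (not part of the commit): a sketch of converting COCO keypoints to the 14-joint common format with `convert_kps`; the input array is a placeholder.

import numpy as np

joints2d = np.random.rand(16, 17, 3)                      # (batch, coco joints, x/y/conf)
common = convert_kps(joints2d, src='coco', dst='common')  # -> (16, 14, 3)
# destination joints absent from COCO (e.g. 'neck', 'headtop') remain zero-filled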
+ def get_mpii3d_test_joint_names():
+     return [
+         'headtop',  # 'head_top',
+         'neck',
+         'rshoulder',  # 'right_shoulder',
+         'relbow',  # 'right_elbow',
+         'rwrist',  # 'right_wrist',
+         'lshoulder',  # 'left_shoulder',
+         'lelbow',  # 'left_elbow',
+         'lwrist',  # 'left_wrist',
+         'rhip',  # 'right_hip',
+         'rknee',  # 'right_knee',
+         'rankle',  # 'right_ankle',
+         'lhip',  # 'left_hip',
+         'lknee',  # 'left_knee',
+         'lankle',  # 'left_ankle'
+         'hip',  # 'pelvis',
+         'Spine (H36M)',  # 'spine',
+         'Head (H36M)',  # 'head'
+     ]
+
+
+ def get_mpii3d_joint_names():
+     return [
+         'spine3',  # 0,
+         'spine4',  # 1,
+         'spine2',  # 2,
+         'Spine (H36M)',  # 'spine', # 3,
+         'hip',  # 'pelvis', # 4,
+         'neck',  # 5,
+         'Head (H36M)',  # 'head', # 6,
+         'headtop',  # 'head_top', # 7,
+         'left_clavicle',  # 8,
+         'lshoulder',  # 'left_shoulder', # 9,
+         'lelbow',  # 'left_elbow', # 10,
+         'lwrist',  # 'left_wrist', # 11,
+         'left_hand',  # 12,
+         'right_clavicle',  # 13,
+         'rshoulder',  # 'right_shoulder', # 14,
+         'relbow',  # 'right_elbow', # 15,
+         'rwrist',  # 'right_wrist', # 16,
+         'right_hand',  # 17,
+         'lhip',  # 'left_hip', # 18,
+         'lknee',  # 'left_knee', # 19,
+         'lankle',  # 'left_ankle', # 20
+         'left_foot',  # 21
+         'left_toe',  # 22
+         'rhip',  # 'right_hip', # 23
+         'rknee',  # 'right_knee', # 24
+         'rankle',  # 'right_ankle', # 25
+         'right_foot',  # 26
+         'right_toe'  # 27
+     ]
+
+
+ # def get_insta_joint_names():
+ #     return [
+ #         'rheel',      # 0
+ #         'rknee',      # 1
+ #         'rhip',       # 2
+ #         'lhip',       # 3
+ #         'lknee',      # 4
+ #         'lheel',      # 5
+ #         'rwrist',     # 6
+ #         'relbow',     # 7
+ #         'rshoulder',  # 8
+ #         'lshoulder',  # 9
+ #         'lelbow',     # 10
+ #         'lwrist',     # 11
+ #         'neck',       # 12
+ #         'headtop',    # 13
+ #         'nose',       # 14
+ #         'leye',       # 15
+ #         'reye',       # 16
+ #         'lear',       # 17
+ #         'rear',       # 18
+ #         'lbigtoe',    # 19
+ #         'rbigtoe',    # 20
+ #         'lsmalltoe',  # 21
+ #         'rsmalltoe',  # 22
+ #         'lankle',     # 23
+ #         'rankle',     # 24
+ #     ]
+
+
+ def get_insta_joint_names():
+     return [
+         'OP RHeel',
+         'OP RKnee',
+         'OP RHip',
+         'OP LHip',
+         'OP LKnee',
+         'OP LHeel',
+         'OP RWrist',
+         'OP RElbow',
+         'OP RShoulder',
+         'OP LShoulder',
+         'OP LElbow',
+         'OP LWrist',
+         'OP Neck',
+         'headtop',
+         'OP Nose',
+         'OP LEye',
+         'OP REye',
+         'OP LEar',
+         'OP REar',
+         'OP LBigToe',
+         'OP RBigToe',
+         'OP LSmallToe',
+         'OP RSmallToe',
+         'OP LAnkle',
+         'OP RAnkle',
+     ]
+
+
+ def get_mmpose_joint_names():
+     # this naming is for the first 23 joints of MMPose
+     # does not include hands and face
+     return [
+         'OP Nose',  # 1
+         'OP LEye',  # 2
+         'OP REye',  # 3
+         'OP LEar',  # 4
+         'OP REar',  # 5
+         'OP LShoulder',  # 6
+         'OP RShoulder',  # 7
+         'OP LElbow',  # 8
+         'OP RElbow',  # 9
+         'OP LWrist',  # 10
+         'OP RWrist',  # 11
+         'OP LHip',  # 12
+         'OP RHip',  # 13
+         'OP LKnee',  # 14
+         'OP RKnee',  # 15
+         'OP LAnkle',  # 16
+         'OP RAnkle',  # 17
+         'OP LBigToe',  # 18
+         'OP LSmallToe',  # 19
+         'OP LHeel',  # 20
+         'OP RBigToe',  # 21
+         'OP RSmallToe',  # 22
+         'OP RHeel',  # 23
+     ]
+
+
+ def get_insta_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [2, 3],
+             [3, 4],
+             [4, 5],
+             [6, 7],
+             [7, 8],
+             [8, 9],
+             [9, 10],
+             [2, 8],
+             [3, 9],
+             [10, 11],
+             [8, 12],
+             [9, 12],
+             [12, 13],
+             [12, 14],
+             [14, 15],
+             [14, 16],
+             [15, 17],
+             [16, 18],
+             [0, 20],
+             [20, 22],
+             [5, 19],
+             [19, 21],
+             [5, 23],
+             [0, 24],
+         ])
+
+
+ def get_staf_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [2, 3],
+             [3, 4],
+             [1, 5],
+             [5, 6],
+             [6, 7],
+             [1, 8],
+             [8, 9],
+             [9, 10],
+             [10, 11],
+             [8, 12],
+             [12, 13],
+             [13, 14],
+             [0, 15],
+             [0, 16],
+             [15, 17],
+             [16, 18],
+             [2, 9],
+             [5, 12],
+             [1, 19],
+             [20, 19],
+         ]
+     )
+
+
+ def get_staf_joint_names():
+     return [
+         'OP Nose',  # 0,
+         'OP Neck',  # 1,
+         'OP RShoulder',  # 2,
+         'OP RElbow',  # 3,
+         'OP RWrist',  # 4,
+         'OP LShoulder',  # 5,
+         'OP LElbow',  # 6,
+         'OP LWrist',  # 7,
+         'OP MidHip',  # 8,
+         'OP RHip',  # 9,
+         'OP RKnee',  # 10,
+         'OP RAnkle',  # 11,
+         'OP LHip',  # 12,
+         'OP LKnee',  # 13,
+         'OP LAnkle',  # 14,
+         'OP REye',  # 15,
+         'OP LEye',  # 16,
+         'OP REar',  # 17,
+         'OP LEar',  # 18,
+         'Neck (LSP)',  # 19,
+         'Top of Head (LSP)',  # 20,
+     ]
+
+
+ def get_spin_op_joint_names():
+     return [
+         'OP Nose',  # 0
+         'OP Neck',  # 1
+         'OP RShoulder',  # 2
+         'OP RElbow',  # 3
+         'OP RWrist',  # 4
+         'OP LShoulder',  # 5
+         'OP LElbow',  # 6
+         'OP LWrist',  # 7
+         'OP MidHip',  # 8
+         'OP RHip',  # 9
+         'OP RKnee',  # 10
+         'OP RAnkle',  # 11
+         'OP LHip',  # 12
+         'OP LKnee',  # 13
+         'OP LAnkle',  # 14
+         'OP REye',  # 15
+         'OP LEye',  # 16
+         'OP REar',  # 17
+         'OP LEar',  # 18
+         'OP LBigToe',  # 19
+         'OP LSmallToe',  # 20
+         'OP LHeel',  # 21
+         'OP RBigToe',  # 22
+         'OP RSmallToe',  # 23
+         'OP RHeel',  # 24
+     ]
+
+
+ def get_openpose_joint_names():
+     return [
+         'OP Nose',  # 0
+         'OP Neck',  # 1
+         'OP RShoulder',  # 2
+         'OP RElbow',  # 3
+         'OP RWrist',  # 4
+         'OP LShoulder',  # 5
+         'OP LElbow',  # 6
+         'OP LWrist',  # 7
+         'OP MidHip',  # 8
+         'OP RHip',  # 9
+         'OP RKnee',  # 10
+         'OP RAnkle',  # 11
+         'OP LHip',  # 12
+         'OP LKnee',  # 13
+         'OP LAnkle',  # 14
+         'OP REye',  # 15
+         'OP LEye',  # 16
+         'OP REar',  # 17
+         'OP LEar',  # 18
+         'OP LBigToe',  # 19
+         'OP LSmallToe',  # 20
+         'OP LHeel',  # 21
+         'OP RBigToe',  # 22
+         'OP RSmallToe',  # 23
+         'OP RHeel',  # 24
+     ]
+
+
+ def get_spin_joint_names():
+     return [
+         'OP Nose',  # 0
+         'OP Neck',  # 1
+         'OP RShoulder',  # 2
+         'OP RElbow',  # 3
+         'OP RWrist',  # 4
+         'OP LShoulder',  # 5
+         'OP LElbow',  # 6
+         'OP LWrist',  # 7
+         'OP MidHip',  # 8
+         'OP RHip',  # 9
+         'OP RKnee',  # 10
+         'OP RAnkle',  # 11
+         'OP LHip',  # 12
+         'OP LKnee',  # 13
+         'OP LAnkle',  # 14
+         'OP REye',  # 15
+         'OP LEye',  # 16
+         'OP REar',  # 17
+         'OP LEar',  # 18
+         'OP LBigToe',  # 19
+         'OP LSmallToe',  # 20
+         'OP LHeel',  # 21
+         'OP RBigToe',  # 22
+         'OP RSmallToe',  # 23
+         'OP RHeel',  # 24
+         'rankle',  # 25
+         'rknee',  # 26
+         'rhip',  # 27
+         'lhip',  # 28
+         'lknee',  # 29
+         'lankle',  # 30
+         'rwrist',  # 31
+         'relbow',  # 32
+         'rshoulder',  # 33
+         'lshoulder',  # 34
+         'lelbow',  # 35
+         'lwrist',  # 36
+         'neck',  # 37
+         'headtop',  # 38
+         'hip',  # 39 'Pelvis (MPII)', # 39
+         'thorax',  # 40 'Thorax (MPII)', # 40
+         'Spine (H36M)',  # 41
+         'Jaw (H36M)',  # 42
+         'Head (H36M)',  # 43
+         'nose',  # 44
+         'leye',  # 45 'Left Eye', # 45
+         'reye',  # 46 'Right Eye', # 46
+         'lear',  # 47 'Left Ear', # 47
+         'rear',  # 48 'Right Ear', # 48
+     ]
+
+
+ def get_muco3dhp_joint_names():
+     return [
+         'headtop',
+         'thorax',
+         'rshoulder',
+         'relbow',
+         'rwrist',
+         'lshoulder',
+         'lelbow',
+         'lwrist',
+         'rhip',
+         'rknee',
+         'rankle',
+         'lhip',
+         'lknee',
+         'lankle',
+         'hip',
+         'Spine (H36M)',
+         'Head (H36M)',
+         'R_Hand',
+         'L_Hand',
+         'R_Toe',
+         'L_Toe'
+     ]
+
+
+ def get_h36m_joint_names():
+     return [
+         'hip',  # 0
+         'lhip',  # 1
+         'lknee',  # 2
+         'lankle',  # 3
+         'rhip',  # 4
+         'rknee',  # 5
+         'rankle',  # 6
+         'Spine (H36M)',  # 7
+         'neck',  # 8
+         'Head (H36M)',  # 9
+         'headtop',  # 10
+         'lshoulder',  # 11
+         'lelbow',  # 12
+         'lwrist',  # 13
+         'rshoulder',  # 14
+         'relbow',  # 15
+         'rwrist',  # 16
+     ]
+
+
+ def get_spin_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [2, 3],
+             [3, 4],
+             [1, 5],
+             [5, 6],
+             [6, 7],
+             [1, 8],
+             [8, 9],
+             [9, 10],
+             [10, 11],
+             [8, 12],
+             [12, 13],
+             [13, 14],
+             [0, 15],
+             [0, 16],
+             [15, 17],
+             [16, 18],
+             [21, 19],
+             [19, 20],
+             [14, 21],
+             [11, 24],
+             [24, 22],
+             [22, 23],
+             [0, 38],
+         ]
+     )
+
+
+ def get_openpose_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [2, 3],
+             [3, 4],
+             [1, 5],
+             [5, 6],
+             [6, 7],
+             [1, 8],
+             [8, 9],
+             [9, 10],
+             [10, 11],
+             [8, 12],
+             [12, 13],
+             [13, 14],
+             [0, 15],
+             [0, 16],
+             [15, 17],
+             [16, 18],
+             [21, 19],
+             [19, 20],
+             [14, 21],
+             [11, 24],
+             [24, 22],
+             [22, 23],
+         ]
+     )
+
+
+ def get_posetrack_joint_names():
+     return [
+         "nose",
+         "neck",
+         "headtop",
+         "lear",
+         "rear",
+         "lshoulder",
+         "rshoulder",
+         "lelbow",
+         "relbow",
+         "lwrist",
+         "rwrist",
+         "lhip",
+         "rhip",
+         "lknee",
+         "rknee",
+         "lankle",
+         "rankle"
+     ]
+
+
+ def get_posetrack_original_kp_names():
+     return [
+         'nose',
+         'head_bottom',
+         'head_top',
+         'left_ear',
+         'right_ear',
+         'left_shoulder',
+         'right_shoulder',
+         'left_elbow',
+         'right_elbow',
+         'left_wrist',
+         'right_wrist',
+         'left_hip',
+         'right_hip',
+         'left_knee',
+         'right_knee',
+         'left_ankle',
+         'right_ankle'
+     ]
+
+
+ def get_pennaction_joint_names():
+     return [
+         "headtop",  # 0
+         "lshoulder",  # 1
+         "rshoulder",  # 2
+         "lelbow",  # 3
+         "relbow",  # 4
+         "lwrist",  # 5
+         "rwrist",  # 6
+         "lhip",  # 7
+         "rhip",  # 8
+         "lknee",  # 9
+         "rknee",  # 10
+         "lankle",  # 11
+         "rankle"  # 12
+     ]
+
+
+ def get_common_joint_names():
+     return [
+         "rankle",  # 0  "lankle", # 0
+         "rknee",  # 1  "lknee", # 1
+         "rhip",  # 2  "lhip", # 2
+         "lhip",  # 3  "rhip", # 3
+         "lknee",  # 4  "rknee", # 4
+         "lankle",  # 5  "rankle", # 5
+         "rwrist",  # 6  "lwrist", # 6
+         "relbow",  # 7  "lelbow", # 7
+         "rshoulder",  # 8  "lshoulder", # 8
+         "lshoulder",  # 9  "rshoulder", # 9
+         "lelbow",  # 10  "relbow", # 10
+         "lwrist",  # 11  "rwrist", # 11
+         "neck",  # 12  "neck", # 12
+         "headtop",  # 13  "headtop", # 13
+     ]
+
+
+ def get_common_paper_joint_names():
+     return [
+         "Right Ankle",  # 0  "lankle", # 0
+         "Right Knee",  # 1  "lknee", # 1
+         "Right Hip",  # 2  "lhip", # 2
+         "Left Hip",  # 3  "rhip", # 3
+         "Left Knee",  # 4  "rknee", # 4
+         "Left Ankle",  # 5  "rankle", # 5
+         "Right Wrist",  # 6  "lwrist", # 6
+         "Right Elbow",  # 7  "lelbow", # 7
+         "Right Shoulder",  # 8  "lshoulder", # 8
+         "Left Shoulder",  # 9  "rshoulder", # 9
+         "Left Elbow",  # 10  "relbow", # 10
+         "Left Wrist",  # 11  "rwrist", # 11
+         "Neck",  # 12  "neck", # 12
+         "Head",  # 13  "headtop", # 13
+     ]
+
+
+ def get_common_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [3, 4],
+             [4, 5],
+             [6, 7],
+             [7, 8],
+             [8, 2],
+             [8, 9],
+             [9, 3],
+             [2, 3],
+             [8, 12],
+             [9, 10],
+             [12, 9],
+             [10, 11],
+             [12, 13],
+         ]
+     )
+
+
+ def get_coco_joint_names():
+     return [
+         "nose",  # 0
+         "leye",  # 1
+         "reye",  # 2
+         "lear",  # 3
+         "rear",  # 4
+         "lshoulder",  # 5
+         "rshoulder",  # 6
+         "lelbow",  # 7
+         "relbow",  # 8
+         "lwrist",  # 9
+         "rwrist",  # 10
+         "lhip",  # 11
+         "rhip",  # 12
+         "lknee",  # 13
+         "rknee",  # 14
+         "lankle",  # 15
+         "rankle",  # 16
+     ]
+
+
+ def get_ochuman_joint_names():
+     return [
+         'rshoulder',
+         'relbow',
+         'rwrist',
+         'lshoulder',
+         'lelbow',
+         'lwrist',
+         'rhip',
+         'rknee',
+         'rankle',
+         'lhip',
+         'lknee',
+         'lankle',
+         'headtop',
+         'neck',
+         'rear',
+         'lear',
+         'nose',
+         'reye',
+         'leye'
+     ]
+
+
+ def get_crowdpose_joint_names():
+     return [
+         'lshoulder',
+         'rshoulder',
+         'lelbow',
+         'relbow',
+         'lwrist',
+         'rwrist',
+         'lhip',
+         'rhip',
+         'lknee',
+         'rknee',
+         'lankle',
+         'rankle',
+         'headtop',
+         'neck'
+     ]
+
+
+ def get_coco_skeleton():
+     # 0 - nose,
+     # 1 - leye,
+     # 2 - reye,
+     # 3 - lear,
+     # 4 - rear,
+     # 5 - lshoulder,
+     # 6 - rshoulder,
+     # 7 - lelbow,
+     # 8 - relbow,
+     # 9 - lwrist,
+     # 10 - rwrist,
+     # 11 - lhip,
+     # 12 - rhip,
+     # 13 - lknee,
+     # 14 - rknee,
+     # 15 - lankle,
+     # 16 - rankle,
+     return np.array(
+         [
+             [15, 13],
+             [13, 11],
+             [16, 14],
+             [14, 12],
+             [11, 12],
+             [5, 11],
+             [6, 12],
+             [5, 6],
+             [5, 7],
+             [6, 8],
+             [7, 9],
+             [8, 10],
+             [1, 2],
+             [0, 1],
+             [0, 2],
+             [1, 3],
+             [2, 4],
+             [3, 5],
+             [4, 6]
+         ]
+     )
+
+
+ def get_mpii_joint_names():
+     return [
+         "rankle",  # 0
+         "rknee",  # 1
+         "rhip",  # 2
+         "lhip",  # 3
+         "lknee",  # 4
+         "lankle",  # 5
+         "hip",  # 6
+         "thorax",  # 7
+         "neck",  # 8
+         "headtop",  # 9
+         "rwrist",  # 10
+         "relbow",  # 11
+         "rshoulder",  # 12
+         "lshoulder",  # 13
+         "lelbow",  # 14
+         "lwrist",  # 15
+     ]
+
+
+ def get_mpii_skeleton():
+     # 0 - rankle,
+     # 1 - rknee,
+     # 2 - rhip,
+     # 3 - lhip,
+     # 4 - lknee,
+     # 5 - lankle,
+     # 6 - hip,
+     # 7 - thorax,
+     # 8 - neck,
+     # 9 - headtop,
+     # 10 - rwrist,
+     # 11 - relbow,
+     # 12 - rshoulder,
+     # 13 - lshoulder,
+     # 14 - lelbow,
+     # 15 - lwrist,
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [2, 6],
+             [6, 3],
+             [3, 4],
+             [4, 5],
+             [6, 7],
+             [7, 8],
+             [8, 9],
+             [7, 12],
+             [12, 11],
+             [11, 10],
+             [7, 13],
+             [13, 14],
+             [14, 15]
+         ]
+     )
+
+
+ def get_aich_joint_names():
+     return [
+         "rshoulder",  # 0
+         "relbow",  # 1
+         "rwrist",  # 2
+         "lshoulder",  # 3
+         "lelbow",  # 4
+         "lwrist",  # 5
+         "rhip",  # 6
+         "rknee",  # 7
+         "rankle",  # 8
+         "lhip",  # 9
+         "lknee",  # 10
+         "lankle",  # 11
+         "headtop",  # 12
+         "neck",  # 13
+     ]
+
+
+ def get_aich_skeleton():
+     # 0 - rshoulder,
+     # 1 - relbow,
+     # 2 - rwrist,
+     # 3 - lshoulder,
+     # 4 - lelbow,
+     # 5 - lwrist,
+     # 6 - rhip,
+     # 7 - rknee,
+     # 8 - rankle,
+     # 9 - lhip,
+     # 10 - lknee,
+     # 11 - lankle,
+     # 12 - headtop,
+     # 13 - neck,
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [3, 4],
+             [4, 5],
+             [6, 7],
+             [7, 8],
+             [9, 10],
+             [10, 11],
+             [12, 13],
+             [13, 0],
+             [13, 3],
+             [0, 6],
+             [3, 9]
+         ]
+     )
+
+
+ def get_3dpw_joint_names():
+     return [
+         "nose",  # 0
+         "thorax",  # 1
+         "rshoulder",  # 2
+         "relbow",  # 3
+         "rwrist",  # 4
+         "lshoulder",  # 5
+         "lelbow",  # 6
+         "lwrist",  # 7
+         "rhip",  # 8
+         "rknee",  # 9
+         "rankle",  # 10
+         "lhip",  # 11
+         "lknee",  # 12
+         "lankle",  # 13
+     ]
+
+
+ def get_3dpw_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [2, 3],
+             [3, 4],
+             [1, 5],
+             [5, 6],
+             [6, 7],
+             [2, 8],
+             [5, 11],
+             [8, 11],
+             [8, 9],
+             [9, 10],
+             [11, 12],
+             [12, 13]
+         ]
+     )
+
+
+ def get_smplcoco_joint_names():
+     return [
+         "rankle",  # 0
+         "rknee",  # 1
+         "rhip",  # 2
+         "lhip",  # 3
+         "lknee",  # 4
+         "lankle",  # 5
+         "rwrist",  # 6
+         "relbow",  # 7
+         "rshoulder",  # 8
+         "lshoulder",  # 9
+         "lelbow",  # 10
+         "lwrist",  # 11
+         "neck",  # 12
+         "headtop",  # 13
+         "nose",  # 14
+         "leye",  # 15
+         "reye",  # 16
+         "lear",  # 17
+         "rear",  # 18
+     ]
+
+
+ def get_smplcoco_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [1, 2],
+             [3, 4],
+             [4, 5],
+             [6, 7],
+             [7, 8],
+             [8, 12],
+             [12, 9],
+             [9, 10],
+             [10, 11],
+             [12, 13],
+             [14, 15],
+             [15, 17],
+             [16, 18],
+             [14, 16],
+             [8, 2],
+             [9, 3],
+             [2, 3],
+         ]
+     )
+
+
+ def get_smpl_joint_names():
+     return [
+         'hips',  # 0
+         'leftUpLeg',  # 1
+         'rightUpLeg',  # 2
+         'spine',  # 3
+         'leftLeg',  # 4
+         'rightLeg',  # 5
+         'spine1',  # 6
+         'leftFoot',  # 7
+         'rightFoot',  # 8
+         'spine2',  # 9
+         'leftToeBase',  # 10
+         'rightToeBase',  # 11
+         'neck',  # 12
+         'leftShoulder',  # 13
+         'rightShoulder',  # 14
+         'head',  # 15
+         'leftArm',  # 16
+         'rightArm',  # 17
+         'leftForeArm',  # 18
+         'rightForeArm',  # 19
+         'leftHand',  # 20
+         'rightHand',  # 21
+         'leftHandIndex1',  # 22
+         'rightHandIndex1',  # 23
+     ]
+
+
+ def get_smpl_paper_joint_names():
+     return [
+         'Hips',  # 0
+         'Left Hip',  # 1
+         'Right Hip',  # 2
+         'Spine',  # 3
+         'Left Knee',  # 4
+         'Right Knee',  # 5
+         'Spine_1',  # 6
+         'Left Ankle',  # 7
+         'Right Ankle',  # 8
+         'Spine_2',  # 9
+         'Left Toe',  # 10
+         'Right Toe',  # 11
+         'Neck',  # 12
+         'Left Shoulder',  # 13
+         'Right Shoulder',  # 14
+         'Head',  # 15
+         'Left Arm',  # 16
+         'Right Arm',  # 17
+         'Left Elbow',  # 18
+         'Right Elbow',  # 19
+         'Left Hand',  # 20
+         'Right Hand',  # 21
+         'Left Thumb',  # 22
+         'Right Thumb',  # 23
+     ]
+
+
+ def get_smpl_neighbor_triplets():
+     return [
+         [0, 1, 2],  # 0
+         [1, 4, 0],  # 1
+         [2, 0, 5],  # 2
+         [3, 0, 6],  # 3
+         [4, 7, 1],  # 4
+         [5, 2, 8],  # 5
+         [6, 3, 9],  # 6
+         [7, 10, 4],  # 7
+         [8, 5, 11],  # 8
+         [9, 13, 14],  # 9
+         [10, 7, 4],  # 10
+         [11, 8, 5],  # 11
+         [12, 9, 15],  # 12
+         [13, 16, 9],  # 13
+         [14, 9, 17],  # 14
+         [15, 9, 12],  # 15
+         [16, 18, 13],  # 16
+         [17, 14, 19],  # 17
+         [18, 20, 16],  # 18
+         [19, 17, 21],  # 19
+         [20, 22, 18],  # 20
+         [21, 19, 23],  # 21
+         [22, 20, 18],  # 22
+         [23, 19, 21],  # 23
+     ]
+
+
+ def get_smpl_skeleton():
+     return np.array(
+         [
+             [0, 1],
+             [0, 2],
+             [0, 3],
+             [1, 4],
+             [2, 5],
+             [3, 6],
+             [4, 7],
+             [5, 8],
+             [6, 9],
+             [7, 10],
+             [8, 11],
+             [9, 12],
+             [9, 13],
+             [9, 14],
+             [12, 15],
+             [13, 16],
+             [14, 17],
+             [16, 18],
+             [17, 19],
+             [18, 20],
+             [19, 21],
+             [20, 22],
+             [21, 23],
+         ]
+     )
+
+
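Illustration only (not part of the commit): the skeleton above is a list of (parent, child) index pairs, so bone lengths can be read off directly; the joint positions here are placeholders.

import numpy as np

joints = np.random.rand(24, 3)   # SMPL joint positions
edges = get_smpl_skeleton()      # (23, 2) parent/child indices
bone_lengths = np.linalg.norm(joints[edges[:, 0]] - joints[edges[:, 1]], axis=1)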
+ def map_spin_joints_to_smpl():
+     # this function primarily will be used to copy 2D keypoint
+     # confidences to pose parameters
+     return [
+         [(39, 27, 28), 0],  # hip,lhip,rhip->hips
+         [(28,), 1],  # lhip->leftUpLeg
+         [(27,), 2],  # rhip->rightUpLeg
+         [(41, 27, 28, 39), 3],  # Spine->spine
+         [(29,), 4],  # lknee->leftLeg
+         [(26,), 5],  # rknee->rightLeg
+         [(41, 40, 33, 34,), 6],  # spine, thorax->spine1
+         [(30,), 7],  # lankle->leftFoot
+         [(25,), 8],  # rankle->rightFoot
+         [(40, 33, 34), 9],  # thorax,shoulders->spine2
+         [(30,), 10],  # lankle->leftToe
+         [(25,), 11],  # rankle->rightToe
+         [(37, 42, 33, 34), 12],  # neck, shoulders->neck
+         [(34,), 13],  # lshoulder->leftShoulder
+         [(33,), 14],  # rshoulder->rightShoulder
+         [(33, 34, 38, 43, 44, 45, 46, 47, 48,), 15],  # nose, eyes, ears, headtop, shoulders->head
+         [(34,), 16],  # lshoulder->leftArm
+         [(33,), 17],  # rshoulder->rightArm
+         [(35,), 18],  # lelbow->leftForeArm
+         [(32,), 19],  # relbow->rightForeArm
+         [(36,), 20],  # lwrist->leftHand
+         [(31,), 21],  # rwrist->rightHand
+         [(36,), 22],  # lhand->leftHandIndex
+         [(31,), 23],  # rhand->rightHandIndex
+     ]
+
+
+ def map_smpl_to_common():
+     return [
+         [(11, 8), 0],  # rightToe, rightFoot -> rankle
+         [(5,), 1],  # rightLeg -> rknee
+         [(2,), 2],  # rhip
+         [(1,), 3],  # lhip
+         [(4,), 4],  # leftLeg -> lknee
+         [(10, 7), 5],  # leftToe, leftFoot -> lankle
+         [(21, 23), 6],  # rwrist
+         [(18,), 7],  # relbow
+         [(17, 14), 8],  # rshoulder
+         [(16, 13), 9],  # lshoulder
+         [(19,), 10],  # lelbow
+         [(20, 22), 11],  # lwrist
+         [(0, 3, 6, 9, 12), 12],  # neck
+         [(15,), 13],  # headtop
+     ]
+
+
+ def relation_among_spin_joints():
+     # this function primarily will be used to copy 2D keypoint
+     # confidences to 3D joints
+     return [
+         [(), 25],
+         [(), 26],
+         [(39,), 27],
+         [(39,), 28],
+         [(), 29],
+         [(), 30],
+         [(), 31],
+         [(), 32],
+         [(), 33],
+         [(), 34],
+         [(), 35],
+         [(), 36],
+         [(40, 42, 44, 43, 38, 33, 34,), 37],
+         [(43, 44, 45, 46, 47, 48, 33, 34,), 38],
+         [(27, 28,), 39],
+         [(27, 28, 37, 41, 42,), 40],
+         [(27, 28, 39, 40,), 41],
+         [(37, 38, 44, 45, 46, 47, 48,), 42],
+         [(44, 45, 46, 47, 48, 38, 42, 37, 33, 34,), 43],
+         [(44, 45, 46, 47, 48, 38, 42, 37, 33, 34), 44],
+         [(44, 45, 46, 47, 48, 38, 42, 37, 33, 34), 45],
+         [(44, 45, 46, 47, 48, 38, 42, 37, 33, 34), 46],
+         [(44, 45, 46, 47, 48, 38, 42, 37, 33, 34), 47],
+         [(44, 45, 46, 47, 48, 38, 42, 37, 33, 34), 48],
+     ]
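Illustration only (not part of the commit): one way these mapping tables can be consumed, pooling per-SMPL-joint confidences into the 14 common joints; the pooling rule (mean) is an assumption.

import numpy as np

smpl_conf = np.random.rand(24)   # per-SMPL-joint confidences
common_conf = np.zeros(14)
for src_idxs, dst_idx in map_smpl_to_common():
    common_conf[dst_idx] = smpl_conf[list(src_idxs)].mean()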
utils/loss.py ADDED
@@ -0,0 +1,207 @@
+ import torch
+ import torch.nn as nn
+ from common import constants
+ from models.smpl import SMPL
+ from smplx import SMPLX
+ import pickle as pkl
+ import numpy as np
+ from utils.mesh_utils import save_results_mesh
+ from utils.diff_renderer import Pytorch3D
+ import os
+ import cv2
+
+
+ class sem_loss_function(nn.Module):
+     def __init__(self):
+         super(sem_loss_function, self).__init__()
+         self.ce = nn.BCELoss()
+
+     def forward(self, y_true, y_pred):
+         loss = self.ce(y_pred, y_true)
+         return loss
+
+
+ class class_loss_function(nn.Module):
+     def __init__(self):
+         super(class_loss_function, self).__init__()
+         self.ce_loss = nn.BCELoss()
+         # self.ce_loss = nn.MultiLabelSoftMarginLoss()
+         # self.ce_loss = nn.MultiLabelMarginLoss()
+
+     def forward(self, y_true, y_pred, valid_mask):
+         # y_true = torch.squeeze(y_true, 1).long()
+         # y_true = torch.squeeze(y_true, 1)
+         # y_pred = torch.squeeze(y_pred, 1)
+         bs = y_true.shape[0]
+         if bs != 1:
+             y_pred = y_pred[valid_mask == 1]
+             y_true = y_true[valid_mask == 1]
+         if len(y_pred) > 0:
+             return self.ce_loss(y_pred, y_true)
+         else:
+             return torch.tensor(0.0).to(y_pred.device)
+
+
+ class pixel_anchoring_function(nn.Module):
+     def __init__(self, model_type, device='cuda'):
+         super(pixel_anchoring_function, self).__init__()
+
+         self.device = device
+
+         self.model_type = model_type
+
+         if self.model_type == 'smplx':
+             # load mapping from smpl vertices to smplx vertices
+             mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl")
+             with open(mapping_pkl, 'rb') as f:
+                 smpl_to_smplx_mapping = pkl.load(f)
+             smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"]
+             self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device)
+
+         # Setup the SMPL model
+         if self.model_type == 'smpl':
+             self.n_vertices = 6890
+             self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
+         if self.model_type == 'smplx':
+             self.n_vertices = 10475
+             self.body_model = SMPLX(constants.SMPLX_MODEL_DIR,
+                                     num_betas=10,
+                                     use_pca=False).to(self.device)
+         self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device)
+
+         self.ce_loss = nn.BCELoss()
+
+     def get_posed_mesh(self, body_params, debug=False):
+         betas = body_params['betas']
+         pose = body_params['pose']
+         transl = body_params['transl']
+
+         # extra smplx params
+         extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
+                       'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
+                       'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
+                       'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
+                       'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
+                       'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}
+
+         smpl_output = self.body_model(betas=betas,
+                                       body_pose=pose[:, 3:],
+                                       global_orient=pose[:, :3],
+                                       pose2rot=True,
+                                       transl=transl,
+                                       **extra_args)
+         smpl_verts = smpl_output.vertices
+         smpl_joints = smpl_output.joints
+
+         if debug:
+             for mesh_i in range(smpl_verts.shape[0]):
+                 out_dir = 'temp_meshes'
+                 os.makedirs(out_dir, exist_ok=True)
+                 out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
+                 save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
+         return smpl_verts, smpl_joints
+
+     def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None, debug=False):
+         bs = smpl_verts.shape[0]
+
+         # Incorporate resizing factor into the camera
+         img_w = 256  # TODO: Remove hardcoding
+         img_h = 256  # TODO: Remove hardcoding
+         focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0]
+         focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1]
+         # convert to float for pytorch3d
+         focal_length_x, focal_length_y = focal_length_x.float(), focal_length_y.float()
+
+         # concatenate focal lengths
+         focal_length = torch.stack([focal_length_x, focal_length_y], dim=1)
+
+         # Setup renderer
+         renderer = Pytorch3D(img_h=img_h,
+                              img_w=img_w,
+                              focal_length=focal_length,
+                              smpl_faces=self.body_faces,
+                              texture_mode='deco',
+                              vertex_colors=vertex_colors,
+                              face_textures=face_textures,
+                              is_train=True,
+                              is_cam_batch=True)
+         front_view = renderer(smpl_verts)
+         if debug:
+             # visualize the front view as images in a temp_images folder
+             for i in range(bs):
+                 front_view_rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu()
+                 front_view_mask = front_view[i, 3, :, :].detach().cpu()
+                 out_dir = 'temp_images'
+                 os.makedirs(out_dir, exist_ok=True)
+                 out_file_rgb = os.path.join(out_dir, f'{i:04d}_rgb.png')
+                 out_file_mask = os.path.join(out_dir, f'{i:04d}_mask.png')
+                 cv2.imwrite(out_file_rgb, front_view_rgb.numpy() * 255)
+                 cv2.imwrite(out_file_mask, front_view_mask.numpy() * 255)
+
+         return front_view
+
+     def paint_contact(self, pred_contact):
+         """
+         Paints the contact vertices on the SMPL mesh.
+
+         Args:
+             pred_contact: probabilities of contact vertices
+
+         Returns:
+             pred_rgb: RGB colors for the contact vertices
+         """
+         bs = pred_contact.shape[0]
+
+         # initialize black and white colors
+         colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device)
+         colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1)
+
+         # add another dimension to the contact probabilities for inverse probabilities
+         pred_contact = torch.unsqueeze(pred_contact, 2)
+         pred_contact = torch.cat((1 - pred_contact, pred_contact), 2)
+
+         # get pred_rgb colors
+         pred_vert_rgb = torch.bmm(pred_contact, colors)
+         pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :]  # take the first vertex color
+         pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device)
+         pred_face_texture[:, :, 0, 0, :] = pred_face_rgb
+         return pred_vert_rgb, pred_face_texture
+
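Illustration only (not part of the commit): the core blend inside `paint_contact` as a standalone sketch, i.e. a soft interpolation between black (no contact) and white (contact).

import torch

bs, n_verts = 2, 6890
pred_contact = torch.rand(bs, n_verts)                         # per-vertex contact probabilities
colors = torch.tensor([[0., 0., 0.], [1., 1., 1.]])            # black / white
colors = colors.unsqueeze(0).expand(bs, -1, -1)                # (bs, 2, 3)
probs = torch.stack((1 - pred_contact, pred_contact), dim=2)   # (bs, n_verts, 2)
vert_rgb = torch.bmm(probs, colors)                            # (bs, n_verts, 3), brighter = more contact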
+     def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask):
+         """
+         Takes predicted contact labels (probabilities), transfers them to the posed mesh and
+         renders them to the image. The loss is computed between the rendered contact and the
+         ground truth polygons from HOT.
+
+         Args:
+             pred_contact: predicted contact labels (probabilities)
+             body_params: SMPL parameters in camera coords
+             cam_k: camera intrinsics
+             img_scale_factor: resizing factor folded into the camera focal lengths
+             gt_contact_polygon: ground truth polygons from HOT
+             valid_mask: mask of samples with valid contact annotations
+         """
+         # convert pred_contact to smplx
+         bs = pred_contact.shape[0]
+         if self.model_type == 'smplx':
+             smpl_to_smplx_mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1)
+             pred_contact = torch.bmm(smpl_to_smplx_mapping, pred_contact[..., None])
+             pred_contact = pred_contact.squeeze()
+
+         # get the posed mesh
+         smpl_verts, smpl_joints = self.get_posed_mesh(body_params)
+
+         # paint the contact vertices on the mesh
+         vertex_colors, face_textures = self.paint_contact(pred_contact)
+
+         # render the mesh
+         front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures)
+         front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1)
+         front_view_mask = front_view[:, 3, :, :]
+
+         # compute segmentation loss between rendered contact mask and ground truth contact mask
+         front_view_rgb = front_view_rgb[valid_mask == 1]
+         gt_contact_polygon = gt_contact_polygon[valid_mask == 1]
+         loss = self.ce_loss(front_view_rgb, gt_contact_polygon)
+         return loss, front_view_rgb, front_view_mask
utils/mesh_utils.py ADDED
@@ -0,0 +1,6 @@
+ import trimesh
+
+
+ def save_results_mesh(vertices, faces, filename):
+     mesh = trimesh.Trimesh(vertices, faces, process=False)
+     mesh.export(filename)
+     print(f'saved results to {filename}')
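Illustration only (not part of the commit): exporting a single triangle with `save_results_mesh`.

import numpy as np

verts = np.array([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.]])
faces = np.array([[0, 1, 2]])
save_results_mesh(verts, faces, 'triangle.obj')  # writes an OBJ file via trimesh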
utils/metrics.py ADDED
@@ -0,0 +1,106 @@
+ import numpy as np
+ import torch
+ import monai.metrics as metrics
+ from common.constants import DIST_MATRIX_PATH
+
+ DIST_MATRIX = np.load(DIST_MATRIX_PATH)
+
+
+ def metric(mask, pred, back=True):
+     iou = metrics.compute_meaniou(pred, mask, back, False)
+     iou = iou.mean()
+
+     return iou
+
+ def precision_recall_f1score(gt, pred):
+     """
+     Compute precision, recall, and f1 score
+     """
+     precision = torch.zeros(gt.shape[0])
+     recall = torch.zeros(gt.shape[0])
+     f1 = torch.zeros(gt.shape[0])
+
+     for b in range(gt.shape[0]):
+         tp_num = gt[b, pred[b, :] >= 0.5].sum()
+         precision_denominator = (pred[b, :] >= 0.5).sum()
+         recall_denominator = (gt[b, :]).sum()
+
+         if precision_denominator == 0:  # if no pred
+             precision_ = 1.
+             recall_ = 0.
+             f1_ = 0.
+         elif recall_denominator == 0:  # if no GT
+             precision_ = 0.
+             recall_ = 1.
+             f1_ = 0.
+         else:
+             precision_ = tp_num / precision_denominator
+             recall_ = tp_num / recall_denominator
+             if (precision_ + recall_) <= 1e-10:  # to avoid precision issues
+                 precision_ = 0.
+                 recall_ = 0.
+                 f1_ = 0.
+             else:
+                 f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
+
+         precision[b] = precision_
+         recall[b] = recall_
+         f1[b] = f1_
+
+     return precision, recall, f1
+
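Illustration only (not part of the commit): a tiny worked example of `precision_recall_f1score` at the 0.5 threshold.

import torch

gt = torch.tensor([[1., 0., 1., 0.]])
pred = torch.tensor([[0.9, 0.2, 0.4, 0.1]])  # one true positive, one missed positive
precision, recall, f1 = precision_recall_f1score(gt, pred)
# precision = 1.0, recall = 0.5, f1 = 2/3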
+ def acc_precision_recall_f1score(gt, pred):
+     """
+     Compute accuracy, precision, recall, and f1 score
+     """
+     acc = torch.zeros(gt.shape[0])
+     precision = torch.zeros(gt.shape[0])
+     recall = torch.zeros(gt.shape[0])
+     f1 = torch.zeros(gt.shape[0])
+
+     for b in range(gt.shape[0]):
+         tp_num = gt[b, pred[b, :] >= 0.5].sum()
+         precision_denominator = (pred[b, :] >= 0.5).sum()
+         recall_denominator = (gt[b, :]).sum()
+         tn_num = gt.shape[-1] - precision_denominator - recall_denominator + tp_num
+
+         acc_ = (tp_num + tn_num) / gt.shape[-1]
+         precision_ = tp_num / (precision_denominator + 1e-10)
+         recall_ = tp_num / (recall_denominator + 1e-10)
+         f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)
+
+         acc[b] = acc_
+         precision[b] = precision_
+         recall[b] = recall_
+         f1[b] = f1_  # this assignment was missing, which left f1 all zeros
+
+     return acc, precision, recall, f1
+
+ def det_error_metric(pred, gt):
+     gt = gt.detach().cpu()
+     pred = pred.detach().cpu()
+
+     dist_matrix = torch.tensor(DIST_MATRIX)
+
+     false_positive_dist = torch.zeros(gt.shape[0])
+     false_negative_dist = torch.zeros(gt.shape[0])
+
+     for b in range(gt.shape[0]):
+         gt_columns = dist_matrix[:, gt[b, :] == 1] if any(gt[b, :] == 1) else dist_matrix
+         error_matrix = gt_columns[pred[b, :] >= 0.5, :] if any(pred[b, :] >= 0.5) else gt_columns
+
+         false_positive_dist_ = error_matrix.min(dim=1)[0].mean()
+         false_negative_dist_ = error_matrix.min(dim=0)[0].mean()
+
+         false_positive_dist[b] = false_positive_dist_
+         false_negative_dist[b] = false_negative_dist_
+
+     return false_positive_dist, false_negative_dist
utils/smpl_uv.py ADDED
@@ -0,0 +1,167 @@
+ import torch
+ import trimesh
+ import numpy as np
+ import skimage.io as io
+ from PIL import Image
+ from smplx import SMPL
+ from matplotlib import cm as mpl_cm, colors as mpl_colors
+ from trimesh.visual.color import face_to_vertex_color, vertex_to_face_color, to_rgba
+
+ from common import constants
+ from .colorwheel import make_color_wheel_image
+
+
+ def get_smpl_uv():
+     uv_obj = 'data/body_models/smpl_uv_20200910/smpl_uv.obj'
+
+     uv_map = []
+     with open(uv_obj) as f:
+         for line in f.readlines():
+             if line.startswith('vt'):
+                 coords = [float(x) for x in line.split(' ')[1:]]
+                 uv_map.append(coords)
+
+     uv_map = np.array(uv_map)
+
+     return uv_map
+
+
+ def show_uv_texture():
+     # image = io.imread('data/body_models/smpl_uv_20200910/smpl_uv_20200910.png')
+     image = make_color_wheel_image(1024, 1024)
+     image = Image.fromarray(image)
+
+     uv = np.load('data/body_models/smpl_uv_20200910/uv_table.npy')  # get_smpl_uv()
+     material = trimesh.visual.texture.SimpleMaterial(image=image)
+     tex_visuals = trimesh.visual.TextureVisuals(uv=uv, image=image, material=material)
+
+     smpl = SMPL(constants.SMPL_MODEL_DIR)
+
+     faces = smpl.faces
+     verts = smpl().vertices[0].detach().numpy()
+
+     # assert(len(uv) == len(verts))
+     print(uv.shape)
+     vc = tex_visuals.to_color().vertex_colors
+     fc = trimesh.visual.color.vertex_to_face_color(vc, faces)
+     face_colors = fc.copy()
+     fc = fc.astype(float)
+     vc = vc.astype(float)
+     fc[:, :3] = fc[:, :3] / 255.
+     vc[:, :3] = vc[:, :3] / 255.
+     print(fc[:, :3].max(), fc[:, :3].min(), fc[:, :3].mean())
+     print(vc[:, :3].max(), vc[:, :3].min(), vc[:, :3].mean())
+     np.save('data/body_models/smpl/color_wheel_face_colors.npy', fc)
+     np.save('data/body_models/smpl/color_wheel_vertex_colors.npy', vc)
+     print(fc.shape)
+     mesh = trimesh.Trimesh(verts, faces, validate=True, process=False, face_colors=face_colors)
+     # mesh = trimesh.load('data/body_models/smpl_uv_20200910/smpl_uv.obj', process=False)
+     # mesh.visual = tex_visuals
+
+     # import ipdb; ipdb.set_trace()
+     # print(vc.shape)
+     mesh.show()
+
+
+ def show_colored_mesh():
+     cm = mpl_cm.get_cmap('jet')
+     norm_gt = mpl_colors.Normalize()
+
+     smpl = SMPL(constants.SMPL_MODEL_DIR)
+
+     faces = smpl.faces
+     verts = smpl().vertices[0].detach().numpy()
+
+     m = trimesh.Trimesh(verts, faces, process=False)
+
+     mode = 1
+     if mode == 0:
+         # mano_segm_labels = m.triangles_center
+         face_labels = m.triangles_center
+         face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
+     elif mode == 1:
+         # print(face_labels.shape)
+         face_labels = m.triangles_center
+         face_labels = np.argsort(np.linalg.norm(face_labels, axis=-1))
+         face_colors = np.ones((13776, 4))
+         face_colors[:, 3] = 1.0
+         face_colors[:, :3] = cm(norm_gt(face_labels))[:, :3]
+     elif mode == 2:
+         # breakpoint()
+         fc = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')[0, :, 0, 0, 0, :]
+         face_colors = np.ones((13776, 4))
+         face_colors[:, :3] = fc
+     mesh = trimesh.Trimesh(verts, faces, process=False, face_colors=face_colors)
+     mesh.show()
+
+
+ def get_tenet_texture(mode='smplpix'):
+     # mode = 'smplpix', 'decomr'
+
+     smpl = SMPL(constants.SMPL_MODEL_DIR)
+
+     faces = smpl.faces
+     verts = smpl().vertices[0].detach().numpy()
+
+     m = trimesh.Trimesh(verts, faces, process=False)
+     if mode == 'smplpix':
+         # mano_segm_labels = m.triangles_center
+         face_labels = m.triangles_center
+         face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
+         texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
+         texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
+         texture = torch.from_numpy(texture).float()
+     elif mode == 'decomr':
+         texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
+         texture = torch.from_numpy(texture).float()
+     elif mode == 'colorwheel':
+         face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
+         texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
+         texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
+         texture = torch.from_numpy(texture).float()
+     else:
+         raise ValueError(f'{mode} is not defined!')
+
+     return texture
+
+
+ def save_tenet_textures(mode='smplpix'):
+     # mode = 'smplpix', 'decomr'
+
+     smpl = SMPL(constants.SMPL_MODEL_DIR)
+
+     faces = smpl.faces
+     verts = smpl().vertices[0].detach().numpy()
+
+     m = trimesh.Trimesh(verts, faces, process=False)
+
+     if mode == 'smplpix':
+         # mano_segm_labels = m.triangles_center
+         face_labels = m.triangles_center
+         face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
+         texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
+         texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
+         texture = torch.from_numpy(texture).float()
+
+         vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
+
+     elif mode == 'decomr':
+         texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
+         texture = torch.from_numpy(texture).float()
+         face_colors = texture[0, :, 0, 0, 0, :]
+         vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
+
+     elif mode == 'colorwheel':
+         face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
+         texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
+         texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
+         texture = torch.from_numpy(texture).float()
+         face_colors[:, :3] *= 255
+         vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
+     else:
+         raise ValueError(f'{mode} is not defined!')
+
+     print(vert_colors.shape, vert_colors.max())
+     np.save(f'data/body_models/smpl/{mode}_vertex_colors.npy', vert_colors)
+     return texture