Spaces:
Sleeping
Sleeping
added missing files
Browse files- utils/__init__.py +0 -0
- utils/cluster.py +99 -0
- utils/colorwheel.py +22 -0
- utils/config.py +196 -0
- utils/default_hparams.py +45 -0
- utils/diff_renderer.py +287 -0
- utils/get_cfg.py +17 -0
- utils/hrnet.py +625 -0
- utils/image_utils.py +444 -0
- utils/kp_utils.py +1114 -0
- utils/loss.py +207 -0
- utils/mesh_utils.py +6 -0
- utils/metrics.py +106 -0
- utils/smpl_uv.py +167 -0
utils/__init__.py
ADDED
File without changes
|
utils/cluster.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import stat
|
4 |
+
import shutil
|
5 |
+
import subprocess
|
6 |
+
|
7 |
+
from loguru import logger
|
8 |
+
|
9 |
+
GPUS = {
|
10 |
+
'v100-v16': ('\"Tesla V100-PCIE-16GB\"', 'tesla', 16000),
|
11 |
+
'v100-p32': ('\"Tesla V100-PCIE-32GB\"', 'tesla', 32000),
|
12 |
+
'v100-s32': ('\"Tesla V100-SXM2-32GB\"', 'tesla', 32000),
|
13 |
+
'v100-p16': ('\"Tesla P100-PCIE-16GB\"', 'tesla', 16000),
|
14 |
+
}
|
15 |
+
|
16 |
+
def get_gpus(min_mem=10000, arch=('tesla', 'quadro', 'rtx')):
|
17 |
+
gpu_names = []
|
18 |
+
for k, (gpu_name, gpu_arch, gpu_mem) in GPUS.items():
|
19 |
+
if gpu_mem >= min_mem and gpu_arch in arch:
|
20 |
+
gpu_names.append(gpu_name)
|
21 |
+
|
22 |
+
assert len(gpu_names) > 0, 'Suitable GPU model could not be found'
|
23 |
+
|
24 |
+
return gpu_names
|
25 |
+
|
26 |
+
|
27 |
+
def execute_task_on_cluster(
|
28 |
+
script,
|
29 |
+
exp_name,
|
30 |
+
output_dir,
|
31 |
+
condor_dir,
|
32 |
+
cfg_file,
|
33 |
+
num_exp=1,
|
34 |
+
exp_opts=None,
|
35 |
+
bid_amount=10,
|
36 |
+
num_workers=2,
|
37 |
+
memory=64000,
|
38 |
+
gpu_min_mem=10000,
|
39 |
+
gpu_arch=('tesla', 'quadro', 'rtx'),
|
40 |
+
num_gpus=1
|
41 |
+
):
|
42 |
+
# copy config to a new experiment directory and source from there.
|
43 |
+
# this makes sure the correct config is copied even if you change the config file
|
44 |
+
# after starting the experiment and before the first job is submitted
|
45 |
+
temp_config_dir = os.path.join(os.path.dirname(condor_dir), 'temp_configs', exp_name)
|
46 |
+
os.makedirs(temp_config_dir, exist_ok=True)
|
47 |
+
new_cfg_file = os.path.join(temp_config_dir, 'config.yaml')
|
48 |
+
shutil.copy(src=cfg_file, dst=new_cfg_file)
|
49 |
+
|
50 |
+
gpus = get_gpus(min_mem=gpu_min_mem, arch=gpu_arch)
|
51 |
+
|
52 |
+
gpus = ' || '.join([f'CUDADeviceName=={x}' for x in gpus])
|
53 |
+
|
54 |
+
condor_log_dir = os.path.join(condor_dir, 'condorlog', exp_name)
|
55 |
+
os.makedirs(condor_log_dir, exist_ok=True)
|
56 |
+
submission = f'executable = {condor_log_dir}/{exp_name}_run.sh\n' \
|
57 |
+
'arguments = $(Process) $(Cluster)\n' \
|
58 |
+
f'error = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).err\n' \
|
59 |
+
f'output = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).out\n' \
|
60 |
+
f'log = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).log\n' \
|
61 |
+
f'request_memory = {memory}\n' \
|
62 |
+
f'request_cpus={int(num_workers)}\n' \
|
63 |
+
f'request_gpus={num_gpus}\n' \
|
64 |
+
f'requirements={gpus}\n' \
|
65 |
+
f'+MaxRunningPrice = 500\n' \
|
66 |
+
f'queue {num_exp}'
|
67 |
+
# f'request_cpus={int(num_workers/2)}\n' \
|
68 |
+
# f'+RunningPriceExceededAction = \"kill\"\n' \
|
69 |
+
print('<<< Condor Submission >>> ')
|
70 |
+
print(submission)
|
71 |
+
|
72 |
+
with open(f'{condor_log_dir}/{exp_name}_submit.sub', 'w') as f:
|
73 |
+
f.write(submission)
|
74 |
+
|
75 |
+
# output_dir = os.path.join(output_dir, exp_name)
|
76 |
+
logger.info(f'The logs for this experiments can be found under: {condor_log_dir}')
|
77 |
+
logger.info(f'The outputs for this experiments can be found under: {output_dir}')
|
78 |
+
## This is the trick. Notice there is no --cluster here
|
79 |
+
bash = 'export PYTHONBUFFERED=1\n export PATH=$PATH\n ' \
|
80 |
+
f'{sys.executable} {script} --cfg {new_cfg_file} --cfg_id $1'
|
81 |
+
|
82 |
+
if exp_opts is not None:
|
83 |
+
bash += ' --opts '
|
84 |
+
for opt in exp_opts:
|
85 |
+
bash += f'{opt} '
|
86 |
+
bash += 'SYSTEM.CLUSTER_NODE $2.$1'
|
87 |
+
else:
|
88 |
+
bash += ' --opts SYSTEM.CLUSTER_NODE $2.$1'
|
89 |
+
|
90 |
+
executable_path = f'{condor_log_dir}/{exp_name}_run.sh'
|
91 |
+
|
92 |
+
with open(executable_path, 'w') as f:
|
93 |
+
f.write(bash)
|
94 |
+
|
95 |
+
os.chmod(executable_path, stat.S_IRWXU)
|
96 |
+
|
97 |
+
cmd = ['condor_submit_bid', f'{bid_amount}', f'{condor_log_dir}/{exp_name}_submit.sub']
|
98 |
+
logger.info('Executing ' + ' '.join(cmd))
|
99 |
+
subprocess.call(cmd)
|
utils/colorwheel.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def make_color_wheel_image(img_width, img_height):
|
6 |
+
"""
|
7 |
+
Creates a color wheel based image of given width and height
|
8 |
+
Args:
|
9 |
+
img_width (int):
|
10 |
+
img_height (int):
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
opencv image (numpy array): color wheel based image
|
14 |
+
"""
|
15 |
+
hue = np.fromfunction(lambda i, j: (np.arctan2(i-img_height/2, img_width/2-j) + np.pi)*(180/np.pi)/2,
|
16 |
+
(img_height, img_width), dtype=np.float)
|
17 |
+
saturation = np.ones((img_height, img_width)) * 255
|
18 |
+
value = np.ones((img_height, img_width)) * 255
|
19 |
+
hsl = np.dstack((hue, saturation, value))
|
20 |
+
color_map = cv2.cvtColor(np.array(hsl, dtype=np.uint8), cv2.COLOR_HSV2BGR)
|
21 |
+
return color_map
|
22 |
+
|
utils/config.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import operator
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
import time
|
6 |
+
from functools import reduce
|
7 |
+
from typing import List, Union
|
8 |
+
|
9 |
+
import configargparse
|
10 |
+
import yaml
|
11 |
+
from flatten_dict import flatten, unflatten
|
12 |
+
from loguru import logger
|
13 |
+
from yacs.config import CfgNode as CN
|
14 |
+
|
15 |
+
from utils.cluster import execute_task_on_cluster
|
16 |
+
from utils.default_hparams import hparams
|
17 |
+
|
18 |
+
|
19 |
+
def parse_args():
|
20 |
+
def add_common_cmdline_args(parser):
|
21 |
+
# for cluster runs
|
22 |
+
parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
|
23 |
+
parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
|
24 |
+
parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
|
25 |
+
parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
|
26 |
+
parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
|
27 |
+
parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
|
28 |
+
parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
|
29 |
+
parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
|
30 |
+
nargs='*', help='additional options to update config')
|
31 |
+
parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
|
32 |
+
return parser
|
33 |
+
|
34 |
+
# For Blender main parser
|
35 |
+
arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
|
36 |
+
cfg_parser = configargparse.YAMLConfigFileParser
|
37 |
+
description = 'PyTorch implementation of DECO'
|
38 |
+
|
39 |
+
parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
|
40 |
+
config_file_parser_class=cfg_parser,
|
41 |
+
description=description,
|
42 |
+
prog='deco')
|
43 |
+
|
44 |
+
parser = add_common_cmdline_args(parser)
|
45 |
+
|
46 |
+
args = parser.parse_args()
|
47 |
+
print(args, end='\n\n')
|
48 |
+
|
49 |
+
return args
|
50 |
+
|
51 |
+
def get_hparams_defaults():
|
52 |
+
"""Get a yacs hparamsNode object with default values for my_project."""
|
53 |
+
# Return a clone so that the defaults will not be altered
|
54 |
+
# This is for the "local variable" use pattern
|
55 |
+
return hparams.clone()
|
56 |
+
|
57 |
+
def update_hparams(hparams_file):
|
58 |
+
hparams = get_hparams_defaults()
|
59 |
+
hparams.merge_from_file(hparams_file)
|
60 |
+
return hparams.clone()
|
61 |
+
|
62 |
+
def update_hparams_from_dict(cfg_dict):
|
63 |
+
hparams = get_hparams_defaults()
|
64 |
+
cfg = hparams.load_cfg(str(cfg_dict))
|
65 |
+
hparams.merge_from_other_cfg(cfg)
|
66 |
+
return hparams.clone()
|
67 |
+
|
68 |
+
def get_grid_search_configs(config, excluded_keys=[]):
|
69 |
+
"""
|
70 |
+
:param config: dictionary with the configurations
|
71 |
+
:return: The different configurations
|
72 |
+
"""
|
73 |
+
|
74 |
+
def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
|
75 |
+
"""
|
76 |
+
boolean to string conversion
|
77 |
+
:param x: list or bool to be converted
|
78 |
+
:return: string converted thinghat
|
79 |
+
"""
|
80 |
+
if isinstance(x, bool):
|
81 |
+
return [str(x)]
|
82 |
+
for i, j in enumerate(x):
|
83 |
+
x[i] = str(j)
|
84 |
+
return x
|
85 |
+
|
86 |
+
# exclude from grid search
|
87 |
+
|
88 |
+
flattened_config_dict = flatten(config, reducer='path')
|
89 |
+
hyper_params = []
|
90 |
+
|
91 |
+
for k,v in flattened_config_dict.items():
|
92 |
+
if isinstance(v,list):
|
93 |
+
if k in excluded_keys:
|
94 |
+
flattened_config_dict[k] = ['+'.join(v)]
|
95 |
+
elif len(v) > 1:
|
96 |
+
hyper_params += [k]
|
97 |
+
|
98 |
+
if isinstance(v, list) and isinstance(v[0], bool) :
|
99 |
+
flattened_config_dict[k] = bool_to_string(v)
|
100 |
+
|
101 |
+
if not isinstance(v,list):
|
102 |
+
if isinstance(v, bool):
|
103 |
+
flattened_config_dict[k] = bool_to_string(v)
|
104 |
+
else:
|
105 |
+
flattened_config_dict[k] = [v]
|
106 |
+
|
107 |
+
keys, values = zip(*flattened_config_dict.items())
|
108 |
+
experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
109 |
+
|
110 |
+
for exp_id, exp in enumerate(experiments):
|
111 |
+
for param in excluded_keys:
|
112 |
+
exp[param] = exp[param].strip().split('+')
|
113 |
+
for param_name, param_value in exp.items():
|
114 |
+
# print(param_name,type(param_value))
|
115 |
+
if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
|
116 |
+
exp[param_name] = [True if x == 'True' else False for x in param_value]
|
117 |
+
if param_value in ['True', 'False']:
|
118 |
+
if param_value == 'True':
|
119 |
+
exp[param_name] = True
|
120 |
+
else:
|
121 |
+
exp[param_name] = False
|
122 |
+
|
123 |
+
|
124 |
+
experiments[exp_id] = unflatten(exp, splitter='path')
|
125 |
+
|
126 |
+
return experiments, hyper_params
|
127 |
+
|
128 |
+
def get_from_dict(dict, keys):
|
129 |
+
return reduce(operator.getitem, keys, dict)
|
130 |
+
|
131 |
+
def save_dict_to_yaml(obj, filename, mode='w'):
|
132 |
+
with open(filename, mode) as f:
|
133 |
+
yaml.dump(obj, f, default_flow_style=False)
|
134 |
+
|
135 |
+
def run_grid_search_experiments(
|
136 |
+
args,
|
137 |
+
script='train.py',
|
138 |
+
change_wt_name=True
|
139 |
+
):
|
140 |
+
cfg = yaml.safe_load(open(args.cfg))
|
141 |
+
# parse config file to split into a list of configs with tuning hyperparameters separated
|
142 |
+
# Also return the names of tuned hyperparameters hyperparameters
|
143 |
+
different_configs, hyperparams = get_grid_search_configs(
|
144 |
+
cfg,
|
145 |
+
excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
|
146 |
+
)
|
147 |
+
logger.info(f'Grid search hparams: \n {hyperparams}')
|
148 |
+
|
149 |
+
# The config file may be missing some default values, so we need to add them
|
150 |
+
different_configs = [update_hparams_from_dict(c) for c in different_configs]
|
151 |
+
logger.info(f'======> Number of experiment configurations is {len(different_configs)}')
|
152 |
+
|
153 |
+
config_to_run = CN(different_configs[args.cfg_id])
|
154 |
+
|
155 |
+
if args.cluster:
|
156 |
+
execute_task_on_cluster(
|
157 |
+
script=script,
|
158 |
+
exp_name=config_to_run.EXP_NAME,
|
159 |
+
output_dir=config_to_run.OUTPUT_DIR,
|
160 |
+
condor_dir=config_to_run.CONDOR_DIR,
|
161 |
+
cfg_file=args.cfg,
|
162 |
+
num_exp=len(different_configs),
|
163 |
+
bid_amount=args.bid,
|
164 |
+
num_workers=config_to_run.DATASET.NUM_WORKERS,
|
165 |
+
memory=args.memory,
|
166 |
+
exp_opts=args.opts,
|
167 |
+
gpu_min_mem=args.gpu_min_mem,
|
168 |
+
gpu_arch=args.gpu_arch,
|
169 |
+
)
|
170 |
+
exit()
|
171 |
+
|
172 |
+
# ==== create logdir using hyperparam settings
|
173 |
+
logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
|
174 |
+
logdir = f'{logtime}_{config_to_run.EXP_NAME}'
|
175 |
+
wt_file = config_to_run.EXP_NAME + '_'
|
176 |
+
for hp in hyperparams:
|
177 |
+
v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
|
178 |
+
logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
|
179 |
+
wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'
|
180 |
+
logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
|
181 |
+
os.makedirs(logdir, exist_ok=True)
|
182 |
+
config_to_run.LOGDIR = logdir
|
183 |
+
|
184 |
+
wt_file += 'best.pth'
|
185 |
+
wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
|
186 |
+
if change_wt_name: config_to_run.TRAINING.BEST_MODEL_PATH = wt_path
|
187 |
+
|
188 |
+
shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))
|
189 |
+
|
190 |
+
# save config
|
191 |
+
save_dict_to_yaml(
|
192 |
+
unflatten(flatten(config_to_run)),
|
193 |
+
os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
|
194 |
+
)
|
195 |
+
|
196 |
+
return config_to_run
|
utils/default_hparams.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from yacs.config import CfgNode as CN
|
2 |
+
|
3 |
+
# Set default hparams to construct new default config
|
4 |
+
# Make sure the defaults are same as in parser
|
5 |
+
hparams = CN()
|
6 |
+
|
7 |
+
# General settings
|
8 |
+
hparams.EXP_NAME = 'default'
|
9 |
+
hparams.PROJECT_NAME = 'default'
|
10 |
+
hparams.OUTPUT_DIR = 'deco_results/'
|
11 |
+
hparams.CONDOR_DIR = '/is/cluster/work/achatterjee/condor/rich/'
|
12 |
+
hparams.LOGDIR = ''
|
13 |
+
|
14 |
+
# Dataset hparams
|
15 |
+
hparams.DATASET = CN()
|
16 |
+
hparams.DATASET.BATCH_SIZE = 64
|
17 |
+
hparams.DATASET.NUM_WORKERS = 4
|
18 |
+
hparams.DATASET.NORMALIZE_IMAGES = True
|
19 |
+
|
20 |
+
# Optimizer hparams
|
21 |
+
hparams.OPTIMIZER = CN()
|
22 |
+
hparams.OPTIMIZER.TYPE = 'adam'
|
23 |
+
hparams.OPTIMIZER.LR = 5e-5
|
24 |
+
hparams.OPTIMIZER.NUM_UPDATE_LR = 10
|
25 |
+
|
26 |
+
# Training hparams
|
27 |
+
hparams.TRAINING = CN()
|
28 |
+
hparams.TRAINING.ENCODER = 'hrnet'
|
29 |
+
hparams.TRAINING.CONTEXT = True
|
30 |
+
hparams.TRAINING.NUM_EPOCHS = 50
|
31 |
+
hparams.TRAINING.SUMMARY_STEPS = 100
|
32 |
+
hparams.TRAINING.CHECKPOINT_EPOCHS = 5
|
33 |
+
hparams.TRAINING.NUM_EARLY_STOP = 10
|
34 |
+
hparams.TRAINING.DATASETS = ['rich']
|
35 |
+
hparams.TRAINING.DATASET_MIX_PDF = ['1.']
|
36 |
+
hparams.TRAINING.DATASET_ROOT_PATH = '/is/cluster/work/achatterjee/rich/npzs'
|
37 |
+
hparams.TRAINING.BEST_MODEL_PATH = '/is/cluster/work/achatterjee/weights/rich/exp/rich_exp.pth'
|
38 |
+
hparams.TRAINING.LOSS_WEIGHTS = 1.
|
39 |
+
hparams.TRAINING.PAL_LOSS_WEIGHTS = 1.
|
40 |
+
|
41 |
+
# Training hparams
|
42 |
+
hparams.VALIDATION = CN()
|
43 |
+
hparams.VALIDATION.SUMMARY_STEPS = 100
|
44 |
+
hparams.VALIDATION.DATASETS = ['rich']
|
45 |
+
hparams.VALIDATION.MAIN_DATASET = 'rich'
|
utils/diff_renderer.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from https://gitlab.tuebingen.mpg.de/mkocabas/projects/-/blob/master/pare/pare/utils/diff_renderer.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from pytorch3d.renderer import (
|
8 |
+
PerspectiveCameras,
|
9 |
+
RasterizationSettings,
|
10 |
+
DirectionalLights,
|
11 |
+
BlendParams,
|
12 |
+
HardFlatShader,
|
13 |
+
MeshRasterizer,
|
14 |
+
TexturesVertex,
|
15 |
+
TexturesAtlas
|
16 |
+
)
|
17 |
+
from pytorch3d.structures import Meshes
|
18 |
+
|
19 |
+
from .image_utils import get_default_camera
|
20 |
+
from .smpl_uv import get_tenet_texture
|
21 |
+
|
22 |
+
|
23 |
+
class MeshRendererWithDepth(nn.Module):
|
24 |
+
"""
|
25 |
+
A class for rendering a batch of heterogeneous meshes. The class should
|
26 |
+
be initialized with a rasterizer and shader class which each have a forward
|
27 |
+
function.
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(self, rasterizer, shader):
|
31 |
+
super().__init__()
|
32 |
+
self.rasterizer = rasterizer
|
33 |
+
self.shader = shader
|
34 |
+
|
35 |
+
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
36 |
+
"""
|
37 |
+
Render a batch of images from a batch of meshes by rasterizing and then
|
38 |
+
shading.
|
39 |
+
|
40 |
+
NOTE: If the blur radius for rasterization is > 0.0, some pixels can
|
41 |
+
have one or more barycentric coordinates lying outside the range [0, 1].
|
42 |
+
For a pixel with out of bounds barycentric coordinates with respect to a
|
43 |
+
face f, clipping is required before interpolating the texture uv
|
44 |
+
coordinates and z buffer so that the colors and depths are limited to
|
45 |
+
the range for the corresponding face.
|
46 |
+
"""
|
47 |
+
fragments = self.rasterizer(meshes_world, **kwargs)
|
48 |
+
images = self.shader(fragments, meshes_world, **kwargs)
|
49 |
+
|
50 |
+
mask = (fragments.zbuf > -1).float()
|
51 |
+
|
52 |
+
zbuf = fragments.zbuf.view(images.shape[0], -1)
|
53 |
+
# print(images.shape, zbuf.shape)
|
54 |
+
depth = (zbuf - zbuf.min(-1, keepdims=True).values) / \
|
55 |
+
(zbuf.max(-1, keepdims=True).values - zbuf.min(-1, keepdims=True).values)
|
56 |
+
depth = depth.reshape(*images.shape[:3] + (1,))
|
57 |
+
|
58 |
+
images = torch.cat([images[:, :, :, :3], mask, depth], dim=-1)
|
59 |
+
return images
|
60 |
+
|
61 |
+
|
62 |
+
class DifferentiableRenderer(nn.Module):
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
img_h,
|
66 |
+
img_w,
|
67 |
+
focal_length,
|
68 |
+
device='cuda',
|
69 |
+
background_color=(0.0, 0.0, 0.0),
|
70 |
+
texture_mode='smplpix',
|
71 |
+
vertex_colors=None,
|
72 |
+
face_textures=None,
|
73 |
+
smpl_faces=None,
|
74 |
+
is_train=False,
|
75 |
+
is_cam_batch=False,
|
76 |
+
):
|
77 |
+
super(DifferentiableRenderer, self).__init__()
|
78 |
+
self.x = 'a'
|
79 |
+
self.img_h = img_h
|
80 |
+
self.img_w = img_w
|
81 |
+
self.device = device
|
82 |
+
self.focal_length = focal_length
|
83 |
+
K, R = get_default_camera(focal_length, img_h, img_w, is_cam_batch=is_cam_batch)
|
84 |
+
K, R = K.to(device), R.to(device)
|
85 |
+
|
86 |
+
# T = torch.tensor([[0, 0, 2.5 * self.focal_length / max(self.img_h, self.img_w)]]).to(device)
|
87 |
+
if is_cam_batch:
|
88 |
+
T = torch.zeros((K.shape[0], 3)).to(device)
|
89 |
+
else:
|
90 |
+
T = torch.tensor([[0.0, 0.0, 0.0]]).to(device)
|
91 |
+
self.background_color = background_color
|
92 |
+
self.renderer = None
|
93 |
+
smpl_faces = smpl_faces
|
94 |
+
|
95 |
+
if texture_mode == 'smplpix':
|
96 |
+
face_colors = get_tenet_texture(mode=texture_mode).to(device).float()
|
97 |
+
vertex_colors = torch.from_numpy(
|
98 |
+
np.load(f'data/smpl/{texture_mode}_vertex_colors.npy')[:,:3]
|
99 |
+
).unsqueeze(0).to(device).float()
|
100 |
+
if texture_mode == 'partseg':
|
101 |
+
vertex_colors = vertex_colors[..., :3].unsqueeze(0).to(device)
|
102 |
+
face_colors = face_textures.to(device)
|
103 |
+
if texture_mode == 'deco':
|
104 |
+
vertex_colors = vertex_colors[..., :3].to(device)
|
105 |
+
face_colors = face_textures.to(device)
|
106 |
+
|
107 |
+
self.register_buffer('K', K)
|
108 |
+
self.register_buffer('R', R)
|
109 |
+
self.register_buffer('T', T)
|
110 |
+
self.register_buffer('face_colors', face_colors)
|
111 |
+
self.register_buffer('vertex_colors', vertex_colors)
|
112 |
+
self.register_buffer('smpl_faces', smpl_faces)
|
113 |
+
|
114 |
+
self.set_requires_grad(is_train)
|
115 |
+
|
116 |
+
def set_requires_grad(self, val=False):
|
117 |
+
self.K.requires_grad_(val)
|
118 |
+
self.R.requires_grad_(val)
|
119 |
+
self.T.requires_grad_(val)
|
120 |
+
self.face_colors.requires_grad_(val)
|
121 |
+
self.vertex_colors.requires_grad_(val)
|
122 |
+
# check if smpl_faces is a FloatTensor as requires_grad_ is not defined for LongTensor
|
123 |
+
if isinstance(self.smpl_faces, torch.FloatTensor):
|
124 |
+
self.smpl_faces.requires_grad_(val)
|
125 |
+
|
126 |
+
def forward(self, vertices, faces=None, R=None, T=None):
|
127 |
+
raise NotImplementedError
|
128 |
+
|
129 |
+
|
130 |
+
class Pytorch3D(DifferentiableRenderer):
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
img_h,
|
134 |
+
img_w,
|
135 |
+
focal_length,
|
136 |
+
device='cuda',
|
137 |
+
background_color=(0.0, 0.0, 0.0),
|
138 |
+
texture_mode='smplpix',
|
139 |
+
vertex_colors=None,
|
140 |
+
face_textures=None,
|
141 |
+
smpl_faces=None,
|
142 |
+
model_type='smpl',
|
143 |
+
is_train=False,
|
144 |
+
is_cam_batch=False,
|
145 |
+
):
|
146 |
+
super(Pytorch3D, self).__init__(
|
147 |
+
img_h,
|
148 |
+
img_w,
|
149 |
+
focal_length,
|
150 |
+
device=device,
|
151 |
+
background_color=background_color,
|
152 |
+
texture_mode=texture_mode,
|
153 |
+
vertex_colors=vertex_colors,
|
154 |
+
face_textures=face_textures,
|
155 |
+
smpl_faces=smpl_faces,
|
156 |
+
is_train=is_train,
|
157 |
+
is_cam_batch=is_cam_batch,
|
158 |
+
)
|
159 |
+
|
160 |
+
# this R converts the camera from pyrender NDC to
|
161 |
+
# OpenGL coordinate frame. It is basicall R(180, X) x R(180, Y)
|
162 |
+
# I manually defined it here for convenience
|
163 |
+
self.R = self.R @ torch.tensor(
|
164 |
+
[[[ -1.0, 0.0, 0.0],
|
165 |
+
[ 0.0, -1.0, 0.0],
|
166 |
+
[ 0.0, 0.0, 1.0]]],
|
167 |
+
dtype=self.R.dtype, device=self.R.device,
|
168 |
+
)
|
169 |
+
|
170 |
+
if is_cam_batch:
|
171 |
+
focal_length = self.focal_length
|
172 |
+
else:
|
173 |
+
focal_length = self.focal_length[None, :]
|
174 |
+
|
175 |
+
principal_point = ((self.img_w // 2, self.img_h // 2),)
|
176 |
+
image_size = ((self.img_h, self.img_w),)
|
177 |
+
|
178 |
+
cameras = PerspectiveCameras(
|
179 |
+
device=self.device,
|
180 |
+
focal_length=focal_length,
|
181 |
+
principal_point=principal_point,
|
182 |
+
R=self.R,
|
183 |
+
T=self.T,
|
184 |
+
in_ndc=False,
|
185 |
+
image_size=image_size,
|
186 |
+
)
|
187 |
+
|
188 |
+
for param in cameras.parameters():
|
189 |
+
param.requires_grad_(False)
|
190 |
+
|
191 |
+
raster_settings = RasterizationSettings(
|
192 |
+
image_size=(self.img_h, self.img_w),
|
193 |
+
blur_radius=0.0,
|
194 |
+
max_faces_per_bin=20000,
|
195 |
+
faces_per_pixel=1,
|
196 |
+
)
|
197 |
+
|
198 |
+
lights = DirectionalLights(
|
199 |
+
device=self.device,
|
200 |
+
ambient_color=((1.0, 1.0, 1.0),),
|
201 |
+
diffuse_color=((0.0, 0.0, 0.0),),
|
202 |
+
specular_color=((0.0, 0.0, 0.0),),
|
203 |
+
direction=((0, 1, 0),),
|
204 |
+
)
|
205 |
+
|
206 |
+
blend_params = BlendParams(background_color=self.background_color)
|
207 |
+
|
208 |
+
shader = HardFlatShader(device=self.device,
|
209 |
+
cameras=cameras,
|
210 |
+
blend_params=blend_params,
|
211 |
+
lights=lights)
|
212 |
+
|
213 |
+
self.textures = TexturesVertex(verts_features=self.vertex_colors)
|
214 |
+
|
215 |
+
self.renderer = MeshRendererWithDepth(
|
216 |
+
rasterizer=MeshRasterizer(
|
217 |
+
cameras=cameras,
|
218 |
+
raster_settings=raster_settings
|
219 |
+
),
|
220 |
+
shader=shader,
|
221 |
+
)
|
222 |
+
|
223 |
+
def forward(self, vertices, faces=None, R=None, T=None, face_atlas=None):
|
224 |
+
batch_size = vertices.shape[0]
|
225 |
+
if faces is None:
|
226 |
+
faces = self.smpl_faces.expand(batch_size, -1, -1)
|
227 |
+
|
228 |
+
if R is None:
|
229 |
+
R = self.R.expand(batch_size, -1, -1)
|
230 |
+
|
231 |
+
if T is None:
|
232 |
+
T = self.T.expand(batch_size, -1)
|
233 |
+
|
234 |
+
# convert camera translation to pytorch3d coordinate frame
|
235 |
+
T = torch.bmm(R, T.unsqueeze(-1)).squeeze(-1)
|
236 |
+
|
237 |
+
vertex_textures = TexturesVertex(
|
238 |
+
verts_features=self.vertex_colors.expand(batch_size, -1, -1)
|
239 |
+
)
|
240 |
+
|
241 |
+
# face_textures needed because vertex_texture cause interpolation at boundaries
|
242 |
+
if face_atlas:
|
243 |
+
face_textures = TexturesAtlas(atlas=face_atlas)
|
244 |
+
else:
|
245 |
+
face_textures = TexturesAtlas(atlas=self.face_colors)
|
246 |
+
|
247 |
+
# we may need to rotate the mesh
|
248 |
+
meshes = Meshes(verts=vertices, faces=faces, textures=face_textures)
|
249 |
+
images = self.renderer(meshes, R=R, T=T)
|
250 |
+
images = images.permute(0, 3, 1, 2)
|
251 |
+
return images
|
252 |
+
|
253 |
+
|
254 |
+
class NeuralMeshRenderer(DifferentiableRenderer):
|
255 |
+
def __init__(self, *args, **kwargs):
|
256 |
+
import neural_renderer as nr
|
257 |
+
|
258 |
+
super(NeuralMeshRenderer, self).__init__(*args, **kwargs)
|
259 |
+
|
260 |
+
self.neural_renderer = nr.Renderer(
|
261 |
+
dist_coeffs=None,
|
262 |
+
orig_size=self.img_size,
|
263 |
+
image_size=self.img_size,
|
264 |
+
light_intensity_ambient=1,
|
265 |
+
light_intensity_directional=0,
|
266 |
+
anti_aliasing=False,
|
267 |
+
)
|
268 |
+
|
269 |
+
def forward(self, vertices, faces=None, R=None, T=None):
|
270 |
+
batch_size = vertices.shape[0]
|
271 |
+
if faces is None:
|
272 |
+
faces = self.smpl_faces.expand(batch_size, -1, -1)
|
273 |
+
|
274 |
+
if R is None:
|
275 |
+
R = self.R.expand(batch_size, -1, -1)
|
276 |
+
|
277 |
+
if T is None:
|
278 |
+
T = self.T.expand(batch_size, -1)
|
279 |
+
rgb, depth, mask = self.neural_renderer(
|
280 |
+
vertices,
|
281 |
+
faces,
|
282 |
+
textures=self.face_colors.expand(batch_size, -1, -1, -1, -1, -1),
|
283 |
+
K=self.K.expand(batch_size, -1, -1),
|
284 |
+
R=R,
|
285 |
+
t=T.unsqueeze(1),
|
286 |
+
)
|
287 |
+
return torch.cat([rgb, depth.unsqueeze(1), mask.unsqueeze(1)], dim=1)
|
utils/get_cfg.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from yacs.config import CfgNode
|
2 |
+
|
3 |
+
_VALID_TYPES = {tuple, list, str, int, float, bool}
|
4 |
+
|
5 |
+
|
6 |
+
def convert_to_dict(cfg_node, key_list=[]):
|
7 |
+
""" Convert a config node to dictionary """
|
8 |
+
if not isinstance(cfg_node, CfgNode):
|
9 |
+
if type(cfg_node) not in _VALID_TYPES:
|
10 |
+
print("Key {} with value {} is not a valid type; valid types: {}".format(
|
11 |
+
".".join(key_list), type(cfg_node), _VALID_TYPES), )
|
12 |
+
return cfg_node
|
13 |
+
else:
|
14 |
+
cfg_dict = dict(cfg_node)
|
15 |
+
for k, v in cfg_dict.items():
|
16 |
+
cfg_dict[k] = convert_to_dict(v, key_list + [k])
|
17 |
+
return cfg_dict
|
utils/hrnet.py
ADDED
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from loguru import logger
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from yacs.config import CfgNode as CN
|
8 |
+
|
9 |
+
models = [
|
10 |
+
'hrnet_w32',
|
11 |
+
'hrnet_w48',
|
12 |
+
]
|
13 |
+
|
14 |
+
BN_MOMENTUM = 0.1
|
15 |
+
|
16 |
+
|
17 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
18 |
+
"""3x3 convolution with padding"""
|
19 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
20 |
+
padding=1, bias=False)
|
21 |
+
|
22 |
+
|
23 |
+
class BasicBlock(nn.Module):
|
24 |
+
expansion = 1
|
25 |
+
|
26 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
27 |
+
super(BasicBlock, self).__init__()
|
28 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
29 |
+
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
30 |
+
self.relu = nn.ReLU(inplace=True)
|
31 |
+
self.conv2 = conv3x3(planes, planes)
|
32 |
+
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
33 |
+
self.downsample = downsample
|
34 |
+
self.stride = stride
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = x
|
38 |
+
|
39 |
+
out = self.conv1(x)
|
40 |
+
out = self.bn1(out)
|
41 |
+
out = self.relu(out)
|
42 |
+
|
43 |
+
out = self.conv2(out)
|
44 |
+
out = self.bn2(out)
|
45 |
+
|
46 |
+
if self.downsample is not None:
|
47 |
+
residual = self.downsample(x)
|
48 |
+
|
49 |
+
out += residual
|
50 |
+
out = self.relu(out)
|
51 |
+
|
52 |
+
return out
|
53 |
+
|
54 |
+
|
55 |
+
class Bottleneck(nn.Module):
|
56 |
+
expansion = 4
|
57 |
+
|
58 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
59 |
+
super(Bottleneck, self).__init__()
|
60 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
61 |
+
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
62 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
63 |
+
padding=1, bias=False)
|
64 |
+
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
65 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
|
66 |
+
bias=False)
|
67 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion,
|
68 |
+
momentum=BN_MOMENTUM)
|
69 |
+
self.relu = nn.ReLU(inplace=True)
|
70 |
+
self.downsample = downsample
|
71 |
+
self.stride = stride
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
residual = x
|
75 |
+
|
76 |
+
out = self.conv1(x)
|
77 |
+
out = self.bn1(out)
|
78 |
+
out = self.relu(out)
|
79 |
+
|
80 |
+
out = self.conv2(out)
|
81 |
+
out = self.bn2(out)
|
82 |
+
out = self.relu(out)
|
83 |
+
|
84 |
+
out = self.conv3(out)
|
85 |
+
out = self.bn3(out)
|
86 |
+
|
87 |
+
if self.downsample is not None:
|
88 |
+
residual = self.downsample(x)
|
89 |
+
|
90 |
+
out += residual
|
91 |
+
out = self.relu(out)
|
92 |
+
|
93 |
+
return out
|
94 |
+
|
95 |
+
|
96 |
+
class HighResolutionModule(nn.Module):
|
97 |
+
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
|
98 |
+
num_channels, fuse_method, multi_scale_output=True):
|
99 |
+
super(HighResolutionModule, self).__init__()
|
100 |
+
self._check_branches(
|
101 |
+
num_branches, blocks, num_blocks, num_inchannels, num_channels)
|
102 |
+
|
103 |
+
self.num_inchannels = num_inchannels
|
104 |
+
self.fuse_method = fuse_method
|
105 |
+
self.num_branches = num_branches
|
106 |
+
|
107 |
+
self.multi_scale_output = multi_scale_output
|
108 |
+
|
109 |
+
self.branches = self._make_branches(
|
110 |
+
num_branches, blocks, num_blocks, num_channels)
|
111 |
+
self.fuse_layers = self._make_fuse_layers()
|
112 |
+
self.relu = nn.ReLU(True)
|
113 |
+
|
114 |
+
def _check_branches(self, num_branches, blocks, num_blocks,
|
115 |
+
num_inchannels, num_channels):
|
116 |
+
if num_branches != len(num_blocks):
|
117 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
|
118 |
+
num_branches, len(num_blocks))
|
119 |
+
logger.error(error_msg)
|
120 |
+
raise ValueError(error_msg)
|
121 |
+
|
122 |
+
if num_branches != len(num_channels):
|
123 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
|
124 |
+
num_branches, len(num_channels))
|
125 |
+
logger.error(error_msg)
|
126 |
+
raise ValueError(error_msg)
|
127 |
+
|
128 |
+
if num_branches != len(num_inchannels):
|
129 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
|
130 |
+
num_branches, len(num_inchannels))
|
131 |
+
logger.error(error_msg)
|
132 |
+
raise ValueError(error_msg)
|
133 |
+
|
134 |
+
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
|
135 |
+
stride=1):
|
136 |
+
downsample = None
|
137 |
+
if stride != 1 or \
|
138 |
+
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
|
139 |
+
downsample = nn.Sequential(
|
140 |
+
nn.Conv2d(
|
141 |
+
self.num_inchannels[branch_index],
|
142 |
+
num_channels[branch_index] * block.expansion,
|
143 |
+
kernel_size=1, stride=stride, bias=False
|
144 |
+
),
|
145 |
+
nn.BatchNorm2d(
|
146 |
+
num_channels[branch_index] * block.expansion,
|
147 |
+
momentum=BN_MOMENTUM
|
148 |
+
),
|
149 |
+
)
|
150 |
+
|
151 |
+
layers = []
|
152 |
+
layers.append(
|
153 |
+
block(
|
154 |
+
self.num_inchannels[branch_index],
|
155 |
+
num_channels[branch_index],
|
156 |
+
stride,
|
157 |
+
downsample
|
158 |
+
)
|
159 |
+
)
|
160 |
+
self.num_inchannels[branch_index] = \
|
161 |
+
num_channels[branch_index] * block.expansion
|
162 |
+
for i in range(1, num_blocks[branch_index]):
|
163 |
+
layers.append(
|
164 |
+
block(
|
165 |
+
self.num_inchannels[branch_index],
|
166 |
+
num_channels[branch_index]
|
167 |
+
)
|
168 |
+
)
|
169 |
+
|
170 |
+
return nn.Sequential(*layers)
|
171 |
+
|
172 |
+
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
173 |
+
branches = []
|
174 |
+
|
175 |
+
for i in range(num_branches):
|
176 |
+
branches.append(
|
177 |
+
self._make_one_branch(i, block, num_blocks, num_channels)
|
178 |
+
)
|
179 |
+
|
180 |
+
return nn.ModuleList(branches)
|
181 |
+
|
182 |
+
def _make_fuse_layers(self):
|
183 |
+
if self.num_branches == 1:
|
184 |
+
return None
|
185 |
+
|
186 |
+
num_branches = self.num_branches
|
187 |
+
num_inchannels = self.num_inchannels
|
188 |
+
fuse_layers = []
|
189 |
+
for i in range(num_branches if self.multi_scale_output else 1):
|
190 |
+
fuse_layer = []
|
191 |
+
for j in range(num_branches):
|
192 |
+
if j > i:
|
193 |
+
fuse_layer.append(
|
194 |
+
nn.Sequential(
|
195 |
+
nn.Conv2d(
|
196 |
+
num_inchannels[j],
|
197 |
+
num_inchannels[i],
|
198 |
+
1, 1, 0, bias=False
|
199 |
+
),
|
200 |
+
nn.BatchNorm2d(num_inchannels[i]),
|
201 |
+
nn.Upsample(scale_factor=2**(j-i), mode='nearest')
|
202 |
+
)
|
203 |
+
)
|
204 |
+
elif j == i:
|
205 |
+
fuse_layer.append(None)
|
206 |
+
else:
|
207 |
+
conv3x3s = []
|
208 |
+
for k in range(i-j):
|
209 |
+
if k == i - j - 1:
|
210 |
+
num_outchannels_conv3x3 = num_inchannels[i]
|
211 |
+
conv3x3s.append(
|
212 |
+
nn.Sequential(
|
213 |
+
nn.Conv2d(
|
214 |
+
num_inchannels[j],
|
215 |
+
num_outchannels_conv3x3,
|
216 |
+
3, 2, 1, bias=False
|
217 |
+
),
|
218 |
+
nn.BatchNorm2d(num_outchannels_conv3x3)
|
219 |
+
)
|
220 |
+
)
|
221 |
+
else:
|
222 |
+
num_outchannels_conv3x3 = num_inchannels[j]
|
223 |
+
conv3x3s.append(
|
224 |
+
nn.Sequential(
|
225 |
+
nn.Conv2d(
|
226 |
+
num_inchannels[j],
|
227 |
+
num_outchannels_conv3x3,
|
228 |
+
3, 2, 1, bias=False
|
229 |
+
),
|
230 |
+
nn.BatchNorm2d(num_outchannels_conv3x3),
|
231 |
+
nn.ReLU(True)
|
232 |
+
)
|
233 |
+
)
|
234 |
+
fuse_layer.append(nn.Sequential(*conv3x3s))
|
235 |
+
fuse_layers.append(nn.ModuleList(fuse_layer))
|
236 |
+
|
237 |
+
return nn.ModuleList(fuse_layers)
|
238 |
+
|
239 |
+
def get_num_inchannels(self):
|
240 |
+
return self.num_inchannels
|
241 |
+
|
242 |
+
def forward(self, x):
|
243 |
+
if self.num_branches == 1:
|
244 |
+
return [self.branches[0](x[0])]
|
245 |
+
|
246 |
+
for i in range(self.num_branches):
|
247 |
+
x[i] = self.branches[i](x[i])
|
248 |
+
|
249 |
+
x_fuse = []
|
250 |
+
|
251 |
+
for i in range(len(self.fuse_layers)):
|
252 |
+
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
253 |
+
for j in range(1, self.num_branches):
|
254 |
+
if i == j:
|
255 |
+
y = y + x[j]
|
256 |
+
else:
|
257 |
+
y = y + self.fuse_layers[i][j](x[j])
|
258 |
+
x_fuse.append(self.relu(y))
|
259 |
+
|
260 |
+
return x_fuse
|
261 |
+
|
262 |
+
|
263 |
+
blocks_dict = {
|
264 |
+
'BASIC': BasicBlock,
|
265 |
+
'BOTTLENECK': Bottleneck
|
266 |
+
}
|
267 |
+
|
268 |
+
|
269 |
+
class PoseHighResolutionNet(nn.Module):
|
270 |
+
|
271 |
+
def __init__(self, cfg):
|
272 |
+
self.inplanes = 64
|
273 |
+
extra = cfg['MODEL']['EXTRA']
|
274 |
+
super(PoseHighResolutionNet, self).__init__()
|
275 |
+
|
276 |
+
self.cfg = extra
|
277 |
+
|
278 |
+
# stem net
|
279 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
|
280 |
+
bias=False)
|
281 |
+
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
282 |
+
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
|
283 |
+
bias=False)
|
284 |
+
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
285 |
+
self.relu = nn.ReLU(inplace=True)
|
286 |
+
self.layer1 = self._make_layer(Bottleneck, 64, 4)
|
287 |
+
|
288 |
+
self.stage2_cfg = extra['STAGE2']
|
289 |
+
num_channels = self.stage2_cfg['NUM_CHANNELS']
|
290 |
+
block = blocks_dict[self.stage2_cfg['BLOCK']]
|
291 |
+
num_channels = [
|
292 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))
|
293 |
+
]
|
294 |
+
self.transition1 = self._make_transition_layer([256], num_channels)
|
295 |
+
self.stage2, pre_stage_channels = self._make_stage(
|
296 |
+
self.stage2_cfg, num_channels)
|
297 |
+
|
298 |
+
self.stage3_cfg = extra['STAGE3']
|
299 |
+
num_channels = self.stage3_cfg['NUM_CHANNELS']
|
300 |
+
block = blocks_dict[self.stage3_cfg['BLOCK']]
|
301 |
+
num_channels = [
|
302 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))
|
303 |
+
]
|
304 |
+
self.transition2 = self._make_transition_layer(
|
305 |
+
pre_stage_channels, num_channels)
|
306 |
+
self.stage3, pre_stage_channels = self._make_stage(
|
307 |
+
self.stage3_cfg, num_channels)
|
308 |
+
|
309 |
+
self.stage4_cfg = extra['STAGE4']
|
310 |
+
num_channels = self.stage4_cfg['NUM_CHANNELS']
|
311 |
+
block = blocks_dict[self.stage4_cfg['BLOCK']]
|
312 |
+
num_channels = [
|
313 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))
|
314 |
+
]
|
315 |
+
self.transition3 = self._make_transition_layer(
|
316 |
+
pre_stage_channels, num_channels)
|
317 |
+
self.stage4, pre_stage_channels = self._make_stage(
|
318 |
+
self.stage4_cfg, num_channels, multi_scale_output=True)
|
319 |
+
|
320 |
+
self.final_layer = nn.Conv2d(
|
321 |
+
in_channels=pre_stage_channels[0],
|
322 |
+
out_channels=cfg['MODEL']['NUM_JOINTS'],
|
323 |
+
kernel_size=extra['FINAL_CONV_KERNEL'],
|
324 |
+
stride=1,
|
325 |
+
padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
|
326 |
+
)
|
327 |
+
|
328 |
+
self.pretrained_layers = extra['PRETRAINED_LAYERS']
|
329 |
+
|
330 |
+
if extra.DOWNSAMPLE and extra.USE_CONV:
|
331 |
+
self.downsample_stage_1 = self._make_downsample_layer(3, num_channel=self.stage2_cfg['NUM_CHANNELS'][0])
|
332 |
+
self.downsample_stage_2 = self._make_downsample_layer(2, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
|
333 |
+
self.downsample_stage_3 = self._make_downsample_layer(1, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
|
334 |
+
elif not extra.DOWNSAMPLE and extra.USE_CONV:
|
335 |
+
self.upsample_stage_2 = self._make_upsample_layer(1, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
|
336 |
+
self.upsample_stage_3 = self._make_upsample_layer(2, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
|
337 |
+
self.upsample_stage_4 = self._make_upsample_layer(3, num_channel=self.stage4_cfg['NUM_CHANNELS'][-1])
|
338 |
+
|
339 |
+
def _make_transition_layer(
|
340 |
+
self, num_channels_pre_layer, num_channels_cur_layer):
|
341 |
+
num_branches_cur = len(num_channels_cur_layer)
|
342 |
+
num_branches_pre = len(num_channels_pre_layer)
|
343 |
+
|
344 |
+
transition_layers = []
|
345 |
+
for i in range(num_branches_cur):
|
346 |
+
if i < num_branches_pre:
|
347 |
+
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
348 |
+
transition_layers.append(
|
349 |
+
nn.Sequential(
|
350 |
+
nn.Conv2d(
|
351 |
+
num_channels_pre_layer[i],
|
352 |
+
num_channels_cur_layer[i],
|
353 |
+
3, 1, 1, bias=False
|
354 |
+
),
|
355 |
+
nn.BatchNorm2d(num_channels_cur_layer[i]),
|
356 |
+
nn.ReLU(inplace=True)
|
357 |
+
)
|
358 |
+
)
|
359 |
+
else:
|
360 |
+
transition_layers.append(None)
|
361 |
+
else:
|
362 |
+
conv3x3s = []
|
363 |
+
for j in range(i+1-num_branches_pre):
|
364 |
+
inchannels = num_channels_pre_layer[-1]
|
365 |
+
outchannels = num_channels_cur_layer[i] \
|
366 |
+
if j == i-num_branches_pre else inchannels
|
367 |
+
conv3x3s.append(
|
368 |
+
nn.Sequential(
|
369 |
+
nn.Conv2d(
|
370 |
+
inchannels, outchannels, 3, 2, 1, bias=False
|
371 |
+
),
|
372 |
+
nn.BatchNorm2d(outchannels),
|
373 |
+
nn.ReLU(inplace=True)
|
374 |
+
)
|
375 |
+
)
|
376 |
+
transition_layers.append(nn.Sequential(*conv3x3s))
|
377 |
+
|
378 |
+
return nn.ModuleList(transition_layers)
|
379 |
+
|
380 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
381 |
+
downsample = None
|
382 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
383 |
+
downsample = nn.Sequential(
|
384 |
+
nn.Conv2d(
|
385 |
+
self.inplanes, planes * block.expansion,
|
386 |
+
kernel_size=1, stride=stride, bias=False
|
387 |
+
),
|
388 |
+
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
|
389 |
+
)
|
390 |
+
|
391 |
+
layers = []
|
392 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
393 |
+
self.inplanes = planes * block.expansion
|
394 |
+
for i in range(1, blocks):
|
395 |
+
layers.append(block(self.inplanes, planes))
|
396 |
+
|
397 |
+
return nn.Sequential(*layers)
|
398 |
+
|
399 |
+
def _make_stage(self, layer_config, num_inchannels,
|
400 |
+
multi_scale_output=True):
|
401 |
+
num_modules = layer_config['NUM_MODULES']
|
402 |
+
num_branches = layer_config['NUM_BRANCHES']
|
403 |
+
num_blocks = layer_config['NUM_BLOCKS']
|
404 |
+
num_channels = layer_config['NUM_CHANNELS']
|
405 |
+
block = blocks_dict[layer_config['BLOCK']]
|
406 |
+
fuse_method = layer_config['FUSE_METHOD']
|
407 |
+
|
408 |
+
modules = []
|
409 |
+
for i in range(num_modules):
|
410 |
+
# multi_scale_output is only used last module
|
411 |
+
if not multi_scale_output and i == num_modules - 1:
|
412 |
+
reset_multi_scale_output = False
|
413 |
+
else:
|
414 |
+
reset_multi_scale_output = True
|
415 |
+
|
416 |
+
modules.append(
|
417 |
+
HighResolutionModule(
|
418 |
+
num_branches,
|
419 |
+
block,
|
420 |
+
num_blocks,
|
421 |
+
num_inchannels,
|
422 |
+
num_channels,
|
423 |
+
fuse_method,
|
424 |
+
reset_multi_scale_output
|
425 |
+
)
|
426 |
+
)
|
427 |
+
num_inchannels = modules[-1].get_num_inchannels()
|
428 |
+
|
429 |
+
return nn.Sequential(*modules), num_inchannels
|
430 |
+
|
431 |
+
def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
|
432 |
+
layers = []
|
433 |
+
for i in range(num_layers):
|
434 |
+
layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
|
435 |
+
layers.append(
|
436 |
+
nn.Conv2d(
|
437 |
+
in_channels=num_channel, out_channels=num_channel,
|
438 |
+
kernel_size=kernel_size, stride=1, padding=1, bias=False,
|
439 |
+
)
|
440 |
+
)
|
441 |
+
layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
|
442 |
+
layers.append(nn.ReLU(inplace=True))
|
443 |
+
|
444 |
+
return nn.Sequential(*layers)
|
445 |
+
|
446 |
+
def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
|
447 |
+
layers = []
|
448 |
+
for i in range(num_layers):
|
449 |
+
layers.append(
|
450 |
+
nn.Conv2d(
|
451 |
+
in_channels=num_channel, out_channels=num_channel,
|
452 |
+
kernel_size=kernel_size, stride=2, padding=1, bias=False,
|
453 |
+
)
|
454 |
+
)
|
455 |
+
layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
|
456 |
+
layers.append(nn.ReLU(inplace=True))
|
457 |
+
|
458 |
+
return nn.Sequential(*layers)
|
459 |
+
|
460 |
+
def forward(self, x):
|
461 |
+
x = self.conv1(x)
|
462 |
+
x = self.bn1(x)
|
463 |
+
x = self.relu(x)
|
464 |
+
x = self.conv2(x)
|
465 |
+
x = self.bn2(x)
|
466 |
+
x = self.relu(x)
|
467 |
+
x = self.layer1(x)
|
468 |
+
|
469 |
+
x_list = []
|
470 |
+
for i in range(self.stage2_cfg['NUM_BRANCHES']):
|
471 |
+
if self.transition1[i] is not None:
|
472 |
+
x_list.append(self.transition1[i](x))
|
473 |
+
else:
|
474 |
+
x_list.append(x)
|
475 |
+
y_list = self.stage2(x_list)
|
476 |
+
|
477 |
+
x_list = []
|
478 |
+
for i in range(self.stage3_cfg['NUM_BRANCHES']):
|
479 |
+
if self.transition2[i] is not None:
|
480 |
+
x_list.append(self.transition2[i](y_list[-1]))
|
481 |
+
else:
|
482 |
+
x_list.append(y_list[i])
|
483 |
+
y_list = self.stage3(x_list)
|
484 |
+
|
485 |
+
x_list = []
|
486 |
+
for i in range(self.stage4_cfg['NUM_BRANCHES']):
|
487 |
+
if self.transition3[i] is not None:
|
488 |
+
x_list.append(self.transition3[i](y_list[-1]))
|
489 |
+
else:
|
490 |
+
x_list.append(y_list[i])
|
491 |
+
x = self.stage4(x_list)
|
492 |
+
|
493 |
+
if self.cfg.DOWNSAMPLE:
|
494 |
+
if self.cfg.USE_CONV:
|
495 |
+
# Downsampling with strided convolutions
|
496 |
+
x1 = self.downsample_stage_1(x[0])
|
497 |
+
x2 = self.downsample_stage_2(x[1])
|
498 |
+
x3 = self.downsample_stage_3(x[2])
|
499 |
+
x = torch.cat([x1, x2, x3, x[3]], 1)
|
500 |
+
else:
|
501 |
+
# Downsampling with interpolation
|
502 |
+
x0_h, x0_w = x[3].size(2), x[3].size(3)
|
503 |
+
x1 = F.interpolate(x[0], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
504 |
+
x2 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
505 |
+
x3 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
506 |
+
x = torch.cat([x1, x2, x3, x[3]], 1)
|
507 |
+
else:
|
508 |
+
if self.cfg.USE_CONV:
|
509 |
+
# Upsampling with interpolations + convolutions
|
510 |
+
x1 = self.upsample_stage_2(x[1])
|
511 |
+
x2 = self.upsample_stage_3(x[2])
|
512 |
+
x3 = self.upsample_stage_4(x[3])
|
513 |
+
x = torch.cat([x[0], x1, x2, x3], 1)
|
514 |
+
else:
|
515 |
+
# Upsampling with interpolation
|
516 |
+
x0_h, x0_w = x[0].size(2), x[0].size(3)
|
517 |
+
x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
518 |
+
x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
519 |
+
x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
520 |
+
x = torch.cat([x[0], x1, x2, x3], 1)
|
521 |
+
|
522 |
+
return x
|
523 |
+
|
524 |
+
def init_weights(self, pretrained=''):
|
525 |
+
logger.info('=> init weights from normal distribution')
|
526 |
+
for m in self.modules():
|
527 |
+
if isinstance(m, nn.Conv2d):
|
528 |
+
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
529 |
+
nn.init.normal_(m.weight, std=0.001)
|
530 |
+
for name, _ in m.named_parameters():
|
531 |
+
if name in ['bias']:
|
532 |
+
nn.init.constant_(m.bias, 0)
|
533 |
+
elif isinstance(m, nn.BatchNorm2d):
|
534 |
+
nn.init.constant_(m.weight, 1)
|
535 |
+
nn.init.constant_(m.bias, 0)
|
536 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
537 |
+
nn.init.normal_(m.weight, std=0.001)
|
538 |
+
for name, _ in m.named_parameters():
|
539 |
+
if name in ['bias']:
|
540 |
+
nn.init.constant_(m.bias, 0)
|
541 |
+
|
542 |
+
if os.path.isfile(pretrained):
|
543 |
+
pretrained_state_dict = torch.load(pretrained)
|
544 |
+
logger.info('=> loading pretrained model {}'.format(pretrained))
|
545 |
+
|
546 |
+
need_init_state_dict = {}
|
547 |
+
for name, m in pretrained_state_dict.items():
|
548 |
+
if name.split('.')[0] in self.pretrained_layers \
|
549 |
+
or self.pretrained_layers[0] is '*':
|
550 |
+
need_init_state_dict[name] = m
|
551 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
552 |
+
elif pretrained:
|
553 |
+
logger.warning('IMPORTANT WARNING!! Please download pre-trained models if you are in TRAINING mode!')
|
554 |
+
# raise ValueError('{} is not exist!'.format(pretrained))
|
555 |
+
|
556 |
+
|
557 |
+
def get_pose_net(cfg, is_train):
|
558 |
+
model = PoseHighResolutionNet(cfg)
|
559 |
+
|
560 |
+
if is_train and cfg['MODEL']['INIT_WEIGHTS']:
|
561 |
+
model.init_weights(cfg['MODEL']['PRETRAINED'])
|
562 |
+
|
563 |
+
return model
|
564 |
+
|
565 |
+
|
566 |
+
def get_cfg_defaults(pretrained, width=32, downsample=False, use_conv=False):
|
567 |
+
# pose_multi_resoluton_net related params
|
568 |
+
HRNET = CN()
|
569 |
+
HRNET.PRETRAINED_LAYERS = [
|
570 |
+
'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
|
571 |
+
'stage2', 'transition2', 'stage3', 'transition3', 'stage4',
|
572 |
+
]
|
573 |
+
HRNET.STEM_INPLANES = 64
|
574 |
+
HRNET.FINAL_CONV_KERNEL = 1
|
575 |
+
HRNET.STAGE2 = CN()
|
576 |
+
HRNET.STAGE2.NUM_MODULES = 1
|
577 |
+
HRNET.STAGE2.NUM_BRANCHES = 2
|
578 |
+
HRNET.STAGE2.NUM_BLOCKS = [4, 4]
|
579 |
+
HRNET.STAGE2.NUM_CHANNELS = [width, width*2]
|
580 |
+
HRNET.STAGE2.BLOCK = 'BASIC'
|
581 |
+
HRNET.STAGE2.FUSE_METHOD = 'SUM'
|
582 |
+
HRNET.STAGE3 = CN()
|
583 |
+
HRNET.STAGE3.NUM_MODULES = 4
|
584 |
+
HRNET.STAGE3.NUM_BRANCHES = 3
|
585 |
+
HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
|
586 |
+
HRNET.STAGE3.NUM_CHANNELS = [width, width*2, width*4]
|
587 |
+
HRNET.STAGE3.BLOCK = 'BASIC'
|
588 |
+
HRNET.STAGE3.FUSE_METHOD = 'SUM'
|
589 |
+
HRNET.STAGE4 = CN()
|
590 |
+
HRNET.STAGE4.NUM_MODULES = 3
|
591 |
+
HRNET.STAGE4.NUM_BRANCHES = 4
|
592 |
+
HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
|
593 |
+
HRNET.STAGE4.NUM_CHANNELS = [width, width*2, width*4, width*8]
|
594 |
+
HRNET.STAGE4.BLOCK = 'BASIC'
|
595 |
+
HRNET.STAGE4.FUSE_METHOD = 'SUM'
|
596 |
+
HRNET.DOWNSAMPLE = downsample
|
597 |
+
HRNET.USE_CONV = use_conv
|
598 |
+
|
599 |
+
cfg = CN()
|
600 |
+
cfg.MODEL = CN()
|
601 |
+
cfg.MODEL.INIT_WEIGHTS = True
|
602 |
+
cfg.MODEL.PRETRAINED = pretrained # 'data/pretrained_models/hrnet_w32-36af842e.pth'
|
603 |
+
cfg.MODEL.EXTRA = HRNET
|
604 |
+
cfg.MODEL.NUM_JOINTS = 24
|
605 |
+
return cfg
|
606 |
+
|
607 |
+
|
608 |
+
def hrnet_w32(
|
609 |
+
pretrained=True,
|
610 |
+
pretrained_ckpt='data/weights/pose_hrnet_w32_256x192.pth',
|
611 |
+
downsample=False,
|
612 |
+
use_conv=False,
|
613 |
+
):
|
614 |
+
cfg = get_cfg_defaults(pretrained_ckpt, width=32, downsample=downsample, use_conv=use_conv)
|
615 |
+
return get_pose_net(cfg, is_train=True)
|
616 |
+
|
617 |
+
|
618 |
+
def hrnet_w48(
|
619 |
+
pretrained=True,
|
620 |
+
pretrained_ckpt='data/weights/pose_hrnet_w48_256x192.pth',
|
621 |
+
downsample=False,
|
622 |
+
use_conv=False,
|
623 |
+
):
|
624 |
+
cfg = get_cfg_defaults(pretrained_ckpt, width=48, downsample=downsample, use_conv=use_conv)
|
625 |
+
return get_pose_net(cfg, is_train=True)
|
utils/image_utils.py
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file contains functions that are used to perform data augmentation.
|
3 |
+
"""
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
import json
|
7 |
+
from skimage.transform import rotate, resize
|
8 |
+
import numpy as np
|
9 |
+
import jpeg4py as jpeg
|
10 |
+
from trimesh.visual import color
|
11 |
+
|
12 |
+
# from ..core import constants
|
13 |
+
# from .vibe_image_utils import gen_trans_from_patch_cv
|
14 |
+
from .kp_utils import map_smpl_to_common, get_smpl_joint_names
|
15 |
+
|
16 |
+
def get_transform(center, scale, res, rot=0):
|
17 |
+
"""Generate transformation matrix."""
|
18 |
+
h = 200 * scale
|
19 |
+
t = np.zeros((3, 3))
|
20 |
+
t[0, 0] = float(res[1]) / h
|
21 |
+
t[1, 1] = float(res[0]) / h
|
22 |
+
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
|
23 |
+
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
|
24 |
+
t[2, 2] = 1
|
25 |
+
if not rot == 0:
|
26 |
+
rot = -rot # To match direction of rotation from cropping
|
27 |
+
rot_mat = np.zeros((3, 3))
|
28 |
+
rot_rad = rot * np.pi / 180
|
29 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
30 |
+
rot_mat[0, :2] = [cs, -sn]
|
31 |
+
rot_mat[1, :2] = [sn, cs]
|
32 |
+
rot_mat[2, 2] = 1
|
33 |
+
# Need to rotate around center
|
34 |
+
t_mat = np.eye(3)
|
35 |
+
t_mat[0, 2] = -res[1] / 2
|
36 |
+
t_mat[1, 2] = -res[0] / 2
|
37 |
+
t_inv = t_mat.copy()
|
38 |
+
t_inv[:2, 2] *= -1
|
39 |
+
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
|
40 |
+
return t
|
41 |
+
|
42 |
+
|
43 |
+
def transform(pt, center, scale, res, invert=0, rot=0):
|
44 |
+
"""Transform pixel location to different reference."""
|
45 |
+
t = get_transform(center, scale, res, rot=rot)
|
46 |
+
if invert:
|
47 |
+
t = np.linalg.inv(t)
|
48 |
+
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
|
49 |
+
new_pt = np.dot(t, new_pt)
|
50 |
+
return new_pt[:2].astype(int) + 1
|
51 |
+
|
52 |
+
|
53 |
+
def crop(img, center, scale, res, rot=0):
|
54 |
+
"""Crop image according to the supplied bounding box."""
|
55 |
+
# Upper left point
|
56 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
57 |
+
# Bottom right point
|
58 |
+
br = np.array(transform([res[0] + 1,
|
59 |
+
res[1] + 1], center, scale, res, invert=1)) - 1
|
60 |
+
|
61 |
+
# Padding so that when rotated proper amount of context is included
|
62 |
+
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
63 |
+
if not rot == 0:
|
64 |
+
ul -= pad
|
65 |
+
br += pad
|
66 |
+
|
67 |
+
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
68 |
+
if len(img.shape) > 2:
|
69 |
+
new_shape += [img.shape[2]]
|
70 |
+
new_img = np.zeros(new_shape)
|
71 |
+
|
72 |
+
# Range to fill new array
|
73 |
+
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
74 |
+
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
75 |
+
# Range to sample from original image
|
76 |
+
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
77 |
+
old_y = max(0, ul[1]), min(len(img), br[1])
|
78 |
+
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
|
79 |
+
old_x[0]:old_x[1]]
|
80 |
+
|
81 |
+
if not rot == 0:
|
82 |
+
# Remove padding
|
83 |
+
|
84 |
+
new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
|
85 |
+
new_img = new_img[pad:-pad, pad:-pad]
|
86 |
+
|
87 |
+
# resize image
|
88 |
+
new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
|
89 |
+
return new_img
|
90 |
+
|
91 |
+
|
92 |
+
def crop_cv2(img, center, scale, res, rot=0):
|
93 |
+
c_x, c_y = center
|
94 |
+
c_x, c_y = int(round(c_x)), int(round(c_y))
|
95 |
+
patch_width, patch_height = int(round(res[0])), int(round(res[1]))
|
96 |
+
bb_width = bb_height = int(round(scale * 200.))
|
97 |
+
|
98 |
+
trans = gen_trans_from_patch_cv(
|
99 |
+
c_x, c_y, bb_width, bb_height,
|
100 |
+
patch_width, patch_height,
|
101 |
+
scale=1.0, rot=rot, inv=False,
|
102 |
+
)
|
103 |
+
|
104 |
+
crop_img = cv2.warpAffine(
|
105 |
+
img, trans, (int(patch_width), int(patch_height)),
|
106 |
+
flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT
|
107 |
+
)
|
108 |
+
|
109 |
+
return crop_img
|
110 |
+
|
111 |
+
|
112 |
+
def get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start):
|
113 |
+
y1 = int((height - crop_height) * h_start)
|
114 |
+
y2 = y1 + crop_height
|
115 |
+
x1 = int((width - crop_width) * w_start)
|
116 |
+
x2 = x1 + crop_width
|
117 |
+
return x1, y1, x2, y2
|
118 |
+
|
119 |
+
|
120 |
+
def random_crop(center, scale, crop_scale_factor, axis='all'):
|
121 |
+
'''
|
122 |
+
center: bbox center [x,y]
|
123 |
+
scale: bbox height / 200
|
124 |
+
crop_scale_factor: amount of cropping to be applied
|
125 |
+
axis: axis which cropping will be applied
|
126 |
+
"x": center the y axis and get random crops in x
|
127 |
+
"y": center the x axis and get random crops in y
|
128 |
+
"all": randomly crop from all locations
|
129 |
+
'''
|
130 |
+
orig_size = int(scale * 200.)
|
131 |
+
ul = (center - (orig_size / 2.)).astype(int)
|
132 |
+
|
133 |
+
crop_size = int(orig_size * crop_scale_factor)
|
134 |
+
|
135 |
+
if axis == 'all':
|
136 |
+
h_start = np.random.rand()
|
137 |
+
w_start = np.random.rand()
|
138 |
+
elif axis == 'x':
|
139 |
+
h_start = np.random.rand()
|
140 |
+
w_start = 0.5
|
141 |
+
elif axis == 'y':
|
142 |
+
h_start = 0.5
|
143 |
+
w_start = np.random.rand()
|
144 |
+
else:
|
145 |
+
raise ValueError(f'axis {axis} is undefined!')
|
146 |
+
|
147 |
+
x1, y1, x2, y2 = get_random_crop_coords(
|
148 |
+
height=orig_size,
|
149 |
+
width=orig_size,
|
150 |
+
crop_height=crop_size,
|
151 |
+
crop_width=crop_size,
|
152 |
+
h_start=h_start,
|
153 |
+
w_start=w_start,
|
154 |
+
)
|
155 |
+
scale = (y2 - y1) / 200.
|
156 |
+
center = ul + np.array([(y1 + y2) / 2, (x1 + x2) / 2])
|
157 |
+
return center, scale
|
158 |
+
|
159 |
+
|
160 |
+
def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
|
161 |
+
"""'Undo' the image cropping/resizing.
|
162 |
+
This function is used when evaluating mask/part segmentation.
|
163 |
+
"""
|
164 |
+
res = img.shape[:2]
|
165 |
+
# Upper left point
|
166 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
167 |
+
# Bottom right point
|
168 |
+
br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
|
169 |
+
# size of cropped image
|
170 |
+
crop_shape = [br[1] - ul[1], br[0] - ul[0]]
|
171 |
+
|
172 |
+
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
173 |
+
if len(img.shape) > 2:
|
174 |
+
new_shape += [img.shape[2]]
|
175 |
+
new_img = np.zeros(orig_shape, dtype=np.uint8)
|
176 |
+
# Range to fill new array
|
177 |
+
new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
|
178 |
+
new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
|
179 |
+
# Range to sample from original image
|
180 |
+
old_x = max(0, ul[0]), min(orig_shape[1], br[0])
|
181 |
+
old_y = max(0, ul[1]), min(orig_shape[0], br[1])
|
182 |
+
img = resize(img, crop_shape) #, interp='nearest') # scipy.misc.imresize(img, crop_shape, interp='nearest')
|
183 |
+
new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
|
184 |
+
return new_img
|
185 |
+
|
186 |
+
|
187 |
+
def rot_aa(aa, rot):
|
188 |
+
"""Rotate axis angle parameters."""
|
189 |
+
# pose parameters
|
190 |
+
R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
|
191 |
+
[np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
|
192 |
+
[0, 0, 1]])
|
193 |
+
# find the rotation of the body in camera frame
|
194 |
+
per_rdg, _ = cv2.Rodrigues(aa)
|
195 |
+
# apply the global rotation to the global orientation
|
196 |
+
resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
|
197 |
+
aa = (resrot.T)[0]
|
198 |
+
return aa
|
199 |
+
|
200 |
+
|
201 |
+
def flip_img(img):
|
202 |
+
"""Flip rgb images or masks.
|
203 |
+
channels come last, e.g. (256,256,3).
|
204 |
+
"""
|
205 |
+
img = np.fliplr(img)
|
206 |
+
return img
|
207 |
+
|
208 |
+
|
209 |
+
def flip_kp(kp):
|
210 |
+
"""Flip keypoints."""
|
211 |
+
if len(kp) == 24:
|
212 |
+
flipped_parts = constants.J24_FLIP_PERM
|
213 |
+
elif len(kp) == 49:
|
214 |
+
flipped_parts = constants.J49_FLIP_PERM
|
215 |
+
kp = kp[flipped_parts]
|
216 |
+
kp[:, 0] = - kp[:, 0]
|
217 |
+
return kp
|
218 |
+
|
219 |
+
|
220 |
+
def flip_pose(pose):
|
221 |
+
"""Flip pose.
|
222 |
+
The flipping is based on SMPL parameters.
|
223 |
+
"""
|
224 |
+
flipped_parts = constants.SMPL_POSE_FLIP_PERM
|
225 |
+
pose = pose[flipped_parts]
|
226 |
+
# we also negate the second and the third dimension of the axis-angle
|
227 |
+
pose[1::3] = -pose[1::3]
|
228 |
+
pose[2::3] = -pose[2::3]
|
229 |
+
return pose
|
230 |
+
|
231 |
+
|
232 |
+
def denormalize_images(images):
|
233 |
+
images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
|
234 |
+
images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
|
235 |
+
return images
|
236 |
+
|
237 |
+
|
238 |
+
def read_img(img_fn):
|
239 |
+
# return pil_img.fromarray(
|
240 |
+
# cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB))
|
241 |
+
# with open(img_fn, 'rb') as f:
|
242 |
+
# img = pil_img.open(f).convert('RGB')
|
243 |
+
# return img
|
244 |
+
if img_fn.endswith('jpeg') or img_fn.endswith('jpg'):
|
245 |
+
try:
|
246 |
+
with open(img_fn, 'rb') as f:
|
247 |
+
img = np.array(jpeg.JPEG(f).decode())
|
248 |
+
except jpeg.JPEGRuntimeError:
|
249 |
+
# logger.warning('{} produced a JPEGRuntimeError', img_fn)
|
250 |
+
img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
|
251 |
+
else:
|
252 |
+
# elif img_fn.endswith('png') or img_fn.endswith('JPG') or img_fn.endswith(''):
|
253 |
+
img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
|
254 |
+
return img.astype(np.float32)
|
255 |
+
|
256 |
+
|
257 |
+
def generate_heatmaps_2d(joints, joints_vis, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
|
258 |
+
'''
|
259 |
+
:param joints: [num_joints, 3]
|
260 |
+
:param joints_vis: [num_joints, 3]
|
261 |
+
:return: target, target_weight(1: visible, 0: invisible)
|
262 |
+
'''
|
263 |
+
target_weight = np.ones((num_joints, 1), dtype=np.float32)
|
264 |
+
target_weight[:, 0] = joints_vis[:, 0]
|
265 |
+
|
266 |
+
target = np.zeros((num_joints, heatmap_size, heatmap_size), dtype=np.float32)
|
267 |
+
|
268 |
+
tmp_size = sigma * 3
|
269 |
+
|
270 |
+
# denormalize joint into heatmap coordinates
|
271 |
+
joints = (joints + 1.) * (image_size / 2.)
|
272 |
+
|
273 |
+
for joint_id in range(num_joints):
|
274 |
+
feat_stride = image_size / heatmap_size
|
275 |
+
mu_x = int(joints[joint_id][0] / feat_stride + 0.5)
|
276 |
+
mu_y = int(joints[joint_id][1] / feat_stride + 0.5)
|
277 |
+
# Check that any part of the gaussian is in-bounds
|
278 |
+
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
279 |
+
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
280 |
+
if ul[0] >= heatmap_size or ul[1] >= heatmap_size \
|
281 |
+
or br[0] < 0 or br[1] < 0:
|
282 |
+
# If not, just return the image as is
|
283 |
+
target_weight[joint_id] = 0
|
284 |
+
continue
|
285 |
+
|
286 |
+
# # Generate gaussian
|
287 |
+
size = 2 * tmp_size + 1
|
288 |
+
x = np.arange(0, size, 1, np.float32)
|
289 |
+
y = x[:, np.newaxis]
|
290 |
+
x0 = y0 = size // 2
|
291 |
+
# The gaussian is not normalized, we want the center value to equal 1
|
292 |
+
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
293 |
+
|
294 |
+
# Usable gaussian range
|
295 |
+
g_x = max(0, -ul[0]), min(br[0], heatmap_size) - ul[0]
|
296 |
+
g_y = max(0, -ul[1]), min(br[1], heatmap_size) - ul[1]
|
297 |
+
# Image range
|
298 |
+
img_x = max(0, ul[0]), min(br[0], heatmap_size)
|
299 |
+
img_y = max(0, ul[1]), min(br[1], heatmap_size)
|
300 |
+
|
301 |
+
v = target_weight[joint_id]
|
302 |
+
if v > 0.5:
|
303 |
+
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
|
304 |
+
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
305 |
+
|
306 |
+
return target, target_weight
|
307 |
+
|
308 |
+
|
309 |
+
def generate_part_labels(vertices, faces, cam_t, neural_renderer, body_part_texture, K, R, part_bins):
|
310 |
+
batch_size = vertices.shape[0]
|
311 |
+
|
312 |
+
body_parts, depth, mask = neural_renderer(
|
313 |
+
vertices,
|
314 |
+
faces.expand(batch_size, -1, -1),
|
315 |
+
textures=body_part_texture.expand(batch_size, -1, -1, -1, -1, -1),
|
316 |
+
K=K.expand(batch_size, -1, -1),
|
317 |
+
R=R.expand(batch_size, -1, -1),
|
318 |
+
t=cam_t.unsqueeze(1),
|
319 |
+
)
|
320 |
+
|
321 |
+
render_rgb = body_parts.clone()
|
322 |
+
|
323 |
+
body_parts = body_parts.permute(0, 2, 3, 1)
|
324 |
+
body_parts *= 255. # multiply it with 255 to make labels distant
|
325 |
+
body_parts, _ = body_parts.max(-1) # reduce to single channel
|
326 |
+
|
327 |
+
body_parts = torch.bucketize(body_parts.detach(), part_bins, right=True) # np.digitize(body_parts, bins, right=True)
|
328 |
+
|
329 |
+
# add 1 to make background label 0
|
330 |
+
body_parts = body_parts.long() + 1
|
331 |
+
body_parts = body_parts * mask.detach()
|
332 |
+
|
333 |
+
return body_parts.long(), render_rgb
|
334 |
+
|
335 |
+
|
336 |
+
def generate_heatmaps_2d_batch(joints, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
|
337 |
+
batch_size = joints.shape[0]
|
338 |
+
|
339 |
+
joints = joints.detach().cpu().numpy()
|
340 |
+
joints_vis = np.ones_like(joints)
|
341 |
+
|
342 |
+
heatmaps = []
|
343 |
+
heatmaps_vis = []
|
344 |
+
for i in range(batch_size):
|
345 |
+
hm, hm_vis = generate_heatmaps_2d(joints[i], joints_vis[i], num_joints, heatmap_size, image_size, sigma)
|
346 |
+
heatmaps.append(hm)
|
347 |
+
heatmaps_vis.append(hm_vis)
|
348 |
+
|
349 |
+
return torch.from_numpy(np.stack(heatmaps)).float().to('cuda'), \
|
350 |
+
torch.from_numpy(np.stack(heatmaps_vis)).float().to('cuda')
|
351 |
+
|
352 |
+
|
353 |
+
def get_body_part_texture(faces, model_type='smpl', non_parametric=False):
|
354 |
+
if model_type == 'smpl':
|
355 |
+
n_vertices = 6890
|
356 |
+
segmentation_path = 'data/smpl_vert_segmentation.json'
|
357 |
+
if model_type == 'smplx':
|
358 |
+
n_vertices = 10475
|
359 |
+
segmentation_path = 'data/smplx_vert_segmentation.json'
|
360 |
+
|
361 |
+
with open(segmentation_path, 'rb') as f:
|
362 |
+
part_segmentation = json.load(f)
|
363 |
+
|
364 |
+
# map all vertex ids to the joint ids
|
365 |
+
joint_names = get_smpl_joint_names()
|
366 |
+
smplx_extra_joint_names = ['leftEye', 'eyeballs', 'rightEye']
|
367 |
+
body_vert_idx = np.zeros((n_vertices), dtype=np.int32) - 1 # -1 for missing label
|
368 |
+
for i, (k, v) in enumerate(part_segmentation.items()):
|
369 |
+
if k in smplx_extra_joint_names and model_type == 'smplx':
|
370 |
+
k = 'head' # map all extra smplx face joints to head
|
371 |
+
body_joint_idx = joint_names.index(k)
|
372 |
+
body_vert_idx[v] = body_joint_idx
|
373 |
+
|
374 |
+
# pare implementation
|
375 |
+
# import joblib
|
376 |
+
# part_segmentation = joblib.load('data/smpl_partSegmentation_mapping.pkl')
|
377 |
+
# body_vert_idx = part_segmentation['smpl_index']
|
378 |
+
|
379 |
+
n_parts = 24.
|
380 |
+
|
381 |
+
if non_parametric:
|
382 |
+
# reduce the number of body_parts to 14
|
383 |
+
# by mapping some joints to others
|
384 |
+
n_parts = 14.
|
385 |
+
joint_mapping = map_smpl_to_common()
|
386 |
+
|
387 |
+
for jm in joint_mapping:
|
388 |
+
for j in jm[0]:
|
389 |
+
body_vert_idx[body_vert_idx==j] = jm[1]
|
390 |
+
|
391 |
+
vertex_colors = np.ones((n_vertices, 4))
|
392 |
+
vertex_colors[:, :3] = body_vert_idx[..., None]
|
393 |
+
|
394 |
+
vertex_colors = color.to_rgba(vertex_colors)
|
395 |
+
vertex_colors = vertex_colors[:, :3]/255.
|
396 |
+
|
397 |
+
face_colors = vertex_colors[faces].min(axis=1)
|
398 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 3), dtype=np.float32)
|
399 |
+
# texture[0, :, 0, 0, :] = face_colors[:, :3] / n_parts
|
400 |
+
texture[0, :, 0, 0, :] = face_colors[:, :3]
|
401 |
+
|
402 |
+
vertex_colors = torch.from_numpy(vertex_colors).float()
|
403 |
+
texture = torch.from_numpy(texture).float()
|
404 |
+
return vertex_colors, texture
|
405 |
+
|
406 |
+
|
407 |
+
def get_default_camera(focal_length, img_h, img_w, is_cam_batch=False):
|
408 |
+
if not is_cam_batch:
|
409 |
+
K = torch.eye(3)
|
410 |
+
K[0, 0] = focal_length
|
411 |
+
K[1, 1] = focal_length
|
412 |
+
K[2, 2] = 1
|
413 |
+
K[0, 2] = img_w / 2.
|
414 |
+
K[1, 2] = img_h / 2.
|
415 |
+
K = K[None, :, :]
|
416 |
+
R = torch.eye(3)[None, :, :]
|
417 |
+
else:
|
418 |
+
bs = focal_length.shape[0]
|
419 |
+
K = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
|
420 |
+
K[:, 0, 0] = focal_length[:, 0]
|
421 |
+
K[:, 1, 1] = focal_length[:, 1]
|
422 |
+
K[:, 2, 2] = 1
|
423 |
+
K[:, 0, 2] = img_w / 2.
|
424 |
+
K[:, 1, 2] = img_h / 2.
|
425 |
+
R = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
|
426 |
+
return K, R
|
427 |
+
|
428 |
+
|
429 |
+
def read_exif_data(img_fname):
|
430 |
+
import PIL.Image
|
431 |
+
import PIL.ExifTags
|
432 |
+
|
433 |
+
img = PIL.Image.open(img_fname)
|
434 |
+
exif_data = img._getexif()
|
435 |
+
|
436 |
+
if exif_data == None:
|
437 |
+
return None
|
438 |
+
|
439 |
+
exif = {
|
440 |
+
PIL.ExifTags.TAGS[k]: v
|
441 |
+
for k, v in exif_data.items()
|
442 |
+
if k in PIL.ExifTags.TAGS
|
443 |
+
}
|
444 |
+
return exif
|
utils/kp_utils.py
ADDED
@@ -0,0 +1,1114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def keypoint_hflip(kp, img_width):
|
5 |
+
# Flip a keypoint horizontally around the y-axis
|
6 |
+
# kp N,2
|
7 |
+
if len(kp.shape) == 2:
|
8 |
+
kp[:,0] = (img_width - 1.) - kp[:,0]
|
9 |
+
elif len(kp.shape) == 3:
|
10 |
+
kp[:, :, 0] = (img_width - 1.) - kp[:, :, 0]
|
11 |
+
return kp
|
12 |
+
|
13 |
+
|
14 |
+
def convert_kps(joints2d, src, dst):
|
15 |
+
src_names = eval(f'get_{src}_joint_names')()
|
16 |
+
dst_names = eval(f'get_{dst}_joint_names')()
|
17 |
+
|
18 |
+
out_joints2d = np.zeros((joints2d.shape[0], len(dst_names), joints2d.shape[-1]))
|
19 |
+
|
20 |
+
for idx, jn in enumerate(dst_names):
|
21 |
+
if jn in src_names:
|
22 |
+
out_joints2d[:, idx] = joints2d[:, src_names.index(jn)]
|
23 |
+
|
24 |
+
return out_joints2d
|
25 |
+
|
26 |
+
|
27 |
+
def get_perm_idxs(src, dst):
|
28 |
+
src_names = eval(f'get_{src}_joint_names')()
|
29 |
+
dst_names = eval(f'get_{dst}_joint_names')()
|
30 |
+
idxs = [src_names.index(h) for h in dst_names if h in src_names]
|
31 |
+
return idxs
|
32 |
+
|
33 |
+
|
34 |
+
def get_mpii3d_test_joint_names():
|
35 |
+
return [
|
36 |
+
'headtop', # 'head_top',
|
37 |
+
'neck',
|
38 |
+
'rshoulder',# 'right_shoulder',
|
39 |
+
'relbow',# 'right_elbow',
|
40 |
+
'rwrist',# 'right_wrist',
|
41 |
+
'lshoulder',# 'left_shoulder',
|
42 |
+
'lelbow', # 'left_elbow',
|
43 |
+
'lwrist', # 'left_wrist',
|
44 |
+
'rhip', # 'right_hip',
|
45 |
+
'rknee', # 'right_knee',
|
46 |
+
'rankle',# 'right_ankle',
|
47 |
+
'lhip',# 'left_hip',
|
48 |
+
'lknee',# 'left_knee',
|
49 |
+
'lankle',# 'left_ankle'
|
50 |
+
'hip',# 'pelvis',
|
51 |
+
'Spine (H36M)',# 'spine',
|
52 |
+
'Head (H36M)',# 'head'
|
53 |
+
]
|
54 |
+
|
55 |
+
|
56 |
+
def get_mpii3d_joint_names():
|
57 |
+
return [
|
58 |
+
'spine3', # 0,
|
59 |
+
'spine4', # 1,
|
60 |
+
'spine2', # 2,
|
61 |
+
'Spine (H36M)', #'spine', # 3,
|
62 |
+
'hip', # 'pelvis', # 4,
|
63 |
+
'neck', # 5,
|
64 |
+
'Head (H36M)', # 'head', # 6,
|
65 |
+
"headtop", # 'head_top', # 7,
|
66 |
+
'left_clavicle', # 8,
|
67 |
+
"lshoulder", # 'left_shoulder', # 9,
|
68 |
+
"lelbow", # 'left_elbow',# 10,
|
69 |
+
"lwrist", # 'left_wrist',# 11,
|
70 |
+
'left_hand',# 12,
|
71 |
+
'right_clavicle',# 13,
|
72 |
+
'rshoulder',# 'right_shoulder',# 14,
|
73 |
+
'relbow',# 'right_elbow',# 15,
|
74 |
+
'rwrist',# 'right_wrist',# 16,
|
75 |
+
'right_hand',# 17,
|
76 |
+
'lhip', # left_hip',# 18,
|
77 |
+
'lknee', # 'left_knee',# 19,
|
78 |
+
'lankle', #left ankle # 20
|
79 |
+
'left_foot', # 21
|
80 |
+
'left_toe', # 22
|
81 |
+
"rhip", # 'right_hip',# 23
|
82 |
+
"rknee", # 'right_knee',# 24
|
83 |
+
"rankle", #'right_ankle', # 25
|
84 |
+
'right_foot',# 26
|
85 |
+
'right_toe' # 27
|
86 |
+
]
|
87 |
+
|
88 |
+
|
89 |
+
# def get_insta_joint_names():
|
90 |
+
# return [
|
91 |
+
# 'rheel' , # 0
|
92 |
+
# 'rknee' , # 1
|
93 |
+
# 'rhip' , # 2
|
94 |
+
# 'lhip' , # 3
|
95 |
+
# 'lknee' , # 4
|
96 |
+
# 'lheel' , # 5
|
97 |
+
# 'rwrist' , # 6
|
98 |
+
# 'relbow' , # 7
|
99 |
+
# 'rshoulder' , # 8
|
100 |
+
# 'lshoulder' , # 9
|
101 |
+
# 'lelbow' , # 10
|
102 |
+
# 'lwrist' , # 11
|
103 |
+
# 'neck' , # 12
|
104 |
+
# 'headtop' , # 13
|
105 |
+
# 'nose' , # 14
|
106 |
+
# 'leye' , # 15
|
107 |
+
# 'reye' , # 16
|
108 |
+
# 'lear' , # 17
|
109 |
+
# 'rear' , # 18
|
110 |
+
# 'lbigtoe' , # 19
|
111 |
+
# 'rbigtoe' , # 20
|
112 |
+
# 'lsmalltoe' , # 21
|
113 |
+
# 'rsmalltoe' , # 22
|
114 |
+
# 'lankle' , # 23
|
115 |
+
# 'rankle' , # 24
|
116 |
+
# ]
|
117 |
+
|
118 |
+
|
119 |
+
def get_insta_joint_names():
|
120 |
+
return [
|
121 |
+
'OP RHeel',
|
122 |
+
'OP RKnee',
|
123 |
+
'OP RHip',
|
124 |
+
'OP LHip',
|
125 |
+
'OP LKnee',
|
126 |
+
'OP LHeel',
|
127 |
+
'OP RWrist',
|
128 |
+
'OP RElbow',
|
129 |
+
'OP RShoulder',
|
130 |
+
'OP LShoulder',
|
131 |
+
'OP LElbow',
|
132 |
+
'OP LWrist',
|
133 |
+
'OP Neck',
|
134 |
+
'headtop',
|
135 |
+
'OP Nose',
|
136 |
+
'OP LEye',
|
137 |
+
'OP REye',
|
138 |
+
'OP LEar',
|
139 |
+
'OP REar',
|
140 |
+
'OP LBigToe',
|
141 |
+
'OP RBigToe',
|
142 |
+
'OP LSmallToe',
|
143 |
+
'OP RSmallToe',
|
144 |
+
'OP LAnkle',
|
145 |
+
'OP RAnkle',
|
146 |
+
]
|
147 |
+
|
148 |
+
|
149 |
+
def get_mmpose_joint_names():
|
150 |
+
# this naming is for the first 23 joints of MMPose
|
151 |
+
# does not include hands and face
|
152 |
+
return [
|
153 |
+
'OP Nose', # 1
|
154 |
+
'OP LEye', # 2
|
155 |
+
'OP REye', # 3
|
156 |
+
'OP LEar', # 4
|
157 |
+
'OP REar', # 5
|
158 |
+
'OP LShoulder', # 6
|
159 |
+
'OP RShoulder', # 7
|
160 |
+
'OP LElbow', # 8
|
161 |
+
'OP RElbow', # 9
|
162 |
+
'OP LWrist', # 10
|
163 |
+
'OP RWrist', # 11
|
164 |
+
'OP LHip', # 12
|
165 |
+
'OP RHip', # 13
|
166 |
+
'OP LKnee', # 14
|
167 |
+
'OP RKnee', # 15
|
168 |
+
'OP LAnkle', # 16
|
169 |
+
'OP RAnkle', # 17
|
170 |
+
'OP LBigToe', # 18
|
171 |
+
'OP LSmallToe', # 19
|
172 |
+
'OP LHeel', # 20
|
173 |
+
'OP RBigToe', # 21
|
174 |
+
'OP RSmallToe', # 22
|
175 |
+
'OP RHeel', # 23
|
176 |
+
]
|
177 |
+
|
178 |
+
|
179 |
+
def get_insta_skeleton():
|
180 |
+
return np.array(
|
181 |
+
[
|
182 |
+
[0 , 1],
|
183 |
+
[1 , 2],
|
184 |
+
[2 , 3],
|
185 |
+
[3 , 4],
|
186 |
+
[4 , 5],
|
187 |
+
[6 , 7],
|
188 |
+
[7 , 8],
|
189 |
+
[8 , 9],
|
190 |
+
[9 ,10],
|
191 |
+
[2 , 8],
|
192 |
+
[3 , 9],
|
193 |
+
[10,11],
|
194 |
+
[8 ,12],
|
195 |
+
[9 ,12],
|
196 |
+
[12,13],
|
197 |
+
[12,14],
|
198 |
+
[14,15],
|
199 |
+
[14,16],
|
200 |
+
[15,17],
|
201 |
+
[16,18],
|
202 |
+
[0 ,20],
|
203 |
+
[20,22],
|
204 |
+
[5 ,19],
|
205 |
+
[19,21],
|
206 |
+
[5 ,23],
|
207 |
+
[0 ,24],
|
208 |
+
])
|
209 |
+
|
210 |
+
|
211 |
+
def get_staf_skeleton():
|
212 |
+
return np.array(
|
213 |
+
[
|
214 |
+
[0, 1],
|
215 |
+
[1, 2],
|
216 |
+
[2, 3],
|
217 |
+
[3, 4],
|
218 |
+
[1, 5],
|
219 |
+
[5, 6],
|
220 |
+
[6, 7],
|
221 |
+
[1, 8],
|
222 |
+
[8, 9],
|
223 |
+
[9, 10],
|
224 |
+
[10, 11],
|
225 |
+
[8, 12],
|
226 |
+
[12, 13],
|
227 |
+
[13, 14],
|
228 |
+
[0, 15],
|
229 |
+
[0, 16],
|
230 |
+
[15, 17],
|
231 |
+
[16, 18],
|
232 |
+
[2, 9],
|
233 |
+
[5, 12],
|
234 |
+
[1, 19],
|
235 |
+
[20, 19],
|
236 |
+
]
|
237 |
+
)
|
238 |
+
|
239 |
+
|
240 |
+
def get_staf_joint_names():
|
241 |
+
return [
|
242 |
+
'OP Nose', # 0,
|
243 |
+
'OP Neck', # 1,
|
244 |
+
'OP RShoulder', # 2,
|
245 |
+
'OP RElbow', # 3,
|
246 |
+
'OP RWrist', # 4,
|
247 |
+
'OP LShoulder', # 5,
|
248 |
+
'OP LElbow', # 6,
|
249 |
+
'OP LWrist', # 7,
|
250 |
+
'OP MidHip', # 8,
|
251 |
+
'OP RHip', # 9,
|
252 |
+
'OP RKnee', # 10,
|
253 |
+
'OP RAnkle', # 11,
|
254 |
+
'OP LHip', # 12,
|
255 |
+
'OP LKnee', # 13,
|
256 |
+
'OP LAnkle', # 14,
|
257 |
+
'OP REye', # 15,
|
258 |
+
'OP LEye', # 16,
|
259 |
+
'OP REar', # 17,
|
260 |
+
'OP LEar', # 18,
|
261 |
+
'Neck (LSP)', # 19,
|
262 |
+
'Top of Head (LSP)', # 20,
|
263 |
+
]
|
264 |
+
|
265 |
+
|
266 |
+
def get_spin_op_joint_names():
|
267 |
+
return [
|
268 |
+
'OP Nose', # 0
|
269 |
+
'OP Neck', # 1
|
270 |
+
'OP RShoulder', # 2
|
271 |
+
'OP RElbow', # 3
|
272 |
+
'OP RWrist', # 4
|
273 |
+
'OP LShoulder', # 5
|
274 |
+
'OP LElbow', # 6
|
275 |
+
'OP LWrist', # 7
|
276 |
+
'OP MidHip', # 8
|
277 |
+
'OP RHip', # 9
|
278 |
+
'OP RKnee', # 10
|
279 |
+
'OP RAnkle', # 11
|
280 |
+
'OP LHip', # 12
|
281 |
+
'OP LKnee', # 13
|
282 |
+
'OP LAnkle', # 14
|
283 |
+
'OP REye', # 15
|
284 |
+
'OP LEye', # 16
|
285 |
+
'OP REar', # 17
|
286 |
+
'OP LEar', # 18
|
287 |
+
'OP LBigToe', # 19
|
288 |
+
'OP LSmallToe', # 20
|
289 |
+
'OP LHeel', # 21
|
290 |
+
'OP RBigToe', # 22
|
291 |
+
'OP RSmallToe', # 23
|
292 |
+
'OP RHeel', # 24
|
293 |
+
]
|
294 |
+
|
295 |
+
|
296 |
+
def get_openpose_joint_names():
|
297 |
+
return [
|
298 |
+
'OP Nose', # 0
|
299 |
+
'OP Neck', # 1
|
300 |
+
'OP RShoulder', # 2
|
301 |
+
'OP RElbow', # 3
|
302 |
+
'OP RWrist', # 4
|
303 |
+
'OP LShoulder', # 5
|
304 |
+
'OP LElbow', # 6
|
305 |
+
'OP LWrist', # 7
|
306 |
+
'OP MidHip', # 8
|
307 |
+
'OP RHip', # 9
|
308 |
+
'OP RKnee', # 10
|
309 |
+
'OP RAnkle', # 11
|
310 |
+
'OP LHip', # 12
|
311 |
+
'OP LKnee', # 13
|
312 |
+
'OP LAnkle', # 14
|
313 |
+
'OP REye', # 15
|
314 |
+
'OP LEye', # 16
|
315 |
+
'OP REar', # 17
|
316 |
+
'OP LEar', # 18
|
317 |
+
'OP LBigToe', # 19
|
318 |
+
'OP LSmallToe', # 20
|
319 |
+
'OP LHeel', # 21
|
320 |
+
'OP RBigToe', # 22
|
321 |
+
'OP RSmallToe', # 23
|
322 |
+
'OP RHeel', # 24
|
323 |
+
]
|
324 |
+
|
325 |
+
|
326 |
+
def get_spin_joint_names():
|
327 |
+
return [
|
328 |
+
'OP Nose', # 0
|
329 |
+
'OP Neck', # 1
|
330 |
+
'OP RShoulder', # 2
|
331 |
+
'OP RElbow', # 3
|
332 |
+
'OP RWrist', # 4
|
333 |
+
'OP LShoulder', # 5
|
334 |
+
'OP LElbow', # 6
|
335 |
+
'OP LWrist', # 7
|
336 |
+
'OP MidHip', # 8
|
337 |
+
'OP RHip', # 9
|
338 |
+
'OP RKnee', # 10
|
339 |
+
'OP RAnkle', # 11
|
340 |
+
'OP LHip', # 12
|
341 |
+
'OP LKnee', # 13
|
342 |
+
'OP LAnkle', # 14
|
343 |
+
'OP REye', # 15
|
344 |
+
'OP LEye', # 16
|
345 |
+
'OP REar', # 17
|
346 |
+
'OP LEar', # 18
|
347 |
+
'OP LBigToe', # 19
|
348 |
+
'OP LSmallToe', # 20
|
349 |
+
'OP LHeel', # 21
|
350 |
+
'OP RBigToe', # 22
|
351 |
+
'OP RSmallToe', # 23
|
352 |
+
'OP RHeel', # 24
|
353 |
+
'rankle', # 25
|
354 |
+
'rknee', # 26
|
355 |
+
'rhip', # 27
|
356 |
+
'lhip', # 28
|
357 |
+
'lknee', # 29
|
358 |
+
'lankle', # 30
|
359 |
+
'rwrist', # 31
|
360 |
+
'relbow', # 32
|
361 |
+
'rshoulder', # 33
|
362 |
+
'lshoulder', # 34
|
363 |
+
'lelbow', # 35
|
364 |
+
'lwrist', # 36
|
365 |
+
'neck', # 37
|
366 |
+
'headtop', # 38
|
367 |
+
'hip', # 39 'Pelvis (MPII)', # 39
|
368 |
+
'thorax', # 40 'Thorax (MPII)', # 40
|
369 |
+
'Spine (H36M)', # 41
|
370 |
+
'Jaw (H36M)', # 42
|
371 |
+
'Head (H36M)', # 43
|
372 |
+
'nose', # 44
|
373 |
+
'leye', # 45 'Left Eye', # 45
|
374 |
+
'reye', # 46 'Right Eye', # 46
|
375 |
+
'lear', # 47 'Left Ear', # 47
|
376 |
+
'rear', # 48 'Right Ear', # 48
|
377 |
+
]
|
378 |
+
|
379 |
+
def get_muco3dhp_joint_names():
|
380 |
+
return [
|
381 |
+
'headtop',
|
382 |
+
'thorax',
|
383 |
+
'rshoulder',
|
384 |
+
'relbow',
|
385 |
+
'rwrist',
|
386 |
+
'lshoulder',
|
387 |
+
'lelbow',
|
388 |
+
'lwrist',
|
389 |
+
'rhip',
|
390 |
+
'rknee',
|
391 |
+
'rankle',
|
392 |
+
'lhip',
|
393 |
+
'lknee',
|
394 |
+
'lankle',
|
395 |
+
'hip',
|
396 |
+
'Spine (H36M)',
|
397 |
+
'Head (H36M)',
|
398 |
+
'R_Hand',
|
399 |
+
'L_Hand',
|
400 |
+
'R_Toe',
|
401 |
+
'L_Toe'
|
402 |
+
]
|
403 |
+
|
404 |
+
def get_h36m_joint_names():
|
405 |
+
return [
|
406 |
+
'hip', # 0
|
407 |
+
'lhip', # 1
|
408 |
+
'lknee', # 2
|
409 |
+
'lankle', # 3
|
410 |
+
'rhip', # 4
|
411 |
+
'rknee', # 5
|
412 |
+
'rankle', # 6
|
413 |
+
'Spine (H36M)', # 7
|
414 |
+
'neck', # 8
|
415 |
+
'Head (H36M)', # 9
|
416 |
+
'headtop', # 10
|
417 |
+
'lshoulder', # 11
|
418 |
+
'lelbow', # 12
|
419 |
+
'lwrist', # 13
|
420 |
+
'rshoulder', # 14
|
421 |
+
'relbow', # 15
|
422 |
+
'rwrist', # 16
|
423 |
+
]
|
424 |
+
|
425 |
+
|
426 |
+
def get_spin_skeleton():
|
427 |
+
return np.array(
|
428 |
+
[
|
429 |
+
[0 , 1],
|
430 |
+
[1 , 2],
|
431 |
+
[2 , 3],
|
432 |
+
[3 , 4],
|
433 |
+
[1 , 5],
|
434 |
+
[5 , 6],
|
435 |
+
[6 , 7],
|
436 |
+
[1 , 8],
|
437 |
+
[8 , 9],
|
438 |
+
[9 ,10],
|
439 |
+
[10,11],
|
440 |
+
[8 ,12],
|
441 |
+
[12,13],
|
442 |
+
[13,14],
|
443 |
+
[0 ,15],
|
444 |
+
[0 ,16],
|
445 |
+
[15,17],
|
446 |
+
[16,18],
|
447 |
+
[21,19],
|
448 |
+
[19,20],
|
449 |
+
[14,21],
|
450 |
+
[11,24],
|
451 |
+
[24,22],
|
452 |
+
[22,23],
|
453 |
+
[0 ,38],
|
454 |
+
]
|
455 |
+
)
|
456 |
+
|
457 |
+
|
458 |
+
def get_openpose_skeleton():
|
459 |
+
return np.array(
|
460 |
+
[
|
461 |
+
[0 , 1],
|
462 |
+
[1 , 2],
|
463 |
+
[2 , 3],
|
464 |
+
[3 , 4],
|
465 |
+
[1 , 5],
|
466 |
+
[5 , 6],
|
467 |
+
[6 , 7],
|
468 |
+
[1 , 8],
|
469 |
+
[8 , 9],
|
470 |
+
[9 ,10],
|
471 |
+
[10,11],
|
472 |
+
[8 ,12],
|
473 |
+
[12,13],
|
474 |
+
[13,14],
|
475 |
+
[0 ,15],
|
476 |
+
[0 ,16],
|
477 |
+
[15,17],
|
478 |
+
[16,18],
|
479 |
+
[21,19],
|
480 |
+
[19,20],
|
481 |
+
[14,21],
|
482 |
+
[11,24],
|
483 |
+
[24,22],
|
484 |
+
[22,23],
|
485 |
+
]
|
486 |
+
)
|
487 |
+
|
488 |
+
|
489 |
+
def get_posetrack_joint_names():
|
490 |
+
return [
|
491 |
+
"nose",
|
492 |
+
"neck",
|
493 |
+
"headtop",
|
494 |
+
"lear",
|
495 |
+
"rear",
|
496 |
+
"lshoulder",
|
497 |
+
"rshoulder",
|
498 |
+
"lelbow",
|
499 |
+
"relbow",
|
500 |
+
"lwrist",
|
501 |
+
"rwrist",
|
502 |
+
"lhip",
|
503 |
+
"rhip",
|
504 |
+
"lknee",
|
505 |
+
"rknee",
|
506 |
+
"lankle",
|
507 |
+
"rankle"
|
508 |
+
]
|
509 |
+
|
510 |
+
|
511 |
+
def get_posetrack_original_kp_names():
|
512 |
+
return [
|
513 |
+
'nose',
|
514 |
+
'head_bottom',
|
515 |
+
'head_top',
|
516 |
+
'left_ear',
|
517 |
+
'right_ear',
|
518 |
+
'left_shoulder',
|
519 |
+
'right_shoulder',
|
520 |
+
'left_elbow',
|
521 |
+
'right_elbow',
|
522 |
+
'left_wrist',
|
523 |
+
'right_wrist',
|
524 |
+
'left_hip',
|
525 |
+
'right_hip',
|
526 |
+
'left_knee',
|
527 |
+
'right_knee',
|
528 |
+
'left_ankle',
|
529 |
+
'right_ankle'
|
530 |
+
]
|
531 |
+
|
532 |
+
|
533 |
+
def get_pennaction_joint_names():
|
534 |
+
return [
|
535 |
+
"headtop", # 0
|
536 |
+
"lshoulder", # 1
|
537 |
+
"rshoulder", # 2
|
538 |
+
"lelbow", # 3
|
539 |
+
"relbow", # 4
|
540 |
+
"lwrist", # 5
|
541 |
+
"rwrist", # 6
|
542 |
+
"lhip" , # 7
|
543 |
+
"rhip" , # 8
|
544 |
+
"lknee", # 9
|
545 |
+
"rknee" , # 10
|
546 |
+
"lankle", # 11
|
547 |
+
"rankle" # 12
|
548 |
+
]
|
549 |
+
|
550 |
+
|
551 |
+
def get_common_joint_names():
|
552 |
+
return [
|
553 |
+
"rankle", # 0 "lankle", # 0
|
554 |
+
"rknee", # 1 "lknee", # 1
|
555 |
+
"rhip", # 2 "lhip", # 2
|
556 |
+
"lhip", # 3 "rhip", # 3
|
557 |
+
"lknee", # 4 "rknee", # 4
|
558 |
+
"lankle", # 5 "rankle", # 5
|
559 |
+
"rwrist", # 6 "lwrist", # 6
|
560 |
+
"relbow", # 7 "lelbow", # 7
|
561 |
+
"rshoulder", # 8 "lshoulder", # 8
|
562 |
+
"lshoulder", # 9 "rshoulder", # 9
|
563 |
+
"lelbow", # 10 "relbow", # 10
|
564 |
+
"lwrist", # 11 "rwrist", # 11
|
565 |
+
"neck", # 12 "neck", # 12
|
566 |
+
"headtop", # 13 "headtop", # 13
|
567 |
+
]
|
568 |
+
|
569 |
+
|
570 |
+
def get_common_paper_joint_names():
|
571 |
+
return [
|
572 |
+
"Right Ankle", # 0 "lankle", # 0
|
573 |
+
"Right Knee", # 1 "lknee", # 1
|
574 |
+
"Right Hip", # 2 "lhip", # 2
|
575 |
+
"Left Hip", # 3 "rhip", # 3
|
576 |
+
"Left Knee", # 4 "rknee", # 4
|
577 |
+
"Left Ankle", # 5 "rankle", # 5
|
578 |
+
"Right Wrist", # 6 "lwrist", # 6
|
579 |
+
"Right Elbow", # 7 "lelbow", # 7
|
580 |
+
"Right Shoulder", # 8 "lshoulder", # 8
|
581 |
+
"Left Shoulder", # 9 "rshoulder", # 9
|
582 |
+
"Left Elbow", # 10 "relbow", # 10
|
583 |
+
"Left Wrist", # 11 "rwrist", # 11
|
584 |
+
"Neck", # 12 "neck", # 12
|
585 |
+
"Head", # 13 "headtop", # 13
|
586 |
+
]
|
587 |
+
|
588 |
+
|
589 |
+
def get_common_skeleton():
|
590 |
+
return np.array(
|
591 |
+
[
|
592 |
+
[ 0, 1 ],
|
593 |
+
[ 1, 2 ],
|
594 |
+
[ 3, 4 ],
|
595 |
+
[ 4, 5 ],
|
596 |
+
[ 6, 7 ],
|
597 |
+
[ 7, 8 ],
|
598 |
+
[ 8, 2 ],
|
599 |
+
[ 8, 9 ],
|
600 |
+
[ 9, 3 ],
|
601 |
+
[ 2, 3 ],
|
602 |
+
[ 8, 12],
|
603 |
+
[ 9, 10],
|
604 |
+
[12, 9 ],
|
605 |
+
[10, 11],
|
606 |
+
[12, 13],
|
607 |
+
]
|
608 |
+
)
|
609 |
+
|
610 |
+
|
611 |
+
def get_coco_joint_names():
|
612 |
+
return [
|
613 |
+
"nose", # 0
|
614 |
+
"leye", # 1
|
615 |
+
"reye", # 2
|
616 |
+
"lear", # 3
|
617 |
+
"rear", # 4
|
618 |
+
"lshoulder", # 5
|
619 |
+
"rshoulder", # 6
|
620 |
+
"lelbow", # 7
|
621 |
+
"relbow", # 8
|
622 |
+
"lwrist", # 9
|
623 |
+
"rwrist", # 10
|
624 |
+
"lhip", # 11
|
625 |
+
"rhip", # 12
|
626 |
+
"lknee", # 13
|
627 |
+
"rknee", # 14
|
628 |
+
"lankle", # 15
|
629 |
+
"rankle", # 16
|
630 |
+
]
|
631 |
+
|
632 |
+
|
633 |
+
def get_ochuman_joint_names():
|
634 |
+
return [
|
635 |
+
'rshoulder',
|
636 |
+
'relbow',
|
637 |
+
'rwrist',
|
638 |
+
'lshoulder',
|
639 |
+
'lelbow',
|
640 |
+
'lwrist',
|
641 |
+
'rhip',
|
642 |
+
'rknee',
|
643 |
+
'rankle',
|
644 |
+
'lhip',
|
645 |
+
'lknee',
|
646 |
+
'lankle',
|
647 |
+
'headtop',
|
648 |
+
'neck',
|
649 |
+
'rear',
|
650 |
+
'lear',
|
651 |
+
'nose',
|
652 |
+
'reye',
|
653 |
+
'leye'
|
654 |
+
]
|
655 |
+
|
656 |
+
|
657 |
+
def get_crowdpose_joint_names():
|
658 |
+
return [
|
659 |
+
'lshoulder',
|
660 |
+
'rshoulder',
|
661 |
+
'lelbow',
|
662 |
+
'relbow',
|
663 |
+
'lwrist',
|
664 |
+
'rwrist',
|
665 |
+
'lhip',
|
666 |
+
'rhip',
|
667 |
+
'lknee',
|
668 |
+
'rknee',
|
669 |
+
'lankle',
|
670 |
+
'rankle',
|
671 |
+
'headtop',
|
672 |
+
'neck'
|
673 |
+
]
|
674 |
+
|
675 |
+
def get_coco_skeleton():
|
676 |
+
# 0 - nose,
|
677 |
+
# 1 - leye,
|
678 |
+
# 2 - reye,
|
679 |
+
# 3 - lear,
|
680 |
+
# 4 - rear,
|
681 |
+
# 5 - lshoulder,
|
682 |
+
# 6 - rshoulder,
|
683 |
+
# 7 - lelbow,
|
684 |
+
# 8 - relbow,
|
685 |
+
# 9 - lwrist,
|
686 |
+
# 10 - rwrist,
|
687 |
+
# 11 - lhip,
|
688 |
+
# 12 - rhip,
|
689 |
+
# 13 - lknee,
|
690 |
+
# 14 - rknee,
|
691 |
+
# 15 - lankle,
|
692 |
+
# 16 - rankle,
|
693 |
+
return np.array(
|
694 |
+
[
|
695 |
+
[15, 13],
|
696 |
+
[13, 11],
|
697 |
+
[16, 14],
|
698 |
+
[14, 12],
|
699 |
+
[11, 12],
|
700 |
+
[ 5, 11],
|
701 |
+
[ 6, 12],
|
702 |
+
[ 5, 6 ],
|
703 |
+
[ 5, 7 ],
|
704 |
+
[ 6, 8 ],
|
705 |
+
[ 7, 9 ],
|
706 |
+
[ 8, 10],
|
707 |
+
[ 1, 2 ],
|
708 |
+
[ 0, 1 ],
|
709 |
+
[ 0, 2 ],
|
710 |
+
[ 1, 3 ],
|
711 |
+
[ 2, 4 ],
|
712 |
+
[ 3, 5 ],
|
713 |
+
[ 4, 6 ]
|
714 |
+
]
|
715 |
+
)
|
716 |
+
|
717 |
+
|
718 |
+
def get_mpii_joint_names():
|
719 |
+
return [
|
720 |
+
"rankle", # 0
|
721 |
+
"rknee", # 1
|
722 |
+
"rhip", # 2
|
723 |
+
"lhip", # 3
|
724 |
+
"lknee", # 4
|
725 |
+
"lankle", # 5
|
726 |
+
"hip", # 6
|
727 |
+
"thorax", # 7
|
728 |
+
"neck", # 8
|
729 |
+
"headtop", # 9
|
730 |
+
"rwrist", # 10
|
731 |
+
"relbow", # 11
|
732 |
+
"rshoulder", # 12
|
733 |
+
"lshoulder", # 13
|
734 |
+
"lelbow", # 14
|
735 |
+
"lwrist", # 15
|
736 |
+
]
|
737 |
+
|
738 |
+
|
739 |
+
def get_mpii_skeleton():
|
740 |
+
# 0 - rankle,
|
741 |
+
# 1 - rknee,
|
742 |
+
# 2 - rhip,
|
743 |
+
# 3 - lhip,
|
744 |
+
# 4 - lknee,
|
745 |
+
# 5 - lankle,
|
746 |
+
# 6 - hip,
|
747 |
+
# 7 - thorax,
|
748 |
+
# 8 - neck,
|
749 |
+
# 9 - headtop,
|
750 |
+
# 10 - rwrist,
|
751 |
+
# 11 - relbow,
|
752 |
+
# 12 - rshoulder,
|
753 |
+
# 13 - lshoulder,
|
754 |
+
# 14 - lelbow,
|
755 |
+
# 15 - lwrist,
|
756 |
+
return np.array(
|
757 |
+
[
|
758 |
+
[ 0, 1 ],
|
759 |
+
[ 1, 2 ],
|
760 |
+
[ 2, 6 ],
|
761 |
+
[ 6, 3 ],
|
762 |
+
[ 3, 4 ],
|
763 |
+
[ 4, 5 ],
|
764 |
+
[ 6, 7 ],
|
765 |
+
[ 7, 8 ],
|
766 |
+
[ 8, 9 ],
|
767 |
+
[ 7, 12],
|
768 |
+
[12, 11],
|
769 |
+
[11, 10],
|
770 |
+
[ 7, 13],
|
771 |
+
[13, 14],
|
772 |
+
[14, 15]
|
773 |
+
]
|
774 |
+
)
|
775 |
+
|
776 |
+
|
777 |
+
def get_aich_joint_names():
|
778 |
+
return [
|
779 |
+
"rshoulder", # 0
|
780 |
+
"relbow", # 1
|
781 |
+
"rwrist", # 2
|
782 |
+
"lshoulder", # 3
|
783 |
+
"lelbow", # 4
|
784 |
+
"lwrist", # 5
|
785 |
+
"rhip", # 6
|
786 |
+
"rknee", # 7
|
787 |
+
"rankle", # 8
|
788 |
+
"lhip", # 9
|
789 |
+
"lknee", # 10
|
790 |
+
"lankle", # 11
|
791 |
+
"headtop", # 12
|
792 |
+
"neck", # 13
|
793 |
+
]
|
794 |
+
|
795 |
+
|
796 |
+
def get_aich_skeleton():
|
797 |
+
# 0 - rshoulder,
|
798 |
+
# 1 - relbow,
|
799 |
+
# 2 - rwrist,
|
800 |
+
# 3 - lshoulder,
|
801 |
+
# 4 - lelbow,
|
802 |
+
# 5 - lwrist,
|
803 |
+
# 6 - rhip,
|
804 |
+
# 7 - rknee,
|
805 |
+
# 8 - rankle,
|
806 |
+
# 9 - lhip,
|
807 |
+
# 10 - lknee,
|
808 |
+
# 11 - lankle,
|
809 |
+
# 12 - headtop,
|
810 |
+
# 13 - neck,
|
811 |
+
return np.array(
|
812 |
+
[
|
813 |
+
[ 0, 1 ],
|
814 |
+
[ 1, 2 ],
|
815 |
+
[ 3, 4 ],
|
816 |
+
[ 4, 5 ],
|
817 |
+
[ 6, 7 ],
|
818 |
+
[ 7, 8 ],
|
819 |
+
[ 9, 10],
|
820 |
+
[10, 11],
|
821 |
+
[12, 13],
|
822 |
+
[13, 0 ],
|
823 |
+
[13, 3 ],
|
824 |
+
[ 0, 6 ],
|
825 |
+
[ 3, 9 ]
|
826 |
+
]
|
827 |
+
)
|
828 |
+
|
829 |
+
|
830 |
+
def get_3dpw_joint_names():
|
831 |
+
return [
|
832 |
+
"nose", # 0
|
833 |
+
"thorax", # 1
|
834 |
+
"rshoulder", # 2
|
835 |
+
"relbow", # 3
|
836 |
+
"rwrist", # 4
|
837 |
+
"lshoulder", # 5
|
838 |
+
"lelbow", # 6
|
839 |
+
"lwrist", # 7
|
840 |
+
"rhip", # 8
|
841 |
+
"rknee", # 9
|
842 |
+
"rankle", # 10
|
843 |
+
"lhip", # 11
|
844 |
+
"lknee", # 12
|
845 |
+
"lankle", # 13
|
846 |
+
]
|
847 |
+
|
848 |
+
|
849 |
+
def get_3dpw_skeleton():
|
850 |
+
return np.array(
|
851 |
+
[
|
852 |
+
[ 0, 1 ],
|
853 |
+
[ 1, 2 ],
|
854 |
+
[ 2, 3 ],
|
855 |
+
[ 3, 4 ],
|
856 |
+
[ 1, 5 ],
|
857 |
+
[ 5, 6 ],
|
858 |
+
[ 6, 7 ],
|
859 |
+
[ 2, 8 ],
|
860 |
+
[ 5, 11],
|
861 |
+
[ 8, 11],
|
862 |
+
[ 8, 9 ],
|
863 |
+
[ 9, 10],
|
864 |
+
[11, 12],
|
865 |
+
[12, 13]
|
866 |
+
]
|
867 |
+
)
|
868 |
+
|
869 |
+
|
870 |
+
def get_smplcoco_joint_names():
|
871 |
+
return [
|
872 |
+
"rankle", # 0
|
873 |
+
"rknee", # 1
|
874 |
+
"rhip", # 2
|
875 |
+
"lhip", # 3
|
876 |
+
"lknee", # 4
|
877 |
+
"lankle", # 5
|
878 |
+
"rwrist", # 6
|
879 |
+
"relbow", # 7
|
880 |
+
"rshoulder", # 8
|
881 |
+
"lshoulder", # 9
|
882 |
+
"lelbow", # 10
|
883 |
+
"lwrist", # 11
|
884 |
+
"neck", # 12
|
885 |
+
"headtop", # 13
|
886 |
+
"nose", # 14
|
887 |
+
"leye", # 15
|
888 |
+
"reye", # 16
|
889 |
+
"lear", # 17
|
890 |
+
"rear", # 18
|
891 |
+
]
|
892 |
+
|
893 |
+
|
894 |
+
def get_smplcoco_skeleton():
|
895 |
+
return np.array(
|
896 |
+
[
|
897 |
+
[ 0, 1 ],
|
898 |
+
[ 1, 2 ],
|
899 |
+
[ 3, 4 ],
|
900 |
+
[ 4, 5 ],
|
901 |
+
[ 6, 7 ],
|
902 |
+
[ 7, 8 ],
|
903 |
+
[ 8, 12],
|
904 |
+
[12, 9 ],
|
905 |
+
[ 9, 10],
|
906 |
+
[10, 11],
|
907 |
+
[12, 13],
|
908 |
+
[14, 15],
|
909 |
+
[15, 17],
|
910 |
+
[16, 18],
|
911 |
+
[14, 16],
|
912 |
+
[ 8, 2 ],
|
913 |
+
[ 9, 3 ],
|
914 |
+
[ 2, 3 ],
|
915 |
+
]
|
916 |
+
)
|
917 |
+
|
918 |
+
|
919 |
+
def get_smpl_joint_names():
|
920 |
+
return [
|
921 |
+
'hips', # 0
|
922 |
+
'leftUpLeg', # 1
|
923 |
+
'rightUpLeg', # 2
|
924 |
+
'spine', # 3
|
925 |
+
'leftLeg', # 4
|
926 |
+
'rightLeg', # 5
|
927 |
+
'spine1', # 6
|
928 |
+
'leftFoot', # 7
|
929 |
+
'rightFoot', # 8
|
930 |
+
'spine2', # 9
|
931 |
+
'leftToeBase', # 10
|
932 |
+
'rightToeBase', # 11
|
933 |
+
'neck', # 12
|
934 |
+
'leftShoulder', # 13
|
935 |
+
'rightShoulder', # 14
|
936 |
+
'head', # 15
|
937 |
+
'leftArm', # 16
|
938 |
+
'rightArm', # 17
|
939 |
+
'leftForeArm', # 18
|
940 |
+
'rightForeArm', # 19
|
941 |
+
'leftHand', # 20
|
942 |
+
'rightHand', # 21
|
943 |
+
'leftHandIndex1', # 22
|
944 |
+
'rightHandIndex1', # 23
|
945 |
+
]
|
946 |
+
|
947 |
+
|
948 |
+
def get_smpl_paper_joint_names():
|
949 |
+
return [
|
950 |
+
'Hips', # 0
|
951 |
+
'Left Hip', # 1
|
952 |
+
'Right Hip', # 2
|
953 |
+
'Spine', # 3
|
954 |
+
'Left Knee', # 4
|
955 |
+
'Right Knee', # 5
|
956 |
+
'Spine_1', # 6
|
957 |
+
'Left Ankle', # 7
|
958 |
+
'Right Ankle', # 8
|
959 |
+
'Spine_2', # 9
|
960 |
+
'Left Toe', # 10
|
961 |
+
'Right Toe', # 11
|
962 |
+
'Neck', # 12
|
963 |
+
'Left Shoulder', # 13
|
964 |
+
'Right Shoulder', # 14
|
965 |
+
'Head', # 15
|
966 |
+
'Left Arm', # 16
|
967 |
+
'Right Arm', # 17
|
968 |
+
'Left Elbow', # 18
|
969 |
+
'Right Elbow', # 19
|
970 |
+
'Left Hand', # 20
|
971 |
+
'Right Hand', # 21
|
972 |
+
'Left Thumb', # 22
|
973 |
+
'Right Thumb', # 23
|
974 |
+
]
|
975 |
+
|
976 |
+
|
977 |
+
def get_smpl_neighbor_triplets():
|
978 |
+
return [
|
979 |
+
[ 0, 1, 2 ], # 0
|
980 |
+
[ 1, 4, 0 ], # 1
|
981 |
+
[ 2, 0, 5 ], # 2
|
982 |
+
[ 3, 0, 6 ], # 3
|
983 |
+
[ 4, 7, 1 ], # 4
|
984 |
+
[ 5, 2, 8 ], # 5
|
985 |
+
[ 6, 3, 9 ], # 6
|
986 |
+
[ 7, 10, 4 ], # 7
|
987 |
+
[ 8, 5, 11], # 8
|
988 |
+
[ 9, 13, 14], # 9
|
989 |
+
[10, 7, 4 ], # 10
|
990 |
+
[11, 8, 5 ], # 11
|
991 |
+
[12, 9, 15], # 12
|
992 |
+
[13, 16, 9 ], # 13
|
993 |
+
[14, 9, 17], # 14
|
994 |
+
[15, 9, 12], # 15
|
995 |
+
[16, 18, 13], # 16
|
996 |
+
[17, 14, 19], # 17
|
997 |
+
[18, 20, 16], # 18
|
998 |
+
[19, 17, 21], # 19
|
999 |
+
[20, 22, 18], # 20
|
1000 |
+
[21, 19, 23], # 21
|
1001 |
+
[22, 20, 18], # 22
|
1002 |
+
[23, 19, 21], # 23
|
1003 |
+
]
|
1004 |
+
|
1005 |
+
|
1006 |
+
def get_smpl_skeleton():
|
1007 |
+
return np.array(
|
1008 |
+
[
|
1009 |
+
[ 0, 1 ],
|
1010 |
+
[ 0, 2 ],
|
1011 |
+
[ 0, 3 ],
|
1012 |
+
[ 1, 4 ],
|
1013 |
+
[ 2, 5 ],
|
1014 |
+
[ 3, 6 ],
|
1015 |
+
[ 4, 7 ],
|
1016 |
+
[ 5, 8 ],
|
1017 |
+
[ 6, 9 ],
|
1018 |
+
[ 7, 10],
|
1019 |
+
[ 8, 11],
|
1020 |
+
[ 9, 12],
|
1021 |
+
[ 9, 13],
|
1022 |
+
[ 9, 14],
|
1023 |
+
[12, 15],
|
1024 |
+
[13, 16],
|
1025 |
+
[14, 17],
|
1026 |
+
[16, 18],
|
1027 |
+
[17, 19],
|
1028 |
+
[18, 20],
|
1029 |
+
[19, 21],
|
1030 |
+
[20, 22],
|
1031 |
+
[21, 23],
|
1032 |
+
]
|
1033 |
+
)
|
1034 |
+
|
1035 |
+
|
1036 |
+
def map_spin_joints_to_smpl():
|
1037 |
+
# this function primarily will be used to copy 2D keypoint
|
1038 |
+
# confidences to pose parameters
|
1039 |
+
return [
|
1040 |
+
[(39, 27, 28), 0], # hip,lhip,rhip->hips
|
1041 |
+
[(28,), 1], # lhip->leftUpLeg
|
1042 |
+
[(27,), 2], # rhip->rightUpLeg
|
1043 |
+
[(41, 27, 28, 39), 3], # Spine->spine
|
1044 |
+
[(29,), 4], # lknee->leftLeg
|
1045 |
+
[(26,), 5], # rknee->rightLeg
|
1046 |
+
[(41, 40, 33, 34,), 6], # spine, thorax ->spine1
|
1047 |
+
[(30,), 7], # lankle->leftFoot
|
1048 |
+
[(25,), 8], # rankle->rightFoot
|
1049 |
+
[(40, 33, 34), 9], # thorax,shoulders->spine2
|
1050 |
+
[(30,), 10], # lankle -> leftToe
|
1051 |
+
[(25,), 11], # rankle -> rightToe
|
1052 |
+
[(37, 42, 33, 34), 12], # neck, shoulders -> neck
|
1053 |
+
[(34,), 13], # lshoulder->leftShoulder
|
1054 |
+
[(33,), 14], # rshoulder->rightShoulder
|
1055 |
+
[(33, 34, 38, 43, 44, 45, 46, 47, 48,), 15], # nose, eyes, ears, headtop, shoulders->head
|
1056 |
+
[(34,), 16], # lshoulder->leftArm
|
1057 |
+
[(33,), 17], # rshoulder->rightArm
|
1058 |
+
[(35,), 18], # lelbow->leftForeArm
|
1059 |
+
[(32,), 19], # relbow->rightForeArm
|
1060 |
+
[(36,), 20], # lwrist->leftHand
|
1061 |
+
[(31,), 21], # rwrist->rightHand
|
1062 |
+
[(36,), 22], # lhand -> leftHandIndex
|
1063 |
+
[(31,), 23], # rhand -> rightHandIndex
|
1064 |
+
]
|
1065 |
+
|
1066 |
+
|
1067 |
+
def map_smpl_to_common():
|
1068 |
+
return [
|
1069 |
+
[(11, 8), 0], # rightToe, rightFoot -> rankle
|
1070 |
+
[(5,), 1], # rightleg -> rknee,
|
1071 |
+
[(2,), 2], # rhip
|
1072 |
+
[(1,), 3], # lhip
|
1073 |
+
[(4,), 4], # leftLeg -> lknee
|
1074 |
+
[(10, 7), 5], # lefttoe, leftfoot -> lankle
|
1075 |
+
[(21, 23), 6], # rwrist
|
1076 |
+
[(18,), 7], # relbow
|
1077 |
+
[(17, 14), 8], # rshoulder
|
1078 |
+
[(16, 13), 9], # lshoulder
|
1079 |
+
[(19,), 10], # lelbow
|
1080 |
+
[(20, 22), 11], # lwrist
|
1081 |
+
[(0, 3, 6, 9, 12), 12], # neck
|
1082 |
+
[(15,), 13], # headtop
|
1083 |
+
]
|
1084 |
+
|
1085 |
+
|
1086 |
+
def relation_among_spin_joints():
|
1087 |
+
# this function primarily will be used to copy 2D keypoint
|
1088 |
+
# confidences to 3D joints
|
1089 |
+
return [
|
1090 |
+
[(), 25],
|
1091 |
+
[(), 26],
|
1092 |
+
[(39,), 27],
|
1093 |
+
[(39,), 28],
|
1094 |
+
[(), 29],
|
1095 |
+
[(), 30],
|
1096 |
+
[(), 31],
|
1097 |
+
[(), 32],
|
1098 |
+
[(), 33],
|
1099 |
+
[(), 34],
|
1100 |
+
[(), 35],
|
1101 |
+
[(), 36],
|
1102 |
+
[(40,42,44,43,38,33,34,), 37],
|
1103 |
+
[(43,44,45,46,47,48,33,34,), 38],
|
1104 |
+
[(27,28,), 39],
|
1105 |
+
[(27,28,37,41,42,), 40],
|
1106 |
+
[(27,28,39,40,), 41],
|
1107 |
+
[(37,38,44,45,46,47,48,), 42],
|
1108 |
+
[(44,45,46,47,48,38,42,37,33,34,), 43],
|
1109 |
+
[(44,45,46,47,48,38,42,37,33,34), 44],
|
1110 |
+
[(44,45,46,47,48,38,42,37,33,34), 45],
|
1111 |
+
[(44,45,46,47,48,38,42,37,33,34), 46],
|
1112 |
+
[(44,45,46,47,48,38,42,37,33,34), 47],
|
1113 |
+
[(44,45,46,47,48,38,42,37,33,34), 48],
|
1114 |
+
]
|
utils/loss.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from common import constants
|
4 |
+
from models.smpl import SMPL
|
5 |
+
from smplx import SMPLX
|
6 |
+
import pickle as pkl
|
7 |
+
import numpy as np
|
8 |
+
from utils.mesh_utils import save_results_mesh
|
9 |
+
from utils.diff_renderer import Pytorch3D
|
10 |
+
import os
|
11 |
+
import cv2
|
12 |
+
|
13 |
+
|
14 |
+
class sem_loss_function(nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super(sem_loss_function, self).__init__()
|
17 |
+
self.ce = nn.BCELoss()
|
18 |
+
|
19 |
+
def forward(self, y_true, y_pred):
|
20 |
+
loss = self.ce(y_pred, y_true)
|
21 |
+
return loss
|
22 |
+
|
23 |
+
|
24 |
+
class class_loss_function(nn.Module):
|
25 |
+
def __init__(self):
|
26 |
+
super(class_loss_function, self).__init__()
|
27 |
+
self.ce_loss = nn.BCELoss()
|
28 |
+
# self.ce_loss = nn.MultiLabelSoftMarginLoss()
|
29 |
+
# self.ce_loss = nn.MultiLabelMarginLoss()
|
30 |
+
|
31 |
+
def forward(self, y_true, y_pred, valid_mask):
|
32 |
+
# y_true = torch.squeeze(y_true, 1).long()
|
33 |
+
# y_true = torch.squeeze(y_true, 1)
|
34 |
+
# y_pred = torch.squeeze(y_pred, 1)
|
35 |
+
bs = y_true.shape[0]
|
36 |
+
if bs != 1:
|
37 |
+
y_pred = y_pred[valid_mask == 1]
|
38 |
+
y_true = y_true[valid_mask == 1]
|
39 |
+
if len(y_pred) > 0:
|
40 |
+
return self.ce_loss(y_pred, y_true)
|
41 |
+
else:
|
42 |
+
return torch.tensor(0.0).to(y_pred.device)
|
43 |
+
|
44 |
+
|
45 |
+
class pixel_anchoring_function(nn.Module):
|
46 |
+
def __init__(self, model_type, device='cuda'):
|
47 |
+
super(pixel_anchoring_function, self).__init__()
|
48 |
+
|
49 |
+
self.device = device
|
50 |
+
|
51 |
+
self.model_type = model_type
|
52 |
+
|
53 |
+
if self.model_type == 'smplx':
|
54 |
+
# load mapping from smpl vertices to smplx vertices
|
55 |
+
mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl")
|
56 |
+
with open(mapping_pkl, 'rb') as f:
|
57 |
+
smpl_to_smplx_mapping = pkl.load(f)
|
58 |
+
smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"]
|
59 |
+
self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device)
|
60 |
+
|
61 |
+
|
62 |
+
# Setup the SMPL model
|
63 |
+
if self.model_type == 'smpl':
|
64 |
+
self.n_vertices = 6890
|
65 |
+
self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
|
66 |
+
if self.model_type == 'smplx':
|
67 |
+
self.n_vertices = 10475
|
68 |
+
self.body_model = SMPLX(constants.SMPLX_MODEL_DIR,
|
69 |
+
num_betas=10,
|
70 |
+
use_pca=False).to(self.device)
|
71 |
+
self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device)
|
72 |
+
|
73 |
+
self.ce_loss = nn.BCELoss()
|
74 |
+
|
75 |
+
def get_posed_mesh(self, body_params, debug=False):
|
76 |
+
betas = body_params['betas']
|
77 |
+
pose = body_params['pose']
|
78 |
+
transl = body_params['transl']
|
79 |
+
|
80 |
+
# extra smplx params
|
81 |
+
extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
|
82 |
+
'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
|
83 |
+
'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
|
84 |
+
'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
|
85 |
+
'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
|
86 |
+
'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}
|
87 |
+
|
88 |
+
smpl_output = self.body_model(betas=betas,
|
89 |
+
body_pose=pose[:, 3:],
|
90 |
+
global_orient=pose[:, :3],
|
91 |
+
pose2rot=True,
|
92 |
+
transl=transl,
|
93 |
+
**extra_args)
|
94 |
+
smpl_verts = smpl_output.vertices
|
95 |
+
smpl_joints = smpl_output.joints
|
96 |
+
|
97 |
+
if debug:
|
98 |
+
for mesh_i in range(smpl_verts.shape[0]):
|
99 |
+
out_dir = 'temp_meshes'
|
100 |
+
os.makedirs(out_dir, exist_ok=True)
|
101 |
+
out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
|
102 |
+
save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
|
103 |
+
return smpl_verts, smpl_joints
|
104 |
+
|
105 |
+
|
106 |
+
def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None, debug=False):
|
107 |
+
|
108 |
+
bs = smpl_verts.shape[0]
|
109 |
+
|
110 |
+
# Incorporate resizing factor into the camera
|
111 |
+
img_w = 256 # TODO: Remove hardcoding
|
112 |
+
img_h = 256 # TODO: Remove hardcoding
|
113 |
+
focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0]
|
114 |
+
focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1]
|
115 |
+
# convert to float for pytorch3d
|
116 |
+
focal_length_x, focal_length_y = focal_length_x.float(), focal_length_y.float()
|
117 |
+
|
118 |
+
# concatenate focal length
|
119 |
+
focal_length = torch.stack([focal_length_x, focal_length_y], dim=1)
|
120 |
+
|
121 |
+
# Setup renderer
|
122 |
+
renderer = Pytorch3D(img_h=img_h,
|
123 |
+
img_w=img_w,
|
124 |
+
focal_length=focal_length,
|
125 |
+
smpl_faces=self.body_faces,
|
126 |
+
texture_mode='deco',
|
127 |
+
vertex_colors=vertex_colors,
|
128 |
+
face_textures=face_textures,
|
129 |
+
is_train=True,
|
130 |
+
is_cam_batch=True)
|
131 |
+
front_view = renderer(smpl_verts)
|
132 |
+
if debug:
|
133 |
+
# visualize the front view as images in a temp_image folder
|
134 |
+
for i in range(bs):
|
135 |
+
front_view_rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu()
|
136 |
+
front_view_mask = front_view[i, 3, :, :].detach().cpu()
|
137 |
+
out_dir = 'temp_images'
|
138 |
+
os.makedirs(out_dir, exist_ok=True)
|
139 |
+
out_file_rgb = os.path.join(out_dir, f'{i:04d}_rgb.png')
|
140 |
+
out_file_mask = os.path.join(out_dir, f'{i:04d}_mask.png')
|
141 |
+
cv2.imwrite(out_file_rgb, front_view_rgb.numpy()*255)
|
142 |
+
cv2.imwrite(out_file_mask, front_view_mask.numpy()*255)
|
143 |
+
|
144 |
+
return front_view
|
145 |
+
|
146 |
+
def paint_contact(self, pred_contact):
|
147 |
+
"""
|
148 |
+
Paints the contact vertices on the SMPL mesh
|
149 |
+
|
150 |
+
Args:
|
151 |
+
pred_contact: prbabilities of contact vertices
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
pred_rgb: RGB colors for the contact vertices
|
155 |
+
"""
|
156 |
+
bs = pred_contact.shape[0]
|
157 |
+
|
158 |
+
# initialize black and while colors
|
159 |
+
colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device)
|
160 |
+
colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1)
|
161 |
+
|
162 |
+
# add another dimension to the contact probabilities for inverse probabilities
|
163 |
+
pred_contact = torch.unsqueeze(pred_contact, 2)
|
164 |
+
pred_contact = torch.cat((1 - pred_contact, pred_contact), 2)
|
165 |
+
|
166 |
+
# get pred_rgb colors
|
167 |
+
pred_vert_rgb = torch.bmm(pred_contact, colors)
|
168 |
+
pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :] # take the first vertex color
|
169 |
+
pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device)
|
170 |
+
pred_face_texture[:, :, 0, 0, :] = pred_face_rgb
|
171 |
+
return pred_vert_rgb, pred_face_texture
|
172 |
+
|
173 |
+
def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask):
|
174 |
+
"""
|
175 |
+
Takes predicted contact labels (probabilities), transfers them to the posed mesh and
|
176 |
+
renders to the image. Loss is computed between the rendered contact and the ground truth
|
177 |
+
polygons from HOT.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
pred_contact: predicted contact labels (probabilities)
|
181 |
+
body_params: SMPL parameters in camera coords
|
182 |
+
cam_k: camera intrinsics
|
183 |
+
gt_contact_polygon: ground truth polygons from HOT
|
184 |
+
"""
|
185 |
+
# convert pred_contact to smplx
|
186 |
+
bs = pred_contact.shape[0]
|
187 |
+
if self.model_type == 'smplx':
|
188 |
+
smpl_to_smplx_mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1)
|
189 |
+
pred_contact = torch.bmm(smpl_to_smplx_mapping, pred_contact[..., None])
|
190 |
+
pred_contact = pred_contact.squeeze()
|
191 |
+
|
192 |
+
# get the posed mesh
|
193 |
+
smpl_verts, smpl_joints = self.get_posed_mesh(body_params)
|
194 |
+
|
195 |
+
# paint the contact vertices on the mesh
|
196 |
+
vertex_colors, face_textures = self.paint_contact(pred_contact)
|
197 |
+
|
198 |
+
# render the mesh
|
199 |
+
front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures)
|
200 |
+
front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1)
|
201 |
+
front_view_mask = front_view[:, 3, :, :]
|
202 |
+
|
203 |
+
# compute segmentation loss between rendered contact mask and ground truth contact mask
|
204 |
+
front_view_rgb = front_view_rgb[valid_mask == 1]
|
205 |
+
gt_contact_polygon = gt_contact_polygon[valid_mask == 1]
|
206 |
+
loss = self.ce_loss(front_view_rgb, gt_contact_polygon)
|
207 |
+
return loss, front_view_rgb, front_view_mask
|
utils/mesh_utils.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import trimesh
|
2 |
+
|
3 |
+
def save_results_mesh(vertices, faces, filename):
|
4 |
+
mesh = trimesh.Trimesh(vertices, faces, process=False)
|
5 |
+
mesh.export(filename)
|
6 |
+
print(f'save results to {filename}')
|
utils/metrics.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import monai.metrics as metrics
|
4 |
+
from common.constants import DIST_MATRIX_PATH
|
5 |
+
|
6 |
+
DIST_MATRIX = np.load(DIST_MATRIX_PATH)
|
7 |
+
|
8 |
+
def metric(mask, pred, back=True):
|
9 |
+
iou = metrics.compute_meaniou(pred, mask, back, False)
|
10 |
+
iou = iou.mean()
|
11 |
+
|
12 |
+
return iou
|
13 |
+
|
14 |
+
def precision_recall_f1score(gt, pred):
|
15 |
+
"""
|
16 |
+
Compute precision, recall, and f1
|
17 |
+
"""
|
18 |
+
|
19 |
+
# gt = gt.numpy()
|
20 |
+
# pred = pred.numpy()
|
21 |
+
|
22 |
+
precision = torch.zeros(gt.shape[0])
|
23 |
+
recall = torch.zeros(gt.shape[0])
|
24 |
+
f1 = torch.zeros(gt.shape[0])
|
25 |
+
|
26 |
+
for b in range(gt.shape[0]):
|
27 |
+
tp_num = gt[b, pred[b, :] >= 0.5].sum()
|
28 |
+
precision_denominator = (pred[b, :] >= 0.5).sum()
|
29 |
+
recall_denominator = (gt[b, :]).sum()
|
30 |
+
|
31 |
+
precision_ = tp_num / precision_denominator
|
32 |
+
recall_ = tp_num / recall_denominator
|
33 |
+
if precision_denominator == 0: # if no pred
|
34 |
+
precision_ = 1.
|
35 |
+
recall_ = 0.
|
36 |
+
f1_ = 0.
|
37 |
+
elif recall_denominator == 0: # if no GT
|
38 |
+
precision_ = 0.
|
39 |
+
recall_ = 1.
|
40 |
+
f1_ = 0.
|
41 |
+
elif (precision_ + recall_) <= 1e-10: # to avoid precision issues
|
42 |
+
precision_= 0.
|
43 |
+
recall_= 0.
|
44 |
+
f1_ = 0.
|
45 |
+
else:
|
46 |
+
f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
|
47 |
+
|
48 |
+
precision[b] = precision_
|
49 |
+
recall[b] = recall_
|
50 |
+
f1[b] = f1_
|
51 |
+
|
52 |
+
# return precision, recall, f1
|
53 |
+
return precision, recall, f1
|
54 |
+
|
55 |
+
def acc_precision_recall_f1score(gt, pred):
|
56 |
+
"""
|
57 |
+
Compute acc, precision, recall, and f1
|
58 |
+
"""
|
59 |
+
|
60 |
+
# gt = gt.numpy()
|
61 |
+
# pred = pred.numpy()
|
62 |
+
|
63 |
+
acc = torch.zeros(gt.shape[0])
|
64 |
+
precision = torch.zeros(gt.shape[0])
|
65 |
+
recall = torch.zeros(gt.shape[0])
|
66 |
+
f1 = torch.zeros(gt.shape[0])
|
67 |
+
|
68 |
+
for b in range(gt.shape[0]):
|
69 |
+
tp_num = gt[b, pred[b, :] >= 0.5].sum()
|
70 |
+
precision_denominator = (pred[b, :] >= 0.5).sum()
|
71 |
+
recall_denominator = (gt[b, :]).sum()
|
72 |
+
tn_num = gt.shape[-1] - precision_denominator - recall_denominator + tp_num
|
73 |
+
|
74 |
+
acc_ = (tp_num + tn_num) / gt.shape[-1]
|
75 |
+
precision_ = tp_num / (precision_denominator + 1e-10)
|
76 |
+
recall_ = tp_num / (recall_denominator + 1e-10)
|
77 |
+
f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)
|
78 |
+
|
79 |
+
acc[b] = acc_
|
80 |
+
precision[b] = precision_
|
81 |
+
recall[b] = recall_
|
82 |
+
|
83 |
+
# return precision, recall, f1
|
84 |
+
return acc, precision, recall, f1
|
85 |
+
|
86 |
+
def det_error_metric(pred, gt):
|
87 |
+
|
88 |
+
gt = gt.detach().cpu()
|
89 |
+
pred = pred.detach().cpu()
|
90 |
+
|
91 |
+
dist_matrix = torch.tensor(DIST_MATRIX)
|
92 |
+
|
93 |
+
false_positive_dist = torch.zeros(gt.shape[0])
|
94 |
+
false_negative_dist = torch.zeros(gt.shape[0])
|
95 |
+
|
96 |
+
for b in range(gt.shape[0]):
|
97 |
+
gt_columns = dist_matrix[:, gt[b, :]==1] if any(gt[b, :]==1) else dist_matrix
|
98 |
+
error_matrix = gt_columns[pred[b, :] >= 0.5, :] if any(pred[b, :] >= 0.5) else gt_columns
|
99 |
+
|
100 |
+
false_positive_dist_ = error_matrix.min(dim=1)[0].mean()
|
101 |
+
false_negative_dist_ = error_matrix.min(dim=0)[0].mean()
|
102 |
+
|
103 |
+
false_positive_dist[b] = false_positive_dist_
|
104 |
+
false_negative_dist[b] = false_negative_dist_
|
105 |
+
|
106 |
+
return false_positive_dist, false_negative_dist
|
utils/smpl_uv.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import trimesh
|
3 |
+
import numpy as np
|
4 |
+
import skimage.io as io
|
5 |
+
from PIL import Image
|
6 |
+
from smplx import SMPL
|
7 |
+
from matplotlib import cm as mpl_cm, colors as mpl_colors
|
8 |
+
from trimesh.visual.color import face_to_vertex_color, vertex_to_face_color, to_rgba
|
9 |
+
|
10 |
+
from common import constants
|
11 |
+
from .colorwheel import make_color_wheel_image
|
12 |
+
|
13 |
+
|
14 |
+
def get_smpl_uv():
|
15 |
+
uv_obj = 'data/body_models/smpl_uv_20200910/smpl_uv.obj'
|
16 |
+
|
17 |
+
uv_map = []
|
18 |
+
with open(uv_obj) as f:
|
19 |
+
for line in f.readlines():
|
20 |
+
if line.startswith('vt'):
|
21 |
+
coords = [float(x) for x in line.split(' ')[1:]]
|
22 |
+
uv_map.append(coords)
|
23 |
+
|
24 |
+
uv_map = np.array(uv_map)
|
25 |
+
|
26 |
+
return uv_map
|
27 |
+
|
28 |
+
|
29 |
+
def show_uv_texture():
|
30 |
+
# image = io.imread('data/body_models/smpl_uv_20200910/smpl_uv_20200910.png')
|
31 |
+
image = make_color_wheel_image(1024, 1024)
|
32 |
+
image = Image.fromarray(image)
|
33 |
+
|
34 |
+
uv = np.load('data/body_models/smpl_uv_20200910/uv_table.npy') # get_smpl_uv()
|
35 |
+
material = trimesh.visual.texture.SimpleMaterial(image=image)
|
36 |
+
tex_visuals = trimesh.visual.TextureVisuals(uv=uv, image=image, material=material)
|
37 |
+
|
38 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
39 |
+
|
40 |
+
faces = smpl.faces
|
41 |
+
verts = smpl().vertices[0].detach().numpy()
|
42 |
+
|
43 |
+
# assert(len(uv) == len(verts))
|
44 |
+
print(uv.shape)
|
45 |
+
vc = tex_visuals.to_color().vertex_colors
|
46 |
+
fc = trimesh.visual.color.vertex_to_face_color(vc, faces)
|
47 |
+
face_colors = fc.copy()
|
48 |
+
fc = fc.astype(float)
|
49 |
+
vc = vc.astype(float)
|
50 |
+
fc[:,:3] = fc[:,:3] / 255.
|
51 |
+
vc[:,:3] = vc[:,:3] / 255.
|
52 |
+
print(fc[:,:3].max(), fc[:,:3].min(), fc[:,:3].mean())
|
53 |
+
print(vc[:, :3].max(), vc[:, :3].min(), vc[:, :3].mean())
|
54 |
+
np.save('data/body_models/smpl/color_wheel_face_colors.npy', fc)
|
55 |
+
np.save('data/body_models/smpl/color_wheel_vertex_colors.npy', vc)
|
56 |
+
print(fc.shape)
|
57 |
+
mesh = trimesh.Trimesh(verts, faces, validate=True, process=False, face_colors=face_colors)
|
58 |
+
# mesh = trimesh.load('data/body_models/smpl_uv_20200910/smpl_uv.obj', process=False)
|
59 |
+
# mesh.visual = tex_visuals
|
60 |
+
|
61 |
+
# import ipdb; ipdb.set_trace()
|
62 |
+
# print(vc.shape)
|
63 |
+
mesh.show()
|
64 |
+
|
65 |
+
|
66 |
+
def show_colored_mesh():
|
67 |
+
cm = mpl_cm.get_cmap('jet')
|
68 |
+
norm_gt = mpl_colors.Normalize()
|
69 |
+
|
70 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
71 |
+
|
72 |
+
faces = smpl.faces
|
73 |
+
verts = smpl().vertices[0].detach().numpy()
|
74 |
+
|
75 |
+
m = trimesh.Trimesh(verts, faces, process=False)
|
76 |
+
|
77 |
+
mode = 1
|
78 |
+
if mode == 0:
|
79 |
+
# mano_segm_labels = m.triangles_center
|
80 |
+
face_labels = m.triangles_center
|
81 |
+
face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
|
82 |
+
|
83 |
+
elif mode == 1:
|
84 |
+
# print(face_labels.shape)
|
85 |
+
face_labels = m.triangles_center
|
86 |
+
face_labels = np.argsort(np.linalg.norm(face_labels, axis=-1))
|
87 |
+
face_colors = np.ones((13776, 4))
|
88 |
+
face_colors[:, 3] = 1.0
|
89 |
+
face_colors[:, :3] = cm(norm_gt(face_labels))[:, :3]
|
90 |
+
elif mode == 2:
|
91 |
+
# breakpoint()
|
92 |
+
fc = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')[0, :, 0, 0, 0, :]
|
93 |
+
face_colors = np.ones((13776, 4))
|
94 |
+
face_colors[:, :3] = fc
|
95 |
+
mesh = trimesh.Trimesh(verts, faces, process=False, face_colors=face_colors)
|
96 |
+
mesh.show()
|
97 |
+
|
98 |
+
|
99 |
+
def get_tenet_texture(mode='smplpix'):
|
100 |
+
# mode = 'smplpix', 'decomr'
|
101 |
+
|
102 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
103 |
+
|
104 |
+
faces = smpl.faces
|
105 |
+
verts = smpl().vertices[0].detach().numpy()
|
106 |
+
|
107 |
+
m = trimesh.Trimesh(verts, faces, process=False)
|
108 |
+
if mode == 'smplpix':
|
109 |
+
# mano_segm_labels = m.triangles_center
|
110 |
+
face_labels = m.triangles_center
|
111 |
+
face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
|
112 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
113 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
114 |
+
texture = torch.from_numpy(texture).float()
|
115 |
+
elif mode == 'decomr':
|
116 |
+
texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
|
117 |
+
texture = torch.from_numpy(texture).float()
|
118 |
+
elif mode == 'colorwheel':
|
119 |
+
face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
|
120 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
121 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
122 |
+
texture = torch.from_numpy(texture).float()
|
123 |
+
else:
|
124 |
+
raise ValueError(f'{mode} is not defined!')
|
125 |
+
|
126 |
+
return texture
|
127 |
+
|
128 |
+
|
129 |
+
def save_tenet_textures(mode='smplpix'):
|
130 |
+
# mode = 'smplpix', 'decomr'
|
131 |
+
|
132 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
133 |
+
|
134 |
+
faces = smpl.faces
|
135 |
+
verts = smpl().vertices[0].detach().numpy()
|
136 |
+
|
137 |
+
m = trimesh.Trimesh(verts, faces, process=False)
|
138 |
+
|
139 |
+
if mode == 'smplpix':
|
140 |
+
# mano_segm_labels = m.triangles_center
|
141 |
+
face_labels = m.triangles_center
|
142 |
+
face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
|
143 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
144 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
145 |
+
texture = torch.from_numpy(texture).float()
|
146 |
+
|
147 |
+
vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
|
148 |
+
|
149 |
+
elif mode == 'decomr':
|
150 |
+
texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
|
151 |
+
texture = torch.from_numpy(texture).float()
|
152 |
+
face_colors = texture[0, :, 0, 0, 0, :]
|
153 |
+
vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
|
154 |
+
|
155 |
+
elif mode == 'colorwheel':
|
156 |
+
face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
|
157 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
158 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
159 |
+
texture = torch.from_numpy(texture).float()
|
160 |
+
face_colors[:, :3] *= 255
|
161 |
+
vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
|
162 |
+
else:
|
163 |
+
raise ValueError(f'{mode} is not defined!')
|
164 |
+
|
165 |
+
print(vert_colors.shape, vert_colors.max())
|
166 |
+
np.save(f'data/body_models/smpl/{mode}_vertex_colors.npy', vert_colors)
|
167 |
+
return texture
|