YoonaAI committed on
Commit 87ea12d · 1 Parent(s): 208e8a7

Upload 8 files

lib/dataset/Evaluator.py ADDED
@@ -0,0 +1,264 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

from lib.renderer.gl.normal_render import NormalRender
from lib.dataset.mesh_util import projection
from lib.common.render import Render
from PIL import Image
import numpy as np
import torch
from torch import nn
import trimesh
import os.path as osp


class Evaluator:

    _normal_render = None

    @staticmethod
    def init_gl():
        Evaluator._normal_render = NormalRender(width=512, height=512)

    def __init__(self, device):
        self.device = device
        self.render = Render(size=512, device=self.device)
        self.error_term = nn.MSELoss()

        self.offset = 0.0
        self.scale_factor = None

    def set_mesh(self, result_dict, scale_factor=1.0, offset=0.0):

        for key in result_dict.keys():
            if torch.is_tensor(result_dict[key]):
                result_dict[key] = result_dict[key].detach().cpu().numpy()

        for k, v in result_dict.items():
            setattr(self, k, v)

        self.scale_factor = scale_factor
        self.offset = offset

    def _render_normal(self, mesh, deg, norms=None):
        view_mat = np.identity(4)
        rz = deg / 180.0 * np.pi
        model_mat = np.identity(4)
        model_mat[:3, :3] = self._normal_render.euler_to_rot_mat(0, rz, 0)
        model_mat[1, 3] = self.offset
        view_mat[2, 2] *= -1

        self._normal_render.set_matrices(view_mat, model_mat)
        if norms is None:
            norms = mesh.vertex_normals
        self._normal_render.set_normal_mesh(self.scale_factor * mesh.vertices,
                                            mesh.faces, norms, mesh.faces)
        self._normal_render.draw()
        normal_img = self._normal_render.get_color()
        return normal_img

    def render_mesh_list(self, mesh_lst):

        self.offset = 0.0
        self.scale_factor = 1.0

        full_list = []
        for mesh in mesh_lst:
            row_lst = []
            for deg in np.arange(0, 360, 90):
                normal = self._render_normal(mesh, deg)
                row_lst.append(normal)
            full_list.append(np.concatenate(row_lst, axis=1))

        res_array = np.concatenate(full_list, axis=0)

        return res_array

    def _get_reproj_normal_error(self, deg):

        tgt_normal = self._render_normal(self.tgt_mesh, deg)
        src_normal = self._render_normal(self.src_mesh, deg)
        error = (((src_normal[:, :, :3] -
                   tgt_normal[:, :, :3])**2).sum(axis=2).mean(axis=(0, 1)))

        return error, [src_normal, tgt_normal]

    def render_normal(self, verts, faces):

        verts = verts[0].detach().cpu().numpy()
        faces = faces[0].detach().cpu().numpy()

        mesh_F = trimesh.Trimesh(verts * np.array([1.0, -1.0, 1.0]), faces)
        mesh_B = trimesh.Trimesh(verts * np.array([1.0, -1.0, -1.0]), faces)

        self.scale_factor = 1.0

        normal_F = self._render_normal(mesh_F, 0)
        normal_B = self._render_normal(mesh_B,
                                       0,
                                       norms=mesh_B.vertex_normals *
                                       np.array([-1.0, -1.0, 1.0]))

        mask = normal_F[:, :, 3:4]
        normal_F = (torch.as_tensor(2.0 * (normal_F - 0.5) * mask).permute(
            2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))
        normal_B = (torch.as_tensor(2.0 * (normal_B - 0.5) * mask).permute(
            2, 0, 1)[:3, :, :].float().unsqueeze(0).to(self.device))

        return {"T_normal_F": normal_F, "T_normal_B": normal_B}

    def calculate_normal_consist(
        self,
        frontal=True,
        back=True,
        left=True,
        right=True,
        save_demo_img=None,
        return_demo=False,
    ):

        # reprojection error
        # if save_demo_img is not None, save a visualization at the given path (e.g., "./test.png")
        if self._normal_render is None:
            print(
                "In order to use the normal render, "
                "you have to call init_gl() before initializing any evaluator objects."
            )
            return -1

        side_cnt = 0
        total_error = 0
        demo_list = []

        if frontal:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(0)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if back:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(180)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if left:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(90)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if right:
            side_cnt += 1
            error, normal_lst = self._get_reproj_normal_error(270)
            total_error += error
            demo_list.append(np.concatenate(normal_lst, axis=0))
        if save_demo_img is not None:
            res_array = np.concatenate(demo_list, axis=1)
            res_img = Image.fromarray((res_array * 255).astype(np.uint8))
            res_img.save(save_demo_img)

        if return_demo:
            res_array = np.concatenate(demo_list, axis=1)
            return res_array
        else:
            return total_error

    def space_transfer(self):

        # convert from GT space to SDF space
        self.verts_pr -= self.recon_size / 2.0
        self.verts_pr /= self.recon_size / 2.0

        self.verts_gt = projection(self.verts_gt, self.calib)
        self.verts_gt[:, 1] *= -1

        self.tgt_mesh = trimesh.Trimesh(self.verts_gt, self.faces_gt)
        self.src_mesh = trimesh.Trimesh(self.verts_pr, self.faces_pr)

        # (self.tgt_mesh + self.src_mesh).show()

    def export_mesh(self, dir, name):
        self.tgt_mesh.visual.vertex_colors = np.array([255, 0, 0])
        self.src_mesh.visual.vertex_colors = np.array([0, 255, 0])

        (self.tgt_mesh + self.src_mesh).export(
            osp.join(dir, f"{name}_gt_pr.obj"))

    def calculate_chamfer_p2s(self, sampled_points=1000):
        """calculate the geometry metrics [chamfer, p2s]

        Args:
            verts_gt (torch.cuda.tensor): [N, 3]
            faces_gt (torch.cuda.tensor): [M, 3]
            verts_pr (torch.cuda.tensor): [N', 3]
            faces_pr (torch.cuda.tensor): [M', 3]
            sampled_points (int, optional): use a smaller number for faster testing. Defaults to 1000.

        Returns:
            tuple: (chamfer, p2s)
        """

        gt_surface_pts, _ = trimesh.sample.sample_surface_even(
            self.tgt_mesh, sampled_points)
        pred_surface_pts, _ = trimesh.sample.sample_surface_even(
            self.src_mesh, sampled_points)

        _, dist_pred_gt, _ = trimesh.proximity.closest_point(
            self.src_mesh, gt_surface_pts)
        _, dist_gt_pred, _ = trimesh.proximity.closest_point(
            self.tgt_mesh, pred_surface_pts)

        dist_pred_gt[np.isnan(dist_pred_gt)] = 0
        dist_gt_pred[np.isnan(dist_gt_pred)] = 0
        chamfer_dist = 0.5 * (dist_pred_gt.mean() +
                              dist_gt_pred.mean()).item() * 100
        p2s_dist = dist_pred_gt.mean().item() * 100

        return chamfer_dist, p2s_dist

    def calc_acc(self, output, target, thres=0.5, use_sdf=False):

        # # remove the surface points with thres
        # non_surf_ids = (target != thres)
        # output = output[non_surf_ids]
        # target = target[non_surf_ids]

        with torch.no_grad():
            output = output.masked_fill(output < thres, 0.0)
            output = output.masked_fill(output > thres, 1.0)

            if use_sdf:
                target = target.masked_fill(target < thres, 0.0)
                target = target.masked_fill(target > thres, 1.0)

            acc = output.eq(target).float().mean()

            # iou, precision, recall
            output = output > thres
            target = target > thres

            union = output | target
            inter = output & target

            _max = torch.tensor(1.0).to(output.device)

            union = max(union.sum().float(), _max)
            true_pos = max(inter.sum().float(), _max)
            vol_pred = max(output.sum().float(), _max)
            vol_gt = max(target.sum().float(), _max)

            return acc, true_pos / union, true_pos / vol_pred, true_pos / vol_gt
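A minimal usage sketch for the evaluator above, on a toy sphere mesh. The dict keys mirror the attributes that set_mesh and space_transfer read (verts_gt, faces_gt, verts_pr, faces_pr, recon_size, calib); the values here are placeholders, not the project's real reconstruction pipeline.

# Hypothetical usage sketch (toy data; real inputs come from the recon pipeline).
import torch
import trimesh

Evaluator.init_gl()                            # must run before any Evaluator exists
evaluator = Evaluator(device=torch.device("cuda:0"))

sphere = trimesh.creation.icosphere(radius=0.5)       # stand-in mesh
verts = torch.as_tensor(sphere.vertices).float()
faces = torch.as_tensor(sphere.faces).long()
calib = torch.eye(4)                                   # identity calibration

evaluator.set_mesh({
    "verts_gt": verts, "faces_gt": faces,              # ground truth
    "verts_pr": verts * 128 + 128, "faces_pr": faces,  # prediction, in SDF grid coords
    "recon_size": 256, "calib": calib,
})
evaluator.space_transfer()                     # builds self.tgt_mesh / self.src_mesh

chamfer, p2s = evaluator.calculate_chamfer_p2s(sampled_points=1000)
normal_err = evaluator.calculate_normal_consist(save_demo_img="./normal_vis.png")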
lib/dataset/NormalDataset.py ADDED
@@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import os.path as osp
import numpy as np
from PIL import Image
import torchvision.transforms as transforms


class NormalDataset():
    def __init__(self, cfg, split='train'):

        self.split = split
        self.root = cfg.root
        self.overfit = cfg.overfit

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.set_splits = self.opt.set_splits
        self.scales = self.opt.scales
        self.pifu = self.opt.pifu

        # input data types and dimensions
        self.in_nml = [item[0] for item in cfg.net.in_nml]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]
        self.in_total = self.in_nml + ['normal_F', 'normal_B']
        self.in_total_dim = self.in_nml_dim + [3, 3]

        if self.split != 'train':
            self.rotations = range(0, 360, 120)
        else:
            self.rotations = np.arange(0, 360, 360 /
                                       self.opt.rotation_num).astype(np.int32)

        self.datasets_dict = {}
        for dataset_id, dataset in enumerate(self.datasets):
            dataset_dir = osp.join(self.root, dataset, "smplx")
            self.datasets_dict[dataset] = {
                "subjects": np.loadtxt(osp.join(self.root, dataset, "all.txt"),
                                       dtype=str),
                "path": dataset_dir,
                "scale": self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)

        # PIL to tensor
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

    def get_subject_list(self, split):

        subject_list = []

        for dataset in self.datasets:

            if self.pifu:
                txt = osp.join(self.root, dataset, f'{split}_pifu.txt')
            else:
                txt = osp.join(self.root, dataset, f'{split}.txt')

            if osp.exists(txt):
                print(f"load from {txt}")
                subject_list += sorted(np.loadtxt(txt, dtype=str).tolist())

                if self.pifu:
                    miss_pifu = sorted(
                        np.loadtxt(osp.join(self.root, dataset,
                                            "miss_pifu.txt"),
                                   dtype=str).tolist())
                    subject_list = [
                        subject for subject in subject_list
                        if subject not in miss_pifu
                    ]
                    subject_list = [
                        "renderpeople/" + subject for subject in subject_list
                    ]

            else:
                train_txt = osp.join(self.root, dataset, 'train.txt')
                val_txt = osp.join(self.root, dataset, 'val.txt')
                test_txt = osp.join(self.root, dataset, 'test.txt')

                print(
                    f"generate lists of [train, val, test] \n {train_txt} \n {val_txt} \n {test_txt} \n"
                )

                split_txt = osp.join(self.root, dataset, f'{split}.txt')

                subjects = self.datasets_dict[dataset]['subjects']
                train_split = int(len(subjects) * self.set_splits[0])
                val_split = int(
                    len(subjects) * self.set_splits[1]) + train_split

                with open(train_txt, "w") as f:
                    f.write("\n".join(dataset + "/" + item
                                      for item in subjects[:train_split]))
                with open(val_txt, "w") as f:
                    f.write("\n".join(
                        dataset + "/" + item
                        for item in subjects[train_split:val_split]))
                with open(test_txt, "w") as f:
                    f.write("\n".join(dataset + "/" + item
                                      for item in subjects[val_split:]))

                subject_list += sorted(
                    np.loadtxt(split_txt, dtype=str).tolist())

        bug_list = sorted(
            np.loadtxt(osp.join(self.root, 'bug.txt'), dtype=str).tolist())

        subject_list = [
            subject for subject in subject_list if (subject not in bug_list)
        ]

        return subject_list

    def __len__(self):
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):

        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]

        # choose specific test sets
        subject = self.subject_list[mid]

        subject_render = "/".join(
            [subject.split("/")[0] + "_12views",
             subject.split("/")[1]])

        # setup paths
        data_dict = {
            'dataset': subject.split("/")[0],
            'subject': subject,
            'rotation': rotation,
            'image_path': osp.join(self.root, subject_render, 'render',
                                   f'{rotation:03d}.png')
        }

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            if name != 'image':
                data_dict.update({
                    f'{name}_path':
                    osp.join(self.root, subject_render, name,
                             f'{rotation:03d}.png')
                })
            data_dict.update({
                name:
                self.imagepath2tensor(data_dict[f'{name}_path'],
                                      channel,
                                      inv='depth_B' in name)
            })

        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):

        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        return (image * (0.5 - inv) * 2.0).float()
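For orientation, a hedged sketch of how this dataset can be constructed. The field names mirror exactly what NormalDataset.__init__ reads, but the concrete values (paths, dataset names, channel spec) are illustrative; in practice the project builds its config elsewhere (see lib.common.config in TestDataset.py below).

# Illustrative wiring only: field names follow what NormalDataset.__init__ reads.
from types import SimpleNamespace as NS

cfg = NS(
    root="./data",                      # expects <root>/<dataset>/all.txt etc.
    overfit=False,
    dataset=NS(types=["thuman2"], input_size=512, set_splits=[0.8, 0.1],
               scales=[100.0], pifu=False, rotation_num=36),
    net=NS(in_nml=[("image", 3)]),      # (feat_name, channel_num) pairs
)

dataset = NormalDataset(cfg=cfg, split="train")
print(len(dataset))                     # num_subjects * num_rotations
item = dataset[0]                       # dict with 'image', 'normal_F', 'normal_B', ...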
lib/dataset/NormalModule.py ADDED
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import numpy as np
from torch.utils.data import DataLoader
from .NormalDataset import NormalDataset

# pytorch lightning related libs
import pytorch_lightning as pl


class NormalModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super(NormalModule, self).__init__()
        self.cfg = cfg
        self.overfit = self.cfg.overfit

        if self.overfit:
            self.batch_size = 1
        else:
            self.batch_size = self.cfg.batch_size

        self.data_size = {}

    def prepare_data(self):

        pass

    @staticmethod
    def worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    def setup(self, stage):

        if stage == 'fit' or stage is None:
            self.train_dataset = NormalDataset(cfg=self.cfg, split="train")
            self.val_dataset = NormalDataset(cfg=self.cfg, split="val")
            self.data_size = {
                'train': len(self.train_dataset),
                'val': len(self.val_dataset)
            }

        if stage == 'test' or stage is None:
            self.test_dataset = NormalDataset(cfg=self.cfg, split="test")

    def train_dataloader(self):

        train_data_loader = DataLoader(self.train_dataset,
                                       batch_size=self.batch_size,
                                       shuffle=not self.overfit,
                                       num_workers=self.cfg.num_threads,
                                       pin_memory=True,
                                       worker_init_fn=self.worker_init_fn)

        return train_data_loader

    def val_dataloader(self):

        if self.overfit:
            current_dataset = self.train_dataset
        else:
            current_dataset = self.val_dataset

        val_data_loader = DataLoader(current_dataset,
                                     batch_size=self.batch_size,
                                     shuffle=False,
                                     num_workers=self.cfg.num_threads,
                                     pin_memory=True)

        return val_data_loader

    def test_dataloader(self):

        test_data_loader = DataLoader(self.test_dataset,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=self.cfg.num_threads,
                                      pin_memory=True)

        return test_data_loader
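The datamodule plugs into a standard Lightning loop. A hedged sketch (the model and trainer flags are placeholders, not from this commit):

# Illustrative: consume the datamodule with a pl.Trainer.
import pytorch_lightning as pl

datamodule = NormalModule(cfg)            # cfg as used throughout this repo
datamodule.setup(stage="fit")
print(datamodule.data_size)               # {'train': ..., 'val': ...}

trainer = pl.Trainer(max_epochs=10)
# trainer.fit(model, datamodule=datamodule)   # `model` is defined elsewhere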
lib/dataset/PIFuDataModule.py ADDED
@@ -0,0 +1,71 @@
import numpy as np
from torch.utils.data import DataLoader
from .PIFuDataset import PIFuDataset
import pytorch_lightning as pl


class PIFuDataModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super(PIFuDataModule, self).__init__()
        self.cfg = cfg
        self.overfit = self.cfg.overfit

        if self.overfit:
            self.batch_size = 1
        else:
            self.batch_size = self.cfg.batch_size

        self.data_size = {}

    def prepare_data(self):

        pass

    @staticmethod
    def worker_init_fn(worker_id):
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    def setup(self, stage):

        if stage == 'fit':
            self.train_dataset = PIFuDataset(cfg=self.cfg, split="train")
            self.val_dataset = PIFuDataset(cfg=self.cfg, split="val")
            self.data_size = {'train': len(self.train_dataset),
                              'val': len(self.val_dataset)}

        if stage == 'test':
            self.test_dataset = PIFuDataset(cfg=self.cfg, split="test")

    def train_dataloader(self):

        train_data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size, shuffle=True,
            num_workers=self.cfg.num_threads, pin_memory=True,
            worker_init_fn=self.worker_init_fn)

        return train_data_loader

    def val_dataloader(self):

        if self.overfit:
            current_dataset = self.train_dataset
        else:
            current_dataset = self.val_dataset

        val_data_loader = DataLoader(
            current_dataset,
            batch_size=1, shuffle=False,
            num_workers=self.cfg.num_threads, pin_memory=True,
            worker_init_fn=self.worker_init_fn)

        return val_data_loader

    def test_dataloader(self):

        test_data_loader = DataLoader(
            self.test_dataset,
            batch_size=1, shuffle=False,
            num_workers=self.cfg.num_threads, pin_memory=True)

        return test_data_loader
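Both datamodules reseed NumPy per worker. The reason: each forked DataLoader worker inherits the parent's NumPy RNG state, so without a worker_init_fn every worker would emit identical "random" augmentations. A standalone illustration (not repo code; behavior assumes fork-based worker start):

# Demonstrates why worker_init_fn reseeds NumPy in multi-worker loading.
import numpy as np
from torch.utils.data import DataLoader, Dataset

class RandDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # without reseeding, forked workers share the same NumPy RNG state
        return np.random.rand()

def reseed(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

loader = DataLoader(RandDataset(), num_workers=2, worker_init_fn=reseed)
print([float(x) for x in loader])    # values now differ across workers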
lib/dataset/PIFuDataset.py ADDED
@@ -0,0 +1,589 @@
from lib.renderer.mesh import load_fit_body
from lib.dataset.hoppeMesh import HoppeMesh
from lib.dataset.body_model import TetraSMPLModel
from lib.common.render import Render
from lib.dataset.mesh_util import SMPLX, projection, cal_sdf_batch, get_visibility
from lib.pare.pare.utils.geometry import rotation_matrix_to_angle_axis
from termcolor import colored
import os.path as osp
import numpy as np
from PIL import Image
import random
import os
import trimesh
import torch
from kaolin.ops.mesh import check_sign
import torchvision.transforms as transforms
from huggingface_hub import hf_hub_download, cached_download


class PIFuDataset():
    def __init__(self, cfg, split='train', vis=False):

        self.split = split
        self.root = cfg.root
        self.bsize = cfg.batch_size
        self.overfit = cfg.overfit

        # for debug, only used in visualize_sampling3D
        self.vis = vis

        self.opt = cfg.dataset
        self.datasets = self.opt.types
        self.input_size = self.opt.input_size
        self.scales = self.opt.scales
        self.workers = cfg.num_threads
        self.prior_type = cfg.net.prior_type

        self.noise_type = self.opt.noise_type
        self.noise_scale = self.opt.noise_scale

        noise_joints = [4, 5, 7, 8, 13, 14, 16, 17, 18, 19, 20, 21]

        self.noise_smpl_idx = []
        self.noise_smplx_idx = []

        for idx in noise_joints:
            self.noise_smpl_idx.append(idx * 3)
            self.noise_smpl_idx.append(idx * 3 + 1)
            self.noise_smpl_idx.append(idx * 3 + 2)

            self.noise_smplx_idx.append((idx - 1) * 3)
            self.noise_smplx_idx.append((idx - 1) * 3 + 1)
            self.noise_smplx_idx.append((idx - 1) * 3 + 2)

        self.use_sdf = cfg.sdf
        self.sdf_clip = cfg.sdf_clip

        # [(feat_name, channel_num), ...]
        self.in_geo = [item[0] for item in cfg.net.in_geo]
        self.in_nml = [item[0] for item in cfg.net.in_nml]

        self.in_geo_dim = [item[1] for item in cfg.net.in_geo]
        self.in_nml_dim = [item[1] for item in cfg.net.in_nml]

        self.in_total = self.in_geo + self.in_nml
        self.in_total_dim = self.in_geo_dim + self.in_nml_dim

        if self.split == 'train':
            self.rotations = np.arange(
                0, 360, 360 / self.opt.rotation_num).astype(np.int32)
        else:
            self.rotations = range(0, 360, 120)

        self.datasets_dict = {}

        for dataset_id, dataset in enumerate(self.datasets):

            mesh_dir = None
            smplx_dir = None

            dataset_dir = osp.join(self.root, dataset)

            if dataset in ['thuman2']:
                mesh_dir = osp.join(dataset_dir, "scans")
                smplx_dir = osp.join(dataset_dir, "fits")
                smpl_dir = osp.join(dataset_dir, "smpl")

            self.datasets_dict[dataset] = {
                "subjects": np.loadtxt(osp.join(dataset_dir, "all.txt"), dtype=str),
                "smplx_dir": smplx_dir,
                "smpl_dir": smpl_dir,
                "mesh_dir": mesh_dir,
                "scale": self.scales[dataset_id]
            }

        self.subject_list = self.get_subject_list(split)
        self.smplx = SMPLX()

        # PIL to tensor
        self.image_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # PIL to tensor
        self.mask_to_tensor = transforms.Compose([
            transforms.Resize(self.input_size),
            transforms.ToTensor(),
            transforms.Normalize((0.0, ), (1.0, ))
        ])

        self.device = torch.device(f"cuda:{cfg.gpus[0]}")
        self.render = Render(size=512, device=self.device)

    def render_normal(self, verts, faces):

        # render optimized mesh (normal, T_normal, image [-1, 1])
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def get_subject_list(self, split):

        subject_list = []

        for dataset in self.datasets:

            split_txt = osp.join(self.root, dataset, f'{split}.txt')

            if osp.exists(split_txt):
                print(f"load from {split_txt}")
                subject_list += np.loadtxt(split_txt, dtype=str).tolist()
            else:
                full_txt = osp.join(self.root, dataset, 'all.txt')
                print(f"split {full_txt} into train/val/test")

                full_lst = np.loadtxt(full_txt, dtype=str)
                full_lst = [dataset + "/" + item for item in full_lst]
                [train_lst, test_lst, val_lst] = np.split(
                    full_lst, [500, 500 + 5, ])

                np.savetxt(full_txt.replace(
                    "all", "train"), train_lst, fmt="%s")
                np.savetxt(full_txt.replace("all", "test"), test_lst, fmt="%s")
                np.savetxt(full_txt.replace("all", "val"), val_lst, fmt="%s")

                print(f"load from {split_txt}")
                subject_list += np.loadtxt(split_txt, dtype=str).tolist()

        if self.split != 'test':
            subject_list += subject_list[:self.bsize -
                                         len(subject_list) % self.bsize]
            print(colored(f"total: {len(subject_list)}", "yellow"))
            random.shuffle(subject_list)

        # subject_list = ["thuman2/0008"]
        return subject_list

    def __len__(self):
        return len(self.subject_list) * len(self.rotations)

    def __getitem__(self, index):

        # only pick the first data if overfitting
        if self.overfit:
            index = 0

        rid = index % len(self.rotations)
        mid = index // len(self.rotations)

        rotation = self.rotations[rid]
        subject = self.subject_list[mid].split("/")[1]
        dataset = self.subject_list[mid].split("/")[0]
        render_folder = "/".join([dataset +
                                  f"_{self.opt.rotation_num}views", subject])

        # setup paths
        data_dict = {
            'dataset': dataset,
            'subject': subject,
            'rotation': rotation,
            'scale': self.datasets_dict[dataset]["scale"],
            'mesh_path': osp.join(self.datasets_dict[dataset]["mesh_dir"], f"{subject}/{subject}.obj"),
            'smplx_path': osp.join(self.datasets_dict[dataset]["smplx_dir"], f"{subject}/smplx_param.pkl"),
            'smpl_path': osp.join(self.datasets_dict[dataset]["smpl_dir"], f"{subject}.pkl"),
            'calib_path': osp.join(self.root, render_folder, 'calib', f'{rotation:03d}.txt'),
            'vis_path': osp.join(self.root, render_folder, 'vis', f'{rotation:03d}.pt'),
            'image_path': osp.join(self.root, render_folder, 'render', f'{rotation:03d}.png')
        }

        # load training data
        data_dict.update(self.load_calib(data_dict))

        # image/normal/depth loader
        for name, channel in zip(self.in_total, self.in_total_dim):

            if f'{name}_path' not in data_dict.keys():
                data_dict.update({
                    f'{name}_path': osp.join(self.root, render_folder, name, f'{rotation:03d}.png')
                })

            # tensor update
            data_dict.update({
                name: self.imagepath2tensor(
                    data_dict[f'{name}_path'], channel, inv=False)
            })

        data_dict.update(self.load_mesh(data_dict))
        data_dict.update(self.get_sampling_geo(
            data_dict, is_valid=self.split == "val", is_sdf=self.use_sdf))
        data_dict.update(self.load_smpl(data_dict, self.vis))

        if self.prior_type == 'pamir':
            data_dict.update(self.load_smpl_voxel(data_dict))

        if (self.split != 'test') and (not self.vis):

            del data_dict['verts']
            del data_dict['faces']

        if not self.vis:
            del data_dict['mesh']

        path_keys = [
            key for key in data_dict.keys() if '_path' in key or '_dir' in key
        ]
        for key in path_keys:
            del data_dict[key]

        return data_dict

    def imagepath2tensor(self, path, channel=3, inv=False):

        rgba = Image.open(path).convert('RGBA')
        mask = rgba.split()[-1]
        image = rgba.convert('RGB')
        image = self.image_to_tensor(image)
        mask = self.mask_to_tensor(mask)
        image = (image * mask)[:channel]

        return (image * (0.5 - inv) * 2.0).float()

    def load_calib(self, data_dict):
        calib_data = np.loadtxt(data_dict['calib_path'], dtype=float)
        extrinsic = calib_data[:4, :4]
        intrinsic = calib_data[4:8, :4]
        calib_mat = np.matmul(intrinsic, extrinsic)
        calib_mat = torch.from_numpy(calib_mat).float()
        return {'calib': calib_mat}

    def load_mesh(self, data_dict):
        mesh_path = data_dict['mesh_path']
        scale = data_dict['scale']

        mesh_ori = trimesh.load(mesh_path,
                                skip_materials=True,
                                process=False,
                                maintain_order=True)
        verts = mesh_ori.vertices * scale
        faces = mesh_ori.faces

        vert_normals = np.array(mesh_ori.vertex_normals)
        face_normals = np.array(mesh_ori.face_normals)

        mesh = HoppeMesh(verts, faces, vert_normals, face_normals)

        return {
            'mesh': mesh,
            'verts': torch.as_tensor(mesh.verts).float(),
            'faces': torch.as_tensor(mesh.faces).long()
        }

    def add_noise(self,
                  beta_num,
                  smpl_pose,
                  smpl_betas,
                  noise_type,
                  noise_scale,
                  type,
                  hashcode):

        np.random.seed(hashcode)

        if type == 'smplx':
            noise_idx = self.noise_smplx_idx
        else:
            noise_idx = self.noise_smpl_idx

        if 'beta' in noise_type and noise_scale[noise_type.index("beta")] > 0.0:
            smpl_betas += (np.random.rand(beta_num) -
                           0.5) * 2.0 * noise_scale[noise_type.index("beta")]
            smpl_betas = smpl_betas.astype(np.float32)

        if 'pose' in noise_type and noise_scale[noise_type.index("pose")] > 0.0:
            smpl_pose[noise_idx] += (
                np.random.rand(len(noise_idx)) -
                0.5) * 2.0 * np.pi * noise_scale[noise_type.index("pose")]
            smpl_pose = smpl_pose.astype(np.float32)

        if type == 'smplx':
            return torch.as_tensor(smpl_pose[None, ...]), torch.as_tensor(smpl_betas[None, ...])
        else:
            return smpl_pose, smpl_betas

    def compute_smpl_verts(self, data_dict, noise_type=None, noise_scale=None):

        dataset = data_dict['dataset']
        smplx_dict = {}

        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)
        smplx_pose = smplx_param["body_pose"]    # [1, 63]
        smplx_betas = smplx_param["betas"]       # [1, 10]
        smplx_pose, smplx_betas = self.add_noise(
            smplx_betas.shape[1],
            smplx_pose[0],
            smplx_betas[0],
            noise_type,
            noise_scale,
            type='smplx',
            hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))

        smplx_out, _ = load_fit_body(fitted_path=data_dict['smplx_path'],
                                     scale=self.datasets_dict[dataset]['scale'],
                                     smpl_type='smplx',
                                     smpl_gender='male',
                                     noise_dict=dict(betas=smplx_betas, body_pose=smplx_pose))

        smplx_dict.update({"type": "smplx",
                           "gender": 'male',
                           "body_pose": torch.as_tensor(smplx_pose),
                           "betas": torch.as_tensor(smplx_betas)})

        return smplx_out.vertices, smplx_dict

    def compute_voxel_verts(self,
                            data_dict,
                            noise_type=None,
                            noise_scale=None):

        smpl_param = np.load(data_dict['smpl_path'], allow_pickle=True)
        smplx_param = np.load(data_dict['smplx_path'], allow_pickle=True)

        smpl_pose = rotation_matrix_to_angle_axis(
            torch.as_tensor(smpl_param['full_pose'][0])).numpy()
        smpl_betas = smpl_param["betas"]

        smpl_path = cached_download(osp.join(self.smplx.model_dir, "smpl/SMPL_MALE.pkl"),
                                    use_auth_token=os.environ['ICON'])
        tetra_path = cached_download(osp.join(self.smplx.tedra_dir, "tetra_male_adult_smpl.npz"),
                                     use_auth_token=os.environ['ICON'])

        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        smpl_pose, smpl_betas = self.add_noise(
            smpl_model.beta_shape[0],
            smpl_pose.flatten(),
            smpl_betas[0],
            noise_type,
            noise_scale,
            type='smpl',
            hashcode=(hash(f"{data_dict['subject']}_{data_dict['rotation']}")) % (10**8))

        smpl_model.set_params(pose=smpl_pose.reshape(-1, 3),
                              beta=smpl_betas,
                              trans=smpl_param["transl"])

        verts = (np.concatenate([smpl_model.verts, smpl_model.verts_added],
                                axis=0) * smplx_param["scale"] + smplx_param["translation"]
                 ) * self.datasets_dict[data_dict['dataset']]['scale']
        faces = np.loadtxt(cached_download(osp.join(self.smplx.tedra_dir, "tetrahedrons_male_adult.txt"),
                                           use_auth_token=os.environ['ICON']),
                           dtype=np.int32) - 1

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32)
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        return verts, faces, pad_v_num, pad_f_num

    def load_smpl(self, data_dict, vis=False):

        smplx_verts, smplx_dict = self.compute_smpl_verts(
            data_dict, self.noise_type,
            self.noise_scale)    # compute using the smplx model

        smplx_verts = projection(smplx_verts, data_dict['calib']).float()
        smplx_faces = torch.as_tensor(self.smplx.faces).long()
        smplx_vis = torch.load(data_dict['vis_path']).float()
        smplx_cmap = torch.as_tensor(
            np.load(self.smplx.cmap_vert_path)).float()

        # get smpl_signs
        query_points = projection(data_dict['samples_geo'],
                                  data_dict['calib']).float()

        pts_signs = 2.0 * (check_sign(smplx_verts.unsqueeze(0),
                                      smplx_faces,
                                      query_points.unsqueeze(0)).float() - 0.5).squeeze(0)

        return_dict = {
            'smpl_verts': smplx_verts,
            'smpl_faces': smplx_faces,
            'smpl_vis': smplx_vis,
            'smpl_cmap': smplx_cmap,
            'pts_signs': pts_signs
        }
        if smplx_dict is not None:
            return_dict.update(smplx_dict)

        if vis:

            (xy, z) = torch.as_tensor(smplx_verts).to(
                self.device).split([2, 1], dim=1)
            smplx_vis = get_visibility(xy, z, torch.as_tensor(
                smplx_faces).to(self.device).long())

            T_normal_F, T_normal_B = self.render_normal(
                (smplx_verts * torch.tensor([1.0, -1.0, 1.0])).to(self.device),
                smplx_faces.to(self.device))

            return_dict.update({"T_normal_F": T_normal_F.squeeze(0),
                                "T_normal_B": T_normal_B.squeeze(0)})
            query_points = projection(data_dict['samples_geo'],
                                      data_dict['calib']).float()

            smplx_sdf, smplx_norm, smplx_cmap, smplx_vis = cal_sdf_batch(
                smplx_verts.unsqueeze(0).to(self.device),
                smplx_faces.unsqueeze(0).to(self.device),
                smplx_cmap.unsqueeze(0).to(self.device),
                smplx_vis.unsqueeze(0).to(self.device),
                query_points.unsqueeze(0).contiguous().to(self.device))

            return_dict.update({
                'smpl_feat':
                torch.cat(
                    (smplx_sdf[0].detach().cpu(),
                     smplx_cmap[0].detach().cpu(),
                     smplx_norm[0].detach().cpu(),
                     smplx_vis[0].detach().cpu()),
                    dim=1)
            })

        return return_dict

    def load_smpl_voxel(self, data_dict):

        smpl_verts, smpl_faces, pad_v_num, pad_f_num = self.compute_voxel_verts(
            data_dict, self.noise_type,
            self.noise_scale)    # compute using the smpl model
        smpl_verts = projection(smpl_verts, data_dict['calib'])

        smpl_verts *= 0.5

        return {
            'voxel_verts': smpl_verts,
            'voxel_faces': smpl_faces,
            'pad_v_num': pad_v_num,
            'pad_f_num': pad_f_num
        }

    def get_sampling_geo(self, data_dict, is_valid=False, is_sdf=False):

        mesh = data_dict['mesh']
        calib = data_dict['calib']

        # Samples are around the true surface with an offset
        n_samples_surface = 4 * self.opt.num_sample_geo
        vert_ids = np.arange(mesh.verts.shape[0])
        thickness_sample_ratio = np.ones_like(vert_ids).astype(np.float32)

        thickness_sample_ratio /= thickness_sample_ratio.sum()

        samples_surface_ids = np.random.choice(vert_ids,
                                               n_samples_surface,
                                               replace=True,
                                               p=thickness_sample_ratio)

        samples_normal_ids = np.random.choice(vert_ids,
                                              self.opt.num_sample_geo // 2,
                                              replace=False,
                                              p=thickness_sample_ratio)

        surf_samples = mesh.verts[samples_normal_ids, :]
        surf_normals = mesh.vert_normals[samples_normal_ids, :]

        samples_surface = mesh.verts[samples_surface_ids, :]

        # Sampling offsets are random noise with constant scale (15cm - 20cm)
        offset = np.random.normal(scale=self.opt.sigma_geo,
                                  size=(n_samples_surface, 1))
        samples_surface += mesh.vert_normals[samples_surface_ids, :] * offset

        # Uniform samples in [-1, 1]
        calib_inv = np.linalg.inv(calib)
        n_samples_space = self.opt.num_sample_geo // 4
        samples_space_img = 2.0 * np.random.rand(n_samples_space, 3) - 1.0
        samples_space = projection(samples_space_img, calib_inv)

        # z-ray direction samples
        if self.opt.zray_type and not is_valid:
            n_samples_rayz = self.opt.ray_sample_num
            samples_surface_cube = projection(samples_surface, calib)
            samples_surface_cube_repeat = np.repeat(samples_surface_cube,
                                                    n_samples_rayz,
                                                    axis=0)

            thickness_repeat = np.repeat(0.5 *
                                         np.ones_like(samples_surface_ids),
                                         n_samples_rayz,
                                         axis=0)

            noise_repeat = np.random.normal(scale=0.40,
                                            size=(n_samples_surface *
                                                  n_samples_rayz, ))
            samples_surface_cube_repeat[:,
                                        -1] += thickness_repeat * noise_repeat
            samples_surface_rayz = projection(samples_surface_cube_repeat,
                                              calib_inv)

            samples = np.concatenate(
                [samples_surface, samples_space, samples_surface_rayz], 0)
        else:
            samples = np.concatenate([samples_surface, samples_space], 0)

        np.random.shuffle(samples)

        # labels: in -> 1.0; out -> 0.0.
        if is_sdf:
            sdfs = mesh.get_sdf(samples)
            inside_samples = samples[sdfs < 0]
            outside_samples = samples[sdfs >= 0]

            inside_sdfs = sdfs[sdfs < 0]
            outside_sdfs = sdfs[sdfs >= 0]
        else:
            inside = mesh.contains(samples)
            inside_samples = samples[inside >= 0.5]
            outside_samples = samples[inside < 0.5]

        nin = inside_samples.shape[0]

        if nin > self.opt.num_sample_geo // 2:
            inside_samples = inside_samples[:self.opt.num_sample_geo // 2]
            outside_samples = outside_samples[:self.opt.num_sample_geo // 2]
            if is_sdf:
                inside_sdfs = inside_sdfs[:self.opt.num_sample_geo // 2]
                outside_sdfs = outside_sdfs[:self.opt.num_sample_geo // 2]
        else:
            outside_samples = outside_samples[:(self.opt.num_sample_geo - nin)]
            if is_sdf:
                outside_sdfs = outside_sdfs[:(self.opt.num_sample_geo - nin)]

        if is_sdf:
            samples = np.concatenate(
                [inside_samples, outside_samples, surf_samples], 0)

            labels = np.concatenate([
                inside_sdfs, outside_sdfs, 0.0 * np.ones(surf_samples.shape[0])
            ])

            normals = np.zeros_like(samples)
            normals[-self.opt.num_sample_geo // 2:, :] = surf_normals

            # convert sdf from [-14, 130] to [0, 1]
            # outside: 0, inside: 1
            # Note: Marching cubes is defined on occupancy space (inside=1.0, outside=0.0)
            labels = -labels.clip(min=-self.sdf_clip, max=self.sdf_clip)
            labels += self.sdf_clip
            labels /= (self.sdf_clip * 2)

        else:
            samples = np.concatenate([inside_samples, outside_samples])
            labels = np.concatenate([
                np.ones(inside_samples.shape[0]),
                np.zeros(outside_samples.shape[0])
            ])

            normals = np.zeros_like(samples)

        samples = torch.from_numpy(samples).float()
        labels = torch.from_numpy(labels).float()
        normals = torch.from_numpy(normals).float()

        return {'samples_geo': samples, 'labels_geo': labels}
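The core of get_sampling_geo is a standard occupancy-sampling recipe: jitter surface points along their normals, mix in uniform space samples, then balance inside/outside labels. A self-contained sketch of just that recipe, skipping the calib, z-ray, and SDF branches (trimesh only; the mesh and parameters are illustrative):

# Occupancy sampling sketch, mirroring get_sampling_geo's main path.
import numpy as np
import trimesh

mesh = trimesh.creation.icosphere(radius=0.5)    # stand-in for a scan
n_geo, sigma_geo = 8000, 0.05

ids = np.random.choice(len(mesh.vertices), 4 * n_geo, replace=True)
surface = mesh.vertices[ids] + mesh.vertex_normals[ids] * \
    np.random.normal(scale=sigma_geo, size=(4 * n_geo, 1))
space = 2.0 * np.random.rand(n_geo // 4, 3) - 1.0    # uniform in [-1, 1]^3
samples = np.concatenate([surface, space], axis=0)
np.random.shuffle(samples)

inside = mesh.contains(samples)                  # in -> True, out -> False
ins, outs = samples[inside], samples[~inside]
ins = ins[:n_geo // 2]                           # cap inside points at half
outs = outs[:n_geo - len(ins)]                   # fill the rest from outside

samples = np.concatenate([ins, outs])
labels = np.concatenate([np.ones(len(ins)), np.zeros(len(outs))])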
lib/dataset/TestDataset.py ADDED
@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-

# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

import os

import lib.smplx as smplx
from lib.pymaf.utils.geometry import rotation_matrix_to_angle_axis, batch_rodrigues
from lib.pymaf.utils.imutils import process_image
from lib.pymaf.core import path_config
from lib.pymaf.models import pymaf_net
from lib.common.config import cfg
from lib.common.render import Render
from lib.dataset.body_model import TetraSMPLModel
from lib.dataset.mesh_util import get_visibility, SMPLX
import os.path as osp
import torch
import numpy as np
import random
from termcolor import colored
from PIL import ImageFile
from huggingface_hub import cached_download

ImageFile.LOAD_TRUNCATED_IMAGES = True


class TestDataset():
    def __init__(self, cfg, device):

        random.seed(1993)

        self.image_path = cfg['image_path']
        self.seg_dir = cfg['seg_dir']
        self.has_det = cfg['has_det']
        self.hps_type = cfg['hps_type']
        self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
        self.smpl_gender = 'neutral'

        self.device = device

        self.subject_list = [self.image_path]

        # smpl related
        self.smpl_data = SMPLX()

        self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(
            model_path=self.smpl_data.model_dir,
            gender=smpl_gender,
            model_type=smpl_type,
            ext='npz')

        # Load SMPL model
        self.smpl_model = self.get_smpl_model(
            self.smpl_type, self.smpl_gender).to(self.device)
        self.faces = self.smpl_model.faces

        self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS,
                             pretrained=True).to(self.device)
        self.hps.load_state_dict(torch.load(
            path_config.CHECKPOINT_FILE)['model'],
                                 strict=True)
        self.hps.eval()

        print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))

        self.render = Render(size=512, device=device)

    def __len__(self):
        return len(self.subject_list)

    def compute_vis_cmap(self, smpl_verts, smpl_faces):

        (xy, z) = torch.as_tensor(smpl_verts).split([2, 1], dim=1)
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
        if self.smpl_type == 'smpl':
            smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
        else:
            smplx_ind = np.arange(smpl_vis.shape[0])
        smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)

        return {
            'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
            'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
            'smpl_verts': smpl_verts.unsqueeze(0)
        }

    def compute_voxel_verts(self, body_pose, global_orient, betas, trans,
                            scale):

        smpl_path = cached_download(osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl"),
                                    use_auth_token=os.environ['ICON'])
        tetra_path = cached_download(osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz'),
                                     use_auth_token=os.environ['ICON'])
        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
        smpl_model.set_params(rotation_matrix_to_angle_axis(pose),
                              beta=betas[0])

        verts = np.concatenate(
            [smpl_model.verts, smpl_model.verts_added],
            axis=0) * scale.item() + trans.detach().cpu().numpy()
        faces = np.loadtxt(cached_download(osp.join(self.smpl_data.tedra_dir,
                                                    'tetrahedrons_neutral_adult.txt'),
                                           use_auth_token=os.environ['ICON']),
                           dtype=np.int32) - 1

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32) * 0.5
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        verts[:, 2] *= -1.0

        voxel_dict = {
            'voxel_verts':
            torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
            'voxel_faces':
            torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
            'pad_v_num':
            torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
            'pad_f_num':
            torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
        }

        return voxel_dict

    def __getitem__(self, index):

        img_path = self.subject_list[index]
        img_name = img_path.split("/")[-1].rsplit(".", 1)[0]

        if self.seg_dir is None:
            img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
                img_path, self.hps_type, 512, self.device)

            data_dict = {
                'name': img_name,
                'image': img_icon.to(self.device).unsqueeze(0),
                'ori_image': img_ori,
                'mask': img_mask,
                'uncrop_param': uncrop_param
            }

        else:
            img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
                img_path, self.hps_type, 512, self.device,
                seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))
            data_dict = {
                'name': img_name,
                'image': img_icon.to(self.device).unsqueeze(0),
                'ori_image': img_ori,
                'mask': img_mask,
                'uncrop_param': uncrop_param,
                'segmentations': segmentations
            }

        with torch.no_grad():
            # import ipdb; ipdb.set_trace()
            preds_dict = self.hps.forward(img_hps)

        data_dict['smpl_faces'] = torch.Tensor(
            self.faces.astype(np.int16)).long().unsqueeze(0).to(self.device)

        if self.hps_type == 'pymaf':
            output = preds_dict['smpl_out'][-1]
            scale, tranX, tranY = output['theta'][0, :3]
            data_dict['betas'] = output['pred_shape']
            data_dict['body_pose'] = output['rotmat'][:, 1:]
            data_dict['global_orient'] = output['rotmat'][:, 0:1]
            data_dict['smpl_verts'] = output['verts']

        elif self.hps_type == 'pare':
            data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['smpl_vertices']
            scale, tranX, tranY = preds_dict['pred_cam'][0, :3]

        elif self.hps_type == 'pixie':
            data_dict.update(preds_dict)
            data_dict['body_pose'] = preds_dict['body_pose']
            data_dict['global_orient'] = preds_dict['global_pose']
            data_dict['betas'] = preds_dict['shape']
            data_dict['smpl_verts'] = preds_dict['vertices']
            scale, tranX, tranY = preds_dict['cam'][0, :3]

        elif self.hps_type == 'hybrik':
            data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['pred_vertices']
            scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
            scale = scale * 2

        elif self.hps_type == 'bev':
            data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[
                [0], :10].to(self.device).float()
            pred_thetas = batch_rodrigues(torch.from_numpy(
                preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
            data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
            data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
            data_dict['smpl_verts'] = torch.from_numpy(
                preds_dict['verts'][[0]]).to(self.device).float()
            tranX = preds_dict['cam_trans'][0, 0]
            tranY = preds_dict['cam'][0, 1] + 0.28
            scale = preds_dict['cam'][0, 0] * 1.1

        data_dict['scale'] = scale
        data_dict['trans'] = torch.tensor(
            [tranX, tranY, 0.0]).to(self.device).float()

        # data_dict info (key - shape):
        # scale, tranX, tranY - tensor.float
        # betas - [1, 10] / [1, 200]
        # body_pose - [1, 23, 3, 3] / [1, 21, 3, 3]
        # global_orient - [1, 1, 3, 3]
        # smpl_verts - [1, 6890, 3] / [1, 10475, 3]

        # from rot_mat to rot_6d for better optimization
        N_body = data_dict["body_pose"].shape[1]
        data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body, -1)
        data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1, -1)

        return data_dict

    def render_normal(self, verts, faces):

        # render optimized mesh (normal, T_normal, image [-1, 1])
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def render_depth(self, verts, faces):

        # render optimized mesh (normal, T_normal, image [-1, 1])
        self.render.load_meshes(verts, faces)
        return self.render.get_depth_map(cam_ids=[0, 2])
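__getitem__ ends by slicing each 3x3 rotation down to its first two columns, the 6D rotation representation of Zhou et al. ("On the Continuity of Rotation Representations in Neural Networks"), which optimizes more stably than full matrices or axis-angle. A valid rotation is recovered later by Gram-Schmidt; a sketch of the standard conversion (this repo's own converter lives in its geometry utils, so treat this version as illustrative):

# 6D-rotation round trip, matching the [:, :, :, :2] slicing above.
import torch
import torch.nn.functional as F

def rot6d_to_rotmat(x):
    # x: [N, 6], row-major flattening of the first two columns of a 3x3 matrix
    x = x.view(-1, 3, 2)
    a1, a2 = x[:, :, 0], x[:, :, 1]
    b1 = F.normalize(a1, dim=1)
    b2 = F.normalize(a2 - (b1 * a2).sum(dim=1, keepdim=True) * b1, dim=1)
    b3 = torch.cross(b1, b2, dim=1)
    return torch.stack([b1, b2, b3], dim=-1)     # [N, 3, 3]

R = torch.eye(3)[None]                           # batch of one rotation
rot6d = R[:, :, :2].reshape(1, -1)               # same slicing as in __getitem__
assert torch.allclose(rot6d_to_rotmat(rot6d), R, atol=1e-6)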
lib/dataset/body_model.py ADDED
@@ -0,0 +1,494 @@
1
+
2
+ # -*- coding: utf-8 -*-
3
+
4
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
5
+ # holder of all proprietary rights on this computer program.
6
+ # You can only use this computer program if you have closed
7
+ # a license agreement with MPG or you get the right to use the computer
8
+ # program from someone who is authorized to grant you that right.
9
+ # Any use of the computer program without a valid license is prohibited and
10
+ # liable to prosecution.
11
+ #
12
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
13
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
14
+ # for Intelligent Systems. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ import numpy as np
19
+ import pickle
20
+ import torch
21
+ import os
22
+
23
+
24
+ class SMPLModel():
25
+ def __init__(self, model_path, age):
26
+ """
27
+ SMPL model.
28
+
29
+ Parameter:
30
+ ---------
31
+ model_path: Path to the SMPL model parameters, pre-processed by
32
+ `preprocess.py`.
33
+
34
+ """
35
+ with open(model_path, 'rb') as f:
36
+ params = pickle.load(f, encoding='latin1')
37
+
38
+ self.J_regressor = params['J_regressor']
39
+ self.weights = np.asarray(params['weights'])
40
+ self.posedirs = np.asarray(params['posedirs'])
41
+ self.v_template = np.asarray(params['v_template'])
42
+ self.shapedirs = np.asarray(params['shapedirs'])
43
+ self.faces = np.asarray(params['f'])
44
+ self.kintree_table = np.asarray(params['kintree_table'])
45
+
46
+ self.pose_shape = [24, 3]
47
+ self.beta_shape = [10]
48
+ self.trans_shape = [3]
49
+
50
+ if age == 'kid':
51
+ v_template_smil = np.load(
52
+ os.path.join(os.path.dirname(model_path),
53
+ "smpl/smpl_kid_template.npy"))
54
+ v_template_smil -= np.mean(v_template_smil, axis=0)
55
+ v_template_diff = np.expand_dims(v_template_smil - self.v_template,
56
+ axis=2)
57
+ self.shapedirs = np.concatenate(
58
+ (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
59
+ axis=2)
60
+ self.beta_shape[0] += 1
61
+
62
+ id_to_col = {
63
+ self.kintree_table[1, i]: i
64
+ for i in range(self.kintree_table.shape[1])
65
+ }
66
+ self.parent = {
67
+ i: id_to_col[self.kintree_table[0, i]]
68
+ for i in range(1, self.kintree_table.shape[1])
69
+ }
70
+
71
+ self.pose = np.zeros(self.pose_shape)
72
+ self.beta = np.zeros(self.beta_shape)
73
+ self.trans = np.zeros(self.trans_shape)
74
+
75
+ self.verts = None
76
+ self.J = None
77
+ self.R = None
78
+ self.G = None
79
+
80
+ self.update()
81
+
+     def set_params(self, pose=None, beta=None, trans=None):
+         """
+         Set pose, shape, and/or translation parameters of the SMPL model.
+         Vertices of the model will be updated and returned.
+
+         Parameters:
+         ----------
+         pose: Also known as 'theta', a [24, 3] matrix of per-joint rotations
+         relative to the parent joint, in axis-angle format. For the root
+         joint it is the global orientation.
+
+         beta: Shape parameter, a vector of shape [10]. Coefficients of the
+         shape PCA components; only 10 components were released by MPI.
+
+         trans: Global translation of shape [3].
+
+         Return:
+         ------
+         Updated vertices.
+
+         """
+         if pose is not None:
+             self.pose = pose
+         if beta is not None:
+             self.beta = beta
+         if trans is not None:
+             self.trans = trans
+         self.update()
+         return self.verts
+
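# A minimal usage sketch (the path and parameter magnitudes are hypothetical;
# the pickle must be one produced by the `preprocess.py` mentioned in the
# docstring above):
#
#   model = SMPLModel('./data/smpl/model.pkl', age='adult')
#   pose = (np.random.rand(24, 3) - 0.5) * 0.4   # small random axis-angle pose
#   beta = (np.random.rand(10) - 0.5) * 0.06     # mild shape variation
#   verts = model.set_params(pose=pose, beta=beta, trans=np.zeros(3))
#   model.save_to_obj('./smpl_sample.obj')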
+     def update(self):
+         """
+         Called automatically when parameters are updated.
+
+         """
+         # how beta affects body shape
+         v_shaped = self.shapedirs.dot(self.beta) + self.v_template
+         # joint locations
+         self.J = self.J_regressor.dot(v_shaped)
+         pose_cube = self.pose.reshape((-1, 1, 3))
+         # rotation matrix for each joint
+         self.R = self.rodrigues(pose_cube)
+         I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
+                                  (self.R.shape[0] - 1, 3, 3))
+         lrotmin = (self.R[1:] - I_cube).ravel()
+         # how pose affects body shape in the zero pose
+         v_posed = v_shaped + self.posedirs.dot(lrotmin)
+         # world transformation of each joint
+         G = np.empty((self.kintree_table.shape[1], 4, 4))
+         G[0] = self.with_zeros(
+             np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
+         for i in range(1, self.kintree_table.shape[1]):
+             G[i] = G[self.parent[i]].dot(
+                 self.with_zeros(
+                     np.hstack([
+                         self.R[i],
+                         (self.J[i, :] - self.J[self.parent[i], :]).reshape(
+                             [3, 1])
+                     ])))
+         # remove the transformation due to the rest pose
+         G = G - self.pack(
+             np.matmul(
+                 G,
+                 np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
+         # transformation of each vertex
+         T = np.tensordot(self.weights, G, axes=[[1], [0]])
+         rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
+         v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
+         self.verts = v + self.trans.reshape([1, 3])
+         self.G = G
+
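# For reference, update() is standard linear blend skinning: with blend
# weights w_{ik} and per-joint transforms G_k (rest-pose transform removed),
# each template vertex t_i is posed as
#
#   v_i' = sum_k w_{ik} * G_k * (t_i + B_S(beta) + B_P(theta)) + trans,
#
# where B_S = shapedirs.dot(beta) and B_P = posedirs.dot(lrotmin) are the
# shape and pose blendshape offsets computed above.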
+     def rodrigues(self, r):
+         """
+         Rodrigues' rotation formula, turning axis-angle vectors into rotation
+         matrices in a batched manner.
+
+         Parameter:
+         ----------
+         r: Axis-angle rotation vectors of shape [batch_size, 1, 3].
+
+         Return:
+         -------
+         Rotation matrices of shape [batch_size, 3, 3].
+
+         """
+         theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
+         # avoid division by zero
+         theta = np.maximum(theta, np.finfo(np.float64).tiny)
+         r_hat = r / theta
+         cos = np.cos(theta)
+         z_stick = np.zeros(theta.shape[0])
+         # skew-symmetric cross-product matrix of the unit rotation axis
+         m = np.dstack([
+             z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick,
+             -r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick
+         ]).reshape([-1, 3, 3])
+         i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
+                                  [theta.shape[0], 3, 3])
+         A = np.transpose(r_hat, axes=[0, 2, 1])
+         B = r_hat
+         dot = np.matmul(A, B)
+         R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m
+         return R
+
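# Sanity check, continuing the sketch above (`model` is the instance from
# the earlier example): a rotation of pi/2 about the z-axis should map the
# x-axis onto the y-axis:
#
#   r = np.array([[[0.0, 0.0, np.pi / 2]]])   # [1, 1, 3] axis-angle
#   R = model.rodrigues(r)                    # [1, 3, 3]
#   assert np.allclose(R[0].dot([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0])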
+     def with_zeros(self, x):
+         """
+         Append a [0, 0, 0, 1] row to a [3, 4] matrix.
+
+         Parameter:
+         ---------
+         x: Matrix to be appended.
+
+         Return:
+         ------
+         Matrix of shape [4, 4] after appending.
+
+         """
+         return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]])))
+
+     def pack(self, x):
+         """
+         Append zero matrices of shape [4, 3] to vectors of shape [4, 1] in a
+         batched manner.
+
+         Parameter:
+         ----------
+         x: Matrices to be appended, of shape [batch_size, 4, 1].
+
+         Return:
+         ------
+         Matrices of shape [batch_size, 4, 4] after appending.
+
+         """
+         return np.dstack((np.zeros((x.shape[0], 4, 3)), x))
+
+     def save_to_obj(self, path):
+         """
+         Save the SMPL model to an .obj file.
+
+         Parameter:
+         ---------
+         path: Path to save to.
+
+         """
+         with open(path, 'w') as fp:
+             for v in self.verts:
+                 fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
+             # .obj face indices are 1-based
+             for f in self.faces + 1:
+                 fp.write('f %d %d %d\n' % (f[0], f[1], f[2]))
+
+
+ class TetraSMPLModel():
+     def __init__(self,
+                  model_path,
+                  model_addition_path,
+                  age='adult',
+                  v_template=None):
+         """
+         SMPL model augmented with interior vertices, so that the surface and
+         the added interior points together form a tetrahedral mesh.
+
+         Parameters:
+         ----------
+         model_path: Path to the SMPL model parameters, pre-processed by
+         `preprocess.py`.
+
+         model_addition_path: Path to the tetrahedral add-on parameters
+         (added template vertices, skinning weights, shape/pose dirs, and
+         the tetrahedra themselves).
+
+         """
+         with open(model_path, 'rb') as f:
+             params = pickle.load(f, encoding='latin1')
+
+         self.J_regressor = params['J_regressor']
+         self.weights = np.asarray(params['weights'])
+         self.posedirs = np.asarray(params['posedirs'])
+
+         if v_template is not None:
+             self.v_template = v_template
+         else:
+             self.v_template = np.asarray(params['v_template'])
+
+         self.shapedirs = np.asarray(params['shapedirs'])
+         self.faces = np.asarray(params['f'])
+         self.kintree_table = np.asarray(params['kintree_table'])
+
+         params_added = np.load(model_addition_path)
+         self.v_template_added = params_added['v_template_added']
+         self.weights_added = params_added['weights_added']
+         self.shapedirs_added = params_added['shapedirs_added']
+         self.posedirs_added = params_added['posedirs_added']
+         self.tetrahedrons = params_added['tetrahedrons']
+
+         id_to_col = {
+             self.kintree_table[1, i]: i
+             for i in range(self.kintree_table.shape[1])
+         }
+         self.parent = {
+             i: id_to_col[self.kintree_table[0, i]]
+             for i in range(1, self.kintree_table.shape[1])
+         }
+
+         self.pose_shape = [24, 3]
+         self.beta_shape = [10]
+         self.trans_shape = [3]
+
+         if age == 'kid':
+             v_template_smil = np.load(
+                 os.path.join(os.path.dirname(model_path),
+                              "smpl/smpl_kid_template.npy"))
+             v_template_smil -= np.mean(v_template_smil, axis=0)
+             v_template_diff = np.expand_dims(
+                 v_template_smil - self.v_template, axis=2)
+             self.shapedirs = np.concatenate(
+                 (self.shapedirs[:, :, :self.beta_shape[0]], v_template_diff),
+                 axis=2)
+             self.beta_shape[0] += 1
+
+         self.pose = np.zeros(self.pose_shape)
+         self.beta = np.zeros(self.beta_shape)
+         self.trans = np.zeros(self.trans_shape)
+
+         self.verts = None
+         self.verts_added = None
+         self.J = None
+         self.R = None
+         self.G = None
+
+         self.update()
+
+     def set_params(self, pose=None, beta=None, trans=None):
+         """
+         Set pose, shape, and/or translation parameters of the SMPL model.
+         Vertices of the model will be updated and returned.
+
+         Parameters:
+         ----------
+         pose: Also known as 'theta', a [24, 3] matrix of per-joint rotations
+         relative to the parent joint, in axis-angle format. For the root
+         joint it is the global orientation.
+
+         beta: Shape parameter, a vector of shape [10]. Coefficients of the
+         shape PCA components; only 10 components were released by MPI.
+
+         trans: Global translation of shape [3].
+
+         Return:
+         ------
+         Updated vertices.
+
+         """
+         # accept torch tensors as well as numpy arrays
+         if torch.is_tensor(pose):
+             pose = pose.detach().cpu().numpy()
+         if torch.is_tensor(beta):
+             beta = beta.detach().cpu().numpy()
+
+         if pose is not None:
+             self.pose = pose
+         if beta is not None:
+             self.beta = beta
+         if trans is not None:
+             self.trans = trans
+         self.update()
+         return self.verts
+
+     def update(self):
+         """
+         Called automatically when parameters are updated.
+
+         """
+         # how beta affects body shape
+         v_shaped = self.shapedirs.dot(self.beta) + self.v_template
+         v_shaped_added = self.shapedirs_added.dot(
+             self.beta) + self.v_template_added
+         # joint locations
+         self.J = self.J_regressor.dot(v_shaped)
+         pose_cube = self.pose.reshape((-1, 1, 3))
+         # rotation matrix for each joint
+         self.R = self.rodrigues(pose_cube)
+         I_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
+                                  (self.R.shape[0] - 1, 3, 3))
+         lrotmin = (self.R[1:] - I_cube).ravel()
+         # how pose affects body shape in the zero pose
+         v_posed = v_shaped + self.posedirs.dot(lrotmin)
+         v_posed_added = v_shaped_added + self.posedirs_added.dot(lrotmin)
+         # world transformation of each joint
+         G = np.empty((self.kintree_table.shape[1], 4, 4))
+         G[0] = self.with_zeros(
+             np.hstack((self.R[0], self.J[0, :].reshape([3, 1]))))
+         for i in range(1, self.kintree_table.shape[1]):
+             G[i] = G[self.parent[i]].dot(
+                 self.with_zeros(
+                     np.hstack([
+                         self.R[i],
+                         (self.J[i, :] - self.J[self.parent[i], :]).reshape(
+                             [3, 1])
+                     ])))
+         # remove the transformation due to the rest pose
+         G = G - self.pack(
+             np.matmul(
+                 G,
+                 np.hstack([self.J, np.zeros([24, 1])]).reshape([24, 4, 1])))
+         self.G = G
+         # transformation of each surface vertex
+         T = np.tensordot(self.weights, G, axes=[[1], [0]])
+         rest_shape_h = np.hstack((v_posed, np.ones([v_posed.shape[0], 1])))
+         v = np.matmul(T, rest_shape_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
+         self.verts = v + self.trans.reshape([1, 3])
+         # transformation of each added interior vertex
+         T_added = np.tensordot(self.weights_added, G, axes=[[1], [0]])
+         rest_shape_added_h = np.hstack(
+             (v_posed_added, np.ones([v_posed_added.shape[0], 1])))
+         v_added = np.matmul(
+             T_added,
+             rest_shape_added_h.reshape([-1, 4, 1])).reshape([-1, 4])[:, :3]
+         self.verts_added = v_added + self.trans.reshape([1, 3])
+
+     def rodrigues(self, r):
+         """
+         Rodrigues' rotation formula, turning axis-angle vectors into rotation
+         matrices in a batched manner.
+
+         Parameter:
+         ----------
+         r: Axis-angle rotation vectors of shape [batch_size, 1, 3].
+
+         Return:
+         -------
+         Rotation matrices of shape [batch_size, 3, 3].
+
+         """
+         theta = np.linalg.norm(r, axis=(1, 2), keepdims=True)
+         # avoid division by zero
+         theta = np.maximum(theta, np.finfo(np.float64).tiny)
+         r_hat = r / theta
+         cos = np.cos(theta)
+         z_stick = np.zeros(theta.shape[0])
+         # skew-symmetric cross-product matrix of the unit rotation axis
+         m = np.dstack([
+             z_stick, -r_hat[:, 0, 2], r_hat[:, 0, 1], r_hat[:, 0, 2], z_stick,
+             -r_hat[:, 0, 0], -r_hat[:, 0, 1], r_hat[:, 0, 0], z_stick
+         ]).reshape([-1, 3, 3])
+         i_cube = np.broadcast_to(np.expand_dims(np.eye(3), axis=0),
+                                  [theta.shape[0], 3, 3])
+         A = np.transpose(r_hat, axes=[0, 2, 1])
+         B = r_hat
+         dot = np.matmul(A, B)
+         R = cos * i_cube + (1 - cos) * dot + np.sin(theta) * m
+         return R
+
+     def with_zeros(self, x):
+         """
+         Append a [0, 0, 0, 1] row to a [3, 4] matrix.
+
+         Parameter:
+         ---------
+         x: Matrix to be appended.
+
+         Return:
+         ------
+         Matrix of shape [4, 4] after appending.
+
+         """
+         return np.vstack((x, np.array([[0.0, 0.0, 0.0, 1.0]])))
+
+     def pack(self, x):
+         """
+         Append zero matrices of shape [4, 3] to vectors of shape [4, 1] in a
+         batched manner.
+
+         Parameter:
+         ----------
+         x: Matrices to be appended, of shape [batch_size, 4, 1].
+
+         Return:
+         ------
+         Matrices of shape [batch_size, 4, 4] after appending.
+
+         """
+         return np.dstack((np.zeros((x.shape[0], 4, 3)), x))
+
+     def save_mesh_to_obj(self, path):
+         """
+         Save the SMPL surface mesh to an .obj file.
+
+         Parameter:
+         ---------
+         path: Path to save to.
+
+         """
+         with open(path, 'w') as fp:
+             for v in self.verts:
+                 fp.write('v %f %f %f\n' % (v[0], v[1], v[2]))
+             # .obj face indices are 1-based
+             for f in self.faces + 1:
+                 fp.write('f %d %d %d\n' % (f[0], f[1], f[2]))
+
+     def save_tetrahedron_to_obj(self, path):
+         """
+         Save the tetrahedral SMPL model to an .obj file, with surface
+         vertices colored red and added interior vertices colored blue.
+
+         Parameter:
+         ---------
+         path: Path to save to.
+
+         """
+         with open(path, 'w') as fp:
+             for v in self.verts:
+                 fp.write('v %f %f %f 1 0 0\n' % (v[0], v[1], v[2]))
+             for va in self.verts_added:
+                 fp.write('v %f %f %f 0 0 1\n' % (va[0], va[1], va[2]))
+             # write the four triangular faces of each (1-based) tetrahedron
+             for t in self.tetrahedrons + 1:
+                 fp.write('f %d %d %d\n' % (t[0], t[2], t[1]))
+                 fp.write('f %d %d %d\n' % (t[0], t[3], t[2]))
+                 fp.write('f %d %d %d\n' % (t[0], t[1], t[3]))
+                 fp.write('f %d %d %d\n' % (t[1], t[2], t[3]))
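# A minimal usage sketch for TetraSMPLModel (both paths are hypothetical;
# the second file must contain the add-on arrays loaded in __init__ above):
#
#   tetra = TetraSMPLModel('./data/smpl/model.pkl',
#                          './data/smpl/tetra_addition.npz', age='adult')
#   tetra.set_params(pose=np.zeros((24, 3)), beta=np.zeros(10),
#                    trans=np.zeros(3))
#   tetra.save_tetrahedron_to_obj('./tetra_sample.obj')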
lib/dataset/hoppeMesh.py ADDED
@@ -0,0 +1,116 @@
+
+ # -*- coding: utf-8 -*-
+
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
+ # holder of all proprietary rights on this computer program.
+ # You can only use this computer program if you have closed
+ # a license agreement with MPG or you get the right to use the computer
+ # program from someone who is authorized to grant you that right.
+ # Any use of the computer program without a valid license is prohibited and
+ # liable to prosecution.
+ #
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
+ # for Intelligent Systems. All rights reserved.
+ #
+ # Contact: ps-license@tuebingen.mpg.de
+
+ import numpy as np
+ from scipy.spatial import cKDTree
+ import trimesh
+
+ import logging
+
+ logging.getLogger("trimesh").setLevel(logging.ERROR)
+
+
+ def save_obj_mesh(mesh_path, verts, faces):
+     file = open(mesh_path, 'w')
+     for v in verts:
+         file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2]))
+     for f in faces:
+         f_plus = f + 1  # .obj face indices are 1-based
+         file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
+     file.close()
+
+
+ def save_obj_mesh_with_color(mesh_path, verts, faces, colors):
+     file = open(mesh_path, 'w')
+
+     for idx, v in enumerate(verts):
+         c = colors[idx]
+         file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' %
+                    (v[0], v[1], v[2], c[0], c[1], c[2]))
+     for f in faces:
+         f_plus = f + 1
+         file.write('f %d %d %d\n' % (f_plus[0], f_plus[1], f_plus[2]))
+     file.close()
+
+
+ def save_ply(mesh_path, points, rgb):
+     '''
+     Save the visualization of sampling to a ply file.
+     Red points represent positive predictions.
+     Green points represent negative predictions.
+     :param mesh_path: File name to save
+     :param points: [N, 3] array of points
+     :param rgb: [N, 3] array of rgb values in the range [0, 1]
+     :return:
+     '''
+     to_save = np.concatenate([points, rgb * 255], axis=-1)
+     return np.savetxt(
+         mesh_path,
+         to_save,
+         fmt='%.6f %.6f %.6f %d %d %d',
+         comments='',
+         header=(
+             'ply\nformat ascii 1.0\nelement vertex {:d}\n' +
+             'property float x\nproperty float y\nproperty float z\n' +
+             'property uchar red\nproperty uchar green\nproperty uchar blue\n' +
+             'end_header').format(points.shape[0]))
+
+
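# A small usage sketch with synthetic data: color points green where z >= 0
# and red elsewhere, then write an ASCII ply:
#
#   pts = np.random.randn(100, 3)
#   rgb = np.where(pts[:, 2:3] >= 0, [[0.0, 1.0, 0.0]], [[1.0, 0.0, 0.0]])
#   save_ply('./sample.ply', pts, rgb)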
+ class HoppeMesh:
+     def __init__(self, verts, faces, vert_normals, face_normals):
+         '''
+         The Hoppe SDF computes signed distances towards a predefined oriented
+         point cloud (http://hhoppe.com/recon.pdf). For clean, high-resolution
+         point-cloud data this is a fast and accurate approximation of the SDF.
+         :param verts: [n, 3] vertices
+         :param faces: [m, 3] face indices
+         :param vert_normals: [n, 3] per-vertex normals
+         :param face_normals: [m, 3] per-face normals
+         '''
+         self.verts = verts  # [n, 3]
+         self.faces = faces  # [m, 3]
+         self.vert_normals = vert_normals  # [n, 3]
+         self.face_normals = face_normals  # [m, 3]
+         # optional per-vertex colors in [0, 255]; set externally before
+         # calling export() / export_ply()
+         self.colors = None
+
+         self.kd_tree = cKDTree(self.verts)
+         self.len = len(self.verts)
+
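# Construction sketch (the path is hypothetical): building a HoppeMesh from
# a trimesh object:
#
#   tm = trimesh.load('./scan.obj', process=False)
#   hoppe = HoppeMesh(np.asarray(tm.vertices), np.asarray(tm.faces),
#                     np.asarray(tm.vertex_normals), np.asarray(tm.face_normals))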
+     def query(self, points):
+         # note: newer SciPy versions rename `n_jobs` to `workers`
+         dists, idx = self.kd_tree.query(points, n_jobs=1)
+         # FIXME: because the eyebrows are removed, the cKDTree is inaccurate
+         # around them, which causes a few false 'inside' labels here.
+         dirs = points - self.verts[idx]
+         # sign of the dot product with the nearest vertex normal:
+         # +1 outside the surface, -1 inside
+         signs = (dirs * self.vert_normals[idx]).sum(axis=1)
+         signs = (signs > 0) * 2 - 1
+         return signs * dists
+
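# Query sketch, continuing the construction example above: positive values
# are outside the surface, negative values inside:
#
#   sdf_vals = hoppe.query(np.zeros((1, 3)))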
+     def contains(self, points):
+
+         labels = trimesh.Trimesh(vertices=self.verts,
+                                  faces=self.faces).contains(points)
+         return labels
+
+     def export(self, path):
+         if self.colors is not None:
+             save_obj_mesh_with_color(path, self.verts, self.faces,
+                                      self.colors[:, 0:3] / 255.0)
+         else:
+             save_obj_mesh(path, self.verts, self.faces)
+
+     def export_ply(self, path):
+         assert self.colors is not None, 'export_ply() needs per-vertex colors'
+         save_ply(path, self.verts, self.colors[:, 0:3] / 255.0)
+
+     def triangles(self):
+         return self.verts[self.faces]  # [m, 3, 3]