daidedou committed
Commit e321b92 · Parent(s): 458efe2

forgot a few things lol

config/diffusion/dfaust_fmap.yaml ADDED
@@ -0,0 +1,65 @@
# misc
misc:
  cuda: True
  device: 0
  checkpoint_interval: 1
  log_interval: 812
  desc: null
  precond: False
  dry_run: False

data:
  root_dir: "data_cache"
  name: DFAUST_fmap_30
  n_fmap: 30
  out: "fmap_exps"
  cond: False
  template_path: "data/template.ply"
  normalize: True
  pairs: False
  abs: True

add_name:
  do: False
  name: "bis"

architecture:
  model: "DiT"
  name_arch: "DiT-S/4"
  input_type: "img"
  cond: False # Conditioning with 3D-CODED

## loss params
#loss:
#  w_gt: False # set w_gt=True to train as a supervised method
#  w_ortho: 1 # orthogonality loss for the functional map (default: 1)
#  w_Qortho: 0 # orthogonality loss for the complex functional map (default: 1)
#  w_bij: 1
#  w_res: 1 # residual loss for the functional map (default: 1)
#  w_rank: -0.1
#  w_srnf: 1
#  min_alpha: 1
#  max_alpha: 100
#

hyper_params:
  iterations: 200
  batch_size: 256
  lr: 0.001
  lr_rampup_kimg: 10000 # learning-rate ramp-up duration
  ema_halflife_nshape: 500 # half-life of the exponential moving average (EMA) of model weights
  ema_rampup_ratio: 0.05 # EMA ramp-up coefficient, None = no ramp-up
  dropout: 0
  loss_name: 'VPLoss'
  ls: 1 # loss scaling

perfs:
  fp16: False
  workers: 1

resume:
  pkl: null
  transfer: null
  kimg_per_tick: 5
  snapshot_ticks: 50
  state_dump_ticks: 50
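For reference, a config like the one above can be read into a plain nested dict with PyYAML; the path below just mirrors this file, and the printed keys are the ones defined above (how the training script actually consumes the config is not shown in this commit):

```python
import yaml

# Minimal sketch: load the training config into a nested dict.
with open("config/diffusion/dfaust_fmap.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["architecture"]["name_arch"])   # "DiT-S/4"
print(cfg["hyper_params"]["batch_size"])  # 256
```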
config/matching/diff_mask.yaml ADDED
@@ -0,0 +1,28 @@
gpu: 0
cache: "cache/fmaps"

sds: True
optimize: False

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 2
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200


zo_shot: 150
config/matching/lap_mask.yaml ADDED
@@ -0,0 +1,23 @@
gpu: 0
cache: "cache/fmaps"

sds: True
optimize: False

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 1e-3
    use_resolvent: False ## if you switch to the resolvent mask, change lambda_ accordingly (around 100)
    resolvent_gamma: 0.5


zo_shot: 150
config/matching/resol_mask.yaml ADDED
@@ -0,0 +1,23 @@
gpu: 0
cache: "cache/fmaps"

sds: True
optimize: False

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 100
    use_resolvent: True ## the resolvent mask needs much larger lambda_ values (around 100)
    resolvent_gamma: 0.5


zo_shot: 150
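`lap_mask.yaml` and `resol_mask.yaml` differ only in `lambda_` and `use_resolvent`: both weight the functional-map energy with a Laplacian-commutativity mask built from the two shapes' eigenvalues. A minimal sketch of the two masks, following the resolvent formulation of Ren et al., "Structured Regularization of Functional Map Computations" (2019); the eigenvalue normalization and the row/column convention here are assumptions, not necessarily this repo's exact code:

```python
import numpy as np

def laplacian_mask(evals1, evals2):
    # Squared difference of normalized eigenvalues: penalizes map entries C_ij
    # that couple eigenfunctions of very different frequencies.
    ev1 = evals1 / evals1.max()
    ev2 = evals2 / evals2.max()
    return (ev2[:, None] - ev1[None, :]) ** 2

def resolvent_mask(evals1, evals2, gamma=0.5):
    # Resolvent variant (Ren et al. 2019): compare 1 / (lambda^gamma + i)
    # through its real and imaginary parts.
    ev1 = (evals1 / evals1.max()) ** gamma
    ev2 = (evals2 / evals2.max()) ** gamma
    re1, im1 = ev1 / (ev1 ** 2 + 1), 1 / (ev1 ** 2 + 1)
    re2, im2 = ev2 / (ev2 ** 2 + 1), 1 / (ev2 ** 2 + 1)
    return (re2[:, None] - re1[None, :]) ** 2 + (im2[:, None] - im1[None, :]) ** 2
```

The resolvent mask takes much smaller values than the raw eigenvalue-difference mask, which is consistent with the comment above about using `lambda_` around 100 when `use_resolvent` is on.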
config/matching/sds.yaml ADDED
@@ -0,0 +1,35 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
oriented: True

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 1e-1
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 300
  soft_p2p: False

loss:
  sds: 1.
  proper: 1.
config/matching/sds_dt4d.yaml ADDED
@@ -0,0 +1,35 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
rotate: True

sds_conf:
  train_dir: fmap_exps
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 1e-3
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 1000
  soft_p2p: False

loss:
  sds: 0.1
  proper: 1.
config/matching/sds_slow.yaml ADDED
@@ -0,0 +1,40 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
oriented: True

diff_model:
  train_dir: pretrained

# diff_model:
#   train_dir: fmap_exps
#   diff_num_exp: 53

sds_conf:
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 1e-3
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 1000
  soft_p2p: False

loss:
  sds: 0.1
  proper: 1.
config/matching/sds_smal.yaml ADDED
@@ -0,0 +1,35 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
oriented: True

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 1e-1
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 1000
  soft_p2p: False

loss:
  sds: 0.1
  proper: 1.
config/matching/snk.yaml ADDED
@@ -0,0 +1,27 @@
gpu: 0
cache: "cache/fmaps"

snk: True
refine: True
optimize: True

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128 ## Doesn't change
    lambda_: 100
    use_resolvent: True ## the resolvent mask needs much larger lambda_ values (around 100)
    resolvent_gamma: 0.5

opt:
  n_loop: 1000
  soft_p2p: True

loss:
  bij: 1.
  ortho: 1.
  cycle: 1
  mse_rec: 1
  prism_rec: 1
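`soft_p2p: True` replaces a hard nearest-neighbor point-to-point map with a differentiable soft assignment. A generic sketch of the idea; the softmax formulation and temperature `tau` are assumptions, not necessarily this repo's implementation:

```python
import torch

def soft_p2p_map(feat1, feat2, tau=0.07):
    # feat1: (n1, d), feat2: (n2, d) pointwise features on the two shapes.
    # Returns a row-stochastic (n2, n1) matrix: soft correspondence shape2 -> shape1.
    d = torch.cdist(feat2, feat1)          # pairwise feature distances (n2, n1)
    return torch.softmax(-d / tau, dim=1)  # soft nearest neighbors

# A hard map is recovered afterwards with Pi.argmax(dim=1).
```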
diffu_models/basis_dataset.py ADDED
@@ -0,0 +1,314 @@
import os

import numpy as np
import torch

N_POSES = 21


class AMASSDataset(torch.utils.data.Dataset):
    def __init__(self, root_path, version='version0', subset='train', basis_path='base_amass.npy',
                 sample_interval=None, num_coeffs=100, return_shape=False,
                 normalize=True, min_max=False):

        self.root_path = root_path
        self.version = version
        assert subset in ['train', 'valid', 'test']
        self.subset = subset
        self.sample_interval = sample_interval
        self.return_shape = return_shape
        self.normalize = normalize
        self.min_max = min_max
        self.num_coeffs = num_coeffs
        self.poses, self.shapes = self.read_data()

        if self.sample_interval:
            self._sample(sample_interval)
        if self.normalize:
            if self.min_max:
                self.min_poses, self.max_poses, self.min_shapes, self.max_shapes = self.Normalize()
            else:
                self.mean_poses, self.std_poses, self.mean_shapes, self.std_shapes = self.Normalize()

        self.real_data_len = len(self.poses)

    def __getitem__(self, idx):
        """
        Return:
            [21, 3] or [21, 6] for poses including body and root orient
            [10] for shapes (betas) [Optional]
        """
        data_poses = self.poses[idx % self.real_data_len]
        if self.return_shape:
            return data_poses, self.shapes[idx % self.real_data_len]
        return data_poses

    def __len__(self):
        return len(self.poses)

    def _sample(self, sample_interval):
        print(f'Class AMASSDataset({self.subset}): sampling dataset every {sample_interval} frames')
        self.poses = self.poses[::sample_interval]

    def read_data(self):
        data_path = os.path.join(self.root_path, self.subset)
        # root_orient = torch.load(os.path.join(data_path, 'root_orient.pt'))
        coeffs = torch.load(os.path.join(data_path, 'train_coeffs.pt'))
        shapes = torch.load(os.path.join(data_path, 'betas.pt')) if self.return_shape else None
        # poses = torch.cat([root_orient, pose_body], dim=1)
        if self.num_coeffs < 300:
            coeffs = coeffs[:, -self.num_coeffs:]

        return coeffs, shapes

    def Normalize(self):
        # Normalization statistics are computed on the train split,
        # either min-max or z-score.
        if self.min_max:
            normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + '_normalize1.pt')
        else:
            normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + '_normalize2.pt')

        if os.path.exists(normalize_path):
            normalize_params = torch.load(normalize_path)
            if self.min_max:
                min_poses, max_poses, min_shapes, max_shapes = (
                    normalize_params['min_poses'],
                    normalize_params['max_poses'],
                    normalize_params['min_shapes'],
                    normalize_params['max_shapes']
                )
            else:
                mean_poses, std_poses, mean_shapes, std_shapes = (
                    normalize_params['mean_poses'],
                    normalize_params['std_poses'],
                    normalize_params['mean_shapes'],
                    normalize_params['std_shapes']
                )
        else:
            if self.min_max:
                min_poses = torch.min(self.poses, dim=0)[0]
                max_poses = torch.max(self.poses, dim=0)[0]

                min_shapes = torch.min(self.shapes, dim=0)[0] if self.return_shape else None
                max_shapes = torch.max(self.shapes, dim=0)[0] if self.return_shape else None

                torch.save({
                    'min_poses': min_poses,
                    'max_poses': max_poses,
                    'min_shapes': min_shapes,
                    'max_shapes': max_shapes
                }, normalize_path)
            else:
                mean_poses = torch.mean(self.poses, dim=0)
                std_poses = torch.std(self.poses, dim=0)

                mean_shapes = torch.mean(self.shapes, dim=0) if self.return_shape else None
                std_shapes = torch.std(self.shapes, dim=0) if self.return_shape else None

                torch.save({
                    'mean_poses': mean_poses,
                    'std_poses': std_poses,
                    'mean_shapes': mean_shapes,
                    'std_shapes': std_shapes
                }, normalize_path)

        if self.min_max:
            self.poses = 2 * (self.poses - min_poses) / (max_poses - min_poses) - 1
            if self.return_shape:
                self.shapes = 2 * (self.shapes - min_shapes) / (max_shapes - min_shapes) - 1
            return min_poses, max_poses, min_shapes, max_shapes
        else:
            self.poses = (self.poses - mean_poses) / std_poses
            if self.return_shape:
                self.shapes = (self.shapes - mean_shapes) / std_shapes
            return mean_poses, std_poses, mean_shapes, std_shapes

    def Denormalize(self, poses, shapes=None):
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if self.min_max:
            min_poses = self.min_poses.view(1, -1).to(poses.device)
            max_poses = self.max_poses.view(1, -1).to(poses.device)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                min_poses = min_poses.unsqueeze(0)
                max_poses = max_poses.unsqueeze(0)

            denormalized_poses = 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)

            if shapes is not None and self.min_shapes is not None:
                min_shapes = self.min_shapes.view(1, -1).to(shapes.device)
                max_shapes = self.max_shapes.view(1, -1).to(shapes.device)

                if len(shapes.shape) == 3:
                    min_shapes = min_shapes.unsqueeze(0)
                    max_shapes = max_shapes.unsqueeze(0)

                denormalized_shapes = 0.5 * ((shapes + 1) * (max_shapes - min_shapes) + 2 * min_shapes)
                return denormalized_poses, denormalized_shapes
            else:
                return denormalized_poses
        else:
            mean_poses = self.mean_poses.view(1, -1).to(poses.device)
            std_poses = self.std_poses.view(1, -1).to(poses.device)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                mean_poses = mean_poses.unsqueeze(0)
                std_poses = std_poses.unsqueeze(0)

            denormalized_poses = poses * std_poses + mean_poses

            if shapes is not None and self.mean_shapes is not None:
                mean_shapes = self.mean_shapes.view(1, -1)
                std_shapes = self.std_shapes.view(1, -1)

                if len(shapes.shape) == 3:
                    mean_shapes = mean_shapes.unsqueeze(0)
                    std_shapes = std_shapes.unsqueeze(0)

                denormalized_shapes = shapes * std_shapes + mean_shapes
                return denormalized_poses, denormalized_shapes
            else:
                return denormalized_poses

    def eval(self, preds):
        pass


class Posenormalizer:
    def __init__(self, data_path, device='cuda:0', normalize=True, min_max=True, rot_rep=None):
        assert rot_rep in ['rot6d', 'axis']
        self.normalize = normalize
        self.min_max = min_max
        self.rot_rep = rot_rep
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize1.pt'.format(rot_rep)))
        self.min_poses, self.max_poses = normalize_params['min_poses'].to(device), normalize_params['max_poses'].to(device)
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize2.pt'.format(rot_rep)))
        self.mean_poses, self.std_poses = normalize_params['mean_poses'].to(device), normalize_params['std_poses'].to(device)

    def offline_normalize(self, poses, from_axis=False):
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if not self.normalize:
            return poses

        if self.min_max:
            min_poses = self.min_poses.view(1, -1)
            max_poses = self.max_poses.view(1, -1)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                min_poses = min_poses.unsqueeze(0)
                max_poses = max_poses.unsqueeze(0)

            normalized_poses = 2 * (poses - min_poses) / (max_poses - min_poses) - 1
        else:
            mean_poses = self.mean_poses.view(1, -1)
            std_poses = self.std_poses.view(1, -1)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                mean_poses = mean_poses.unsqueeze(0)
                std_poses = std_poses.unsqueeze(0)

            normalized_poses = (poses - mean_poses) / std_poses

        return normalized_poses

    def offline_denormalize(self, poses, to_axis=False):
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if not self.normalize:
            denormalized_poses = poses
        else:
            if self.min_max:
                min_poses = self.min_poses.view(1, -1)
                max_poses = self.max_poses.view(1, -1)

                if len(poses.shape) == 3:  # [t, b, data_dim]
                    min_poses = min_poses.unsqueeze(0)
                    max_poses = max_poses.unsqueeze(0)

                denormalized_poses = 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)
            else:
                mean_poses = self.mean_poses.view(1, -1)
                std_poses = self.std_poses.view(1, -1)

                if len(poses.shape) == 3:  # [t, b, data_dim]
                    mean_poses = mean_poses.unsqueeze(0)
                    std_poses = std_poses.unsqueeze(0)

                denormalized_poses = poses * std_poses + mean_poses

        return denormalized_poses


# a simple eval process for the completion task
class Evaler:
    def __init__(self, body_model, part=None):
        self.body_model = body_model
        self.part = part

        if self.part is not None:
            # NB: BodyPartIndices / BodySegIndices are used but not imported in
            # this commit; they are expected to come from the body-model utilities.
            self.joint_idx = np.array(getattr(BodyPartIndices, self.part)) + 1  # skip pelvis
            self.vert_idx = np.array(getattr(BodySegIndices, self.part))
        else:
            self.joint_idx = slice(None)
            self.vert_idx = slice(None)

    def eval_bodys(self, outs, gts):
        """
        :param outs: [b, j*3] axis-angle results of body poses
        :param gts: [b, j*3] axis-angle ground truth of body poses
        :return: result dict for every sample
        """
        sample_num = len(outs)
        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        body_gt = self.body_model(pose_body=gts)
        body_out = self.body_model(pose_body=outs)

        for n in range(sample_num):
            # MPVPE over all vertices
            mesh_gt = body_gt.v.detach().cpu().numpy()[n, self.vert_idx]
            mesh_out = body_out.v.detach().cpu().numpy()[n, self.vert_idx]
            eval_result['mpvpe_all'].append(np.sqrt(np.sum((mesh_out - mesh_gt) ** 2, 1)).mean() * 1000)

            joint_gt_body = body_gt.Jtr.detach().cpu().numpy()[n, self.joint_idx]
            joint_out_body = body_out.Jtr.detach().cpu().numpy()[n, self.joint_idx]

            eval_result['mpjpe_body'].append(
                np.sqrt(np.sum((joint_out_body - joint_gt_body) ** 2, 1)).mean() * 1000)

        return eval_result

    def multi_eval_bodys(self, outs, gts):
        """
        :param outs: [b, hypo, j*3] axis-angle results of body poses, multiple hypotheses
        :param gts: [b, j*3] axis-angle ground truth of body poses
        :return: result dict
        """
        hypo_num = outs.shape[1]
        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        for hypo in range(hypo_num):
            result = self.eval_bodys(outs[:, hypo], gts)
            eval_result['mpvpe_all'].append(result['mpvpe_all'])
            eval_result['mpjpe_body'].append(result['mpjpe_body'])

        # keep the best hypothesis per sample
        eval_result['mpvpe_all'] = np.min(eval_result['mpvpe_all'], axis=0)
        eval_result['mpjpe_body'] = np.min(eval_result['mpjpe_body'], axis=0)

        return eval_result

    def print_eval_result(self, eval_result):
        print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print('MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))

    def print_multi_eval_result(self, eval_result, hypo_num):
        print(f'multihypo {hypo_num} MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print(f'multihypo {hypo_num} MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))
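A minimal usage sketch for the dataset above; the root path is a placeholder, and the only hard requirement implied by `read_data` is that `<root>/train/train_coeffs.pt` exists:

```python
import torch
from diffu_models.basis_dataset import AMASSDataset

# Hypothetical data layout: <root>/train/train_coeffs.pt must exist.
dataset = AMASSDataset(root_path="data_cache/amass", subset="train",
                       num_coeffs=100, normalize=True, min_max=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=True)

coeffs = next(iter(loader))              # normalized coefficients, (256, 100)
restored = dataset.Denormalize(coeffs)   # back to the original scale
```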
diffu_models/dit_models.py ADDED
@@ -0,0 +1,383 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


#################################################################################
#                                 Core DiT Model                                #
#################################################################################

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        conditioning=False
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.conditioning = conditioning
        if conditioning:
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        if self.conditioning:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(self, x, t, y=None):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)

        c = t                                    # (N, D)
        if self.conditioning:
            y = self.y_embedder(y, self.training)  # (N, D)
            c = c + y  # fixed: was `c += t`, which doubled t and dropped the label embedding
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


#################################################################################
#                                   DiT Configs                                 #
#################################################################################

def DiT_XL_2(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def DiT_XL_4(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def DiT_XL_8(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def DiT_L_2(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def DiT_L_4(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def DiT_L_5(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=5, num_heads=16, **kwargs)

def DiT_L_8(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

def DiT_B_2(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def DiT_B_4(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def DiT_B_5(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=5, num_heads=12, **kwargs)

def DiT_B_8(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiT_S_2(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def DiT_S_4(**kwargs):
    # NB: uses patch_size=5 despite the "/4" name (kept as committed; a 30x30
    # input divides evenly into 5x5 patches, not 4x4).
    return DiT(depth=12, hidden_size=384, patch_size=5, num_heads=6, **kwargs)

def DiT_S_8(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT_models = {
    'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
    'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 'DiT-L/5': DiT_L_5,
    'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 'DiT-B/5': DiT_B_5,
    'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
}
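As a sanity check, the `DiT-S/4` entry named in `dfaust_fmap.yaml` can be run on a batch of 30x30 single-channel maps (requires `timm`); `in_channels=1` and `learn_sigma=False` are illustrative choices here, not necessarily the training script's:

```python
import torch
from diffu_models.dit_models import DiT_models

model = DiT_models['DiT-S/4'](input_size=30, in_channels=1,
                              learn_sigma=False, conditioning=False)
x = torch.randn(2, 1, 30, 30)  # batch of 30x30 functional maps
t = torch.rand(2)              # diffusion timesteps
out = model(x, t)              # (2, 1, 30, 30)
```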
diffu_models/losses.py ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Loss functions used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import torch
from edm.torch_utils import persistence

#----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VPLoss:
    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.epsilon_t = epsilon_t

    def noise_and_weight(self, shape, device, sds=False):
        rnd_uniform = torch.rand([shape, 1, 1, 1], device=device)
        if sds:
            # Restrict to [0.02, 0.98] to avoid extreme noise levels, see
            # https://github.com/ashawkey/stable-dreamfusion/blob/5550b91862a3af7842bb04875b7f1211e5095a63/guidance/sd_utils.py#L180
            rnd_uniform = 0.02 + rnd_uniform * 0.96
        sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
        weight = 1 / sigma ** 2
        return sigma, weight

    def __call__(self, net, x, latents, augment_pipe=None):
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        n = torch.randn_like(x) * sigma
        D_xn = net(x + n, sigma, latents)
        loss = weight * ((D_xn - x) ** 2)
        return loss

    def sigma(self, t):
        t = torch.as_tensor(t)
        return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()

#----------------------------------------------------------------------------
# Loss function corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VELoss:
    def __init__(self, sigma_min=0.02, sigma_max=100):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def noise_and_weight(self, shape, device, sds=False):
        # Fixed: referenced an undefined `x`; use the arguments instead.
        rnd_uniform = torch.rand([shape, 1, 1, 1], device=device)
        sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
        weight = 1 / sigma ** 2
        return sigma, weight

    def __call__(self, net, x, latents, augment_pipe=None):
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        n = torch.randn_like(x) * sigma
        D_xn = net(x + n, sigma, latents)
        loss = weight * ((D_xn - x) ** 2)
        return loss

#----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMLoss:
    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.sigma_min = 0.4
        self.sigma_max = 10
        self.rho = 3

    def noise_and_weight(self, shape, device, sds=False):
        rnd_normal = torch.randn([shape, 1, 1, 1], device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        return sigma.float(), weight.float()

    def __call__(self, net, x, latents, augment_pipe=None):
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        n = torch.randn_like(x) * sigma
        D_xn = net(x + n, sigma, latents)
        loss = weight * ((D_xn - x) ** 2)
        return loss

#----------------------------------------------------------------------------
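A minimal sketch of a training step with `VPLoss` (assuming the `edm` package that `persistence` comes from is on the path). The `net(x_noisy, sigma, latents)` signature follows `__call__` above; the identity lambda is only a stand-in for the preconditioned model:

```python
import torch
from diffu_models.losses import VPLoss

loss_fn = VPLoss()
x = torch.randn(8, 1, 30, 30)  # a batch of clean functional maps

# Stand-in denoiser with the expected (x_noisy, sigma, latents) signature;
# in training this would be the preconditioned DiT.
net = lambda x_noisy, sigma, latents: x_noisy

loss = loss_fn(net, x, latents=None)  # per-element weighted squared error
print(loss.mean().item())
```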
diffu_models/precond.py ADDED
@@ -0,0 +1,152 @@
import numpy as np
import torch

#----------------------------------------------------------------------------
# Preconditioning corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

class VEPrecond(torch.nn.Module):
    def __init__(self,
        model,
        label_dim = 0,            # Number of class labels, 0 = unconditional.
        use_fp16 = False,         # Execute the underlying model at FP16 precision?
        sigma_min = 0.02,         # Minimum supported noise level.
        sigma_max = 100,          # Maximum supported noise level.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = sigma
        c_in = 1
        c_noise = (0.5 * sigma).log()

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

#----------------------------------------------------------------------------
# Preconditioning corresponding to improved DDPM (iDDPM) formulation from
# the paper "Improved Denoising Diffusion Probabilistic Models".

class iDDPMPrecond(torch.nn.Module):
    def __init__(self,
        model,
        label_dim = 0,            # Number of class labels, 0 = unconditional.
        use_fp16 = False,         # Execute the underlying model at FP16 precision?
        C_1 = 0.001,              # Timestep adjustment at low noise levels.
        C_2 = 0.008,              # Timestep adjustment at high noise levels.
        M = 1000,                 # Original number of timesteps in the DDPM formulation.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.C_1 = C_1
        self.C_2 = C_2
        self.M = M
        self.model = model
        u = torch.zeros(M + 1)
        for j in range(M, 0, -1):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        self.register_buffer('u', u)
        self.sigma_min = float(u[M - 1])
        self.sigma_max = float(u[0])

    def forward(self, x, sigma, class_labels=None, lamb=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
        # if class_labels is not None:
        #     F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        # else:
        if lamb is not None:
            F_x = self.model((c_in * x).to(dtype), lamb, c_noise.flatten(), **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def alpha_bar(self, j):
        j = torch.as_tensor(j)
        return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

    def round_sigma(self, sigma, return_index=False):
        sigma = torch.as_tensor(sigma)
        index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
        result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
        return result.reshape(sigma.shape).to(sigma.device)

#----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).

class EDMPrecond(torch.nn.Module):
    def __init__(self,
        model,
        label_dim = 0,            # Number of class labels, 0 = unconditional.
        use_fp16 = False,         # Execute the underlying model at FP16 precision?
        sigma_min = 0,            # Minimum supported noise level.
        sigma_max = float('inf'), # Maximum supported noise level.
        sigma_data = 0.5,         # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if class_labels is not None:
            if self.label_dim == 0:
                class_labels = None
            else:
                class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_in = c_in.to(x.device)
        c_noise = sigma.log() / 4
        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), c_latent=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

#----------------------------------------------------------------------------
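The preconditioners wrap a raw network and apply the EDM input/output scalings; a sketch combining `EDMPrecond` with the DiT above (sizes are again illustrative, and the unconditional path is used so `class_labels` stays `None`):

```python
import torch
from diffu_models.dit_models import DiT_models
from diffu_models.precond import EDMPrecond

raw_net = DiT_models['DiT-S/4'](input_size=30, in_channels=1,
                                learn_sigma=False, conditioning=False)
net = EDMPrecond(raw_net, sigma_data=0.5)

x = torch.randn(4, 1, 30, 30)
sigma = torch.full([4], 0.8)
D_x = net(x, sigma)  # denoised estimate, same shape as x
```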
diffu_models/sds.py ADDED
@@ -0,0 +1,77 @@
import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
import torch.nn.functional as F


class WarmupCosineDecayScheduler(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, warmup_start_lr=1e-9, max_lr=1e-4, min_lr=1e-6, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.warmup_start_lr = warmup_start_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(WarmupCosineDecayScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup from warmup_start_lr to max_lr
            lr = self.max_lr * self.last_epoch / self.warmup_steps \
                 + (1 - self.last_epoch / self.warmup_steps) * self.warmup_start_lr
        else:
            # Cosine decay from max_lr to min_lr
            cosine_decay = 0.5 * (1 + np.cos(np.pi * (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)))
            decayed = (1 - self.min_lr / self.max_lr) * cosine_decay + self.min_lr / self.max_lr
            lr = self.max_lr * decayed
        return [lr for _ in self.base_lrs]


def guidance_grad(pred_shape, net, scale_noise, grad_scale=1, batch_size=32, device="cpu", save_guidance_path=None):
    # Noise level ~ U(0.01, 0.01 + scale_noise), avoiding the degenerate sigma = 0
    # (cf. https://github.com/ashawkey/stable-dreamfusion/blob/5550b91862a3af7842bb04875b7f1211e5095a63/guidance/sd_utils.py#L180)
    sigma = 0.01 + torch.rand([batch_size, 1, 1, 1], device=device) * scale_noise
    # predict the denoised shape, NO grad!
    with torch.no_grad():
        # sample noise
        noise = torch.randn_like(pred_shape) * sigma
        # denoise the perturbed shape
        x = pred_shape + noise
        denoised = net(x, sigma)
    # w(t), sigma_t^2
    grad = torch.mean(grad_scale * (pred_shape - denoised), dim=0)  # / sigma**2
    grad = torch.nan_to_num(grad)

    # if save_guidance_path:
    #     with torch.no_grad():
    #         if as_latent:
    #             pred_rgb_512 = self.decode_latents(latents)

    #         # visualize predicted denoised image
    #         # The following block of code is equivalent to `predict_start_from_noise`...
    #         # see zero123_utils.py's version for a simpler implementation.
    #         alphas = self.scheduler.alphas.to(latents)
    #         total_timesteps = self.max_step - self.min_step + 1
    #         index = total_timesteps - t.to(latents.device) - 1
    #         b = len(noise_pred)
    #         a_t = alphas[index].reshape(b, 1, 1, 1).to(self.device)
    #         sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
    #         sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b, 1, 1, 1)).to(self.device)
    #         pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt()  # current prediction for x_0
    #         result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))

    #         # visualize noisier image
    #         result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))

    #         # TODO: also denoise all-the-way

    #         # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
    #         viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image], dim=0)
    #         save_image(viz_images, save_guidance_path)

    return grad, denoised


def guidance_loss(pred_shape, scale_noise, net, grad_scale=1, batch_size=32, device="cpu", save_guidance_path=None):
    # NB: fixed the original call, which passed arguments out of order and
    # ignored the (grad, denoised) tuple returned by guidance_grad.
    grad, _ = guidance_grad(pred_shape, net, scale_noise, grad_scale=grad_scale,
                            batch_size=batch_size, device=device,
                            save_guidance_path=save_guidance_path)
    targets = (pred_shape - grad).detach()
    loss = 0.5 * F.mse_loss(pred_shape.float(), targets, reduction='sum') / pred_shape.shape[0]
    return loss
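A sketch of how `guidance_grad` would drive an SDS-style optimization of a predicted functional map; the optimizer setup and `scale_noise` value are illustrative, and the stand-in `net` would in practice be a pretrained model wrapped in one of the preconditioners from `precond.py`:

```python
import torch
from diffu_models.sds import guidance_grad

# Stand-in denoiser; replace with e.g. EDMPrecond(trained_model).
net = lambda x, sigma: x

fmap = torch.zeros(1, 1, 30, 30, requires_grad=True)  # map being optimized
opt = torch.optim.Adam([fmap], lr=1e-2)

for _ in range(300):
    opt.zero_grad()
    g, _ = guidance_grad(fmap, net, scale_noise=0.9, batch_size=32,
                         device=fmap.device)
    # Inject the SDS gradient directly into fmap.grad.
    fmap.backward(gradient=g.detach().expand_as(fmap))
    opt.step()
```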
shape_data/__init__.py ADDED
@@ -0,0 +1,74 @@
import sys
import os.path as osp
import numpy as np
import torch
from collections import defaultdict

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

DATA_DIRS = {
    'faust': 'FAUST_r',
    'faust_ori': 'FAUST_r_ori',
    'scape': 'SCAPE_r',
    'scape_ori': 'SCAPE_r_ori',
    'smalr': 'SMAL_r',
    'smalr_ori': 'SMAL_r_ori',
    'shrec19': 'SHREC_r',
    'shrec19_ori': 'SHREC_r_ori',
    'dt4d': 'DT4D_r',
    'dt4dintra': 'DT4D_r',
    'dt4dintra_ori': 'DT4D_r_ori',
    'dt4dinter': 'DT4D_r',
    'dt4dinter_ori': 'DT4D_r_ori',
    'tosca': 'TOSCA_r',
    'tosca_ori': 'TOSCA_r',
}


def get_data_dirs(root, name, mode):
    prefix = osp.join(root, DATA_DIRS[name])
    shape_dir = osp.join(prefix, 'shapes')
    corr_dir = osp.join(prefix, 'correspondences')
    return shape_dir, DATA_DIRS[name], corr_dir


# def collate_default(data_list):
#     data_dict = defaultdict(list)
#     for pair_dict in data_list:
#         for k, v in pair_dict.items():
#             data_dict[k].append(v)
#     for k in data_dict.keys():
#         if k.startswith('fmap') or k.startswith('evals') or k.endswith('_sub'):
#             data_dict[k] = np.stack(data_dict[k], axis=0)
#     batch_size = len(data_list)
#     for k, v in data_dict.items():
#         assert len(v) == batch_size

#     return data_dict


def prepare_batch(data_dict, device):
    for k in data_dict.keys():
        if isinstance(data_dict[k], np.ndarray):
            data_dict[k] = torch.from_numpy(data_dict[k]).to(device)
        else:
            if k.startswith('gradX') or \
               k.startswith('gradY') or \
               k.startswith('L'):
                from diffusion_net.utils import sparse_np_to_torch
                tmp_list = [sparse_np_to_torch(st).to(device) for st in data_dict[k]]
                if len(data_dict[k]) == 1:
                    data_dict[k] = torch.stack(tmp_list, dim=0)
                else:
                    data_dict[k] = tmp_list
            else:
                if isinstance(data_dict[k][0], np.ndarray):
                    tmp_list = [torch.from_numpy(st).to(device) for st in data_dict[k]]
                    if len(data_dict[k]) == 1:
                        data_dict[k] = torch.stack(tmp_list, dim=0).to(device)
                    else:
                        data_dict[k] = tmp_list

    return data_dict
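`prepare_batch` moves a dict of per-shape numpy arrays (and, for keys prefixed `gradX`/`gradY`/`L`, sparse operators) onto a device; a small sketch of the expected input layout, where the key names other than those prefixes are illustrative:

```python
import numpy as np
import torch
from shape_data import prepare_batch

batch = {
    'verts': [np.random.rand(1000, 3).astype(np.float32)],  # one shape in the batch
    'evals': np.random.rand(1, 30).astype(np.float32),
}
batch = prepare_batch(batch, device='cpu')
print(batch['verts'].shape, batch['evals'].dtype)  # torch tensors on the device
```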
shape_data/data_utils.py ADDED
@@ -0,0 +1,270 @@
import scipy
import scipy.sparse
import scipy.sparse.linalg
from scipy.io import loadmat
import sys
import os
import os.path as osp
import math
import numpy as np
import open3d as o3d
import potpourri3d as pp3d
import torch
from pathlib import Path


class CorrLoader(object):

    def __init__(self, root_dir, data_type='mat'):
        self.root_dir = root_dir
        self.data_type = data_type

    def get_by_names(self, sname0, sname1):
        if self.data_type.endswith('mat'):
            pmap10 = self._load_mat(osp.join(self.root_dir, f'{sname0}-{sname1}.mat'))
            return np.stack((pmap10, np.arange(len(pmap10))), axis=1)
        else:
            raise RuntimeError(f'Data type {self.data_type} is not supported.')

    def _load_mat(self, filepath):
        data = loadmat(filepath)
        pmap10 = np.squeeze(np.asarray(data['pmap10'], dtype=np.int32))
        return pmap10


# https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/HKS_functions.py
def HKS(evals, evects, time_list, scaled=False):
    evals_s = np.asarray(evals).flatten()
    t_list = np.asarray(time_list).flatten()

    coefs = np.exp(-np.outer(t_list, evals_s))
    weighted_evects = evects[None, :, :] * coefs[:, None, :]
    natural_HKS = np.einsum('tnk,nk->nt', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        return (1 / inv_scaling)[None, :] * natural_HKS
    else:
        return natural_HKS


def lm_HKS(evals, evects, landmarks, time_list, scaled=False):
    evals_s = np.asarray(evals).flatten()
    t_list = np.asarray(time_list).flatten()

    coefs = np.exp(-np.outer(t_list, evals_s))
    weighted_evects = evects[None, landmarks, :] * coefs[:, None, :]

    landmarks_HKS = np.einsum('tpk,nk->ptn', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        landmarks_HKS = (1 / inv_scaling)[None, :, None] * landmarks_HKS

    return landmarks_HKS.reshape(-1, evects.shape[0]).T


def auto_HKS(evals, evects, num_T, landmarks=None, scaled=True):
    abs_ev = sorted(np.abs(evals))
    t_list = np.geomspace(4 * np.log(10) / abs_ev[-1], 4 * np.log(10) / abs_ev[1], num_T)

    if landmarks is None:
        return HKS(abs_ev, evects, t_list, scaled=scaled)
    else:
        return lm_HKS(abs_ev, evects, landmarks, t_list, scaled=scaled)


# https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/WKS_functions.py
def WKS(evals, evects, energy_list, sigma, scaled=False):
    assert sigma > 0, f"Sigma should be positive! Given value: {sigma}"

    evals = np.asarray(evals).flatten()
    indices = np.where(evals > 1e-5)[0].flatten()
    evals = evals[indices]
    evects = evects[:, indices]

    e_list = np.asarray(energy_list)
    coefs = np.exp(-np.square(e_list[:, None] - np.log(np.abs(evals))[None, :]) / (2 * sigma**2))

    weighted_evects = evects[None, :, :] * coefs[:, None, :]

    natural_WKS = np.einsum('tnk,nk->nt', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        return (1 / inv_scaling)[None, :] * natural_WKS
    else:
        return natural_WKS


def lm_WKS(evals, evects, landmarks, energy_list, sigma, scaled=False):
    assert sigma > 0, f"Sigma should be positive! Given value: {sigma}"

    evals = np.asarray(evals).flatten()
    indices = np.where(evals > 1e-2)[0].flatten()
    evals = evals[indices]
    evects = evects[:, indices]

    e_list = np.asarray(energy_list)
    coefs = np.exp(-np.square(e_list[:, None] - np.log(np.abs(evals))[None, :]) / (2 * sigma**2))
    weighted_evects = evects[None, landmarks, :] * coefs[:, None, :]

    landmarks_WKS = np.einsum('tpk,nk->ptn', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        landmarks_WKS = ((1 / inv_scaling)[None, :, None] * landmarks_WKS)

    return landmarks_WKS.reshape(-1, evects.shape[0]).T


def auto_WKS(evals, evects, num_E, landmarks=None, scaled=True):
    abs_ev = sorted(np.abs(evals))

    e_min, e_max = np.log(abs_ev[1]), np.log(abs_ev[-1])
    sigma = 7 * (e_max - e_min) / num_E

    e_min += 2 * sigma
    e_max -= 2 * sigma

    energy_list = np.linspace(e_min, e_max, num_E)

    if landmarks is None:
        return WKS(abs_ev, evects, energy_list, sigma, scaled=scaled)
    else:
        return lm_WKS(abs_ev, evects, landmarks, energy_list, sigma, scaled=scaled)


def compute_hks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35):
    feats = auto_HKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
    feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
    feats_norm2 = np.einsum('np,np->p', feats, np.expand_dims(mass, 1) * feats).flatten()
    feats /= np.expand_dims(np.sqrt(feats_norm2), 0)
    return feats.astype(np.float32)


def compute_wks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35):
    feats = auto_WKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
    feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
    feats_norm2 = np.einsum('np,np->p', feats, np.expand_dims(mass, 1) * feats).flatten()
    feats /= np.expand_dims(np.sqrt(feats_norm2), 0)
    return feats.astype(np.float32)
153
+
154
+
155
+ def compute_geodesic_distance(V, F, vindices):
156
+ solver = pp3d.MeshHeatMethodDistanceSolver(np.asarray(V, dtype=np.float32), np.asarray(F, dtype=np.int32))
157
+ dists = [solver.compute_distance(vid)[vindices] for vid in vindices]
158
+ dists = np.stack(dists, axis=0)
159
+ assert dists.ndim == 2
160
+ return dists.astype(np.float32)
161
+
162
+
163
+ def compute_vertex_normals(vertices, faces):
164
+ mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
165
+ mesh.compute_vertex_normals()
166
+ return np.asarray(mesh.vertex_normals, dtype=np.float32)
167
+
168
+
169
+ def compute_surface_area(vertices, faces):
170
+ mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
171
+ return mesh.get_surface_area()
172
+
173
+ def numpy_to_open3d_mesh(V, F):
174
+ # Create an empty TriangleMesh object
175
+ mesh = o3d.geometry.TriangleMesh()
176
+ # Set vertices
177
+ mesh.vertices = o3d.utility.Vector3dVector(V)
178
+ # Set triangles
179
+ mesh.triangles = o3d.utility.Vector3iVector(F)
180
+ return mesh
181
+
182
+
183
+ def load_mesh(filepath, scale=True, return_vnormals=False):
184
+ if os.path.splitext(filepath)[1] == ".obj": #Avoid pre process from open3d
185
+ V, F = pp3d.read_mesh(filepath)
186
+ mesh = numpy_to_open3d_mesh(V, F)
187
+ else:
188
+ mesh = o3d.io.read_triangle_mesh(filepath)
189
+
190
+ tmat = np.identity(4, dtype=np.float32)
191
+ center = mesh.get_center()
192
+ tmat[:3, 3] = -center
193
+ if scale:
194
+ smat = np.identity(4, dtype=np.float32)
195
+ area = mesh.get_surface_area()
196
+ smat[:3, :3] = np.identity(3, dtype=np.float32) / math.sqrt(area)
197
+ tmat = smat @ tmat
198
+ mesh.transform(tmat)
199
+
200
+ vertices = np.asarray(mesh.vertices, dtype=np.float32)
201
+ faces = np.asarray(mesh.triangles, dtype=np.int32)
202
+ if return_vnormals:
203
+ mesh.compute_vertex_normals()
204
+ vnormals = np.asarray(mesh.vertex_normals, dtype=np.float32)
205
+ return vertices, faces, vnormals
206
+ else:
207
+ return vertices, faces
208
+
209
+
210
+ def save_mesh(filepath, vertices, faces):
211
+ mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
212
+ o3d.io.write_triangle_mesh(filepath, mesh)
213
+
214
+
215
+ def load_geodist(filepath):
216
+ data = loadmat(filepath)
217
+ if 'geodist' in data and 'sqrt_area' in data:
218
+ geodist = np.asarray(data['geodist'], dtype=np.float32)
219
+ sqrt_area = data['sqrt_area'].toarray().flatten()[0]
220
+ elif 'G' in data and 'SQRarea' in data:
221
+ geodist = np.asarray(data['G'], dtype=np.float32)
222
+ sqrt_area = data['SQRarea'].flatten()[0]
223
+ else:
224
+ raise RuntimeError(f'File {filepath} does not have geodesics data.')
225
+ return geodist, sqrt_area
226
+
227
+
228
+ def farthest_point_sampling(points, max_points, random_start=True):
229
+ import torch_cluster
230
+
231
+ if torch.is_tensor(points):
232
+ device = points.device
233
+ is_batch = points.dim() == 3
234
+ if not is_batch:
235
+ points = torch.unsqueeze(points, dim=0)
236
+ assert points.dim() == 3
237
+
238
+ B, N, D = points.size()
239
+ assert N >= max_points
240
+ bindices = torch.flatten(torch.unsqueeze(torch.arange(B), 1).repeat(1, N)).long().to(device)
241
+ points = torch.reshape(points, (B * N, D)).float()
242
+ sindices = torch_cluster.fps(points, bindices, ratio=float(max_points) / N, random_start=random_start)
243
+ if is_batch:
244
+ sindices = torch.reshape(sindices, (B, max_points)) - torch.unsqueeze(torch.arange(B), 1).long().to(device) * N
245
+ elif isinstance(points, np.ndarray):
246
+ device = torch.device('cpu')
247
+ is_batch = points.ndim == 3
248
+ if not is_batch:
249
+ points = np.expand_dims(points, axis=0)
250
+ assert points.ndim == 3
251
+
252
+ B, N, D = points.shape
253
+ assert N >= max_points
254
+ bindices = np.tile(np.expand_dims(np.arange(B), 1), (1, N)).flatten()
255
+ bindices = torch.as_tensor(bindices, device=device).long()
256
+ points = torch.as_tensor(np.reshape(points, (B * N, D)), device=device).float()
257
+ sindices = torch_cluster.fps(points, bindices, ratio=float(max_points) / N, random_start=random_start)
258
+ sindices = sindices.cpu().numpy()
259
+ if is_batch:
260
+ sindices = np.reshape(sindices, (B, max_points)) - np.expand_dims(np.arange(B), 1) * N
261
+ else:
262
+ raise NotImplementedError
263
+ return sindices
264
+
265
+
266
+ def lstsq(A, B):
267
+ assert A.ndim == B.ndim == 2
268
+ sols = scipy.linalg.lstsq(A, B)[0]
269
+ return sols
270
+
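
A quick way to sanity-check the descriptor helpers above (a minimal sketch: the import path is a guess, since this hunk does not show the file's header, and the spectrum is synthetic rather than a real Laplace-Beltrami eigendecomposition):

import numpy as np
from utils.mesh import compute_hks, compute_wks  # assumed module path; adjust to this file's actual location

# Fake spectral data: 35 orthonormal "eigenvectors" on 100 vertices, positive eigenvalues.
rng = np.random.default_rng(0)
evecs, _ = np.linalg.qr(rng.standard_normal((100, 35)))
evals = np.sort(rng.uniform(1e-3, 10.0, size=35))
mass = np.full(100, 1.0 / 100)  # uniform vertex areas

hks = compute_hks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35)
wks = compute_wks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35)
print(hks.shape, wks.shape)  # (100, 20) each: 100 time/energy samples, subsampled by 5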
shape_data/dt4dinter.py ADDED
@@ -0,0 +1,50 @@
+ import os.path as osp
+ import sys
+ import numpy as np
+ from collections import defaultdict
+ 
+ ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
+ if ROOT_DIR not in sys.path:
+     sys.path.append(ROOT_DIR)
+ 
+ from .dt4dintra import ShapeDataset  # re-exported: the inter-class split uses the same shapes
+ from .faust import ShapePairDataset as FaustShapePairDataset
+ from utils.mesh import list_files
+ 
+ # IGNORED_CATEGORIES = ["drake", "mannequin", "ninja", "prisoner", "zlorp", "pumpkinhulk"]
+ IGNORED_CATEGORIES = ["pumpkinhulk"]
+ 
+ 
+ class ShapePairDataset(FaustShapePairDataset):
+ 
+     def _init(self):
+         self.name_id_map = self.shape_data.get_name_id_map()
+         categories = defaultdict(list)
+         for sname in self.name_id_map.keys():
+             categories[sname.split('/')[0]].append(sname)
+         self.pair_indices = list()
+         for filename in list_files(osp.join(self.corr_dir, 'cross_category_corres'), '*.vts', alphanum_sort=False):
+             cname0, cname1 = filename[:-4].split('_')
+             if cname0 in IGNORED_CATEGORIES or cname1 in IGNORED_CATEGORIES:
+                 continue
+             for sname0 in categories[cname0]:
+                 for sname1 in categories[cname1]:
+                     self.pair_indices.append((self.name_id_map[sname0], self.name_id_map[sname1]))
+ 
+     def _load_corr_gt(self, sdict0, sdict1):
+         sname0 = sdict0['name']
+         sname1 = sdict1['name']
+         cname0 = sname0.split('/')[0]
+         cname1 = sname1.split('/')[0]
+         assert cname0 != cname1
+         # Compose the per-shape template correspondences with the cross-category landmark map.
+         lmk01 = self._load_corr_file(f'cross_category_corres/{cname0}_{cname1}')
+         corr0 = self._load_corr_file(sname0)
+         corr1 = self._load_corr_file(sname1)
+         corr_gt = np.stack((corr0, corr1[lmk01]), axis=1)
+         return corr_gt
+ 
+     def _load_corr_file(self, sname):
+         corr_path = osp.join(self.corr_dir, f'{sname}.vts')
+         corr = np.loadtxt(corr_path, dtype=np.int32)
+         return corr - 1  # .vts files are 1-indexed
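
To see what `_load_corr_gt` assembles here, a toy example with made-up index arrays (illustrative numbers only): `corr0` and `corr1` send each landmark of their category to a vertex of the shape, and `lmk01` re-indexes shape 1's landmarks to match shape 0's.

import numpy as np

corr0 = np.array([10, 11, 12])   # landmark k of category 0 -> vertex id on shape 0
corr1 = np.array([20, 21, 22])   # landmark k of category 1 -> vertex id on shape 1
lmk01 = np.array([2, 0, 1])      # landmark k of category 0 <-> landmark lmk01[k] of category 1

corr_gt = np.stack((corr0, corr1[lmk01]), axis=1)
print(corr_gt)  # [[10 22] [11 20] [12 21]]: matched vertex pairs across the two shapes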
shape_data/dt4dintra.py ADDED
@@ -0,0 +1,57 @@
+ import os.path as osp
+ import sys
+ import numpy as np
+ import itertools
+ from pathlib import Path
+ from collections import defaultdict
+ 
+ ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
+ if ROOT_DIR not in sys.path:
+     sys.path.append(ROOT_DIR)
+ 
+ from .faust import ShapeDataset as FaustShapeDataset
+ from .faust import ShapePairDataset as FaustShapePairDataset
+ from utils.utils_legacy import read_lines
+ 
+ IGNORED_CATEGORIES = ['pumpkinhulk']
+ 
+ 
+ class ShapeDataset(FaustShapeDataset):
+     TRAIN_IDX = None
+     TEST_IDX = None
+     NAME = "DT4D"
+ 
+     def _get_file_list(self):
+         if self.mode.startswith('train'):
+             file_list = read_lines(osp.join(self.shape_dir, '..', 'train.txt'))
+         elif self.mode.startswith('test'):
+             file_list = read_lines(osp.join(self.shape_dir, '..', 'test.txt'))
+         else:
+             raise RuntimeError(f'Mode {self.mode} is not supported.')
+         shape_list = [fn + '.ply' for fn in file_list]
+         return shape_list
+ 
+ 
+ class ShapePairDataset(FaustShapePairDataset):
+ 
+     def _init(self):
+         self.name_id_map = self.shape_data.get_name_id_map()
+         categories = defaultdict(list)
+         for sname in self.name_id_map.keys():
+             categories[sname.split('/')[0]].append(sname)
+         self.pair_indices = list()
+         for cname, clist in categories.items():
+             if cname in IGNORED_CATEGORIES:
+                 continue
+             for pname in itertools.combinations(clist, 2):
+                 self.pair_indices.append((self.name_id_map[pname[0]], self.name_id_map[pname[1]]))
+ 
+     def _load_corr_gt(self, sdict0, sdict1):
+         corr0 = self._load_corr_file(sdict0['name'])
+         corr1 = self._load_corr_file(sdict1['name'])
+         corr_gt = np.stack((corr0, corr1), axis=1)
+         return corr_gt
+ 
+     def _load_corr_file(self, sname):
+         corr_path = osp.join(self.corr_dir, f'{sname}.vts')
+         corr = np.loadtxt(corr_path, dtype=np.int32)
+         return corr - 1
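
The pairing logic in `_init` reduces to grouping shape names by their category prefix and enumerating unordered pairs within each kept category; a standalone sketch with toy names:

import itertools
from collections import defaultdict

name_id_map = {'cat/pose0': 0, 'cat/pose1': 1, 'dog/pose0': 2, 'pumpkinhulk/pose0': 3}
categories = defaultdict(list)
for sname in name_id_map:
    categories[sname.split('/')[0]].append(sname)

pair_indices = [(name_id_map[a], name_id_map[b])
                for cname, clist in categories.items() if cname not in ['pumpkinhulk']
                for a, b in itertools.combinations(clist, 2)]
print(pair_indices)  # [(0, 1)]: only the two 'cat' shapes form a pair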
shape_data/faust.py ADDED
@@ -0,0 +1,408 @@
+ import os.path as osp
+ import itertools
+ import math
+ import numpy as np
+ import torch
+ from torch.utils.data import Dataset
+ from pathlib import Path
+ import potpourri3d as pp3d
+ import open3d as o3d
+ from utils.geometry import get_operators, load_operators
+ from utils.surfaces import Surface
+ from utils.utils_func import may_create_folder
+ from utils.mesh import find_mesh_files
+ 
+ # def opt_rot_points(pts_1, pts_2, device="cuda:0"):  # torch variant, kept for reference
+ #     center_1 = pts_1.mean(dim=0)
+ #     pts_c1 = pts_1 - center_1
+ #     center_2 = pts_2.mean(dim=0)
+ #     pts_c2 = pts_2 - center_2
+ #     A = pts_c1.T @ pts_c2
+ #     u, _, v = torch.linalg.svd(A)
+ #     a = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, torch.sign(torch.linalg.det(A))]]).float().to(device)
+ #     O = u @ a @ v
+ #     return O.T
+ 
+ 
+ def opt_rot_points(pts_1, pts_2):
+     """Kabsch / orthogonal Procrustes: returns R such that pts_2 @ R best aligns with pts_1 (up to translation)."""
+     center_1 = pts_1.mean(axis=0)
+     pts_c1 = pts_1 - center_1
+     center_2 = pts_2.mean(axis=0)
+     pts_c2 = pts_2 - center_2
+ 
+     A = np.dot(pts_c1.T, pts_c2)
+     u, _, v = np.linalg.svd(A)
+     # The sign correction keeps the result a proper rotation (det = +1).
+     a = np.array([[1, 0, 0], [0, 1, 0], [0, 0, np.sign(np.linalg.det(A))]])
+     O = u @ a @ v
+     return O.T
+ 
+ 
+ def compute_vertex_normals(vertices, faces):
+     mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
+     mesh.compute_vertex_normals()
+     return np.asarray(mesh.vertex_normals, dtype=np.float32)
+ 
+ 
+ def numpy_to_open3d_mesh(V, F):
+     # Create an empty TriangleMesh object, then set vertices and triangles.
+     mesh = o3d.geometry.TriangleMesh()
+     mesh.vertices = o3d.utility.Vector3dVector(V)
+     mesh.triangles = o3d.utility.Vector3iVector(F)
+     return mesh
+ 
+ 
+ def open_mesh(path):
+     """Try to open a mesh; on failure, fall back to .ply, .obj, and .off siblings.
+ 
+     Parameters
+     ----------
+     path : str
+         Path of the mesh.
+ 
+     Returns
+     -------
+     (V, F) arrays if loading succeeds, else None.
+     """
+     p = Path(path)
+     base, ext = p.with_suffix(""), p.suffix
+     tried_exts = [ext, ".ply", ".obj", ".off"]
+     for e in tried_exts:
+         path = base.with_suffix(e)
+         if Path.exists(path):
+             try:
+                 temp = pp3d.read_mesh(str(path))
+                 return temp
+             except Exception as err:
+                 print(f"Failed loading {path}: {err}")
+     return None
+ 
+ 
+ KEYS = ['vertices', 'faces', 'frames', 'mass', 'L', 'evals', 'evecs', 'gradX', 'gradY', 'hks', 'wks', 'idx', 'name']
+ 
+ 
+ class ShapeDataset(Dataset):
+     TRAIN_IDX = np.arange(0, 80)
+     TEST_IDX = np.arange(80, 100)
+     NAME = "FAUST"
+ 
+     def __init__(self,
+                  shape_dir,
+                  cache_dir,
+                  mode,
+                  oriented=False,
+                  rot_auto=False,
+                  num_eigenbasis=256,
+                  laplacian_type='mesh',
+                  feature_type=None,
+                  **kwargs):
+         super().__init__()
+ 
+         self.shape_dir = shape_dir
+         self.cache_dir = cache_dir
+         self.mode = mode
+         self.oriented = oriented
+         if self.oriented:
+             self.NAME = self.NAME + "_ori"
+         self.num_eigenbasis = num_eigenbasis
+         self.laplacian_type = laplacian_type
+         self.feature_type = feature_type
+         for k, w in kwargs.items():
+             setattr(self, k, w)
+ 
+         print(f'Loading {mode} data from {shape_dir}')
+         self.shape_list = self._get_file_list()
+         self._prepare()
+ 
+         self.randg = np.random.RandomState(0)
+ 
+     def _get_file_list(self):
+         path_list = find_mesh_files(Path(self.shape_dir), alphanum_sort=True)
+         file_list = [f.name for f in path_list]
+         if self.mode.startswith('train'):
+             assert self.TRAIN_IDX is not None
+             shape_list = [file_list[idx] for idx in self.TRAIN_IDX]
+         elif self.mode.startswith('test'):
+             assert self.TEST_IDX is not None
+             shape_list = [file_list[idx] for idx in self.TEST_IDX]
+         else:
+             raise RuntimeError(f'Mode {self.mode} is not supported.')
+         return shape_list
+ 
+     def _load_mesh(self, filepath, scale=True, return_vnormals=False):
+         V, F = open_mesh(filepath)
+         mesh = numpy_to_open3d_mesh(V, F)
+ 
+         # Center the mesh and (optionally) normalize it to unit surface area.
+         tmat = np.identity(4, dtype=np.float32)
+         center = mesh.get_center()
+         tmat[:3, 3] = -center
+         if scale:
+             smat = np.identity(4, dtype=np.float32)
+             area = mesh.get_surface_area()
+             smat[:3, :3] = np.identity(3, dtype=np.float32) / math.sqrt(area)
+             tmat = smat @ tmat
+         mesh.transform(tmat)
+ 
+         vertices = np.asarray(mesh.vertices, dtype=np.float32)
+         faces = np.asarray(mesh.triangles, dtype=np.int32)
+         if return_vnormals:
+             mesh.compute_vertex_normals()
+             vnormals = np.asarray(mesh.vertex_normals, dtype=np.float32)
+             return vertices, faces, vnormals
+         else:
+             return vertices, faces
+ 
+     def _prepare(self):
+         # Make sure the per-dataset cache subfolder exists before writing operators there.
+         may_create_folder(osp.join(self.cache_dir, self.NAME))
+         for sid, sname in enumerate(self.shape_list):
+             cache_prefix = osp.join(self.cache_dir, self.NAME, f'{sname[:-4]}_{self.laplacian_type}_{self.num_eigenbasis}k')
+             cache_path = cache_prefix + '_0n.npz'
+             if not Path(cache_path).is_file():
+                 vertices_np, faces_np, vertex_normals_np = self._load_mesh(osp.join(self.shape_dir, sname),
+                                                                            scale=True,
+                                                                            return_vnormals=True)
+ 
+                 if self.laplacian_type == 'mesh':
+                     _ = get_operators(torch.from_numpy(vertices_np).float(), torch.from_numpy(faces_np).long(),
+                                       self.num_eigenbasis, cache_path=cache_path)
+                 # elif self.laplacian_type == 'pcd':
+                 #     compute_operators(vertices_np, np.asarray([], dtype=np.int32), vertex_normals_np,
+                 #                       self.num_eigenbasis, cache_path)
+                 else:
+                     raise RuntimeError(f'Basis type {self.laplacian_type} is not supported')
+ 
+             # if self.aug_noise_type is not None and self.aug_noise_type != 'naive':
+             #     max_magnitude, max_levels = self.aug_noise_args[:2]
+             #     randg = np.random.RandomState(sid)
+             #     for nlevel in range(1, max_levels + 1):
+             #         cache_path = cache_prefix + f'_{nlevel}n.npz'
+             #         if Path(cache_path).is_file():
+             #             continue
+             #         noise_mag = max_magnitude * nlevel / max_levels
+             #         noise_mat = np.clip(noise_mag * randg.randn(vertices_np.shape[0], vertices_np.shape[1]),
+             #                             -noise_mag, noise_mag)
+             #         vertices_noise_np = vertices_np + noise_mat.astype(vertices_np.dtype)
+             #         vertex_normals_noise_np = compute_vertex_normals(vertices_noise_np, faces_np)
+             #         if self.laplacian_type == 'mesh':
+             #             compute_operators(vertices_noise_np, faces_np, vertex_normals_noise_np,
+             #                               self.num_eigenbasis, cache_path)
+             #         elif self.laplacian_type == 'pcd':
+             #             compute_operators(vertices_noise_np, np.asarray([], dtype=np.int32),
+             #                               vertex_normals_noise_np, self.num_eigenbasis, cache_path)
+             #         else:
+             #             raise RuntimeError(f'Basis type {self.laplacian_type} is not supported')
+ 
+     def __getitem__(self, idx):
+         sname = self.shape_list[idx]
+ 
+         cache_prefix = osp.join(self.cache_dir, self.NAME, f'{sname[:-4]}_{self.laplacian_type}_{self.num_eigenbasis}k')
+         cache_path = cache_prefix + '_0n.npz'
+ 
+         assert Path(cache_path).is_file()
+ 
+         sdict = load_operators(cache_path)
+         sdict['idx'] = idx
+         sdict['name'] = sname[:-4]
+ 
+         if self.feature_type is not None:
+             sdict['feats'] = np.concatenate([sdict[ft] for ft in self.feature_type.split('_')], axis=-1)
+         vertices_np, _, _ = self._load_mesh(osp.join(self.shape_dir, sname), scale=True, return_vnormals=True)
+         sdict['vertices'] = vertices_np
+         sdict = self._centering(sdict)
+         return sdict
+ 
+     def __len__(self):
+         return len(self.shape_list)
+ 
+     def _centering(self, sdict):
+         vertices, areas = sdict['vertices'], sdict["mass"]
+         # Area-weighted centroid; sum over the vertex axis so that center stays a 3-vector.
+         center = (vertices * areas[:, None]).sum(axis=0) / areas.sum()
+         sdict['vertices'] = vertices - center
+         return sdict
+ 
+     def _random_noise_naive(self, sdict, randg, args):
+         vertices = sdict['vertices']
+         dtype = vertices.dtype
+         shape = vertices.shape
+         std, clip = args
+ 
+         noise = np.clip(std * randg.randn(*shape), -clip, clip)
+         sdict['vertices'] = vertices + noise.astype(dtype)
+         return sdict
+ 
+     def _random_rotation(self, sdict, randg, axes, args):
+         vertices = sdict['vertices']
+         dtype = vertices.dtype
+ 
+         max_x, max_y, max_z = args
+         if 'x' in axes:
+             anglex = randg.rand() * max_x * np.pi / 180.0
+             cosx = np.cos(anglex)
+             sinx = np.sin(anglex)
+             Rx = np.asarray([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]], dtype=dtype)
+         else:
+             Rx = np.eye(3, dtype=dtype)
+ 
+         if 'y' in axes:
+             angley = randg.rand() * max_y * np.pi / 180.0
+             cosy = np.cos(angley)
+             siny = np.sin(angley)
+             Ry = np.asarray([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]], dtype=dtype)
+         else:
+             Ry = np.eye(3, dtype=dtype)
+ 
+         if 'z' in axes:
+             anglez = randg.rand() * max_z * np.pi / 180.0
+             cosz = np.cos(anglez)
+             sinz = np.sin(anglez)
+             Rz = np.asarray([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]], dtype=dtype)
+         else:
+             Rz = np.eye(3, dtype=dtype)
+ 
+         # Compose the three axis rotations in a random order.
+         Rxyz = randg.permutation(np.stack((Rx, Ry, Rz), axis=0))
+         R = Rxyz[2] @ Rxyz[1] @ Rxyz[0]
+         sdict['vertices'] = vertices @ R.T
+ 
+         return sdict
+ 
+     def _random_scaling(self, sdict, randg, args):
+         scale_min, scale_max = args
+         vertices = sdict['vertices']
+         scale = scale_min + randg.rand(1, 3) * (scale_max - scale_min)
+         sdict['vertices'] = vertices * scale
+         return sdict
+ 
+     def get_name_id_map(self):
+         return {sname[:-4]: sid for sid, sname in enumerate(self.shape_list)}
+ 
+ 
+ class ShapePairDataset(Dataset):
+ 
+     def __init__(self, corr_dir, mode, shape_data, rotate=False, **kwargs):
+         super().__init__()
+         self.corr_dir = corr_dir
+         self.mode = mode
+         self.shape_data = shape_data
+         self.rotate = rotate
+         if self.shape_data.oriented and self.rotate:
+             # Pre-oriented shapes are already rigidly aligned.
+             self.rotate = False
+         for k, w in kwargs.items():
+             setattr(self, k, w)
+ 
+         self._init()
+ 
+         self.randg = np.random.RandomState(0)
+ 
+     def _init(self):
+         self.name_id_map = self.shape_data.get_name_id_map()
+         self.pair_indices = list(itertools.combinations(range(len(self.shape_data)), 2))
+ 
+     def __getitem__(self, idx):
+         pidx = self.pair_indices[idx]
+         sdict0 = self.shape_data[pidx[0]]
+         sdict1 = self.shape_data[pidx[1]]
+         return self._prepare_pair(sdict0, sdict1)
+ 
+     def get_by_names(self, sname0, sname1):
+         sdict0 = self.shape_data[self.name_id_map[sname0]]
+         sdict1 = self.shape_data[self.name_id_map[sname1]]
+         return self._prepare_pair(sdict0, sdict1)
+ 
+     def _prepare_pair(self, sdict0, sdict1):
+         corr_gt = self._load_corr_gt(sdict0, sdict1)
+         # for fmap_size in self.fmap_sizes:
+         #     fmap01_gt = pmap_to_fmap(sdict0['evecs'][:, :fmap_size], sdict1['evecs'][:, :fmap_size], corr_gt)
+         #     pdict[f'fmap01_{fmap_size}_gt'] = fmap01_gt
+ 
+         # for idx in range(2):
+         #     indices_sel = farthest_point_sampling(pdict[f'vertices{idx}'], self.num_corrs, random_start=is_train)
+         #     for k in ['vertices', 'evecs', 'feats']:
+         #         kid = f'{k}{idx}'
+         #         if kid in pdict:
+         #             pdict[kid + '_sub'] = pdict[kid][indices_sel, :]
+         #     if self.use_geodists:
+         #         geodists = compute_geodesic_distance(pdict[f'vertices{idx}'], pdict[f'faces{idx}'], indices_sel)
+         #         pdict[f'geodists{idx}_sub'] = geodists
+         #     pdict[f'vindices{idx}_sub'] = indices_sel
+ 
+         # fmap_size = self.fmap_sizes[-1]
+         # corr_gt_sub = fmap_to_pmap(pdict['evecs0_sub'][:, :fmap_size], pdict['evecs1_sub'][:, :fmap_size],
+         #                            pdict[f'fmap01_{fmap_size}_gt'])
+         # pdict['corr_gt_sub'] = corr_gt_sub
+ 
+         # if is_train:
+         #     fmap_size = self.fmap_sizes[0]
+         #     axis = self.randg.choice([0, 1]).item()
+         #     max_bases = fmap_size // 2
+         #     noise_ratio = 0.5
+         #     if self.randg.rand() > 0.5:
+         #         pdict[f'fmap01_{fmap_size}'] = self._random_scale(pdict[f'fmap01_{fmap_size}_gt'], self.randg, axis, max_bases)
+         #     else:
+         #         pdict[f'fmap01_{fmap_size}'] = self._random_noise(pdict[f'fmap01_{fmap_size}_gt'], self.randg, axis,
+         #                                                           max_bases, noise_ratio)
+         # else:
+         #     if self.corr_loader is not None:
+         #         corr_init = self.corr_loader.get_by_names(sdict0['name'], sdict1['name'])
+         #         assert corr_init.ndim == 2 and len(corr_init) == len(sdict1['vertices'])
+         #         fmap_size = self.fmap_sizes[0]
+         #         fmap01_init = pmap_to_fmap(sdict0['evecs'][:, :fmap_size], sdict1['evecs'][:, :fmap_size], corr_init)
+         #         pdict[f'fmap01_{fmap_size}'] = fmap01_init
+         #         pdict['pmap10'] = corr_init[:, 0]
+ 
+         vts_1, vts_2 = corr_gt[:, 0], corr_gt[:, 1]
+         shape_dict, target_dict = sdict0, sdict1
+ 
+         if self.rotate:
+             # Rigidly align the target onto the source using the ground-truth matches.
+             pts_1, pts_2 = shape_dict['vertices'][vts_1], target_dict['vertices'][vts_2]
+             rot = opt_rot_points(pts_1, pts_2).astype(np.float32)
+             target_dict['vertices'] = target_dict['vertices'] @ rot
+ 
+         # The target surface and normals are part of the return value, so build them
+         # whether or not the rotation was applied.
+         target_surf = Surface(FV=[target_dict['faces'], target_dict['vertices']])
+         target_normals = torch.from_numpy(target_surf.surfel / np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().cuda()
+ 
+         shape_surf = Surface(FV=[shape_dict['faces'], shape_dict['vertices']])
+         map_info = (shape_dict['name'], vts_1, vts_2)
+         return shape_dict, shape_surf, target_dict, target_surf, target_normals, map_info
+ 
+     def _random_scale(self, fmap, randg, axis, max_bases):
+         assert max_bases > 1
+         assert axis in [0, 1]
+         num_bases = randg.randint(1, max_bases)
+         ids = randg.choice(fmap.shape[axis], num_bases, replace=False)
+         fmap_out = np.copy(fmap)
+         if axis == 0:
+             fmap_out[ids, :] *= (randg.rand(num_bases, 1) * 2 - 1)
+         else:
+             fmap_out[:, ids] *= (randg.rand(1, num_bases) * 2 - 1)
+         return fmap_out
+ 
+     def _random_noise(self, fmap, randg, axis, max_bases, max_ratio):
+         assert max_bases > 1
+         assert axis in [0, 1]
+         num_bases = randg.randint(1, max_bases)
+         ids = randg.choice(fmap.shape[axis], num_bases, replace=False)
+         fmap_out = np.copy(fmap)
+         ratio = randg.rand() * max_ratio
+         if axis == 0:
+             maxvals = np.amax(np.abs(fmap_out[ids, :]), axis=1 - axis, keepdims=True)
+             noise = ratio * maxvals * randg.randn(num_bases, fmap.shape[1 - axis])
+             fmap_out[ids, :] += noise
+         else:
+             maxvals = np.amax(np.abs(fmap_out[:, ids]), axis=1 - axis, keepdims=True)
+             noise = ratio * maxvals * randg.randn(fmap.shape[1 - axis], num_bases)
+             fmap_out[:, ids] += noise
+         return fmap_out
+ 
+     def _load_corr_gt(self, sdict0, sdict1):
+         corr0 = self._load_corr_file(sdict0['name'])
+         corr1 = self._load_corr_file(sdict1['name'])
+         corr_gt = np.stack((corr0, corr1), axis=1)
+         return corr_gt
+ 
+     def _load_corr_file(self, sname):
+         corr_path = osp.join(self.corr_dir, f'{sname}.vts')
+         corr = np.loadtxt(corr_path, dtype=np.int32)
+         return corr - 1
+ 
+     def __len__(self):
+         return len(self.pair_indices)
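
`opt_rot_points` is the classic Kabsch / orthogonal-Procrustes solve, and its convention (right-multiplying the target realigns it with the source, exactly as `_prepare_pair` does) can be checked numerically on synthetic points:

import numpy as np
from shape_data.faust import opt_rot_points

def random_rotation(rng):
    # QR of a Gaussian matrix gives an orthogonal matrix; flip the sign if det = -1.
    q, _ = np.linalg.qr(rng.standard_normal((3, 3)))
    return q * np.sign(np.linalg.det(q))

rng = np.random.default_rng(0)
pts_1 = rng.standard_normal((50, 3))
pts_2 = pts_1 @ random_rotation(rng)  # rotated copy of the source points

rot = opt_rot_points(pts_1, pts_2)
assert np.allclose(pts_2 @ rot, pts_1, atol=1e-6)  # the target rotates back onto the source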
shape_data/scape.py ADDED
@@ -0,0 +1,17 @@
+ import os.path as osp
+ import sys
+ import numpy as np
+ from pathlib import Path
+ 
+ ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
+ if ROOT_DIR not in sys.path:
+     sys.path.append(ROOT_DIR)
+ 
+ from shape_data.faust import ShapeDataset as FaustShapeDataset
+ from shape_data.faust import ShapePairDataset
+ 
+ 
+ class ShapeDataset(FaustShapeDataset):
+     TRAIN_IDX = np.arange(0, 51)
+     TEST_IDX = np.arange(51, 71)
+     NAME = "SCAPE"
shape_data/shrec19.py ADDED
@@ -0,0 +1,42 @@
+ import os.path as osp
+ import sys
+ import numpy as np
+ from pathlib import Path
+ 
+ ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
+ if ROOT_DIR not in sys.path:
+     sys.path.append(ROOT_DIR)
+ 
+ from shape_data.faust import ShapeDataset as FaustShapeDataset
+ from shape_data.faust import ShapePairDataset as FaustShapePairDataset
+ from utils.io import list_files
+ 
+ 
+ class ShapeDataset(FaustShapeDataset):
+     TRAIN_IDX = None
+     TEST_IDX = np.arange(44)
+     NAME = "SHREC"
+ 
+ 
+ class ShapePairDataset(FaustShapePairDataset):
+ 
+     def _init(self):
+         assert self.mode.startswith('test')
+ 
+         self.name_id_map = self.shape_data.get_name_id_map()
+         self.pair_indices = list()
+         for corr_filename in list_files(self.corr_dir, '*.map', alphanum_sort=True):
+             sname0, sname1 = corr_filename[:-4].split('_')
+             if sname0 == '40' or sname1 == '40':
+                 continue
+             self.pair_indices.append((self.name_id_map[sname1], self.name_id_map[sname0]))
+ 
+     def _load_corr_gt(self, sdict0, sdict1):
+         pmap10 = self._load_corr_file(sdict1['name'], sdict0['name'])
+         corr_gt = np.stack((pmap10, np.arange(len(pmap10))), axis=1)
+         return corr_gt
+ 
+     def _load_corr_file(self, sname0, sname1):
+         corr_path = osp.join(self.corr_dir, f'{sname0}_{sname1}.map')
+         corr = np.loadtxt(corr_path, dtype=np.int32)
+         return corr - 1
shape_data/smalr.py ADDED
@@ -0,0 +1,31 @@
+ import os.path as osp
+ import sys
+ import numpy as np
+ from pathlib import Path
+ 
+ ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
+ if ROOT_DIR not in sys.path:
+     sys.path.append(ROOT_DIR)
+ 
+ from shape_data.faust import ShapeDataset as FaustShapeDataset
+ from shape_data.faust import ShapePairDataset
+ from utils.mesh import find_mesh_files
+ 
+ 
+ class ShapeDataset(FaustShapeDataset):
+     TRAIN_IDX = None
+     TEST_IDX = None
+     NAME = "SMAL"
+ 
+     def _get_file_list(self):
+         if self.mode.startswith('train'):
+             categories = ['cow', 'dog', 'fox', 'lion', 'wolf']
+         elif self.mode.startswith('test'):
+             categories = ['cougar', 'hippo', 'horse']
+         else:
+             raise RuntimeError(f'Mode {self.mode} is not supported.')
+ 
+         path_list = find_mesh_files(Path(self.shape_dir), alphanum_sort=True)
+         file_list = [f.name for f in path_list]
+         shape_list = [fn for fn in file_list if fn.split('_')[0] in categories]
+         return shape_list
shape_data/tosca.py ADDED
@@ -0,0 +1,46 @@
+ import os.path as osp
+ import sys
+ import numpy as np
+ import re
+ from pathlib import Path
+ from itertools import permutations as pmt
+ 
+ ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
+ if ROOT_DIR not in sys.path:
+     sys.path.append(ROOT_DIR)
+ 
+ from shape_data.faust import ShapeDataset as FaustShapeDataset
+ from shape_data.faust import ShapePairDataset as FaustShapePairDataset
+ from utils.io import list_files
+ 
+ 
+ def contains_any_regex(substrings, ext, texts):
+     # Keep the entries that match any of the substrings and carry the right extension.
+     pattern = re.compile('|'.join(map(re.escape, substrings)))  # compile the regex once
+     return [text for text in texts if pattern.search(text) and ext in text]
+ 
+ 
+ class ShapeDataset(FaustShapeDataset):
+     TRAIN_IDX = None
+     TEST_IDX = None
+ 
+     def _get_file_list(self):
+         if self.mode.startswith('train'):
+             categories = None  # no category filter at train time
+         elif self.mode.startswith('test'):
+             categories = ['cat', 'dog', 'horse', 'wolf']
+         else:
+             raise RuntimeError(f'Mode {self.mode} is not supported.')
+         file_list = list_files(self.shape_dir, '*.off', alphanum_sort=True)
+         if categories is None:
+             return file_list
+         shape_list = contains_any_regex(categories, ".off", file_list)
+         return shape_list
+ 
+ 
+ class ShapePairDataset(FaustShapePairDataset):
+     categories = ['cat', 'dog', 'horse', 'wolf']
+ 
+     def _init(self):
+         assert self.mode.startswith('test')
+         self.name_id_map = self.shape_data.get_name_id_map()
+         self.pair_indices = list()
+         for cat in self.categories:
+             shape_list_temp = [self.name_id_map[fn] for fn in self.name_id_map if cat in fn]
+             self.pair_indices += list(pmt(shape_list_temp, 2))
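
`contains_any_regex` simply keeps the names that mention one of the category substrings and carry the requested extension; for example, run from the repo root:

from shape_data.tosca import contains_any_regex

names = ['cat0.off', 'cat1.off', 'dog3.off', 'victoria2.off', 'cat_notes.txt']
print(contains_any_regex(['cat', 'dog'], '.off', names))
# ['cat0.off', 'cat1.off', 'dog3.off'] -- 'cat_notes.txt' fails the extension check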
snk/__init__.py ADDED
File without changes
snk/loss.py ADDED
@@ -0,0 +1,119 @@
+ from trimesh.graph import face_adjacency
+ import torch
+ import torch.nn as nn
+ 
+ 
+ class PrismRegularizationLoss(nn.Module):
+     """
+     Loss based on the PriMo energy, as described in the paper
+     "PriMo: Coupled Prisms for Intuitive Surface Modeling".
+     """
+ 
+     def __init__(self, primo_h):
+         super().__init__()
+         self.h = primo_h
+ 
+         # Precompute the 16 weights 2^{-|i - k| - |j - l|} over the flattened 2x2 corner grid.
+         indices = torch.tensor([(i, j) for i in range(2) for j in range(2)])
+         indices_A = indices.repeat_interleave(4, dim=0)
+         indices_B = indices.repeat(4, 1)
+         self.coeff = (torch.ones(1) * 2).pow(((indices_A - indices_B).abs() * -1).sum(dim=1))[None, :]
+ 
+     def forward(self, transformed_prism, rotations, verts, faces, normals):
+         # transformed_prism is (n_faces, 3, 3); verts and faces come from the template (shape 2).
+         # * for now assumes there is only one batch
+         # todo: add batch support
+         bs = 1
+         verts = verts.reshape(-1, 3)
+         normals = normals.reshape(-1, 3)
+ 
+         # area of each face
+         face_areas = self.get_face_areas(verts, faces)  # (n_faces,)
+ 
+         # list of edges and the pair of faces that share each edge
+         face_ids, edges = face_adjacency(faces.cpu().numpy(), return_edges=True)  # (n_edges, 2), (n_edges, 2)
+         face_ids, edges = torch.from_numpy(face_ids).to(verts.device), torch.from_numpy(edges).to(verts.device)
+ 
+         # per-vertex normals and per-face rotations on either side of each edge
+         normals1, normals2 = normals[edges[:, 0]], normals[edges[:, 1]]  # (n_edges, 3)
+         rotations1, rotations2 = rotations[face_ids[:, 0]], rotations[face_ids[:, 1]]  # (n_edges, 3, 3)
+ 
+         # locate the edge endpoints inside the two adjacent faces of the prism
+         face_id1, face_id2 = face_ids[:, 0], face_ids[:, 1]  # (n_edges,)
+         faces_to_verts = self.get_verts_id_face(faces, edges, face_ids)  # (n_edges, 4)
+         verts1_p1, verts2_p1 = transformed_prism[face_id1, faces_to_verts[:, 0]], transformed_prism[face_id1, faces_to_verts[:, 1]]  # (n_edges, 3)
+         verts1_p2, verts2_p2 = transformed_prism[face_id2, faces_to_verts[:, 2]], transformed_prism[face_id2, faces_to_verts[:, 3]]  # (n_edges, 3)
+ 
+         # rotate the vertex normals by each face's rotation
+         prism1_n1, prism1_n2 = (normals1[:, None] @ rotations1).squeeze(1), (normals2[:, None] @ rotations1).squeeze(1)
+         prism2_n1, prism2_n2 = (normals1[:, None] @ rotations2).squeeze(1), (normals2[:, None] @ rotations2).squeeze(1)
+ 
+         # corner coordinates of the prism faces, extruded along +/- h
+         # prism 1 (1 -> 2)
+         f_p1_00, f_p1_01 = verts1_p1 + prism1_n1 * self.h, verts2_p1 + prism1_n2 * self.h  # (n_edges, 3)
+         f_p1_10, f_p1_11 = verts1_p1 - prism1_n1 * self.h, verts2_p1 - prism1_n2 * self.h  # (n_edges, 3)
+         # prism 2 (2 -> 1)
+         f_p2_00, f_p2_01 = verts1_p2 + prism2_n1 * self.h, verts2_p2 + prism2_n2 * self.h  # (n_edges, 3)
+         f_p2_10, f_p2_11 = verts1_p2 - prism2_n1 * self.h, verts2_p2 - prism2_n2 * self.h  # (n_edges, 3)
+ 
+         # energy between the two prism faces meeting at each edge
+         A, B = torch.stack((f_p1_00, f_p1_01, f_p1_10, f_p1_11), dim=1), torch.stack((f_p2_00, f_p2_01, f_p2_10, f_p2_11), dim=1)  # (n_edges, 4, 3)
+         energy = self.compute_energy(A - B, A - B)  # (n_edges,)
+ 
+         # edge weight: squared edge length over the combined face area
+         area1, area2 = face_areas[face_id1], face_areas[face_id2]  # (n_edges,)
+         weight = torch.norm(verts[edges[:, 0]] - verts[edges[:, 1]], dim=1).square() / (area1 + area2)  # (n_edges,)
+         energy = energy * weight  # (n_edges,)
+ 
+         loss = energy.sum() / bs  # todo: divide by the batch size once batching is enabled
+         return loss
+ 
+     def compute_energy(self, A, B):
+         """
+         Compute sum_{i,j,k,l=0}^{1} a_{ij} b_{kl} 2^{-|i - k| - |j - l|} / 9.
+         Assumes that A and B are tensors of size (bs, 4, 3), where bs is the batch size.
+         """
+         self.coeff = self.coeff.to(A.device)
+ 
+         A_repeated = A.repeat_interleave(4, dim=1)
+         B_repeated = B.repeat(1, 4, 1)
+ 
+         energy = (A_repeated * B_repeated).sum(dim=-1)
+         energy = (energy * self.coeff).sum(dim=1)
+         energy = energy / 9
+ 
+         return energy
+ 
+     def get_face_areas(self, verts, faces):
+         # area of each face via the cross product
+         v1, v2, v3 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
+         area = 0.5 * torch.cross(v2 - v1, v3 - v1, dim=-1).norm(dim=1)
+         return area
+ 
+     def get_verts_id_face(self, F, E, Q):
+         # For each edge, find the local corner index (0/1/2) of its two endpoints
+         # within each of the two adjacent faces (-1 if not found).
+         e = E.shape[0]
+         Z = torch.zeros((e, 4), dtype=torch.long)
+ 
+         v1 = F[:, 0][Q[:, 0]]
+         v2 = F[:, 1][Q[:, 0]]
+         v3 = F[:, 2][Q[:, 0]]
+         v4 = F[:, 0][Q[:, 1]]
+         v5 = F[:, 1][Q[:, 1]]
+         v6 = F[:, 2][Q[:, 1]]
+ 
+         idx1 = torch.where(v1 == E[:, 0], 0, torch.where(v2 == E[:, 0], 1, torch.where(v3 == E[:, 0], 2, -1)))
+         idx2 = torch.where(v1 == E[:, 1], 0, torch.where(v2 == E[:, 1], 1, torch.where(v3 == E[:, 1], 2, -1)))
+         idx3 = torch.where(v4 == E[:, 0], 0, torch.where(v5 == E[:, 0], 1, torch.where(v6 == E[:, 0], 2, -1)))
+         idx4 = torch.where(v4 == E[:, 1], 0, torch.where(v5 == E[:, 1], 1, torch.where(v6 == E[:, 1], 2, -1)))
+ 
+         Z[:, 0:2] = torch.stack((idx1, idx2), dim=1)
+         Z[:, 2:4] = torch.stack((idx3, idx4), dim=1)
+         Z = Z.to(F.device)
+ 
+         return Z
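
The vectorized `compute_energy` depends on flattening the 2x2 corner grid in the order (0,0), (0,1), (1,0), (1,1) and pre-computing the 16 weights 2^(-|i-k|-|j-l|). A standalone check (synthetic tensors, logic replicated inline so it runs on its own) that the broadcasted form matches the naive quadruple loop:

import torch

corners = [(i, j) for i in range(2) for j in range(2)]  # the flattening order used above
idx_A = torch.tensor(corners).repeat_interleave(4, dim=0)
idx_B = torch.tensor(corners).repeat(4, 1)
coeff = 2.0 ** (-(idx_A - idx_B).abs().sum(dim=1).float())  # the 16 weights

A = torch.randn(5, 4, 3)  # 5 edges, 4 corner difference vectors each
B = torch.randn(5, 4, 3)

vectorized = ((A.repeat_interleave(4, dim=1) * B.repeat(1, 4, 1)).sum(-1) * coeff[None]).sum(1) / 9

naive = torch.zeros(5)
for i, j in corners:
    for k, l in corners:
        w = 2.0 ** (-abs(i - k) - abs(j - l))
        naive += w * (A[:, 2 * i + j] * B[:, 2 * k + l]).sum(-1)
naive /= 9

assert torch.allclose(vectorized, naive, atol=1e-5)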
snk/prism_decoder.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import torch.nn as nn
+ import roma
+ from shape_models.layers import DiffusionNet
+ 
+ 
+ class PrismDecoder(torch.nn.Module):
+ 
+     def __init__(self, dim_in=1024, dim_out=512, n_width=256, n_block=4, pairwise_dot=True, dropout=False,
+                  dot_linear_complex=True, neig=128):
+         super().__init__()
+ 
+         self.diffusion_net = DiffusionNet(
+             C_in=dim_in,
+             C_out=dim_out,
+             C_width=n_width,
+             N_block=n_block,
+             dropout=dropout,
+             with_gradient_features=pairwise_dot,
+             with_gradient_rotations=dot_linear_complex,
+         )
+ 
+         self.mlp_refine = nn.Sequential(
+             nn.Linear(dim_out, dim_out),
+             nn.ReLU(),
+             nn.Linear(dim_out, 512),
+             nn.ReLU(),
+             nn.Linear(512, 256),
+             nn.ReLU(),
+             nn.Linear(256, 12),
+         )
+ 
+     def forward(self, batch_dict, latent):
+         # original prism: the template triangles themselves
+         try:
+             verts = batch_dict["vertices"]
+         except KeyError:
+             verts = batch_dict["verts"]
+         faces = batch_dict["faces"]
+         prism_base = verts[faces]  # (n_faces, 3, 3)
+         bs = 1  # batching not supported yet
+ 
+         # forward through DiffusionNet
+         features = self.diffusion_net(latent, batch_dict["mass"], batch_dict["L"], evals=batch_dict["evals"],
+                                       evecs=batch_dict["evecs"], gradX=batch_dict["gradX"],
+                                       gradY=batch_dict["gradY"], faces=batch_dict["faces"])  # (n_verts, dim_out)
+ 
+         # average the per-vertex features over each face
+         x_gather = features.unsqueeze(-1).expand(-1, -1, 3)  # (n_verts, dim_out, 3)
+         faces_gather = faces.unsqueeze(1).expand(-1, features.shape[-1], -1)  # (n_faces, dim_out, 3)
+         xf = torch.gather(x_gather, 0, faces_gather)
+         features = torch.mean(xf, dim=-1)  # (n_faces, dim_out)
+ 
+         # refine features with an MLP
+         features = self.mlp_refine(features)  # (n_faces, 12)
+ 
+         # split into a rotation (projected onto SO(3)) and a translation per face
+         rotations = features[:, :9].reshape(-1, 3, 3)
+         rotations = roma.special_procrustes(rotations)  # (n_faces, 3, 3)
+         translations = features[:, 9:].reshape(-1, 3)  # (n_faces, 3)
+ 
+         # transform the prism
+         transformed_prism = (prism_base @ rotations) + translations[:, None]
+ 
+         # prism corners back to per-vertex positions
+         features = self.prism_to_vertices(transformed_prism, faces, verts)
+ 
+         out_features = features.reshape(bs, -1, 3)
+         return out_features, transformed_prism, rotations
+ 
+     def prism_to_vertices(self, prism, faces, verts):
+         # initialize the per-vertex feature tensor
+         N = verts.shape[0]
+         d = prism.shape[-1]
+         device = prism.device
+         features = torch.zeros((N, d), device=device)
+ 
+         # scatter the per-face corner values onto the vertices they touch
+         features.scatter_add_(0, faces[:, :, None].repeat(1, 1, d).reshape(-1, d), prism.reshape(-1, d))
+ 
+         # divide each row by the number of faces the corresponding vertex appears in
+         num_faces_per_vertex = torch.zeros(N, dtype=torch.float32, device=device)
+         num_faces_per_vertex.index_add_(0, faces.reshape(-1), torch.ones(faces.shape[0] * 3, device=device))
+         features /= num_faces_per_vertex.unsqueeze(1).clamp(min=1)
+ 
+         return features
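
`prism_to_vertices` averages each vertex's incident prism corners with `scatter_add_`; a small standalone check against a direct loop (toy mesh of two triangles sharing an edge, scatter logic replicated inline):

import torch

faces = torch.tensor([[0, 1, 2], [1, 2, 3]])
prism = torch.arange(2 * 3 * 3, dtype=torch.float32).reshape(2, 3, 3)  # per-face corner features

N, d = 4, 3
feats = torch.zeros(N, d)
feats.scatter_add_(0, faces[:, :, None].repeat(1, 1, d).reshape(-1, d), prism.reshape(-1, d))
counts = torch.zeros(N)
counts.index_add_(0, faces.reshape(-1), torch.ones(faces.numel()))
feats /= counts.unsqueeze(1).clamp(min=1)

# each vertex should end up with the mean of the corner vectors of every face containing it
for v in range(N):
    vals = [prism[f, c] for f in range(2) for c in range(3) if faces[f, c] == v]
    assert torch.allclose(feats[v], torch.stack(vals).mean(0))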